| Classes in this File | Line Coverage | Branch Coverage | Complexity | ||||
| EvaluationUtils |
|
| 2.0;2 |
| 1 | /* | |
| 2 | * This program is free software: you can redistribute it and/or modify | |
| 3 | * it under the terms of the GNU General Public License as published by | |
| 4 | * the Free Software Foundation, either version 3 of the License, or | |
| 5 | * (at your option) any later version. | |
| 6 | * | |
| 7 | * This program is distributed in the hope that it will be useful, | |
| 8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of | |
| 9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
| 10 | * GNU General Public License for more details. | |
| 11 | * | |
| 12 | * You should have received a copy of the GNU General Public License | |
| 13 | * along with this program. If not, see <http://www.gnu.org/licenses/>. | |
| 14 | */ | |
| 15 | ||
| 16 | /* | |
| 17 | * EvaluationUtils.java | |
| 18 | * Copyright (C) 2002-2012 University of Waikato, Hamilton, New Zealand | |
| 19 | * | |
| 20 | */ | |
| 21 | ||
| 22 | package weka.classifiers.evaluation; | |
| 23 | ||
| 24 | import java.util.Random; | |
| 25 | ||
| 26 | import weka.classifiers.Classifier; | |
| 27 | import weka.core.FastVector; | |
| 28 | import weka.core.Instance; | |
| 29 | import weka.core.Instances; | |
| 30 | import weka.core.RevisionHandler; | |
| 31 | import weka.core.RevisionUtils; | |
| 32 | ||
| 33 | /** | |
| 34 | * Contains utility functions for generating lists of predictions in | |
| 35 | * various manners. | |
| 36 | * | |
| 37 | * @author Len Trigg (len@reeltwo.com) | |
| 38 | * @version $Revision: 8034 $ | |
| 39 | */ | |
| 40 | 0 | public class EvaluationUtils |
| 41 | implements RevisionHandler { | |
| 42 | ||
| 43 | /** Seed used to randomize data in cross-validation */ | |
| 44 | 0 | private int m_Seed = 1; |
| 45 | ||
| 46 | /** Sets the seed for randomization during cross-validation */ | |
| 47 | 0 | public void setSeed(int seed) { m_Seed = seed; } |
| 48 | ||
| 49 | /** Gets the seed for randomization during cross-validation */ | |
| 50 | 0 | public int getSeed() { return m_Seed; } |
| 51 | ||
| 52 | /** | |
| 53 | * Generate a bunch of predictions ready for processing, by performing a | |
| 54 | * cross-validation on the supplied dataset. | |
| 55 | * | |
| 56 | * @param classifier the Classifier to evaluate | |
| 57 | * @param data the dataset | |
| 58 | * @param numFolds the number of folds in the cross-validation. | |
| 59 | * @exception Exception if an error occurs | |
| 60 | */ | |
| 61 | public FastVector getCVPredictions(Classifier classifier, | |
| 62 | Instances data, | |
| 63 | int numFolds) | |
| 64 | throws Exception { | |
| 65 | ||
| 66 | 0 | FastVector predictions = new FastVector(); |
| 67 | 0 | Instances runInstances = new Instances(data); |
| 68 | 0 | Random random = new Random(m_Seed); |
| 69 | 0 | runInstances.randomize(random); |
| 70 | 0 | if (runInstances.classAttribute().isNominal() && (numFolds > 1)) { |
| 71 | 0 | runInstances.stratify(numFolds); |
| 72 | } | |
| 73 | 0 | int inst = 0; |
| 74 | 0 | for (int fold = 0; fold < numFolds; fold++) { |
| 75 | 0 | Instances train = runInstances.trainCV(numFolds, fold, random); |
| 76 | 0 | Instances test = runInstances.testCV(numFolds, fold); |
| 77 | 0 | FastVector foldPred = getTrainTestPredictions(classifier, train, test); |
| 78 | 0 | predictions.appendElements(foldPred); |
| 79 | } | |
| 80 | 0 | return predictions; |
| 81 | } | |
| 82 | ||
| 83 | /** | |
| 84 | * Generate a bunch of predictions ready for processing, by performing a | |
| 85 | * evaluation on a test set after training on the given training set. | |
| 86 | * | |
| 87 | * @param classifier the Classifier to evaluate | |
| 88 | * @param train the training dataset | |
| 89 | * @param test the test dataset | |
| 90 | * @exception Exception if an error occurs | |
| 91 | */ | |
| 92 | public FastVector getTrainTestPredictions(Classifier classifier, | |
| 93 | Instances train, Instances test) | |
| 94 | throws Exception { | |
| 95 | ||
| 96 | 0 | classifier.buildClassifier(train); |
| 97 | 0 | return getTestPredictions(classifier, test); |
| 98 | } | |
| 99 | ||
| 100 | /** | |
| 101 | * Generate a bunch of predictions ready for processing, by performing a | |
| 102 | * evaluation on a test set assuming the classifier is already trained. | |
| 103 | * | |
| 104 | * @param classifier the pre-trained Classifier to evaluate | |
| 105 | * @param test the test dataset | |
| 106 | * @exception Exception if an error occurs | |
| 107 | */ | |
| 108 | public FastVector getTestPredictions(Classifier classifier, | |
| 109 | Instances test) | |
| 110 | throws Exception { | |
| 111 | ||
| 112 | 0 | FastVector predictions = new FastVector(); |
| 113 | 0 | for (int i = 0; i < test.numInstances(); i++) { |
| 114 | 0 | if (!test.instance(i).classIsMissing()) { |
| 115 | 0 | predictions.addElement(getPrediction(classifier, test.instance(i))); |
| 116 | } | |
| 117 | } | |
| 118 | 0 | return predictions; |
| 119 | } | |
| 120 | ||
| 121 | ||
| 122 | /** | |
| 123 | * Generate a single prediction for a test instance given the pre-trained | |
| 124 | * classifier. | |
| 125 | * | |
| 126 | * @param classifier the pre-trained Classifier to evaluate | |
| 127 | * @param test the test instance | |
| 128 | * @exception Exception if an error occurs | |
| 129 | */ | |
| 130 | public Prediction getPrediction(Classifier classifier, | |
| 131 | Instance test) | |
| 132 | throws Exception { | |
| 133 | ||
| 134 | 0 | double actual = test.classValue(); |
| 135 | 0 | double [] dist = classifier.distributionForInstance(test); |
| 136 | 0 | if (test.classAttribute().isNominal()) { |
| 137 | 0 | return new NominalPrediction(actual, dist, test.weight()); |
| 138 | } else { | |
| 139 | 0 | return new NumericPrediction(actual, dist[0], test.weight()); |
| 140 | } | |
| 141 | } | |
| 142 | ||
| 143 | /** | |
| 144 | * Returns the revision string. | |
| 145 | * | |
| 146 | * @return the revision | |
| 147 | */ | |
| 148 | public String getRevision() { | |
| 149 | 0 | return RevisionUtils.extract("$Revision: 8034 $"); |
| 150 | } | |
| 151 | } | |
| 152 |