Coverage Report - weka.classifiers.evaluation.EvaluationUtils
 
Classes in this File | Line Coverage | Branch Coverage | Complexity
EvaluationUtils | 0% (0/30) | 0% (0/12) | 2
 
 1  
 /*
 2  
  *   This program is free software: you can redistribute it and/or modify
 3  
  *   it under the terms of the GNU General Public License as published by
 4  
  *   the Free Software Foundation, either version 3 of the License, or
 5  
  *   (at your option) any later version.
 6  
  *
 7  
  *   This program is distributed in the hope that it will be useful,
 8  
  *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 9  
  *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 10  
  *   GNU General Public License for more details.
 11  
  *
 12  
  *   You should have received a copy of the GNU General Public License
 13  
  *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 14  
  */
 15  
 
 16  
 /*
 17  
  *    EvaluationUtils.java
 18  
  *    Copyright (C) 2002-2012 University of Waikato, Hamilton, New Zealand
 19  
  *
 20  
  */
 21  
 
 22  
 package weka.classifiers.evaluation;
 23  
 
 24  
 import java.util.Random;
 25  
 
 26  
 import weka.classifiers.Classifier;
 27  
 import weka.core.FastVector;
 28  
 import weka.core.Instance;
 29  
 import weka.core.Instances;
 30  
 import weka.core.RevisionHandler;
 31  
 import weka.core.RevisionUtils;
 32  
 
 33  
 /**
 34  
  * Contains utility functions for generating lists of predictions in 
 35  
  * various manners.
 36  
  *
 37  
  * @author Len Trigg (len@reeltwo.com)
 38  
  * @version $Revision: 8034 $
 39  
  */
 40  0
 public class EvaluationUtils
 41  
   implements RevisionHandler {
 42  
 
 43  
   /** Seed used to randomize data in cross-validation */
 44  0
   private int m_Seed = 1;
 45  
 
 46  
   /** Sets the seed for randomization during cross-validation */
 47  0
   public void setSeed(int seed) { m_Seed = seed; }
 48  
 
 49  
   /** Gets the seed for randomization during cross-validation */
 50  0
   public int getSeed() { return m_Seed; }
 51  
   
 52  
   /**
 53  
    * Generate a bunch of predictions ready for processing, by performing a
 54  
    * cross-validation on the supplied dataset.
 55  
    *
 56  
    * @param classifier the Classifier to evaluate
 57  
    * @param data the dataset
 58  
    * @param numFolds the number of folds in the cross-validation.
 59  
    * @exception Exception if an error occurs
 60  
    */
 61  
   public FastVector getCVPredictions(Classifier classifier, 
 62  
                                      Instances data, 
 63  
                                      int numFolds) 
 64  
     throws Exception {
 65  
 
 66  0
     FastVector predictions = new FastVector();
 67  0
     Instances runInstances = new Instances(data);
 68  0
     Random random = new Random(m_Seed);
 69  0
     runInstances.randomize(random);
 70  0
     if (runInstances.classAttribute().isNominal() && (numFolds > 1)) {
 71  0
       runInstances.stratify(numFolds);
 72  
     }
 73  0
     int inst = 0;
 74  0
     for (int fold = 0; fold < numFolds; fold++) {
 75  0
       Instances train = runInstances.trainCV(numFolds, fold, random);
 76  0
       Instances test = runInstances.testCV(numFolds, fold);
 77  0
       FastVector foldPred = getTrainTestPredictions(classifier, train, test);
 78  0
       predictions.appendElements(foldPred);
 79  
     } 
 80  0
     return predictions;
 81  
   }
 82  
 
 83  
   /**
 84  
    * Generate a bunch of predictions ready for processing, by performing a
 85  
    * evaluation on a test set after training on the given training set.
 86  
    *
 87  
    * @param classifier the Classifier to evaluate
 88  
    * @param train the training dataset
 89  
    * @param test the test dataset
 90  
    * @exception Exception if an error occurs
 91  
    */
 92  
   public FastVector getTrainTestPredictions(Classifier classifier, 
 93  
                                             Instances train, Instances test) 
 94  
     throws Exception {
 95  
     
 96  0
     classifier.buildClassifier(train);
 97  0
     return getTestPredictions(classifier, test);
 98  
   }
 99  
 
 100  
   /**
 101  
    * Generate a bunch of predictions ready for processing, by performing a
 102  
    * evaluation on a test set assuming the classifier is already trained.
 103  
    *
 104  
    * @param classifier the pre-trained Classifier to evaluate
 105  
    * @param test the test dataset
 106  
    * @exception Exception if an error occurs
 107  
    */
 108  
   public FastVector getTestPredictions(Classifier classifier, 
 109  
                                        Instances test) 
 110  
     throws Exception {
 111  
     
 112  0
     FastVector predictions = new FastVector();
 113  0
     for (int i = 0; i < test.numInstances(); i++) {
 114  0
       if (!test.instance(i).classIsMissing()) {
 115  0
         predictions.addElement(getPrediction(classifier, test.instance(i)));
 116  
       }
 117  
     }
 118  0
     return predictions;
 119  
   }
 120  
 
 121  
   
 122  
   /**
 123  
    * Generate a single prediction for a test instance given the pre-trained
 124  
    * classifier.
 125  
    *
 126  
    * @param classifier the pre-trained Classifier to evaluate
 127  
    * @param test the test instance
 128  
    * @exception Exception if an error occurs
 129  
    */
 130  
   public Prediction getPrediction(Classifier classifier,
 131  
                                   Instance test)
 132  
     throws Exception {
 133  
    
 134  0
     double actual = test.classValue();
 135  0
     double [] dist = classifier.distributionForInstance(test);
 136  0
     if (test.classAttribute().isNominal()) {
 137  0
       return new NominalPrediction(actual, dist, test.weight());
 138  
     } else {
 139  0
       return new NumericPrediction(actual, dist[0], test.weight());
 140  
     }
 141  
   }
 142  
   
 143  
   /**
 144  
    * Returns the revision string.
 145  
    * 
 146  
    * @return                the revision
 147  
    */
 148  
   public String getRevision() {
 149  0
     return RevisionUtils.extract("$Revision: 8034 $");
 150  
   }
 151  
 }
 152