Coverage Report - weka.classifiers.bayes.net.estimate.DiscreteEstimatorBayes
 
Classes in this File Line Coverage Branch Coverage Complexity
DiscreteEstimatorBayes
0%
0/68
0%
0/36
3.667
 
 1  
 /*
 2  
  *   This program is free software: you can redistribute it and/or modify
 3  
  *   it under the terms of the GNU General Public License as published by
 4  
  *   the Free Software Foundation, either version 3 of the License, or
 5  
  *   (at your option) any later version.
 6  
  *
 7  
  *   This program is distributed in the hope that it will be useful,
 8  
  *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 9  
  *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 10  
  *   GNU General Public License for more details.
 11  
  *
 12  
  *   You should have received a copy of the GNU General Public License
 13  
  *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 14  
  */
 15  
 
 16  
 /*
 17  
  * DiscreteEstimatorBayes.java
 18  
  * Adapted from DiscreteEstimator.java
 19  
  * Copyright (C) 2012 University of Waikato, Hamilton, New Zealand
 20  
  * 
 21  
  */
 22  
 package weka.classifiers.bayes.net.estimate;
 23  
 
 24  
 import weka.classifiers.bayes.net.search.local.Scoreable;
 25  
 import weka.core.RevisionUtils;
 26  
 import weka.core.Statistics;
 27  
 import weka.core.Utils;
 28  
 import weka.estimators.DiscreteEstimator;
 29  
 import weka.estimators.Estimator;
 30  
 
 31  
 /**
 32  
  * Symbolic probability estimator based on symbol counts and a prior.
 33  
  * 
 34  
  * @author Remco Bouckaert (rrb@xm.co.nz)
 35  
  * @version $Revision: 8034 $
 36  
  */
 37  
 public class DiscreteEstimatorBayes extends Estimator
 38  
   implements Scoreable {
 39  
 
 40  
   /** for serialization */
 41  
   static final long serialVersionUID = 4215400230843212684L;
 42  
   
 43  
   /**
 44  
    * Hold the counts
 45  
    */
 46  
   protected double[] m_Counts;
 47  
 
 48  
   /**
 49  
    * Hold the sum of counts
 50  
    */
 51  
   protected double   m_SumOfCounts;
 52  
 
 53  
   /**
 54  
    * Holds number of symbols in distribution
 55  
    */
 56  0
   protected int      m_nSymbols = 0;
 57  
 
 58  
   /**
 59  
    * Holds the prior probability
 60  
    */
 61  0
   protected double   m_fPrior = 0.0;
 62  
 
 63  
   /**
 64  
    * Constructor
 65  
    * 
 66  
    * @param nSymbols the number of possible symbols (remember to include 0)
 67  
    * @param fPrior
 68  
    */
 69  0
   public DiscreteEstimatorBayes(int nSymbols, double fPrior) {
 70  0
     m_fPrior = fPrior;
 71  0
     m_nSymbols = nSymbols;
 72  0
     m_Counts = new double[m_nSymbols];
 73  
 
 74  0
     for (int iSymbol = 0; iSymbol < m_nSymbols; iSymbol++) {
 75  0
       m_Counts[iSymbol] = m_fPrior;
 76  
     } 
 77  
 
 78  0
     m_SumOfCounts = m_fPrior * (double) m_nSymbols;
 79  0
   }    // DiscreteEstimatorBayes
 80  
 
 81  
   /**
 82  
    * Add a new data value to the current estimator.
 83  
    * 
 84  
    * @param data the new data value
 85  
    * @param weight the weight assigned to the data value
 86  
    */
 87  
   public void addValue(double data, double weight) {
 88  0
     m_Counts[(int) data] += weight;
 89  0
     m_SumOfCounts += weight;
 90  0
   } 
 91  
 
 92  
   /**
 93  
    * Get a probability estimate for a value
 94  
    * 
 95  
    * @param data the value to estimate the probability of
 96  
    * @return the estimated probability of the supplied value
 97  
    */
 98  
   public double getProbability(double data) {
 99  0
     if (m_SumOfCounts == 0) {
 100  
 
 101  
       // this can only happen if numSymbols = 0 in constructor
 102  0
       return 0;
 103  
     } 
 104  
 
 105  0
     return (double) m_Counts[(int) data] / m_SumOfCounts;
 106  
   } 
 107  
 
 108  
   /**
 109  
    * Get a counts for a value
 110  
    * 
 111  
    * @param data the value to get the counts for
 112  
    * @return the count of the supplied value
 113  
    */
 114  
   public double getCount(double data) {
 115  0
     if (m_SumOfCounts == 0) {
 116  
       // this can only happen if numSymbols = 0 in constructor
 117  0
       return 0;
 118  
     } 
 119  
 
 120  0
     return m_Counts[(int) data];
 121  
   } 
 122  
   
 123  
   /**
 124  
    * Gets the number of symbols this estimator operates with
 125  
    * 
 126  
    * @return the number of estimator symbols
 127  
    */
 128  
   public int getNumSymbols() {
 129  0
     return (m_Counts == null) ? 0 : m_Counts.length;
 130  
   } 
 131  
 
 132  
   /**
 133  
    * Gets the log score contribution of this distribution
 134  
    * @param nType score type
 135  
    * @return the score
 136  
    */
 137  
   public double logScore(int nType, int nCardinality) {
 138  0
             double fScore = 0.0;
 139  
 
 140  0
             switch (nType) {
 141  
 
 142  
             case (Scoreable.BAYES): {
 143  0
               for (int iSymbol = 0; iSymbol < m_nSymbols; iSymbol++) {
 144  0
                 fScore += Statistics.lnGamma(m_Counts[iSymbol]);
 145  
               } 
 146  
 
 147  0
               fScore -= Statistics.lnGamma(m_SumOfCounts);
 148  0
               if (m_fPrior != 0.0) {
 149  0
                       fScore -= m_nSymbols * Statistics.lnGamma(m_fPrior);
 150  0
                       fScore += Statistics.lnGamma(m_nSymbols * m_fPrior);
 151  
               }
 152  
             } 
 153  
 
 154  
               break;
 155  
                   case (Scoreable.BDeu): {
 156  0
                   for (int iSymbol = 0; iSymbol < m_nSymbols; iSymbol++) {
 157  0
                         fScore += Statistics.lnGamma(m_Counts[iSymbol]);
 158  
                   } 
 159  
 
 160  0
                   fScore -= Statistics.lnGamma(m_SumOfCounts);
 161  
                   //fScore -= m_nSymbols * Statistics.lnGamma(1.0);
 162  
                   //fScore += Statistics.lnGamma(m_nSymbols * 1.0);
 163  0
               fScore -= m_nSymbols * Statistics.lnGamma(1.0/(m_nSymbols * nCardinality));
 164  0
               fScore += Statistics.lnGamma(1.0/nCardinality);
 165  
                 } 
 166  0
                   break;
 167  
 
 168  
             case (Scoreable.MDL):
 169  
 
 170  
             case (Scoreable.AIC):
 171  
 
 172  
             case (Scoreable.ENTROPY): {
 173  0
               for (int iSymbol = 0; iSymbol < m_nSymbols; iSymbol++) {
 174  0
                 double fP = getProbability(iSymbol);
 175  
 
 176  0
                 fScore += m_Counts[iSymbol] * Math.log(fP);
 177  
               } 
 178  
             } 
 179  
 
 180  0
               break;
 181  
 
 182  
             default: {}
 183  
             }
 184  
 
 185  0
             return fScore;
 186  
           } 
 187  
 
 188  
   /**
 189  
    * Display a representation of this estimator
 190  
    * 
 191  
    * @return a string representation of the estimator
 192  
    */
 193  
   public String toString() {
 194  0
     String result = "Discrete Estimator. Counts = ";
 195  
 
 196  0
     if (m_SumOfCounts > 1) {
 197  0
       for (int i = 0; i < m_Counts.length; i++) {
 198  0
         result += " " + Utils.doubleToString(m_Counts[i], 2);
 199  
       } 
 200  
 
 201  0
       result += "  (Total = " + Utils.doubleToString(m_SumOfCounts, 2) 
 202  
                 + ")\n";
 203  
     } else {
 204  0
       for (int i = 0; i < m_Counts.length; i++) {
 205  0
         result += " " + m_Counts[i];
 206  
       } 
 207  
 
 208  0
       result += "  (Total = " + m_SumOfCounts + ")\n";
 209  
     } 
 210  
 
 211  0
     return result;
 212  
   } 
 213  
   
 214  
   /**
 215  
    * Returns the revision string.
 216  
    * 
 217  
    * @return                the revision
 218  
    */
 219  
   public String getRevision() {
 220  0
     return RevisionUtils.extract("$Revision: 8034 $");
 221  
   }
 222  
   
 223  
   /**
 224  
    * Main method for testing this class.
 225  
    * 
 226  
    * @param argv should contain a sequence of integers which
 227  
    * will be treated as symbolic.
 228  
    */
 229  
   public static void main(String[] argv) {
 230  
     try {
 231  0
       if (argv.length == 0) {
 232  0
         System.out.println("Please specify a set of instances.");
 233  
 
 234  0
         return;
 235  
       } 
 236  
 
 237  0
       int current = Integer.parseInt(argv[0]);
 238  0
       int max = current;
 239  
 
 240  0
       for (int i = 1; i < argv.length; i++) {
 241  0
         current = Integer.parseInt(argv[i]);
 242  
 
 243  0
         if (current > max) {
 244  0
           max = current;
 245  
         } 
 246  
       } 
 247  
 
 248  0
       DiscreteEstimator newEst = new DiscreteEstimator(max + 1, true);
 249  
 
 250  0
       for (int i = 0; i < argv.length; i++) {
 251  0
         current = Integer.parseInt(argv[i]);
 252  
 
 253  0
         System.out.println(newEst);
 254  0
         System.out.println("Prediction for " + current + " = " 
 255  
                            + newEst.getProbability(current));
 256  0
         newEst.addValue(current, 1);
 257  
       } 
 258  0
     } catch (Exception e) {
 259  0
       System.out.println(e.getMessage());
 260  0
     } 
 261  0
   }    // main
 262  
  
 263  
 }      // class DiscreteEstimatorBayes