Coverage Report - weka.classifiers.functions.Logistic
 
Classes in this File   Line Coverage   Branch Coverage   Complexity
Logistic               0% (0/266)      0% (0/122)        3.963
Logistic$1             N/A             N/A               3.963
Logistic$OptEng        0% (0/54)       0% (0/34)         3.963
 
/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    Logistic.java
 *    Copyright (C) 2003-2012 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.functions;

import java.util.Enumeration;
import java.util.Vector;

import weka.classifiers.AbstractClassifier;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Optimization;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.RemoveUseless;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;

/**
 <!-- globalinfo-start -->
 * Class for building and using a multinomial logistic regression model with a ridge estimator.<br/>
 * <br/>
 * There are some modifications, however, compared to the paper of le Cessie and van Houwelingen (1992): <br/>
 * <br/>
 * If there are k classes for n instances with m attributes, the parameter matrix B to be calculated will be an m*(k-1) matrix.<br/>
 * <br/>
 * The probability for class j with the exception of the last class is<br/>
 * <br/>
 * Pj(Xi) = exp(Xi*Bj)/((sum[j=1..(k-1)]exp(Xi*Bj))+1) <br/>
 * <br/>
 * The last class has probability<br/>
 * <br/>
 * 1-(sum[j=1..(k-1)]Pj(Xi)) <br/>
 *         = 1/((sum[j=1..(k-1)]exp(Xi*Bj))+1)<br/>
 * <br/>
 * The (negative) multinomial log-likelihood is thus: <br/>
 * <br/>
 * L = -sum[i=1..n]{<br/>
 *         sum[j=1..(k-1)](Yij * ln(Pj(Xi)))<br/>
 *         +(1 - (sum[j=1..(k-1)]Yij)) <br/>
 *         * ln(1 - sum[j=1..(k-1)]Pj(Xi))<br/>
 *         } + ridge * (B^2)<br/>
 * <br/>
 * In order to find the matrix B for which L is minimised, a Quasi-Newton Method is used to search for the optimized values of the m*(k-1) variables.  Note that before we use the optimization procedure, we 'squeeze' the matrix B into an m*(k-1) vector.  For details of the optimization procedure, please check the weka.core.Optimization class.<br/>
 * <br/>
 * Although original logistic regression does not deal with instance weights, we modify the algorithm slightly to handle instance weights.<br/>
 * <br/>
 * For more information see:<br/>
 * <br/>
 * le Cessie, S., van Houwelingen, J.C. (1992). Ridge Estimators in Logistic Regression. Applied Statistics. 41(1):191-201.<br/>
 * <br/>
 * Note: Missing values are replaced using a ReplaceMissingValuesFilter, and nominal attributes are transformed into numeric attributes using a NominalToBinaryFilter.
 * <p/>
 <!-- globalinfo-end -->
 *
 <!-- technical-bibtex-start -->
 * BibTeX:
 * <pre>
 * &#64;article{leCessie1992,
 *    author = {le Cessie, S. and van Houwelingen, J.C.},
 *    journal = {Applied Statistics},
 *    number = {1},
 *    pages = {191-201},
 *    title = {Ridge Estimators in Logistic Regression},
 *    volume = {41},
 *    year = {1992}
 * }
 * </pre>
 * <p/>
 <!-- technical-bibtex-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 * 
 * <pre> -D
 *  Turn on debugging output.</pre>
 * 
 * <pre> -R &lt;ridge&gt;
 *  Set the ridge in the log-likelihood.</pre>
 * 
 * <pre> -M &lt;number&gt;
 *  Set the maximum number of iterations (default -1, until convergence).</pre>
 * 
 <!-- options-end -->
 *
 * @author Xin Xu (xx5@cs.waikato.ac.nz)
 * @version $Revision: 8034 $
 */
public class Logistic extends AbstractClassifier 
  implements OptionHandler, WeightedInstancesHandler, TechnicalInformationHandler {
  
  /** for serialization */
  static final long serialVersionUID = 3932117032546553727L;
  
  /** The coefficients (optimized parameters) of the model */
  protected double [][] m_Par;
    
  /** The data saved as a matrix */
  protected double [][] m_Data;
    
  /** The number of attributes in the model */
  protected int m_NumPredictors;
    
  /** The index of the class attribute */
  protected int m_ClassIndex;
    
  /** The number of the class labels */
  protected int m_NumClasses;
    
  /** The ridge parameter. */
  protected double m_Ridge = 1e-8;
    
  /** An attribute filter */
  private RemoveUseless m_AttFilter;
    
  /** The filter used to make attributes numeric. */
  private NominalToBinary m_NominalToBinary;
    
  /** The filter used to get rid of missing values. */
  private ReplaceMissingValues m_ReplaceMissingValues;
    
  /** Debugging output */
  protected boolean m_Debug;

  /** Log-likelihood of the searched model */
  protected double m_LL;
    
  /** The maximum number of iterations. */
  private int m_MaxIts = -1;

  /** The header of the training data, kept for printing the model */
  private Instances m_structure;

  /**
   * Returns a string describing this classifier
   * @return a description of the classifier suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {
    return "Class for building and using a multinomial logistic "
      +"regression model with a ridge estimator.\n\n"
      +"There are some modifications, however, compared to the paper of "
      +"le Cessie and van Houwelingen (1992): \n\n" 
      +"If there are k classes for n instances with m attributes, the "
      +"parameter matrix B to be calculated will be an m*(k-1) matrix.\n\n"
      +"The probability for class j with the exception of the last class is\n\n"
      +"Pj(Xi) = exp(Xi*Bj)/((sum[j=1..(k-1)]exp(Xi*Bj))+1) \n\n"
      +"The last class has probability\n\n"
      +"1-(sum[j=1..(k-1)]Pj(Xi)) \n\t= 1/((sum[j=1..(k-1)]exp(Xi*Bj))+1)\n\n"
      +"The (negative) multinomial log-likelihood is thus: \n\n"
      +"L = -sum[i=1..n]{\n\tsum[j=1..(k-1)](Yij * ln(Pj(Xi)))"
      +"\n\t+(1 - (sum[j=1..(k-1)]Yij)) \n\t* ln(1 - sum[j=1..(k-1)]Pj(Xi))"
      +"\n\t} + ridge * (B^2)\n\n"
      +"In order to find the matrix B for which L is minimised, a "
      +"Quasi-Newton Method is used to search for the optimized values of "
      +"the m*(k-1) variables.  Note that before we use the optimization "
      +"procedure, we 'squeeze' the matrix B into an m*(k-1) vector.  For "
      +"details of the optimization procedure, please check the "
      +"weka.core.Optimization class.\n\n"
      +"Although original logistic regression does not deal with instance "
      +"weights, we modify the algorithm slightly to handle "
      +"instance weights.\n\n"
      +"For more information see:\n\n"
      + getTechnicalInformation().toString() + "\n\n"
      +"Note: Missing values are replaced using a ReplaceMissingValuesFilter, and "
      +"nominal attributes are transformed into numeric attributes using a "
      +"NominalToBinaryFilter.";
  }
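
  // For intuition (an illustrative note, not part of the original source):
  // with k = 2 classes the model described above reduces to plain binary
  // logistic regression, since B has a single column B0 and
  //
  //   P(class 0 | Xi) = exp(Xi*B0) / (1 + exp(Xi*B0))
  //   P(class 1 | Xi) = 1 / (1 + exp(Xi*B0))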

  /**
   * Returns an instance of a TechnicalInformation object, containing 
   * detailed information about the technical background of this class,
   * e.g., paper reference or book this class is based on.
   * 
   * @return the technical information about this class
   */
  public TechnicalInformation getTechnicalInformation() {
    TechnicalInformation result;
    
    result = new TechnicalInformation(Type.ARTICLE);
    result.setValue(Field.AUTHOR, "le Cessie, S. and van Houwelingen, J.C.");
    result.setValue(Field.YEAR, "1992");
    result.setValue(Field.TITLE, "Ridge Estimators in Logistic Regression");
    result.setValue(Field.JOURNAL, "Applied Statistics");
    result.setValue(Field.VOLUME, "41");
    result.setValue(Field.NUMBER, "1");
    result.setValue(Field.PAGES, "191-201");
    
    return result;
  }

  /**
   * Returns an enumeration describing the available options
   *
   * @return an enumeration of all the available options
   */
  public Enumeration listOptions() {
    Vector newVector = new Vector(3);
    newVector.addElement(new Option("\tTurn on debugging output.",
                                    "D", 0, "-D"));
    newVector.addElement(new Option("\tSet the ridge in the log-likelihood.",
                                    "R", 1, "-R <ridge>"));
    newVector.addElement(new Option("\tSet the maximum number of iterations"+
                                    " (default -1, until convergence).",
                                    "M", 1, "-M <number>"));
    return newVector.elements();
  }

  /**
   * Parses a given list of options. <p/>
   *
   <!-- options-start -->
   * Valid options are: <p/>
   * 
   * <pre> -D
   *  Turn on debugging output.</pre>
   * 
   * <pre> -R &lt;ridge&gt;
   *  Set the ridge in the log-likelihood.</pre>
   * 
   * <pre> -M &lt;number&gt;
   *  Set the maximum number of iterations (default -1, until convergence).</pre>
   * 
   <!-- options-end -->
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {
    setDebug(Utils.getFlag('D', options));

    String ridgeString = Utils.getOption('R', options);
    if (ridgeString.length() != 0) 
      m_Ridge = Double.parseDouble(ridgeString);
    else 
      m_Ridge = 1.0e-8;
        
    String maxItsString = Utils.getOption('M', options);
    if (maxItsString.length() != 0) 
      m_MaxIts = Integer.parseInt(maxItsString);
    else 
      m_MaxIts = -1;
  }
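
  // A minimal usage sketch for the options above (illustrative; the flags
  // -D, -R and -M are documented in the javadoc, the concrete values are
  // assumptions):
  //
  //   Logistic model = new Logistic();
  //   model.setOptions(new String[]{"-R", "1.0E-4", "-M", "200"});
  //   // equivalent to: model.setRidge(1.0E-4); model.setMaxIts(200);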

  /**
   * Gets the current settings of the classifier.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String [] getOptions() {
        
    String [] options = new String [5];
    int current = 0;
        
    if (getDebug()) 
      options[current++] = "-D";
    options[current++] = "-R";
    options[current++] = ""+m_Ridge;        
    options[current++] = "-M";
    options[current++] = ""+m_MaxIts;
    while (current < options.length) 
      options[current++] = "";
    return options;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String debugTipText() {
    return "Output debug information to the console.";
  }

  /**
   * Sets whether debugging output will be printed.
   *
   * @param debug true if debugging output should be printed
   */
  public void setDebug(boolean debug) {
    m_Debug = debug;
  }
    
  /**
   * Gets whether debugging output will be printed.
   *
   * @return true if debugging output will be printed
   */
  public boolean getDebug() {
    return m_Debug;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String ridgeTipText() {
    return "Set the Ridge value in the log-likelihood.";
  }

  /**
   * Sets the ridge in the log-likelihood.
   *
   * @param ridge the ridge
   */
  public void setRidge(double ridge) {
    m_Ridge = ridge;
  }
    
  /**
   * Gets the ridge in the log-likelihood.
   *
   * @return the ridge
   */
  public double getRidge() {
    return m_Ridge;
  }
   
  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String maxItsTipText() {
    return "Maximum number of iterations to perform.";
  }

  /**
   * Get the value of MaxIts.
   *
   * @return Value of MaxIts.
   */
  public int getMaxIts() {
    return m_MaxIts;
  }
    
  /**
   * Set the value of MaxIts.
   *
   * @param newMaxIts Value to assign to MaxIts.
   */
  public void setMaxIts(int newMaxIts) {
    m_MaxIts = newMaxIts;
  }    
    
  private class OptEng extends Optimization {
    
    /** Weights of instances in the data */
    private double[] weights;

    /** Class labels of instances */
    private int[] cls;
        
    /** 
     * Set the weights of instances
     * @param w the weights to be set
     */ 
    public void setWeights(double[] w) {
      weights = w;
    }
        
    /** 
     * Set the class labels of instances
     * @param c the class labels to be set
     */ 
    public void setClassLabels(int[] c) {
      cls = c;
    }
        
    /** 
     * Evaluate objective function
     * @param x the current values of variables
     * @return the value of the objective function 
     */
    protected double objectiveFunction(double[] x){
      double nll = 0; // -LogLikelihood
      int dim = m_NumPredictors+1; // Number of variables per class
            
      for(int i=0; i<cls.length; i++){ // ith instance

        double[] exp = new double[m_NumClasses-1];
        int index;
        for(int offset=0; offset<m_NumClasses-1; offset++){ 
          index = offset * dim;
          for(int j=0; j<dim; j++)
            exp[offset] += m_Data[i][j]*x[index + j];
        }
        double max = exp[Utils.maxIndex(exp)];
        double denom = Math.exp(-max);
        double num;
        if (cls[i] == m_NumClasses - 1) { // Class of this instance
          num = -max;
        } else {
          num = exp[cls[i]] - max;
        }
        for(int offset=0; offset<m_NumClasses-1; offset++){
          denom += Math.exp(exp[offset] - max);
        }
                
        nll -= weights[i]*(num - Math.log(denom)); // Weighted NLL
      }
            
      // Ridge: note that intercepts NOT included
      for(int offset=0; offset<m_NumClasses-1; offset++){
        for(int r=1; r<dim; r++)
          nll += m_Ridge*x[offset*dim+r]*x[offset*dim+r];
      }
            
      return nll;
    }
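
    // Note on the max-subtraction above (explanatory sketch, not original
    // code): subtracting max before exponentiating is the log-sum-exp
    // trick, i.e. log(1 + sum_j e^{z_j}) is computed as
    //
    //   max + Math.log(Math.exp(-max) + sumJ(Math.exp(z[j] - max)))
    //
    // where sumJ is a hypothetical shorthand for the loop over offsets;
    // this avoids overflow when some z_j is large.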

    /** 
     * Evaluate the gradient vector
     * @param x the current values of variables
     * @return the gradient vector 
     */
    protected double[] evaluateGradient(double[] x){
      double[] grad = new double[x.length];
      int dim = m_NumPredictors+1; // Number of variables per class
            
      for(int i=0; i<cls.length; i++){ // ith instance
        double[] num=new double[m_NumClasses-1]; // numerator of [-log(1+sum(exp))]'
        int index;
        for(int offset=0; offset<m_NumClasses-1; offset++){ // Which part of x
          double exp=0.0;
          index = offset * dim;
          for(int j=0; j<dim; j++)
            exp += m_Data[i][j]*x[index + j];
          num[offset] = exp;
        }

        double max = num[Utils.maxIndex(num)];
        double denom = Math.exp(-max); // Denominator of [-log(1+sum(exp))]'
        for(int offset=0; offset<m_NumClasses-1; offset++){
          num[offset] = Math.exp(num[offset] - max);
          denom += num[offset];
        }
        Utils.normalize(num, denom);
                
        // Update denominator of the gradient of -log(Posterior)
        double firstTerm;
        for(int offset=0; offset<m_NumClasses-1; offset++){ // Which part of x
          index = offset * dim;
          firstTerm = weights[i] * num[offset];
          for(int q=0; q<dim; q++){
            grad[index + q] += firstTerm * m_Data[i][q];
          }
        }
                
        if(cls[i] != m_NumClasses-1){ // Not the last class
          for(int p=0; p<dim; p++){
            grad[cls[i]*dim+p] -= weights[i]*m_Data[i][p]; 
          }
        }
      }
            
      // Ridge: note that intercepts NOT included
      for(int offset=0; offset<m_NumClasses-1; offset++){
        for(int r=1; r<dim; r++)
          grad[offset*dim+r] += 2*m_Ridge*x[offset*dim+r];
      }
            
      return grad;
    }
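
    // Sanity note (explanatory, not original code): for each instance the
    // two blocks above accumulate, for variable x[offset*dim + q],
    //
    //   grad += w_i * (P_offset(X_i) - Y_i,offset) * X_i[q]
    //
    // i.e. the standard softmax-regression gradient, to which the loop
    // below adds the 2 * ridge * x derivative of the penalty term.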
    
    /**
     * Returns the revision string.
     * 
     * @return                the revision
     */
    public String getRevision() {
      return RevisionUtils.extract("$Revision: 8034 $");
    }
  }

  /**
   * Returns default capabilities of the classifier.
   *
   * @return      the capabilities of this classifier
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();
    result.disableAll();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enable(Capability.DATE_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.enable(Capability.NOMINAL_CLASS);
    result.enable(Capability.MISSING_CLASS_VALUES);
    
    return result;
  }
    
  /**
   * Builds the classifier
   *
   * @param train the training data to be used for generating the
   * classifier.
   * @throws Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances train) throws Exception {
    // can classifier handle the data?
    getCapabilities().testWithFail(train);

    // remove instances with missing class
    train = new Instances(train);
    train.deleteWithMissingClass();
    
    // Replace missing values        
    m_ReplaceMissingValues = new ReplaceMissingValues();
    m_ReplaceMissingValues.setInputFormat(train);
    train = Filter.useFilter(train, m_ReplaceMissingValues);

    // Remove useless attributes
    m_AttFilter = new RemoveUseless();
    m_AttFilter.setInputFormat(train);
    train = Filter.useFilter(train, m_AttFilter);
        
    // Transform attributes
    m_NominalToBinary = new NominalToBinary();
    m_NominalToBinary.setInputFormat(train);
    train = Filter.useFilter(train, m_NominalToBinary);
    
    // Save the structure for printing the model
    m_structure = new Instances(train, 0);
        
    // Extract data
    m_ClassIndex = train.classIndex();
    m_NumClasses = train.numClasses();

    int nK = m_NumClasses - 1;                     // Only K-1 class labels needed 
    int nR = m_NumPredictors = train.numAttributes() - 1;
    int nC = train.numInstances();
        
    m_Data = new double[nC][nR + 1];               // Data values
    int [] Y  = new int[nC];                       // Class labels
    double [] xMean= new double[nR + 1];           // Attribute means
    double [] xSD  = new double[nR + 1];           // Attribute stddev's
    double [] sY = new double[nK + 1];             // Number of instances in each class
    double [] weights = new double[nC];            // Weights of instances
    double totWeights = 0;                         // Total weights of the instances
    m_Par = new double[nR + 1][nK];                // Optimized parameter values
        
    if (m_Debug) {
      System.out.println("Extracting data...");
    }
        
    for (int i = 0; i < nC; i++) {
      // initialize X[][]
      Instance current = train.instance(i);
      Y[i] = (int)current.classValue();  // Class value starts from 0
      weights[i] = current.weight();     // Dealing with weights
      totWeights += weights[i];
            
      m_Data[i][0] = 1;
      int j = 1;
      for (int k = 0; k <= nR; k++) {
        if (k != m_ClassIndex) {
          double x = current.value(k);
          m_Data[i][j] = x;
          xMean[j] += weights[i]*x;
          xSD[j] += weights[i]*x*x;
          j++;
        }
      }
            
      // Class count
      sY[Y[i]]++;        
    }
        
    if((totWeights <= 1) && (nC > 1))
      throw new Exception("Sum of weights of instances is 1 or less, please reweight!");

    xMean[0] = 0; xSD[0] = 1;
    for (int j = 1; j <= nR; j++) {
      xMean[j] = xMean[j] / totWeights;
      if(totWeights > 1)
        xSD[j] = Math.sqrt(Math.abs(xSD[j] - totWeights*xMean[j]*xMean[j])/(totWeights-1));
      else
        xSD[j] = 0;
    }

    if (m_Debug) {            
      // Output stats about input data
      System.out.println("Descriptives...");
      for (int m = 0; m <= nK; m++)
        System.out.println(sY[m] + " cases have class " + m);
      System.out.println("\n Variable     Avg       SD    ");
      for (int j = 1; j <= nR; j++) 
        System.out.println(Utils.doubleToString(j,8,4) 
                           + Utils.doubleToString(xMean[j], 10, 4) 
                           + Utils.doubleToString(xSD[j], 10, 4)
                           );
    }
        
    // Normalise input data 
    for (int i = 0; i < nC; i++) {
      for (int j = 0; j <= nR; j++) {
        if (xSD[j] != 0) {
          m_Data[i][j] = (m_Data[i][j] - xMean[j]) / xSD[j];
        }
      }
    }
        
    if (m_Debug) {
      System.out.println("\nIteration History..." );
    }
        
    double x[] = new double[(nR+1)*nK];
    double[][] b = new double[2][x.length]; // Boundary constraints, N/A here

    // Initialize
    for(int p=0; p<nK; p++){
      int offset=p*(nR+1);         
      x[offset] =  Math.log(sY[p]+1.0) - Math.log(sY[nK]+1.0); // Null model
      b[0][offset] = Double.NaN;
      b[1][offset] = Double.NaN;   
      for (int q=1; q <= nR; q++){
        x[offset+q] = 0.0;                
        b[0][offset+q] = Double.NaN;
        b[1][offset+q] = Double.NaN;
      }        
    }
        
    OptEng opt = new OptEng();        
    opt.setDebug(m_Debug);
    opt.setWeights(weights);
    opt.setClassLabels(Y);

    if(m_MaxIts == -1){  // Search until convergence
      x = opt.findArgmin(x, b);
      while(x==null){
        x = opt.getVarbValues();
        if (m_Debug)
          System.out.println("200 iterations finished, not enough!");
        x = opt.findArgmin(x, b);
      }
      if (m_Debug)
        System.out.println(" -------------<Converged>--------------");
    }
    else{
      opt.setMaxIteration(m_MaxIts);
      x = opt.findArgmin(x, b);
      if(x==null) // Not enough, but use the current value
        x = opt.getVarbValues();
    }
        
    m_LL = -opt.getMinFunction(); // Log-likelihood

    // Don't need data matrix anymore
    m_Data = null;
            
    // Convert coefficients back to non-normalized attribute units
    for(int i=0; i < nK; i++){
      m_Par[0][i] = x[i*(nR+1)];
      for(int j = 1; j <= nR; j++) {
        m_Par[j][i] = x[i*(nR+1)+j];
        if (xSD[j] != 0) {
          m_Par[j][i] /= xSD[j];
          m_Par[0][i] -= m_Par[j][i] * xMean[j];
        }
      }
    }
  }
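
  // A minimal end-to-end sketch (illustrative; assumes an Instances object
  // has been loaded, e.g. via weka.core.converters.ConverterUtils.DataSource,
  // and that its class attribute is the last one):
  //
  //   Instances data = ...;                        // elided: load a dataset
  //   data.setClassIndex(data.numAttributes() - 1);
  //   Logistic model = new Logistic();
  //   model.buildClassifier(data);
  //   System.out.println(model);                   // coefficients and odds ratios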
    
  /**
   * Computes the distribution for a given instance
   *
   * @param instance the instance for which distribution is computed
   * @return the distribution
   * @throws Exception if the distribution can't be computed successfully
   */
  public double [] distributionForInstance(Instance instance) 
    throws Exception {
        
    m_ReplaceMissingValues.input(instance);
    instance = m_ReplaceMissingValues.output();
    m_AttFilter.input(instance);
    instance = m_AttFilter.output();
    m_NominalToBinary.input(instance);
    instance = m_NominalToBinary.output();
        
    // Extract the predictor columns into an array
    double [] instDat = new double [m_NumPredictors + 1];
    int j = 1;
    instDat[0] = 1;
    for (int k = 0; k <= m_NumPredictors; k++) {
      if (k != m_ClassIndex) {
        instDat[j++] = instance.value(k);
      }
    }
        
    double [] distribution = evaluateProbability(instDat);
    return distribution;
  }
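
  // Illustrative call (assuming "model" was built as in the sketch above and
  // "inst" is an Instance compatible with the training data):
  //
  //   double[] dist = model.distributionForInstance(inst);
  //   int predicted = Utils.maxIndex(dist);  // index of the most probable class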

  /**
   * Compute the posterior distribution using optimized parameter values
   * and the testing instance.
   * @param data the testing instance
   * @return the posterior probability distribution
   */ 
  private double[] evaluateProbability(double[] data){
    double[] prob = new double[m_NumClasses],
      v = new double[m_NumClasses];

    // Log-posterior before normalizing
    for(int j = 0; j < m_NumClasses-1; j++){
      for(int k = 0; k <= m_NumPredictors; k++){
        v[j] += m_Par[k][j] * data[k];
      }
    }
    v[m_NumClasses-1] = 0;
        
    // Normalize this way to avoid numerical overflow in the exponentials
    for(int m=0; m < m_NumClasses; m++){
      double sum = 0;
      for(int n=0; n < m_NumClasses-1; n++)
        sum += Math.exp(v[n] - v[m]);
      prob[m] = 1 / (sum + Math.exp(-v[m]));
    }
        
    return prob;
  } 
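
  // A hand-worked check (not original code): with two classes the loop above
  // reduces to the familiar binary logistic form, since v[1] == 0:
  //
  //   prob[0] = 1 / (1 + Math.exp(-v[0]));
  //   prob[1] = 1 / (Math.exp(v[0]) + 1);   // equals 1 - prob[0]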

  /**
   * Returns the coefficients for this logistic model.
   * The first dimension indexes the attributes, and
   * the second the classes.
   * 
   * @return the coefficients for this logistic model
   */
  public double [][] coefficients() {
    return m_Par;
  }
    
  /**
   * Gets a string describing the classifier.
   *
   * @return a string describing the classifier built.
   */
  public String toString() {
    StringBuffer temp = new StringBuffer();

    temp.append("Logistic Regression with ridge parameter of " + m_Ridge);
    if (m_Par == null) {
      return temp.toString() + ": No model built yet.";
    }

    // find longest attribute name
    int attLength = 0;
    for (int i = 0; i < m_structure.numAttributes(); i++) {
      if (i != m_structure.classIndex() && 
          m_structure.attribute(i).name().length() > attLength) {
        attLength = m_structure.attribute(i).name().length();
      }
    }

    if ("Intercept".length() > attLength) {
      attLength = "Intercept".length();
    }

    if ("Variable".length() > attLength) {
      attLength = "Variable".length();
    }
    attLength += 2;

    int colWidth = 0;
    // check length of class names
    for (int i = 0; i < m_structure.classAttribute().numValues() - 1; i++) {
      if (m_structure.classAttribute().value(i).length() > colWidth) {
        colWidth = m_structure.classAttribute().value(i).length();
      }
    }

    // check against coefficients and odds ratios
    for (int j = 1; j <= m_NumPredictors; j++) {
      for (int k = 0; k < m_NumClasses - 1; k++) {
        if (Utils.doubleToString(m_Par[j][k], 12, 4).trim().length() > colWidth) {
          colWidth = Utils.doubleToString(m_Par[j][k], 12, 4).trim().length();
        }
        double ORc = Math.exp(m_Par[j][k]);
        String t = " " + ((ORc > 1e10) ?  "" + ORc : Utils.doubleToString(ORc, 12, 4));
        if (t.trim().length() > colWidth) {
          colWidth = t.trim().length();
        }
      }
    }

    if ("Class".length() > colWidth) {
      colWidth = "Class".length();
    }
    colWidth += 2;
    
    temp.append("\nCoefficients...\n");
    temp.append(Utils.padLeft(" ", attLength) + Utils.padLeft("Class", colWidth) + "\n");
    temp.append(Utils.padRight("Variable", attLength));

    for (int i = 0; i < m_NumClasses - 1; i++) {
      String className = m_structure.classAttribute().value(i);
      temp.append(Utils.padLeft(className, colWidth));
    }
    temp.append("\n");
    int separatorL = attLength + ((m_NumClasses - 1) * colWidth);
    for (int i = 0; i < separatorL; i++) {
      temp.append("=");
    }
    temp.append("\n");
                
    int j = 1;
    for (int i = 0; i < m_structure.numAttributes(); i++) {
      if (i != m_structure.classIndex()) {
        temp.append(Utils.padRight(m_structure.attribute(i).name(), attLength));
        for (int k = 0; k < m_NumClasses-1; k++) {
          temp.append(Utils.padLeft(Utils.doubleToString(m_Par[j][k], 12, 4).trim(), colWidth));
        }
        temp.append("\n");
        j++;
      }
    }
        
    temp.append(Utils.padRight("Intercept", attLength));
    for (int k = 0; k < m_NumClasses-1; k++) {
      temp.append(Utils.padLeft(Utils.doubleToString(m_Par[0][k], 10, 4).trim(), colWidth)); 
    }
    temp.append("\n");
        
    temp.append("\n\nOdds Ratios...\n");
    temp.append(Utils.padLeft(" ", attLength) + Utils.padLeft("Class", colWidth) + "\n");
    temp.append(Utils.padRight("Variable", attLength));

    for (int i = 0; i < m_NumClasses - 1; i++) {
      String className = m_structure.classAttribute().value(i);
      temp.append(Utils.padLeft(className, colWidth));
    }
    temp.append("\n");
    for (int i = 0; i < separatorL; i++) {
      temp.append("=");
    }
    temp.append("\n");

    j = 1;
    for (int i = 0; i < m_structure.numAttributes(); i++) {
      if (i != m_structure.classIndex()) {
        temp.append(Utils.padRight(m_structure.attribute(i).name(), attLength));
        for (int k = 0; k < m_NumClasses-1; k++) {
          double ORc = Math.exp(m_Par[j][k]);
          String ORs = " " + ((ORc > 1e10) ?  "" + ORc : Utils.doubleToString(ORc, 12, 4));
          temp.append(Utils.padLeft(ORs.trim(), colWidth));
        }
        temp.append("\n");
        j++;
      }
    }

    return temp.toString();
  }
  
  /**
   * Returns the revision string.
   * 
   * @return                the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 8034 $");
  }
    
  /**
   * Main method for testing this class.
   *
   * @param argv should contain the command line arguments to the
   * scheme (see Evaluation)
   */
  public static void main(String [] argv) {
    runClassifier(new Logistic(), argv);
  }
}
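
// Command-line sketch (illustrative; the classpath and dataset name are
// assumptions, -t is the standard WEKA training-file option handled via
// runClassifier/Evaluation, -R and -M are the options of this class):
//
//   java -cp weka.jar weka.classifiers.functions.Logistic -t iris.arff -R 1.0E-8 -M -1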