Coverage Report - weka.classifiers.bayes.net.search.global.HillClimber
 
Classes in this File Line Coverage Branch Coverage Complexity
HillClimber
0%
0/115
0%
0/64
2.48
HillClimber$Operation
0%
0/12
0%
0/8
2.48
 
 1  
 /*
 2  
  *   This program is free software: you can redistribute it and/or modify
 3  
  *   it under the terms of the GNU General Public License as published by
 4  
  *   the Free Software Foundation, either version 3 of the License, or
 5  
  *   (at your option) any later version.
 6  
  *
 7  
  *   This program is distributed in the hope that it will be useful,
 8  
  *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 9  
  *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 10  
  *   GNU General Public License for more details.
 11  
  *
 12  
  *   You should have received a copy of the GNU General Public License
 13  
  *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 14  
  */
 15  
 
 16  
 /*
 17  
  * HillClimber.java
 18  
  * Copyright (C) 2004-2012 University of Waikato, Hamilton, New Zealand
 19  
  * 
 20  
  */
 21  
  
 22  
 package weka.classifiers.bayes.net.search.global;
 23  
 
 24  
 import java.io.Serializable;
 25  
 import java.util.Enumeration;
 26  
 import java.util.Vector;
 27  
 
 28  
 import weka.classifiers.bayes.BayesNet;
 29  
 import weka.classifiers.bayes.net.ParentSet;
 30  
 import weka.core.Instances;
 31  
 import weka.core.Option;
 32  
 import weka.core.RevisionHandler;
 33  
 import weka.core.RevisionUtils;
 34  
 import weka.core.Utils;
 35  
 
 36  
 /** 
 37  
  <!-- globalinfo-start -->
 38  
  * This Bayes Network learning algorithm uses a hill climbing algorithm adding, deleting and reversing arcs. The search is not restricted by an order on the variables (unlike K2). The difference with B and B2 is that this hill climber also considers arrows part of the naive Bayes structure for deletion.
 39  
  * <p/>
 40  
  <!-- globalinfo-end -->
 41  
  *
 42  
  <!-- options-start -->
 43  
  * Valid options are: <p/>
 44  
  * 
 45  
  * <pre> -P &lt;nr of parents&gt;
 46  
  *  Maximum number of parents</pre>
 47  
  * 
 48  
  * <pre> -R
 49  
  *  Use arc reversal operation.
 50  
  *  (default false)</pre>
 51  
  * 
 52  
  * <pre> -N
 53  
  *  Initial structure is empty (instead of Naive Bayes)</pre>
 54  
  * 
 55  
  * <pre> -mbc
 56  
  *  Applies a Markov Blanket correction to the network structure, 
 57  
  *  after a network structure is learned. This ensures that all 
 58  
  *  nodes in the network are part of the Markov blanket of the 
 59  
  *  classifier node.</pre>
 60  
  * 
 61  
  * <pre> -S [LOO-CV|k-Fold-CV|Cumulative-CV]
 62  
  *  Score type (LOO-CV,k-Fold-CV,Cumulative-CV)</pre>
 63  
  * 
 64  
  * <pre> -Q
 65  
  *  Use probabilistic or 0/1 scoring.
 66  
  *  (default probabilistic scoring)</pre>
 67  
  * 
 68  
  <!-- options-end -->
 69  
  * 
 70  
  * @author Remco Bouckaert (rrb@xm.co.nz)
 71  
  * @version $Revision: 8034 $
 72  
  */
 73  0
 public class HillClimber 
 74  
     extends GlobalScoreSearchAlgorithm {
 75  
 
 76  
     /** for serialization */
 77  
     static final long serialVersionUID = -3885042888195820149L;
 78  
   
 79  
   /** 
 80  
    * the Operation class contains info on operations performed
 81  
    * on the current Bayesian network.
 82  
    */
 83  
     class Operation 
 84  
             implements Serializable, RevisionHandler {
 85  
       
 86  
               /** for serialization */
 87  
         static final long serialVersionUID = -2934970456587374967L;
 88  
       
 89  
             // constants indicating the type of an operation
 90  
             final static int OPERATION_ADD = 0;
 91  
             final static int OPERATION_DEL = 1;
 92  
             final static int OPERATION_REVERSE = 2;
 93  
 
 94  
             /** c'tor **/
 95  0
         public Operation() {
 96  0
         }
 97  
         
 98  
                 /** c'tor + initializers
 99  
                  * 
 100  
                  * @param nTail
 101  
                  * @param nHead
 102  
                  * @param nOperation
 103  
                  */ 
 104  0
             public Operation(int nTail, int nHead, int nOperation) {
 105  0
                         m_nHead = nHead;
 106  0
                         m_nTail = nTail;
 107  0
                         m_nOperation = nOperation;
 108  0
                 }
 109  
                 /** compare this operation with another
 110  
                  * @param other operation to compare with
 111  
                  * @return true if operation is the same
 112  
                  */
 113  
                 public boolean equals(Operation other) {
 114  0
                         if (other == null) {
 115  0
                                 return false;
 116  
                         }
 117  0
                         return ((        m_nOperation == other.m_nOperation) &&
 118  
                         (m_nHead == other.m_nHead) &&
 119  
                         (m_nTail == other.m_nTail));
 120  
                 } // equals
 121  
                 /** number of the tail node **/
 122  
         public int m_nTail;
 123  
                 /** number of the head node **/
 124  
         public int m_nHead;
 125  
                 /** type of operation (ADD, DEL, REVERSE) **/
 126  
         public int m_nOperation;
 127  
         /** change of score due to this operation **/
 128  0
         public double m_fScore = -1E100;
 129  
         
 130  
         /**
 131  
          * Returns the revision string.
 132  
          * 
 133  
          * @return                the revision
 134  
          */
 135  
         public String getRevision() {
 136  0
           return RevisionUtils.extract("$Revision: 8034 $");
 137  
         }
 138  
     } // class Operation
 139  
         
 140  
     /** use the arc reversal operator **/
 141  0
     boolean m_bUseArcReversal = false;
 142  
 
 143  
     /**
 144  
      * search determines the network structure/graph of the network
 145  
      * with the Taby algorithm.
 146  
      * 
 147  
      * @param bayesNet the network to search
 148  
      * @param instances the instances to work with
 149  
      * @throws Exception if something goes wrong
 150  
      */
 151  
     protected void search(BayesNet bayesNet, Instances instances) throws Exception {
 152  0
             m_BayesNet = bayesNet;
 153  0
                 double fScore = calcScore(bayesNet);
 154  
         // go do the search        
 155  0
                 Operation oOperation = getOptimalOperation(bayesNet, instances);
 156  0
                 while ((oOperation != null) && (oOperation.m_fScore > fScore)) {
 157  0
                         performOperation(bayesNet, instances, oOperation);
 158  0
                         fScore = oOperation.m_fScore;
 159  0
                         oOperation = getOptimalOperation(bayesNet, instances);
 160  
         }        
 161  0
     } // search
 162  
 
 163  
 
 164  
 
 165  
         /** check whether the operation is not in the forbidden.
 166  
          * For base hill climber, there are no restrictions on operations,
 167  
          * so we always return true.
 168  
          * @param oOperation operation to be checked
 169  
          * @return true if operation is not in the tabu list
 170  
          */
 171  
         boolean isNotTabu(Operation oOperation) {
 172  0
                 return true;
 173  
         } // isNotTabu
 174  
 
 175  
         /** 
 176  
          * getOptimalOperation finds the optimal operation that can be performed
 177  
          * on the Bayes network that is not in the tabu list.
 178  
          * 
 179  
          * @param bayesNet Bayes network to apply operation on
 180  
          * @param instances data set to learn from
 181  
          * @return optimal operation found
 182  
          * @throws Exception if something goes wrong
 183  
          */
 184  
     Operation getOptimalOperation(BayesNet bayesNet, Instances instances) throws Exception {
 185  0
         Operation oBestOperation = new Operation();
 186  
 
 187  
                 // Add???
 188  0
                 oBestOperation = findBestArcToAdd(bayesNet, instances, oBestOperation);
 189  
                 // Delete???
 190  0
                 oBestOperation = findBestArcToDelete(bayesNet, instances, oBestOperation);
 191  
                 // Reverse???
 192  0
                 if (getUseArcReversal()) {
 193  0
                         oBestOperation = findBestArcToReverse(bayesNet, instances, oBestOperation);
 194  
                 }
 195  
 
 196  
                 // did we find something?
 197  0
                 if (oBestOperation.m_fScore == -1E100) {
 198  0
                         return null;
 199  
                 }
 200  
 
 201  0
         return oBestOperation;
 202  
     } // getOptimalOperation
 203  
 
 204  
         /** performOperation applies an operation 
 205  
          * on the Bayes network and update the cache.
 206  
          * 
 207  
          * @param bayesNet Bayes network to apply operation on
 208  
          * @param instances data set to learn from
 209  
          * @param oOperation operation to perform
 210  
          * @throws Exception if something goes wrong
 211  
          */
 212  
         void performOperation(BayesNet bayesNet, Instances instances, Operation oOperation) throws Exception {
 213  
                 // perform operation
 214  0
                 switch (oOperation.m_nOperation) {
 215  
                         case Operation.OPERATION_ADD:
 216  0
                                 applyArcAddition(bayesNet, oOperation.m_nHead, oOperation.m_nTail, instances);
 217  0
                                 if (bayesNet.getDebug()) {
 218  0
                                         System.out.print("Add " + oOperation.m_nHead + " -> " + oOperation.m_nTail);
 219  
                                 }
 220  
                                 break;
 221  
                         case Operation.OPERATION_DEL:
 222  0
                                 applyArcDeletion(bayesNet, oOperation.m_nHead, oOperation.m_nTail, instances);
 223  0
                                 if (bayesNet.getDebug()) {
 224  0
                                         System.out.print("Del " + oOperation.m_nHead + " -> " + oOperation.m_nTail);
 225  
                                 }
 226  
                                 break;
 227  
                         case Operation.OPERATION_REVERSE:
 228  0
                                 applyArcDeletion(bayesNet, oOperation.m_nHead, oOperation.m_nTail, instances);
 229  0
                                 applyArcAddition(bayesNet, oOperation.m_nTail, oOperation.m_nHead, instances);
 230  0
                                 if (bayesNet.getDebug()) {
 231  0
                                         System.out.print("Rev " + oOperation.m_nHead+ " -> " + oOperation.m_nTail);
 232  
                                 }
 233  
                                 break;
 234  
                 }
 235  0
         } // performOperation
 236  
 
 237  
         /**
 238  
          * 
 239  
          * @param bayesNet
 240  
          * @param iHead
 241  
          * @param iTail
 242  
          * @param instances
 243  
          */
 244  
         void applyArcAddition(BayesNet bayesNet, int iHead, int iTail, Instances instances) {
 245  0
                 ParentSet bestParentSet = bayesNet.getParentSet(iHead);
 246  0
                 bestParentSet.addParent(iTail, instances);
 247  0
         } // applyArcAddition
 248  
 
 249  
         /**
 250  
          * 
 251  
          * @param bayesNet
 252  
          * @param iHead
 253  
          * @param iTail
 254  
          * @param instances
 255  
          */
 256  
         void applyArcDeletion(BayesNet bayesNet, int iHead, int iTail, Instances instances) {
 257  0
                 ParentSet bestParentSet = bayesNet.getParentSet(iHead);
 258  0
                 bestParentSet.deleteParent(iTail, instances);
 259  0
         } // applyArcAddition
 260  
 
 261  
 
 262  
         /** 
 263  
          * find best (or least bad) arc addition operation
 264  
          * 
 265  
          * @param bayesNet Bayes network to add arc to
 266  
          * @param instances data set
 267  
          * @param oBestOperation
 268  
          * @return Operation containing best arc to add, or null if no arc addition is allowed 
 269  
          * (this can happen if any arc addition introduces a cycle, or all parent sets are filled
 270  
          * up to the maximum nr of parents).
 271  
          * @throws Exception if something goes wrong
 272  
          */
 273  
         Operation findBestArcToAdd(BayesNet bayesNet, Instances instances, Operation oBestOperation) throws Exception {
 274  0
                 int nNrOfAtts = instances.numAttributes();
 275  
                 // find best arc to add
 276  0
                 for (int iAttributeHead = 0; iAttributeHead < nNrOfAtts; iAttributeHead++) {
 277  0
                         if (bayesNet.getParentSet(iAttributeHead).getNrOfParents() < m_nMaxNrOfParents) {
 278  0
                                 for (int iAttributeTail = 0; iAttributeTail < nNrOfAtts; iAttributeTail++) {
 279  0
                                         if (addArcMakesSense(bayesNet, instances, iAttributeHead, iAttributeTail)) {
 280  0
                                                 Operation oOperation = new Operation(iAttributeTail, iAttributeHead, Operation.OPERATION_ADD);
 281  0
                                                 double fScore = calcScoreWithExtraParent(oOperation.m_nHead, oOperation.m_nTail);
 282  0
                                                 if (fScore > oBestOperation.m_fScore) {
 283  0
                                                         if (isNotTabu(oOperation)) {
 284  0
                                                                 oBestOperation = oOperation;
 285  0
                                                                 oBestOperation.m_fScore = fScore;
 286  
                                                         }
 287  
                                                 }
 288  
                                         }
 289  
                                 }
 290  
                         }
 291  
                 }
 292  0
                 return oBestOperation;
 293  
         } // findBestArcToAdd
 294  
 
 295  
         /** 
 296  
          * find best (or least bad) arc deletion operation
 297  
          * 
 298  
          * @param bayesNet Bayes network to delete arc from
 299  
          * @param instances data set
 300  
          * @param oBestOperation
 301  
          * @return Operation containing best arc to delete, or null if no deletion can be made 
 302  
          * (happens when there is no arc in the network yet).
 303  
          * @throws Exception of something goes wrong
 304  
          */
 305  
         Operation findBestArcToDelete(BayesNet bayesNet, Instances instances, Operation oBestOperation) throws Exception {
 306  0
                 int nNrOfAtts = instances.numAttributes();
 307  
                 // find best arc to delete
 308  0
                 for (int iNode = 0; iNode < nNrOfAtts; iNode++) {
 309  0
                         ParentSet parentSet = bayesNet.getParentSet(iNode);
 310  0
                         for (int iParent = 0; iParent < parentSet.getNrOfParents(); iParent++) {
 311  0
                                 Operation oOperation = new Operation(parentSet.getParent(iParent), iNode, Operation.OPERATION_DEL);
 312  0
                                 double fScore = calcScoreWithMissingParent(oOperation.m_nHead, oOperation.m_nTail);
 313  0
                                 if (fScore > oBestOperation.m_fScore) {
 314  0
                                         if (isNotTabu(oOperation)) {
 315  0
                                                 oBestOperation = oOperation;
 316  0
                                                 oBestOperation.m_fScore = fScore;
 317  
                                         }
 318  
                                 }
 319  
                         }
 320  
                 }
 321  0
                 return oBestOperation;
 322  
         } // findBestArcToDelete
 323  
 
 324  
         /** 
 325  
          * find best (or least bad) arc reversal operation
 326  
          * 
 327  
          * @param bayesNet Bayes network to reverse arc in
 328  
          * @param instances data set
 329  
          * @param oBestOperation
 330  
          * @return Operation containing best arc to reverse, or null if no reversal is allowed
 331  
          * (happens if there is no arc in the network yet, or when any such reversal introduces
 332  
          * a cycle).
 333  
          * @throws Exception if something goes wrong
 334  
          */
 335  
         Operation findBestArcToReverse(BayesNet bayesNet, Instances instances, Operation oBestOperation) throws Exception {
 336  0
                 int nNrOfAtts = instances.numAttributes();
 337  
                 // find best arc to reverse
 338  0
                 for (int iNode = 0; iNode < nNrOfAtts; iNode++) {
 339  0
                         ParentSet parentSet = bayesNet.getParentSet(iNode);
 340  0
                         for (int iParent = 0; iParent < parentSet.getNrOfParents(); iParent++) {
 341  0
                                 int iTail = parentSet.getParent(iParent);
 342  
                                 // is reversal allowed?
 343  0
                                 if (reverseArcMakesSense(bayesNet, instances, iNode, iTail) && 
 344  
                                     bayesNet.getParentSet(iTail).getNrOfParents() < m_nMaxNrOfParents) {
 345  
                                         // go check if reversal results in the best step forward
 346  0
                                         Operation oOperation = new Operation(parentSet.getParent(iParent), iNode, Operation.OPERATION_REVERSE);
 347  0
                                         double fScore = calcScoreWithReversedParent(oOperation.m_nHead, oOperation.m_nTail);
 348  0
                                         if (fScore > oBestOperation.m_fScore) {
 349  0
                                                 if (isNotTabu(oOperation)) {
 350  0
                                                         oBestOperation = oOperation;
 351  0
                                                         oBestOperation.m_fScore = fScore;
 352  
                                                 }
 353  
                                         }
 354  
                                 }
 355  
                         }
 356  
                 }
 357  0
                 return oBestOperation;
 358  
         } // findBestArcToReverse
 359  
         
 360  
 
 361  
         /**
 362  
          * Sets the max number of parents
 363  
          *
 364  
          * @param nMaxNrOfParents the max number of parents
 365  
          */
 366  
         public void setMaxNrOfParents(int nMaxNrOfParents) {
 367  0
           m_nMaxNrOfParents = nMaxNrOfParents;
 368  0
         } 
 369  
 
 370  
         /**
 371  
          * Gets the max number of parents.
 372  
          *
 373  
          * @return the max number of parents
 374  
          */
 375  
         public int getMaxNrOfParents() {
 376  0
           return m_nMaxNrOfParents;
 377  
         } 
 378  
 
 379  
         /**
 380  
          * Returns an enumeration describing the available options.
 381  
          *
 382  
          * @return an enumeration of all the available options.
 383  
          */
 384  
         public Enumeration listOptions() {
 385  0
                 Vector newVector = new Vector(2);
 386  
 
 387  0
                 newVector.addElement(new Option("\tMaximum number of parents", "P", 1, "-P <nr of parents>"));
 388  0
                 newVector.addElement(new Option("\tUse arc reversal operation.\n\t(default false)", "R", 0, "-R"));
 389  0
                 newVector.addElement(new Option("\tInitial structure is empty (instead of Naive Bayes)", "N", 0, "-N"));
 390  
 
 391  0
                 Enumeration enu = super.listOptions();
 392  0
                 while (enu.hasMoreElements()) {
 393  0
                         newVector.addElement(enu.nextElement());
 394  
                 }
 395  0
                 return newVector.elements();
 396  
         } // listOptions
 397  
 
 398  
         /**
 399  
          * Parses a given list of options. <p/>
 400  
          *
 401  
          <!-- options-start -->
 402  
          * Valid options are: <p/>
 403  
          * 
 404  
          * <pre> -P &lt;nr of parents&gt;
 405  
          *  Maximum number of parents</pre>
 406  
          * 
 407  
          * <pre> -R
 408  
          *  Use arc reversal operation.
 409  
          *  (default false)</pre>
 410  
          * 
 411  
          * <pre> -N
 412  
          *  Initial structure is empty (instead of Naive Bayes)</pre>
 413  
          * 
 414  
          * <pre> -mbc
 415  
          *  Applies a Markov Blanket correction to the network structure, 
 416  
          *  after a network structure is learned. This ensures that all 
 417  
          *  nodes in the network are part of the Markov blanket of the 
 418  
          *  classifier node.</pre>
 419  
          * 
 420  
          * <pre> -S [LOO-CV|k-Fold-CV|Cumulative-CV]
 421  
          *  Score type (LOO-CV,k-Fold-CV,Cumulative-CV)</pre>
 422  
          * 
 423  
          * <pre> -Q
 424  
          *  Use probabilistic or 0/1 scoring.
 425  
          *  (default probabilistic scoring)</pre>
 426  
          * 
 427  
          <!-- options-end -->
 428  
          *
 429  
          * @param options the list of options as an array of strings
 430  
          * @throws Exception if an option is not supported
 431  
          */
 432  
         public void setOptions(String[] options) throws Exception {
 433  0
                 setUseArcReversal(Utils.getFlag('R', options));
 434  
 
 435  0
                 setInitAsNaiveBayes (!(Utils.getFlag('N', options)));
 436  
                 
 437  0
                 String sMaxNrOfParents = Utils.getOption('P', options);
 438  0
                 if (sMaxNrOfParents.length() != 0) {
 439  0
                   setMaxNrOfParents(Integer.parseInt(sMaxNrOfParents));
 440  
                 } else {
 441  0
                   setMaxNrOfParents(100000);
 442  
                 }
 443  
                 
 444  0
                 super.setOptions(options);
 445  0
         } // setOptions
 446  
 
 447  
         /**
 448  
          * Gets the current settings of the search algorithm.
 449  
          *
 450  
          * @return an array of strings suitable for passing to setOptions
 451  
          */
 452  
         public String[] getOptions() {
 453  0
                 String[] superOptions = super.getOptions();
 454  0
                 String[] options = new String[7 + superOptions.length];
 455  0
                 int current = 0;
 456  0
                 if (getUseArcReversal()) {
 457  0
                   options[current++] = "-R";
 458  
                 }
 459  
                 
 460  0
                 if (!getInitAsNaiveBayes()) {
 461  0
                   options[current++] = "-N";
 462  
                 } 
 463  
 
 464  0
                 options[current++] = "-P";
 465  0
                 options[current++] = "" + m_nMaxNrOfParents;
 466  
 
 467  
                 // insert options from parent class
 468  0
                 for (int iOption = 0; iOption < superOptions.length; iOption++) {
 469  0
                         options[current++] = superOptions[iOption];
 470  
                 }
 471  
 
 472  
                 // Fill up rest with empty strings, not nulls!
 473  0
                 while (current < options.length) {
 474  0
                         options[current++] = "";
 475  
                 }
 476  0
                 return options;
 477  
         } // getOptions
 478  
 
 479  
         /**
 480  
          * Sets whether to init as naive bayes
 481  
          *
 482  
          * @param bInitAsNaiveBayes whether to init as naive bayes
 483  
          */
 484  
         public void setInitAsNaiveBayes(boolean bInitAsNaiveBayes) {
 485  0
           m_bInitAsNaiveBayes = bInitAsNaiveBayes;
 486  0
         } 
 487  
 
 488  
         /**
 489  
          * Gets whether to init as naive bayes
 490  
          *
 491  
          * @return whether to init as naive bayes
 492  
          */
 493  
         public boolean getInitAsNaiveBayes() {
 494  0
           return m_bInitAsNaiveBayes;
 495  
         } 
 496  
 
 497  
         /** get use the arc reversal operation
 498  
          * @return whether the arc reversal operation should be used
 499  
          */
 500  
         public boolean getUseArcReversal() {
 501  0
                 return m_bUseArcReversal;
 502  
         } // getUseArcReversal
 503  
 
 504  
         /** set use the arc reversal operation
 505  
          * @param bUseArcReversal whether the arc reversal operation should be used
 506  
          */
 507  
         public void setUseArcReversal(boolean bUseArcReversal) {
 508  0
                 m_bUseArcReversal = bUseArcReversal;
 509  0
         } // setUseArcReversal
 510  
 
 511  
         /**
 512  
          * This will return a string describing the search algorithm.
 513  
          * @return The string.
 514  
          */
 515  
         public String globalInfo() {
 516  0
           return "This Bayes Network learning algorithm uses a hill climbing algorithm " +
 517  
           "adding, deleting and reversing arcs. The search is not restricted by an order " +
 518  
           "on the variables (unlike K2). The difference with B and B2 is that this hill " +          
 519  
           "climber also considers arrows part of the naive Bayes structure for deletion.";
 520  
         } // globalInfo
 521  
 
 522  
         /**
 523  
          * @return a string to describe the Use Arc Reversal option.
 524  
          */
 525  
         public String useArcReversalTipText() {
 526  0
           return "When set to true, the arc reversal operation is used in the search.";
 527  
         } // useArcReversalTipText
 528  
 
 529  
         /**
 530  
          * Returns the revision string.
 531  
          * 
 532  
          * @return                the revision
 533  
          */
 534  
         public String getRevision() {
 535  0
           return RevisionUtils.extract("$Revision: 8034 $");
 536  
         }
 537  
 } // HillClimber