Coverage Report - weka.classifiers.bayes.net.search.local.TAN
 
Classes in this File Line Coverage Branch Coverage Complexity
TAN
0%
0/71
0%
0/54
5.143
 
 1  
 /*
 2  
  *   This program is free software: you can redistribute it and/or modify
 3  
  *   it under the terms of the GNU General Public License as published by
 4  
  *   the Free Software Foundation, either version 3 of the License, or
 5  
  *   (at your option) any later version.
 6  
  *
 7  
  *   This program is distributed in the hope that it will be useful,
 8  
  *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 9  
  *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 10  
  *   GNU General Public License for more details.
 11  
  *
 12  
  *   You should have received a copy of the GNU General Public License
 13  
  *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 14  
  */
 15  
 
 16  
 /*
 17  
  * TAN.java
 18  
  * Copyright (C) 2004-2012 University of Waikato, Hamilton, New Zealand
 19  
  * 
 20  
  */
 21  
 
 22  
 package weka.classifiers.bayes.net.search.local;
 23  
 
 24  
 import java.util.Enumeration;
 25  
 
 26  
 import weka.classifiers.bayes.BayesNet;
 27  
 import weka.core.Instances;
 28  
 import weka.core.RevisionUtils;
 29  
 import weka.core.TechnicalInformation;
 30  
 import weka.core.TechnicalInformation.Field;
 31  
 import weka.core.TechnicalInformation.Type;
 32  
 import weka.core.TechnicalInformationHandler;
 33  
 
 34  
 /** 
 35  
  <!-- globalinfo-start -->
 36  
  * This Bayes Network learning algorithm determines the maximum weight spanning tree  and returns a Naive Bayes network augmented with a tree.<br/>
 37  
  * <br/>
 38  
  * For more information see:<br/>
 39  
  * <br/>
 40  
  * N. Friedman, D. Geiger, M. Goldszmidt (1997). Bayesian network classifiers. Machine Learning. 29(2-3):131-163.
 41  
  * <p/>
 42  
  <!-- globalinfo-end -->
 43  
  * 
 44  
  <!-- technical-bibtex-start -->
 45  
  * BibTeX:
 46  
  * <pre>
 47  
  * &#64;article{Friedman1997,
 48  
  *    author = {N. Friedman and D. Geiger and M. Goldszmidt},
 49  
  *    journal = {Machine Learning},
 50  
  *    number = {2-3},
 51  
  *    pages = {131-163},
 52  
  *    title = {Bayesian network classifiers},
 53  
  *    volume = {29},
 54  
  *    year = {1997}
 55  
  * }
 56  
  * </pre>
 57  
  * <p/>
 58  
  <!-- technical-bibtex-end -->
 59  
  *
 60  
  <!-- options-start -->
 61  
  * Valid options are: <p/>
 62  
  * 
 63  
  * <pre> -mbc
 64  
  *  Applies a Markov Blanket correction to the network structure, 
 65  
  *  after a network structure is learned. This ensures that all 
 66  
  *  nodes in the network are part of the Markov blanket of the 
 67  
  *  classifier node.</pre>
 68  
  * 
 69  
  * <pre> -S [BAYES|MDL|ENTROPY|AIC|CROSS_CLASSIC|CROSS_BAYES]
 70  
  *  Score type (BAYES, BDeu, MDL, ENTROPY and AIC)</pre>
 71  
  * 
 72  
  <!-- options-end -->
 73  
  *
 74  
  * @author Remco Bouckaert
 75  
  * @version $Revision: 8034 $
 76  
  */
 77  0
 public class TAN 
 78  
         extends LocalScoreSearchAlgorithm
 79  
         implements TechnicalInformationHandler {
 80  
   
 81  
           /** for serialization */
 82  
           static final long serialVersionUID = 965182127977228690L;
 83  
 
 84  
           /**
 85  
            * Returns an instance of a TechnicalInformation object, containing 
 86  
            * detailed information about the technical background of this class,
 87  
            * e.g., paper reference or book this class is based on.
 88  
            * 
 89  
            * @return the technical information about this class
 90  
            */
 91  
           public TechnicalInformation getTechnicalInformation() {
 92  
             TechnicalInformation         result;
 93  
             
 94  0
             result = new TechnicalInformation(Type.ARTICLE);
 95  0
             result.setValue(Field.AUTHOR, "N. Friedman and D. Geiger and M. Goldszmidt");
 96  0
             result.setValue(Field.YEAR, "1997");
 97  0
             result.setValue(Field.TITLE, "Bayesian network classifiers");
 98  0
             result.setValue(Field.JOURNAL, "Machine Learning");
 99  0
             result.setValue(Field.VOLUME, "29");
 100  0
             result.setValue(Field.NUMBER, "2-3");
 101  0
             result.setValue(Field.PAGES, "131-163");
 102  
             
 103  0
             return result;
 104  
           }
 105  
 
 106  
         /**
 107  
          * buildStructure determines the network structure/graph of the network
 108  
          * using the maximimum weight spanning tree algorithm of Chow and Liu
 109  
          * 
 110  
          * @param bayesNet the network
 111  
          * @param instances the data to use
 112  
          * @throws Exception if something goes wrong
 113  
          */
 114  
         public void buildStructure(BayesNet bayesNet, Instances instances) throws Exception {
 115  
           
 116  0
                 m_bInitAsNaiveBayes = true;
 117  0
                 m_nMaxNrOfParents = 2;
 118  0
                 super.buildStructure(bayesNet, instances);
 119  0
                 int      nNrOfAtts = instances.numAttributes();
 120  
                 
 121  0
                 if (nNrOfAtts <= 2) {
 122  0
                     return;
 123  
                 }
 124  
 
 125  
                 // determine base scores
 126  0
                 double[] fBaseScores = new double[instances.numAttributes()];
 127  
 
 128  0
                 for (int iAttribute = 0; iAttribute < nNrOfAtts; iAttribute++) {
 129  0
                   fBaseScores[iAttribute] = calcNodeScore(iAttribute);
 130  
                 } 
 131  
 
 132  
                 //                // cache scores & whether adding an arc makes sense
 133  0
                 double[][]  fScore = new double[nNrOfAtts][nNrOfAtts];
 134  
 
 135  0
                 for (int iAttributeHead = 0; iAttributeHead < nNrOfAtts; iAttributeHead++) {
 136  0
                         for (int iAttributeTail = 0; iAttributeTail < nNrOfAtts; iAttributeTail++) {
 137  0
                                 if (iAttributeHead != iAttributeTail) {
 138  0
                                         fScore[iAttributeHead][iAttributeTail] = calcScoreWithExtraParent(iAttributeHead, iAttributeTail);
 139  
                                 }
 140  
                         } 
 141  
                 }
 142  
                 
 143  
                 // TAN greedy search (not restricted by ordering like K2)
 144  
                 // 1. find strongest link
 145  
                 // 2. find remaining links by adding strongest link to already
 146  
                 //    connected nodes
 147  
                 // 3. assign direction to links
 148  0
                 int nClassNode = instances.classIndex();
 149  0
                 int [] link1 = new int [nNrOfAtts - 1];
 150  0
                 int [] link2 = new int [nNrOfAtts - 1];
 151  0
                 boolean [] linked = new boolean [nNrOfAtts];
 152  
 
 153  
                 // 1. find strongest link
 154  0
                 int    nBestLinkNode1 = -1;
 155  0
                 int    nBestLinkNode2 = -1;
 156  0
                 double fBestDeltaScore = 0.0;
 157  
                 int iLinkNode1;
 158  0
                 for (iLinkNode1 = 0; iLinkNode1 < nNrOfAtts; iLinkNode1++) {
 159  0
                         if (iLinkNode1 != nClassNode) {
 160  0
                         for (int iLinkNode2 = 0; iLinkNode2 < nNrOfAtts; iLinkNode2++) {
 161  0
                                 if ((iLinkNode1 != iLinkNode2) &&
 162  
                                     (iLinkNode2 != nClassNode) && (
 163  
                                     (nBestLinkNode1 == -1) || (fScore[iLinkNode1][iLinkNode2] - fBaseScores[iLinkNode1] > fBestDeltaScore)
 164  
                                 )) {
 165  0
                                         fBestDeltaScore = fScore[iLinkNode1][iLinkNode2] - fBaseScores[iLinkNode1];
 166  0
                                         nBestLinkNode1 = iLinkNode2;
 167  0
                                         nBestLinkNode2 = iLinkNode1;
 168  
                             } 
 169  
                         } 
 170  
                         }
 171  
                 }
 172  0
                 link1[0] = nBestLinkNode1;
 173  0
                 link2[0] = nBestLinkNode2;
 174  0
                 linked[nBestLinkNode1] = true;
 175  0
                 linked[nBestLinkNode2] = true;
 176  
         
 177  
                 // 2. find remaining links by adding strongest link to already
 178  
                 //    connected nodes
 179  0
                 for (int iLink = 1; iLink < nNrOfAtts - 2; iLink++) {
 180  0
                         nBestLinkNode1 = -1;
 181  0
                         for (iLinkNode1 = 0; iLinkNode1 < nNrOfAtts; iLinkNode1++) {
 182  0
                                 if (iLinkNode1 != nClassNode) {
 183  0
                                 for (int iLinkNode2 = 0; iLinkNode2 < nNrOfAtts; iLinkNode2++) {
 184  0
                                         if ((iLinkNode1 != iLinkNode2) &&
 185  
                                             (iLinkNode2 != nClassNode) && 
 186  
                                         (linked[iLinkNode1] || linked[iLinkNode2]) &&
 187  
                                         (!linked[iLinkNode1] || !linked[iLinkNode2]) &&
 188  
                                         (
 189  
                                         (nBestLinkNode1 == -1) || (fScore[iLinkNode1][iLinkNode2] - fBaseScores[iLinkNode1] > fBestDeltaScore)
 190  
                                         )) {
 191  0
                                                 fBestDeltaScore = fScore[iLinkNode1][iLinkNode2] - fBaseScores[iLinkNode1];
 192  0
                                                 nBestLinkNode1 = iLinkNode2;
 193  0
                                                 nBestLinkNode2 = iLinkNode1;
 194  
                                         } 
 195  
                                 } 
 196  
                                 }
 197  
                         }
 198  
 
 199  0
                         link1[iLink] = nBestLinkNode1;
 200  0
                         link2[iLink] = nBestLinkNode2;
 201  0
                         linked[nBestLinkNode1] = true;
 202  0
                         linked[nBestLinkNode2] = true;
 203  
                 }
 204  
                 
 205  
                 // 3. assign direction to links
 206  0
                 boolean [] hasParent = new boolean [nNrOfAtts];
 207  0
                 for (int iLink = 0; iLink < nNrOfAtts - 2; iLink++) {
 208  0
                         if (!hasParent[link1[iLink]]) {
 209  0
                                 bayesNet.getParentSet(link1[iLink]).addParent(link2[iLink], instances);
 210  0
                                 hasParent[link1[iLink]] = true;
 211  
                         } else {
 212  0
                                 if (hasParent[link2[iLink]]) {
 213  0
                                         throw new Exception("Bug condition found: too many arrows");
 214  
                                 }
 215  0
                                 bayesNet.getParentSet(link2[iLink]).addParent(link1[iLink], instances);
 216  0
                                 hasParent[link2[iLink]] = true;
 217  
                         }
 218  
                 }
 219  
 
 220  0
         } // buildStructure
 221  
 
 222  
 
 223  
         /**
 224  
          * Returns an enumeration describing the available options.
 225  
          *
 226  
          * @return an enumeration of all the available options.
 227  
          */
 228  
         public Enumeration listOptions() {
 229  0
                 return super.listOptions();
 230  
         } // listOption
 231  
 
 232  
         /**
 233  
          * Parses a given list of options. <p/>
 234  
          *
 235  
          <!-- options-start -->
 236  
          * Valid options are: <p/>
 237  
          * 
 238  
          * <pre> -mbc
 239  
          *  Applies a Markov Blanket correction to the network structure, 
 240  
          *  after a network structure is learned. This ensures that all 
 241  
          *  nodes in the network are part of the Markov blanket of the 
 242  
          *  classifier node.</pre>
 243  
          * 
 244  
          * <pre> -S [BAYES|MDL|ENTROPY|AIC|CROSS_CLASSIC|CROSS_BAYES]
 245  
          *  Score type (BAYES, BDeu, MDL, ENTROPY and AIC)</pre>
 246  
          * 
 247  
          <!-- options-end -->
 248  
          * 
 249  
          * @param options the list of options as an array of strings
 250  
          * @throws Exception if an option is not supported
 251  
          */
 252  
         public void setOptions(String[] options) throws Exception {
 253  0
                 super.setOptions(options);
 254  0
         } // setOptions
 255  
         
 256  
         /**
 257  
          * Gets the current settings of the Classifier.
 258  
          *
 259  
          * @return an array of strings suitable for passing to setOptions
 260  
          */
 261  
         public String [] getOptions() {
 262  0
                 return super.getOptions();
 263  
         } // getOptions
 264  
 
 265  
         /**
 266  
          * This will return a string describing the classifier.
 267  
          * @return The string.
 268  
          */
 269  
         public String globalInfo() {
 270  0
                 return 
 271  
                     "This Bayes Network learning algorithm determines the maximum weight spanning tree "
 272  
                   + " and returns a Naive Bayes network augmented with a tree.\n\n"
 273  
                   + "For more information see:\n\n"
 274  
                   + getTechnicalInformation().toString();
 275  
         } // globalInfo
 276  
 
 277  
         /**
 278  
          * Returns the revision string.
 279  
          * 
 280  
          * @return                the revision
 281  
          */
 282  
         public String getRevision() {
 283  0
           return RevisionUtils.extract("$Revision: 8034 $");
 284  
         }
 285  
 
 286  
 } // TAN
 287