Coverage Report - weka.classifiers.functions.neural.NeuralNode
 
Classes in this File Line Coverage Branch Coverage Complexity
NeuralNode
0%
0/94
0%
0/48
2.529
 
 1  
 /*
 2  
  *   This program is free software: you can redistribute it and/or modify
 3  
  *   it under the terms of the GNU General Public License as published by
 4  
  *   the Free Software Foundation, either version 3 of the License, or
 5  
  *   (at your option) any later version.
 6  
  *
 7  
  *   This program is distributed in the hope that it will be useful,
 8  
  *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 9  
  *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 10  
  *   GNU General Public License for more details.
 11  
  *
 12  
  *   You should have received a copy of the GNU General Public License
 13  
  *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 14  
  */
 15  
 
 16  
 /*
 17  
  *    NeuralNode.java
 18  
  *    Copyright (C) 2000-2012 University of Waikato, Hamilton, New Zealand
 19  
  */
 20  
 
 21  
 package weka.classifiers.functions.neural;
 22  
 
 23  
 import weka.core.RevisionUtils;
 24  
 
 25  
 import java.util.Random;
 26  
 
 27  
 /**
 28  
  * This class is used to represent a node in the neuralnet.
 29  
  * 
 30  
  * @author Malcolm Ware (mfw4@cs.waikato.ac.nz)
 31  
  * @version $Revision: 8034 $
 32  
  */
 33  
 public class NeuralNode
 34  
   extends NeuralConnection {
 35  
 
 36  
   /** for serialization */
 37  
   private static final long serialVersionUID = -1085750607680839163L;
 38  
     
 39  
   /** The weights for each of the input connections, and the threshold. */
 40  
   private double[] m_weights;
 41  
   
 42  
   /** The best (lowest error) weights. Only used when validation set is used */
 43  
   private double[] m_bestWeights;
 44  
   
 45  
   /** The change in the weights. */
 46  
   private double[] m_changeInWeights;
 47  
   
 48  
   private Random m_random;
 49  
 
 50  
   /** Performs the operations for this node. Currently this
 51  
    * defines that the node is either a sigmoid or a linear unit. */
 52  
   private NeuralMethod m_methods;
 53  
 
 54  
   /** 
 55  
    * @param id The string name for this node (used to id this node).
 56  
    * @param r A random number generator used to generate initial weights.
 57  
    * @param m The methods this node should use to update.
 58  
    */
 59  
   public NeuralNode(String id, Random r, NeuralMethod m) {
 60  0
     super(id);
 61  0
     m_weights = new double[1];
 62  0
     m_bestWeights = new double[1];
 63  0
     m_changeInWeights = new double[1];
 64  
     
 65  0
     m_random = r;
 66  
     
 67  0
     m_weights[0] = m_random.nextDouble() * .1 - .05;
 68  0
     m_changeInWeights[0] = 0;
 69  
 
 70  0
     m_methods = m;
 71  0
   }
 72  
   
 73  
   /**
 74  
    * Set how this node should operate (note that the neural method has no
 75  
    * internal state, so the same object can be used by any number of nodes.
 76  
    * @param m The new method.
 77  
    */
 78  
   public void setMethod(NeuralMethod m) {
 79  0
     m_methods = m;
 80  0
   } 
 81  
 
 82  
   public NeuralMethod getMethod() {
 83  0
     return m_methods;
 84  
   }
 85  
 
 86  
   /**
 87  
    * Call this to get the output value of this unit. 
 88  
    * @param calculate True if the value should be calculated if it hasn't been
 89  
    * already.
 90  
    * @return The output value, or NaN, if the value has not been calculated.
 91  
    */
 92  
   public double outputValue(boolean calculate) {
 93  
     
 94  0
     if (Double.isNaN(m_unitValue) && calculate) {
 95  
       //then calculate the output value;
 96  0
       m_unitValue = m_methods.outputValue(this);
 97  
     }
 98  
     
 99  0
     return m_unitValue;
 100  
   }
 101  
 
 102  
   
 103  
   /**
 104  
    * Call this to get the error value of this unit.
 105  
    * @param calculate True if the value should be calculated if it hasn't been
 106  
    * already.
 107  
    * @return The error value, or NaN, if the value has not been calculated.
 108  
    */
 109  
   public double errorValue(boolean calculate) {
 110  
 
 111  0
     if (!Double.isNaN(m_unitValue) && Double.isNaN(m_unitError) && calculate) {
 112  
       //then calculate the error.
 113  0
       m_unitError = m_methods.errorValue(this);
 114  
     }
 115  0
     return m_unitError;
 116  
   }
 117  
 
 118  
   /**
 119  
    * Call this to reset the value and error for this unit, ready for the next
 120  
    * run. This will also call the reset function of all units that are 
 121  
    * connected as inputs to this one.
 122  
    * This is also the time that the update for the listeners will be performed.
 123  
    */
 124  
   public void reset() {
 125  
     
 126  0
     if (!Double.isNaN(m_unitValue) || !Double.isNaN(m_unitError)) {
 127  0
       m_unitValue = Double.NaN;
 128  0
       m_unitError = Double.NaN;
 129  0
       m_weightsUpdated = false;
 130  0
       for (int noa = 0; noa < m_numInputs; noa++) {
 131  0
         m_inputList[noa].reset();
 132  
       }
 133  
     }
 134  0
   }
 135  
   
 136  
   /**
 137  
    * Call this to have the connection save the current
 138  
    * weights.
 139  
    */
 140  
   public void saveWeights() {
 141  
     // copy the current weights
 142  0
     System.arraycopy(m_weights, 0, m_bestWeights, 0, m_weights.length);
 143  
     
 144  
     // tell inputs to save weights
 145  0
     for (int i = 0; i < m_numInputs; i++) {
 146  0
       m_inputList[i].saveWeights();
 147  
     }
 148  0
   }
 149  
   
 150  
   /**
 151  
    * Call this to have the connection restore from the saved
 152  
    * weights.
 153  
    */
 154  
   public void restoreWeights() {
 155  
     // copy the saved best weights back into the weights
 156  0
     System.arraycopy(m_bestWeights, 0, m_weights, 0, m_weights.length);
 157  
     
 158  
     // tell inputs to restore weights
 159  0
     for (int i = 0; i < m_numInputs; i++) {
 160  0
       m_inputList[i].restoreWeights();
 161  
     }
 162  0
   }
 163  
 
 164  
   /**
 165  
    * Call this to get the weight value on a particular connection.
 166  
    * @param n The connection number to get the weight for, -1 if The threshold
 167  
    * weight should be returned.
 168  
    * @return The value for the specified connection or if -1 then it should 
 169  
    * return the threshold value. If no value exists for the specified 
 170  
    * connection, NaN will be returned.
 171  
    */
 172  
   public double weightValue(int n) {
 173  0
     if (n >= m_numInputs || n < -1) {
 174  0
       return Double.NaN;
 175  
     }
 176  0
     return m_weights[n + 1];
 177  
   }
 178  
 
 179  
   /**
 180  
    * call this function to get the weights array.
 181  
    * This will also allow the weights to be updated.
 182  
    * @return The weights array.
 183  
    */
 184  
   public double[] getWeights() {
 185  0
     return m_weights;
 186  
   }
 187  
 
 188  
   /**
 189  
    * call this function to get the chnage in weights array.
 190  
    * This will also allow the change in weights to be updated.
 191  
    * @return The change in weights array.
 192  
    */
 193  
   public double[] getChangeInWeights() {
 194  0
     return m_changeInWeights;
 195  
   }
 196  
 
 197  
   /**
 198  
    * Call this function to update the weight values at this unit.
 199  
    * After the weights have been updated at this unit, All the
 200  
    * input connections will then be called from this to have their
 201  
    * weights updated.
 202  
    * @param l The learning rate to use.
 203  
    * @param m The momentum to use.
 204  
    */
 205  
   public void updateWeights(double l, double m) {
 206  
     
 207  0
     if (!m_weightsUpdated && !Double.isNaN(m_unitError)) {
 208  0
       m_methods.updateWeights(this, l, m);
 209  
      
 210  
       //note that the super call to update the inputs is done here and
 211  
       //not in the m_method updateWeights, because it is not deemed to be
 212  
       //required to update the weights at this node (while the error and output
 213  
       //value ao need to be recursively calculated)
 214  0
       super.updateWeights(l, m); //to call all of the inputs.
 215  
     }
 216  
     
 217  0
   }
 218  
 
 219  
   /**
 220  
    * This will connect the specified unit to be an input to this unit.
 221  
    * @param i The unit.
 222  
    * @param n It's connection number for this connection.
 223  
    * @return True if the connection was made, false otherwise.
 224  
    */
 225  
   protected boolean connectInput(NeuralConnection i, int n) {
 226  
     
 227  
     //the function that this overrides can do most of the work.
 228  0
     if (!super.connectInput(i, n)) {
 229  0
       return false;
 230  
     }
 231  
     
 232  
     //note that the weights are shifted 1 forward in the array so
 233  
     //it leaves the numinputs aligned on the space the weight needs to go.
 234  0
     m_weights[m_numInputs] = m_random.nextDouble() * .1 - .05;
 235  0
     m_changeInWeights[m_numInputs] = 0;
 236  
     
 237  0
     return true;
 238  
   }
 239  
 
 240  
   /**
 241  
    * This will allocate more space for input connection information
 242  
    * if the arrays for this have been filled up.
 243  
    */
 244  
   protected void allocateInputs() {
 245  
     
 246  0
     NeuralConnection[] temp1 = new NeuralConnection[m_inputList.length + 15];
 247  0
     int[] temp2 = new int[m_inputNums.length + 15];
 248  0
     double[] temp4 = new double[m_weights.length + 15];
 249  0
     double[] temp5 = new double[m_changeInWeights.length + 15];
 250  0
     double[] temp6 = new double[m_bestWeights.length + 15];
 251  
 
 252  0
     temp4[0] = m_weights[0];
 253  0
     temp5[0] = m_changeInWeights[0];
 254  0
     temp6[0] = m_bestWeights[0];
 255  0
     for (int noa = 0; noa < m_numInputs; noa++) {
 256  0
       temp1[noa] = m_inputList[noa];
 257  0
       temp2[noa] = m_inputNums[noa];
 258  0
       temp4[noa+1] = m_weights[noa+1];
 259  0
       temp5[noa+1] = m_changeInWeights[noa+1];
 260  0
       temp6[noa+1] = m_bestWeights[noa+1];
 261  
     }
 262  
     
 263  0
     m_inputList = temp1;
 264  0
     m_inputNums = temp2;
 265  0
     m_weights = temp4;
 266  0
     m_changeInWeights = temp5;
 267  0
     m_bestWeights = temp6;
 268  0
   }
 269  
 
 270  
   
 271  
   
 272  
 
 273  
   /**
 274  
    * This will disconnect the input with the specific connection number
 275  
    * From this node (only on this end however).
 276  
    * @param i The unit to disconnect.
 277  
    * @param n The connection number at the other end, -1 if all the connections
 278  
    * to this unit should be severed (not the same as removeAllInputs).
 279  
    * @return True if the connection was removed, false if the connection was 
 280  
    * not found.
 281  
    */
 282  
   protected boolean disconnectInput(NeuralConnection i, int n) {
 283  
     
 284  0
     int loc = -1;
 285  0
     boolean removed = false;
 286  
     do {
 287  0
       loc = -1;
 288  0
       for (int noa = 0; noa < m_numInputs; noa++) {
 289  0
         if (i == m_inputList[noa] && (n == -1 || n == m_inputNums[noa])) {
 290  0
           loc = noa;
 291  0
           break;
 292  
         }
 293  
       }
 294  
       
 295  0
       if (loc >= 0) {
 296  0
         for (int noa = loc+1; noa < m_numInputs; noa++) {
 297  0
           m_inputList[noa-1] = m_inputList[noa];
 298  0
           m_inputNums[noa-1] = m_inputNums[noa];
 299  
           
 300  0
           m_weights[noa] = m_weights[noa+1];
 301  0
           m_changeInWeights[noa] = m_changeInWeights[noa+1];
 302  
           
 303  0
           m_inputList[noa-1].changeOutputNum(m_inputNums[noa-1], noa-1);
 304  
         }
 305  0
         m_numInputs--;
 306  0
         removed = true;
 307  
       }      
 308  0
     } while (n == -1 && loc != -1);
 309  0
     return removed;
 310  
   }
 311  
   
 312  
   /**
 313  
    * This function will remove all the inputs to this unit.
 314  
    * In doing so it will also terminate the connections at the other end.
 315  
    */
 316  
   public void removeAllInputs() {
 317  0
     super.removeAllInputs();
 318  
     
 319  0
     double temp1 = m_weights[0];
 320  0
     double temp2 = m_changeInWeights[0];
 321  
 
 322  0
     m_weights = new double[1];
 323  0
     m_changeInWeights = new double[1];
 324  
 
 325  0
     m_weights[0] = temp1;
 326  0
     m_changeInWeights[0] = temp2;
 327  
     
 328  0
   }  
 329  
   
 330  
   /**
 331  
    * Returns the revision string.
 332  
    * 
 333  
    * @return                the revision
 334  
    */
 335  
   public String getRevision() {
 336  0
     return RevisionUtils.extract("$Revision: 8034 $");
 337  
   }
 338  
 }