| Classes in this File | Line Coverage | Branch Coverage | Complexity | ||||
| NeuralNode |
|
| 2.5294117647058822;2.529 |
| 1 | /* | |
| 2 | * This program is free software: you can redistribute it and/or modify | |
| 3 | * it under the terms of the GNU General Public License as published by | |
| 4 | * the Free Software Foundation, either version 3 of the License, or | |
| 5 | * (at your option) any later version. | |
| 6 | * | |
| 7 | * This program is distributed in the hope that it will be useful, | |
| 8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of | |
| 9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
| 10 | * GNU General Public License for more details. | |
| 11 | * | |
| 12 | * You should have received a copy of the GNU General Public License | |
| 13 | * along with this program. If not, see <http://www.gnu.org/licenses/>. | |
| 14 | */ | |
| 15 | ||
| 16 | /* | |
| 17 | * NeuralNode.java | |
| 18 | * Copyright (C) 2000-2012 University of Waikato, Hamilton, New Zealand | |
| 19 | */ | |
| 20 | ||
| 21 | package weka.classifiers.functions.neural; | |
| 22 | ||
| 23 | import weka.core.RevisionUtils; | |
| 24 | ||
| 25 | import java.util.Random; | |
| 26 | ||
| 27 | /** | |
| 28 | * This class is used to represent a node in the neuralnet. | |
| 29 | * | |
| 30 | * @author Malcolm Ware (mfw4@cs.waikato.ac.nz) | |
| 31 | * @version $Revision: 8034 $ | |
| 32 | */ | |
| 33 | public class NeuralNode | |
| 34 | extends NeuralConnection { | |
| 35 | ||
| 36 | /** for serialization */ | |
| 37 | private static final long serialVersionUID = -1085750607680839163L; | |
| 38 | ||
| 39 | /** The weights for each of the input connections, and the threshold. */ | |
| 40 | private double[] m_weights; | |
| 41 | ||
| 42 | /** The best (lowest error) weights. Only used when validation set is used */ | |
| 43 | private double[] m_bestWeights; | |
| 44 | ||
| 45 | /** The change in the weights. */ | |
| 46 | private double[] m_changeInWeights; | |
| 47 | ||
| 48 | private Random m_random; | |
| 49 | ||
| 50 | /** Performs the operations for this node. Currently this | |
| 51 | * defines that the node is either a sigmoid or a linear unit. */ | |
| 52 | private NeuralMethod m_methods; | |
| 53 | ||
| 54 | /** | |
| 55 | * @param id The string name for this node (used to id this node). | |
| 56 | * @param r A random number generator used to generate initial weights. | |
| 57 | * @param m The methods this node should use to update. | |
| 58 | */ | |
| 59 | public NeuralNode(String id, Random r, NeuralMethod m) { | |
| 60 | 0 | super(id); |
| 61 | 0 | m_weights = new double[1]; |
| 62 | 0 | m_bestWeights = new double[1]; |
| 63 | 0 | m_changeInWeights = new double[1]; |
| 64 | ||
| 65 | 0 | m_random = r; |
| 66 | ||
| 67 | 0 | m_weights[0] = m_random.nextDouble() * .1 - .05; |
| 68 | 0 | m_changeInWeights[0] = 0; |
| 69 | ||
| 70 | 0 | m_methods = m; |
| 71 | 0 | } |
| 72 | ||
| 73 | /** | |
| 74 | * Set how this node should operate (note that the neural method has no | |
| 75 | * internal state, so the same object can be used by any number of nodes. | |
| 76 | * @param m The new method. | |
| 77 | */ | |
| 78 | public void setMethod(NeuralMethod m) { | |
| 79 | 0 | m_methods = m; |
| 80 | 0 | } |
| 81 | ||
| 82 | public NeuralMethod getMethod() { | |
| 83 | 0 | return m_methods; |
| 84 | } | |
| 85 | ||
| 86 | /** | |
| 87 | * Call this to get the output value of this unit. | |
| 88 | * @param calculate True if the value should be calculated if it hasn't been | |
| 89 | * already. | |
| 90 | * @return The output value, or NaN, if the value has not been calculated. | |
| 91 | */ | |
| 92 | public double outputValue(boolean calculate) { | |
| 93 | ||
| 94 | 0 | if (Double.isNaN(m_unitValue) && calculate) { |
| 95 | //then calculate the output value; | |
| 96 | 0 | m_unitValue = m_methods.outputValue(this); |
| 97 | } | |
| 98 | ||
| 99 | 0 | return m_unitValue; |
| 100 | } | |
| 101 | ||
| 102 | ||
| 103 | /** | |
| 104 | * Call this to get the error value of this unit. | |
| 105 | * @param calculate True if the value should be calculated if it hasn't been | |
| 106 | * already. | |
| 107 | * @return The error value, or NaN, if the value has not been calculated. | |
| 108 | */ | |
| 109 | public double errorValue(boolean calculate) { | |
| 110 | ||
| 111 | 0 | if (!Double.isNaN(m_unitValue) && Double.isNaN(m_unitError) && calculate) { |
| 112 | //then calculate the error. | |
| 113 | 0 | m_unitError = m_methods.errorValue(this); |
| 114 | } | |
| 115 | 0 | return m_unitError; |
| 116 | } | |
| 117 | ||
| 118 | /** | |
| 119 | * Call this to reset the value and error for this unit, ready for the next | |
| 120 | * run. This will also call the reset function of all units that are | |
| 121 | * connected as inputs to this one. | |
| 122 | * This is also the time that the update for the listeners will be performed. | |
| 123 | */ | |
| 124 | public void reset() { | |
| 125 | ||
| 126 | 0 | if (!Double.isNaN(m_unitValue) || !Double.isNaN(m_unitError)) { |
| 127 | 0 | m_unitValue = Double.NaN; |
| 128 | 0 | m_unitError = Double.NaN; |
| 129 | 0 | m_weightsUpdated = false; |
| 130 | 0 | for (int noa = 0; noa < m_numInputs; noa++) { |
| 131 | 0 | m_inputList[noa].reset(); |
| 132 | } | |
| 133 | } | |
| 134 | 0 | } |
| 135 | ||
| 136 | /** | |
| 137 | * Call this to have the connection save the current | |
| 138 | * weights. | |
| 139 | */ | |
| 140 | public void saveWeights() { | |
| 141 | // copy the current weights | |
| 142 | 0 | System.arraycopy(m_weights, 0, m_bestWeights, 0, m_weights.length); |
| 143 | ||
| 144 | // tell inputs to save weights | |
| 145 | 0 | for (int i = 0; i < m_numInputs; i++) { |
| 146 | 0 | m_inputList[i].saveWeights(); |
| 147 | } | |
| 148 | 0 | } |
| 149 | ||
| 150 | /** | |
| 151 | * Call this to have the connection restore from the saved | |
| 152 | * weights. | |
| 153 | */ | |
| 154 | public void restoreWeights() { | |
| 155 | // copy the saved best weights back into the weights | |
| 156 | 0 | System.arraycopy(m_bestWeights, 0, m_weights, 0, m_weights.length); |
| 157 | ||
| 158 | // tell inputs to restore weights | |
| 159 | 0 | for (int i = 0; i < m_numInputs; i++) { |
| 160 | 0 | m_inputList[i].restoreWeights(); |
| 161 | } | |
| 162 | 0 | } |
| 163 | ||
| 164 | /** | |
| 165 | * Call this to get the weight value on a particular connection. | |
| 166 | * @param n The connection number to get the weight for, -1 if The threshold | |
| 167 | * weight should be returned. | |
| 168 | * @return The value for the specified connection or if -1 then it should | |
| 169 | * return the threshold value. If no value exists for the specified | |
| 170 | * connection, NaN will be returned. | |
| 171 | */ | |
| 172 | public double weightValue(int n) { | |
| 173 | 0 | if (n >= m_numInputs || n < -1) { |
| 174 | 0 | return Double.NaN; |
| 175 | } | |
| 176 | 0 | return m_weights[n + 1]; |
| 177 | } | |
| 178 | ||
| 179 | /** | |
| 180 | * call this function to get the weights array. | |
| 181 | * This will also allow the weights to be updated. | |
| 182 | * @return The weights array. | |
| 183 | */ | |
| 184 | public double[] getWeights() { | |
| 185 | 0 | return m_weights; |
| 186 | } | |
| 187 | ||
| 188 | /** | |
| 189 | * call this function to get the chnage in weights array. | |
| 190 | * This will also allow the change in weights to be updated. | |
| 191 | * @return The change in weights array. | |
| 192 | */ | |
| 193 | public double[] getChangeInWeights() { | |
| 194 | 0 | return m_changeInWeights; |
| 195 | } | |
| 196 | ||
| 197 | /** | |
| 198 | * Call this function to update the weight values at this unit. | |
| 199 | * After the weights have been updated at this unit, All the | |
| 200 | * input connections will then be called from this to have their | |
| 201 | * weights updated. | |
| 202 | * @param l The learning rate to use. | |
| 203 | * @param m The momentum to use. | |
| 204 | */ | |
| 205 | public void updateWeights(double l, double m) { | |
| 206 | ||
| 207 | 0 | if (!m_weightsUpdated && !Double.isNaN(m_unitError)) { |
| 208 | 0 | m_methods.updateWeights(this, l, m); |
| 209 | ||
| 210 | //note that the super call to update the inputs is done here and | |
| 211 | //not in the m_method updateWeights, because it is not deemed to be | |
| 212 | //required to update the weights at this node (while the error and output | |
| 213 | //value ao need to be recursively calculated) | |
| 214 | 0 | super.updateWeights(l, m); //to call all of the inputs. |
| 215 | } | |
| 216 | ||
| 217 | 0 | } |
| 218 | ||
| 219 | /** | |
| 220 | * This will connect the specified unit to be an input to this unit. | |
| 221 | * @param i The unit. | |
| 222 | * @param n It's connection number for this connection. | |
| 223 | * @return True if the connection was made, false otherwise. | |
| 224 | */ | |
| 225 | protected boolean connectInput(NeuralConnection i, int n) { | |
| 226 | ||
| 227 | //the function that this overrides can do most of the work. | |
| 228 | 0 | if (!super.connectInput(i, n)) { |
| 229 | 0 | return false; |
| 230 | } | |
| 231 | ||
| 232 | //note that the weights are shifted 1 forward in the array so | |
| 233 | //it leaves the numinputs aligned on the space the weight needs to go. | |
| 234 | 0 | m_weights[m_numInputs] = m_random.nextDouble() * .1 - .05; |
| 235 | 0 | m_changeInWeights[m_numInputs] = 0; |
| 236 | ||
| 237 | 0 | return true; |
| 238 | } | |
| 239 | ||
| 240 | /** | |
| 241 | * This will allocate more space for input connection information | |
| 242 | * if the arrays for this have been filled up. | |
| 243 | */ | |
| 244 | protected void allocateInputs() { | |
| 245 | ||
| 246 | 0 | NeuralConnection[] temp1 = new NeuralConnection[m_inputList.length + 15]; |
| 247 | 0 | int[] temp2 = new int[m_inputNums.length + 15]; |
| 248 | 0 | double[] temp4 = new double[m_weights.length + 15]; |
| 249 | 0 | double[] temp5 = new double[m_changeInWeights.length + 15]; |
| 250 | 0 | double[] temp6 = new double[m_bestWeights.length + 15]; |
| 251 | ||
| 252 | 0 | temp4[0] = m_weights[0]; |
| 253 | 0 | temp5[0] = m_changeInWeights[0]; |
| 254 | 0 | temp6[0] = m_bestWeights[0]; |
| 255 | 0 | for (int noa = 0; noa < m_numInputs; noa++) { |
| 256 | 0 | temp1[noa] = m_inputList[noa]; |
| 257 | 0 | temp2[noa] = m_inputNums[noa]; |
| 258 | 0 | temp4[noa+1] = m_weights[noa+1]; |
| 259 | 0 | temp5[noa+1] = m_changeInWeights[noa+1]; |
| 260 | 0 | temp6[noa+1] = m_bestWeights[noa+1]; |
| 261 | } | |
| 262 | ||
| 263 | 0 | m_inputList = temp1; |
| 264 | 0 | m_inputNums = temp2; |
| 265 | 0 | m_weights = temp4; |
| 266 | 0 | m_changeInWeights = temp5; |
| 267 | 0 | m_bestWeights = temp6; |
| 268 | 0 | } |
| 269 | ||
| 270 | ||
| 271 | ||
| 272 | ||
| 273 | /** | |
| 274 | * This will disconnect the input with the specific connection number | |
| 275 | * From this node (only on this end however). | |
| 276 | * @param i The unit to disconnect. | |
| 277 | * @param n The connection number at the other end, -1 if all the connections | |
| 278 | * to this unit should be severed (not the same as removeAllInputs). | |
| 279 | * @return True if the connection was removed, false if the connection was | |
| 280 | * not found. | |
| 281 | */ | |
| 282 | protected boolean disconnectInput(NeuralConnection i, int n) { | |
| 283 | ||
| 284 | 0 | int loc = -1; |
| 285 | 0 | boolean removed = false; |
| 286 | do { | |
| 287 | 0 | loc = -1; |
| 288 | 0 | for (int noa = 0; noa < m_numInputs; noa++) { |
| 289 | 0 | if (i == m_inputList[noa] && (n == -1 || n == m_inputNums[noa])) { |
| 290 | 0 | loc = noa; |
| 291 | 0 | break; |
| 292 | } | |
| 293 | } | |
| 294 | ||
| 295 | 0 | if (loc >= 0) { |
| 296 | 0 | for (int noa = loc+1; noa < m_numInputs; noa++) { |
| 297 | 0 | m_inputList[noa-1] = m_inputList[noa]; |
| 298 | 0 | m_inputNums[noa-1] = m_inputNums[noa]; |
| 299 | ||
| 300 | 0 | m_weights[noa] = m_weights[noa+1]; |
| 301 | 0 | m_changeInWeights[noa] = m_changeInWeights[noa+1]; |
| 302 | ||
| 303 | 0 | m_inputList[noa-1].changeOutputNum(m_inputNums[noa-1], noa-1); |
| 304 | } | |
| 305 | 0 | m_numInputs--; |
| 306 | 0 | removed = true; |
| 307 | } | |
| 308 | 0 | } while (n == -1 && loc != -1); |
| 309 | 0 | return removed; |
| 310 | } | |
| 311 | ||
| 312 | /** | |
| 313 | * This function will remove all the inputs to this unit. | |
| 314 | * In doing so it will also terminate the connections at the other end. | |
| 315 | */ | |
| 316 | public void removeAllInputs() { | |
| 317 | 0 | super.removeAllInputs(); |
| 318 | ||
| 319 | 0 | double temp1 = m_weights[0]; |
| 320 | 0 | double temp2 = m_changeInWeights[0]; |
| 321 | ||
| 322 | 0 | m_weights = new double[1]; |
| 323 | 0 | m_changeInWeights = new double[1]; |
| 324 | ||
| 325 | 0 | m_weights[0] = temp1; |
| 326 | 0 | m_changeInWeights[0] = temp2; |
| 327 | ||
| 328 | 0 | } |
| 329 | ||
| 330 | /** | |
| 331 | * Returns the revision string. | |
| 332 | * | |
| 333 | * @return the revision | |
| 334 | */ | |
| 335 | public String getRevision() { | |
| 336 | 0 | return RevisionUtils.extract("$Revision: 8034 $"); |
| 337 | } | |
| 338 | } |