| Classes in this File | Line Coverage | Branch Coverage | Complexity | ||||
| CachedKernel |
|
| 2.75;2.75 |
| 1 | /* | |
| 2 | * This program is free software: you can redistribute it and/or modify | |
| 3 | * it under the terms of the GNU General Public License as published by | |
| 4 | * the Free Software Foundation, either version 3 of the License, or | |
| 5 | * (at your option) any later version. | |
| 6 | * | |
| 7 | * This program is distributed in the hope that it will be useful, | |
| 8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of | |
| 9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
| 10 | * GNU General Public License for more details. | |
| 11 | * | |
| 12 | * You should have received a copy of the GNU General Public License | |
| 13 | * along with this program. If not, see <http://www.gnu.org/licenses/>. | |
| 14 | */ | |
| 15 | ||
| 16 | /* | |
| 17 | * CachedKernel.java | |
| 18 | * Copyright (C) 2005-2012 University of Waikato, Hamilton, New Zealand | |
| 19 | */ | |
| 20 | ||
| 21 | package weka.classifiers.functions.supportVector; | |
| 22 | ||
| 23 | import java.util.Enumeration; | |
| 24 | import java.util.Vector; | |
| 25 | ||
| 26 | import weka.core.Instance; | |
| 27 | import weka.core.Instances; | |
| 28 | import weka.core.Option; | |
| 29 | import weka.core.Utils; | |
| 30 | ||
| 31 | /** | |
| 32 | * Base class for RBFKernel and PolyKernel that implements a simple LRU | |
| 33 | * (least-recently-used) cache if the cache size is set to a value > 0. | |
| 34 | * A size of 0 selects a full cache; a size of -1 turns caching off. | |
| 35 | * | |
| 36 | * @author Eibe Frank (eibe@cs.waikato.ac.nz) | |
| 37 | * @author Shane Legg (shane@intelligenesis.net) (sparse vector code) | |
| 38 | * @author Stuart Inglis (stuart@reeltwo.com) (sparse vector code) | |
| 39 | * @author J. Lindgren (jtlindgr{at}cs.helsinki.fi) (RBF kernel) | |
| 40 | * @author Steven Hugg (hugg@fasterlight.com) (refactored, LRU cache) | |
| 41 | * @author Bernhard Pfahringer (bernhard@cs.waikato.ac.nz) (full cache) | |
| 42 | * @version $Revision: 8034 $ | |
| 43 | */ | |
| 44 | public abstract class CachedKernel | |
| 45 | extends Kernel { | |
| 46 | ||
| 47 | /** for serialization */ | |
| 48 | private static final long serialVersionUID = 702810182699015136L; | |
| 49 | ||
| 50 | /** Counts the number of kernel evaluations. */ | |
| 51 | protected int m_kernelEvals; | |
| 52 | ||
| 53 | /** Counts the number of kernel cache hits. */ | |
| 54 | protected int m_cacheHits; | |
| 55 | ||
| 56 | /** The size of the cache (a prime number) */ | |
| 57 | 0 | protected int m_cacheSize = 250007; |
| 58 | ||
| 59 | /** Kernel cache */ | |
| 60 | protected double[] m_storage; | |
| 61 | protected long[] m_keys; | |
| 62 | ||
| 63 | /** The kernel matrix if full cache is used (i.e. size is set to 0) */ | |
| 64 | protected double[][] m_kernelMatrix; | |
| 65 | ||
| 66 | /** The number of instances in the dataset */ | |
| 67 | protected int m_numInsts; | |
| 68 | ||
| 69 | /** number of cache slots in an entry */ | |
| 70 | 0 | protected int m_cacheSlots = 4; |
| 71 | ||
| 72 | /** | |
| 73 | * default constructor - does nothing. | |
| 74 | */ | |
| 75 | public CachedKernel() { | |
| 76 | 0 | super(); |
| 77 | 0 | } |
| 78 | ||
| 79 | /** | |
| 80 | * Initializes the kernel cache. The actual size of the cache in bytes is | |
| 81 | * (64 * cacheSize). | |
| 82 | * | |
| 83 | * @param data the data to use | |
| 84 | * @param cacheSize the cache size | |
| 85 | * @throws Exception if something goes wrong | |
| 86 | */ | |
| 87 | protected CachedKernel(Instances data, int cacheSize) throws Exception { | |
| 88 | 0 | super(); |
| 89 | ||
| 90 | 0 | setCacheSize(cacheSize); |
| 91 | ||
| 92 | 0 | buildKernel(data); |
| 93 | 0 | } |
| 94 | ||
| 95 | /** | |
| 96 | * Returns an enumeration describing the available options. | |
| 97 | * | |
| 98 | * @return an enumeration of all the available options. | |
| 99 | */ | |
| 100 | public Enumeration listOptions() { | |
| 101 | Vector result; | |
| 102 | Enumeration en; | |
| 103 | ||
| 104 | 0 | result = new Vector(); |
| 105 | ||
| 106 | 0 | en = super.listOptions(); |
| 107 | 0 | while (en.hasMoreElements()) |
| 108 | 0 | result.addElement(en.nextElement()); |
| 109 | ||
| 110 | 0 | result.addElement(new Option( |
| 111 | "\tThe size of the cache (a prime number), 0 for full cache and \n" | |
| 112 | + "\t-1 to turn it off.\n" | |
| 113 | + "\t(default: 250007)", | |
| 114 | "C", 1, "-C <num>")); | |
| 115 | ||
| 116 | 0 | return result.elements(); |
| 117 | } | |
| 118 | ||
| 119 | /** | |
| 120 | * Parses a given list of options. <p/> | |
| 121 | * | |
| 122 | * @param options the list of options as an array of strings | |
| 123 | * @throws Exception if an option is not supported | |
| 124 | */ | |
| 125 | public void setOptions(String[] options) throws Exception { | |
| 126 | String tmpStr; | |
| 127 | ||
| 128 | 0 | tmpStr = Utils.getOption('C', options); |
| 129 | 0 | if (tmpStr.length() != 0) |
| 130 | 0 | setCacheSize(Integer.parseInt(tmpStr)); |
| 131 | else | |
| 132 | 0 | setCacheSize(250007); |
| 133 | ||
| 134 | 0 | super.setOptions(options); |
| 135 | 0 | } |
| 136 | ||
| 137 | /** | |
| 138 | * Gets the current settings of the Kernel. | |
| 139 | * | |
| 140 | * @return an array of strings suitable for passing to setOptions | |
| 141 | */ | |
| 142 | public String[] getOptions() { | |
| 143 | int i; | |
| 144 | Vector result; | |
| 145 | String[] options; | |
| 146 | ||
| 147 | 0 | result = new Vector(); |
| 148 | 0 | options = super.getOptions(); |
| 149 | 0 | for (i = 0; i < options.length; i++) |
| 150 | 0 | result.add(options[i]); |
| 151 | ||
| 152 | 0 | result.add("-C"); |
| 153 | 0 | result.add("" + getCacheSize()); |
| 154 | ||
| 155 | 0 | return (String[]) result.toArray(new String[result.size()]); |
| 156 | } | |
| 157 | ||
| 158 | /** | |
| 159 | * This method is overridden in subclasses to implement specific kernels. | |
| 160 | * | |
| 161 | * @param id1 the index of instance 1 | |
| 162 | * @param id2 the index of instance 2 | |
| 163 | * @param inst1 the instance 1 object | |
| 164 | * @return the dot product | |
| 165 | * @throws Exception if something goes wrong | |
| 166 | */ | |
| 167 | protected abstract double evaluate(int id1, int id2, Instance inst1) | |
| 168 | throws Exception; | |
| 169 | ||
| 170 | /** | |
| 171 | * Implements the abstract function of Kernel using the cache. This method | |
| 172 | * uses the evaluate() method to do the actual dot product. | |
| 173 | * | |
| 174 | * @param id1 the index of the first instance in the dataset | |
| 175 | * @param id2 the index of the second instance in the dataset | |
| 176 | * @param inst1 the instance corresponding to id1 (used if id1 == -1) | |
| 177 | * @return the result of the kernel function | |
| 178 | * @throws Exception if something goes wrong | |
| 179 | */ | |
| 180 | public double eval(int id1, int id2, Instance inst1) throws Exception { | |
| 181 | ||
| 182 | 0 | double result = 0; |
| 183 | 0 | long key = -1; |
| 184 | 0 | int location = -1; |
| 185 | ||
| 186 | // we can only cache if we know the indexes and caching is not | |
| 187 | // disabled (m_cacheSize == -1) | |
| 188 | 0 | if ( (id1 >= 0) && (m_cacheSize != -1) ) { |
| 189 | ||
| 190 | // Use full cache? | |
| 191 | 0 | if (m_cacheSize == 0) { |
| 192 | 0 | if (m_kernelMatrix == null) { |
| 193 | 0 | m_kernelMatrix = new double[m_data.numInstances()][]; |
| 194 | 0 | for(int i = 0; i < m_data.numInstances(); i++) { |
| 195 | 0 | m_kernelMatrix[i] = new double[i + 1]; |
| 196 | 0 | for(int j = 0; j <= i; j++) { |
| 197 | 0 | m_kernelEvals++; |
| 198 | 0 | m_kernelMatrix[i][j] = evaluate(i, j, m_data.instance(i)); |
| 199 | } | |
| 200 | } | |
| 201 | } | |
| 202 | 0 | m_cacheHits++; |
| 203 | 0 | result = (id1 > id2) ? m_kernelMatrix[id1][id2] : m_kernelMatrix[id2][id1]; |
| 204 | 0 | return result; |
| 205 | } | |
| 206 | ||
| 207 | // Use LRU cache | |
| 208 | 0 | if (id1 > id2) { |
| 209 | 0 | key = (id1 + ((long) id2 * m_numInsts)); |
| 210 | } else { | |
| 211 | 0 | key = (id2 + ((long) id1 * m_numInsts)); |
| 212 | } | |
| 213 | 0 | location = (int) (key % m_cacheSize) * m_cacheSlots; |
| 214 | 0 | int loc = location; |
| 215 | 0 | for (int i = 0; i < m_cacheSlots; i++) { |
| 216 | 0 | long thiskey = m_keys[loc]; |
| 217 | 0 | if (thiskey == 0) |
| 218 | 0 | break; // empty slot, so break out of loop early |
| 219 | 0 | if (thiskey == (key + 1)) { |
| 220 | 0 | m_cacheHits++; |
| 221 | // move entry to front of cache (LRU) by swapping | |
| 222 | // only if it's not already at the front of cache | |
| 223 | 0 | if (i > 0) { |
| 224 | 0 | double tmps = m_storage[loc]; |
| 225 | 0 | m_storage[loc] = m_storage[location]; |
| 226 | 0 | m_keys[loc] = m_keys[location]; |
| 227 | 0 | m_storage[location] = tmps; |
| 228 | 0 | m_keys[location] = thiskey; |
| 229 | 0 | return tmps; |
| 230 | } else | |
| 231 | 0 | return m_storage[loc]; |
| 232 | } | |
| 233 | 0 | loc++; |
| 234 | } | |
| 235 | } | |
| 236 | ||
| 237 | 0 | result = evaluate(id1, id2, inst1); |
| 238 | ||
| 239 | 0 | m_kernelEvals++; |
| 240 | ||
| 241 | // store result in cache | |
| 242 | 0 | if ( (key != -1) && (m_cacheSize != -1) ) { |
| 243 | // move all cache slots forward one array index | |
| 244 | // to make room for the new entry | |
| 245 | 0 | System.arraycopy(m_keys, location, m_keys, location + 1, |
| 246 | m_cacheSlots - 1); | |
| 247 | 0 | System.arraycopy(m_storage, location, m_storage, location + 1, |
| 248 | m_cacheSlots - 1); | |
| 249 | 0 | m_storage[location] = result; |
| 250 | 0 | m_keys[location] = (key + 1); |
| 251 | } | |
| 252 | 0 | return result; |
| 253 | } | |
| 254 | ||
| 255 | /** | |
| 256 | * Returns the number of kernel evaluations performed (cache hits excluded). | |
| 257 | * | |
| 258 | * @return the number of kernel evaluations. | |
| 259 | */ | |
| 260 | public int numEvals() { | |
| 261 | 0 | return m_kernelEvals; |
| 262 | } | |
| 263 | ||
| 264 | /** | |
| 265 | * Returns the number of cache hits on dot products. | |
| 266 | * | |
| 267 | * @return the number of cache hits. | |
| 268 | */ | |
| 269 | public int numCacheHits() { | |
| 270 | 0 | return m_cacheHits; |
| 271 | } | |
| 272 | ||
| 273 | /** | |
| 274 | * Frees the cache used by the kernel. | |
| 275 | */ | |
| 276 | public void clean() { | |
| 277 | 0 | m_storage = null; |
| 278 | 0 | m_keys = null; |
| 279 | 0 | m_kernelMatrix = null; |
| 280 | 0 | } |
| 281 | ||
| 282 | /** | |
| 283 | * Calculates a dot product between two instances, skipping the class attribute. | |
| 284 | * | |
| 285 | * @param inst1 the first instance | |
| 286 | * @param inst2 the second instance | |
| 287 | * @return the dot product of the two instances. | |
| 288 | * @throws Exception if an error occurs | |
| 289 | */ | |
| 290 | protected final double dotProd(Instance inst1, Instance inst2) | |
| 291 | throws Exception { | |
| 292 | ||
| 293 | 0 | double result = 0; |
| 294 | ||
| 295 | // we can do a fast dot product | |
| 296 | 0 | int n1 = inst1.numValues(); |
| 297 | 0 | int n2 = inst2.numValues(); |
| 298 | 0 | int classIndex = m_data.classIndex(); |
| 299 | 0 | for (int p1 = 0, p2 = 0; p1 < n1 && p2 < n2;) { |
| 300 | 0 | int ind1 = inst1.index(p1); |
| 301 | 0 | int ind2 = inst2.index(p2); |
| 302 | 0 | if (ind1 == ind2) { |
| 303 | 0 | if (ind1 != classIndex) { |
| 304 | 0 | result += inst1.valueSparse(p1) * inst2.valueSparse(p2); |
| 305 | } | |
| 306 | 0 | p1++; |
| 307 | 0 | p2++; |
| 308 | 0 | } else if (ind1 > ind2) { |
| 309 | 0 | p2++; |
| 310 | } else { | |
| 311 | 0 | p1++; |
| 312 | } | |
| 313 | 0 | } |
| 314 | 0 | return (result); |
| 315 | } | |
| 316 | ||
| 317 | /** | |
| 318 | * Sets the size of the cache to use (a prime number) | |
| 319 | * | |
| 320 | * @param value the size of the cache | |
| 321 | */ | |
| 322 | public void setCacheSize(int value) { | |
| 323 | 0 | if (value >= -1) { |
| 324 | 0 | m_cacheSize = value; |
| 325 | 0 | clean(); |
| 326 | } | |
| 327 | else { | |
| 328 | 0 | System.out.println( |
| 329 | "Cache size cannot be smaller than -1 (provided: " + value + ")!"); | |
| 330 | } | |
| 331 | 0 | } |
| 332 | ||
| 333 | /** | |
| 334 | * Gets the size of the cache | |
| 335 | * | |
| 336 | * @return the cache size | |
| 337 | */ | |
| 338 | public int getCacheSize() { | |
| 339 | 0 | return m_cacheSize; |
| 340 | } | |
| 341 | ||
| 342 | /** | |
| 343 | * Returns the tip text for this property | |
| 344 | * | |
| 345 | * @return tip text for this property suitable for | |
| 346 | * displaying in the explorer/experimenter gui | |
| 347 | */ | |
| 348 | public String cacheSizeTipText() { | |
| 349 | 0 | return "The size of the cache (a prime number), 0 for full cache and -1 to turn it off."; |
| 350 | } | |
| 351 | ||
| 352 | /** | |
| 353 | * initializes variables etc. | |
| 354 | * | |
| 355 | * @param data the data to use | |
| 356 | */ | |
| 357 | protected void initVars(Instances data) { | |
| 358 | 0 | super.initVars(data); |
| 359 | ||
| 360 | 0 | m_kernelEvals = 0; |
| 361 | 0 | m_cacheHits = 0; |
| 362 | 0 | m_numInsts = m_data.numInstances(); |
| 363 | ||
| 364 | 0 | if (getCacheSize() > 0) { |
| 365 | // Use LRU cache | |
| 366 | 0 | m_storage = new double[m_cacheSize * m_cacheSlots]; |
| 367 | 0 | m_keys = new long[m_cacheSize * m_cacheSlots]; |
| 368 | } | |
| 369 | else { | |
| 370 | 0 | m_storage = null; |
| 371 | 0 | m_keys = null; |
| 372 | 0 | m_kernelMatrix = null; |
| 373 | } | |
| 374 | 0 | } |
| 375 | ||
| 376 | /** | |
| 377 | * builds the kernel with the given data. Initializes the kernel cache. | |
| 378 | * The actual size of the cache in bytes is (64 * cacheSize). | |
| 379 | * | |
| 380 | * @param data the data to base the kernel on | |
| 381 | * @throws Exception if something goes wrong | |
| 382 | */ | |
| 383 | public void buildKernel(Instances data) throws Exception { | |
| 384 | // does kernel handle the data? | |
| 385 | 0 | if (!getChecksTurnedOff()) |
| 386 | 0 | getCapabilities().testWithFail(data); |
| 387 | ||
| 388 | 0 | initVars(data); |
| 389 | 0 | } |
| 390 | } |