Coverage Report - weka.classifiers.functions.supportVector.CachedKernel
 
Classes in this File Line Coverage Branch Coverage Complexity
CachedKernel
0%
0/115
0%
0/50
2.75
 
 1  
 /*
 2  
  *   This program is free software: you can redistribute it and/or modify
 3  
  *   it under the terms of the GNU General Public License as published by
 4  
  *   the Free Software Foundation, either version 3 of the License, or
 5  
  *   (at your option) any later version.
 6  
  *
 7  
  *   This program is distributed in the hope that it will be useful,
 8  
  *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 9  
  *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 10  
  *   GNU General Public License for more details.
 11  
  *
 12  
  *   You should have received a copy of the GNU General Public License
 13  
  *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 14  
  */
 15  
 
 16  
 /*
 17  
  * CachedKernel.java
 18  
  * Copyright (C) 2005-2012 University of Waikato, Hamilton, New Zealand
 19  
  */
 20  
 
 21  
 package weka.classifiers.functions.supportVector;
 22  
 
 23  
 import java.util.Enumeration;
 24  
 import java.util.Vector;
 25  
 
 26  
 import weka.core.Instance;
 27  
 import weka.core.Instances;
 28  
 import weka.core.Option;
 29  
 import weka.core.Utils;
 30  
 
 31  
 /**
 32  
  * Base class for RBFKernel and PolyKernel that implements a simple LRU.
 33  
  * (least-recently-used) cache if the cache size is set to a value > 0.
 34  
  * Otherwise it uses a full cache.
 35  
  * 
 36  
  * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 37  
  * @author Shane Legg (shane@intelligenesis.net) (sparse vector code)
 38  
  * @author Stuart Inglis (stuart@reeltwo.com) (sparse vector code)
 39  
  * @author J. Lindgren (jtlindgr{at}cs.helsinki.fi) (RBF kernel)
 40  
  * @author Steven Hugg (hugg@fasterlight.com) (refactored, LRU cache)
 41  
  * @author Bernhard Pfahringer (bernhard@cs.waikato.ac.nz) (full cache)
 42  
  * @version $Revision: 8034 $
 43  
  */
 44  
 public abstract class CachedKernel 
 45  
   extends Kernel {
 46  
 
 47  
   /** for serialization */
 48  
   private static final long serialVersionUID = 702810182699015136L;
 49  
     
 50  
   /** Counts the number of kernel evaluations. */
 51  
   protected int m_kernelEvals;
 52  
 
 53  
   /** Counts the number of kernel cache hits. */
 54  
   protected int m_cacheHits;
 55  
 
 56  
   /** The size of the cache (a prime number) */
 57  0
   protected int m_cacheSize = 250007;
 58  
 
 59  
   /** Kernel cache */
 60  
   protected double[] m_storage;
 61  
   protected long[] m_keys;
 62  
 
 63  
   /** The kernel matrix if full cache is used (i.e. size is set to 0) */
 64  
   protected double[][] m_kernelMatrix;
 65  
 
 66  
   /** The number of instance in the dataset */
 67  
   protected int m_numInsts;
 68  
 
 69  
   /** number of cache slots in an entry */
 70  0
   protected int m_cacheSlots = 4;
 71  
 
 72  
   /**
 73  
    * default constructor - does nothing.
 74  
    */
 75  
   public CachedKernel() {
 76  0
     super();
 77  0
   }
 78  
   
 79  
   /**
 80  
    * Initializes the kernel cache. The actual size of the cache in bytes is
 81  
    * (64 * cacheSize).
 82  
    * 
 83  
    * @param data        the data to use
 84  
    * @param cacheSize        the cache size
 85  
    * @throws Exception        if something goes wrong
 86  
    */
 87  
   protected CachedKernel(Instances data, int cacheSize) throws Exception {
 88  0
     super();
 89  
     
 90  0
     setCacheSize(cacheSize);
 91  
     
 92  0
     buildKernel(data);
 93  0
   }
 94  
   
 95  
   /**
 96  
    * Returns an enumeration describing the available options.
 97  
    *
 98  
    * @return                 an enumeration of all the available options.
 99  
    */
 100  
   public Enumeration listOptions() {
 101  
     Vector                result;
 102  
     Enumeration                en;
 103  
     
 104  0
     result = new Vector();
 105  
 
 106  0
     en = super.listOptions();
 107  0
     while (en.hasMoreElements())
 108  0
       result.addElement(en.nextElement());
 109  
 
 110  0
     result.addElement(new Option(
 111  
         "\tThe size of the cache (a prime number), 0 for full cache and \n"
 112  
         + "\t-1 to turn it off.\n"
 113  
         + "\t(default: 250007)",
 114  
         "C", 1, "-C <num>"));
 115  
 
 116  0
     return result.elements();
 117  
   }
 118  
 
 119  
   /**
 120  
    * Parses a given list of options. <p/>
 121  
    * 
 122  
    * @param options         the list of options as an array of strings
 123  
    * @throws Exception         if an option is not supported
 124  
    */
 125  
   public void setOptions(String[] options) throws Exception {
 126  
     String        tmpStr;
 127  
     
 128  0
     tmpStr = Utils.getOption('C', options);
 129  0
     if (tmpStr.length() != 0)
 130  0
       setCacheSize(Integer.parseInt(tmpStr));
 131  
     else
 132  0
       setCacheSize(250007);
 133  
     
 134  0
     super.setOptions(options);
 135  0
   }
 136  
 
 137  
   /**
 138  
    * Gets the current settings of the Kernel.
 139  
    *
 140  
    * @return an array of strings suitable for passing to setOptions
 141  
    */
 142  
   public String[] getOptions() {
 143  
     int       i;
 144  
     Vector    result;
 145  
     String[]  options;
 146  
 
 147  0
     result = new Vector();
 148  0
     options = super.getOptions();
 149  0
     for (i = 0; i < options.length; i++)
 150  0
       result.add(options[i]);
 151  
 
 152  0
     result.add("-C");
 153  0
     result.add("" + getCacheSize());
 154  
 
 155  0
     return (String[]) result.toArray(new String[result.size()]);          
 156  
   }
 157  
 
 158  
   /**
 159  
    * This method is overridden in subclasses to implement specific kernels.
 160  
    * 
 161  
    * @param id1           the index of instance 1
 162  
    * @param id2                the index of instance 2
 163  
    * @param inst1        the instance 1 object
 164  
    * @return                 the dot product
 165  
    * @throws Exception         if something goes wrong
 166  
    */
 167  
   protected abstract double evaluate(int id1, int id2, Instance inst1)
 168  
     throws Exception;
 169  
 
 170  
   /**
 171  
    * Implements the abstract function of Kernel using the cache. This method
 172  
    * uses the evaluate() method to do the actual dot product.
 173  
    *
 174  
    * @param id1         the index of the first instance in the dataset
 175  
    * @param id2         the index of the second instance in the dataset
 176  
    * @param inst1         the instance corresponding to id1 (used if id1 == -1)
 177  
    * @return                 the result of the kernel function
 178  
    * @throws Exception         if something goes wrong
 179  
    */
 180  
   public double eval(int id1, int id2, Instance inst1) throws Exception {
 181  
                 
 182  0
     double result = 0;
 183  0
     long key = -1;
 184  0
     int location = -1;
 185  
 
 186  
     // we can only cache if we know the indexes and caching is not 
 187  
     // disbled (m_cacheSize == -1)
 188  0
     if ( (id1 >= 0) && (m_cacheSize != -1) ) {
 189  
 
 190  
       // Use full cache?
 191  0
       if (m_cacheSize == 0) {
 192  0
         if (m_kernelMatrix == null) {
 193  0
           m_kernelMatrix = new double[m_data.numInstances()][];
 194  0
           for(int i = 0; i < m_data.numInstances(); i++) {
 195  0
             m_kernelMatrix[i] = new double[i + 1];
 196  0
             for(int j = 0; j <= i; j++) {
 197  0
               m_kernelEvals++;
 198  0
               m_kernelMatrix[i][j] = evaluate(i, j, m_data.instance(i));
 199  
             }
 200  
           }
 201  
         } 
 202  0
         m_cacheHits++;
 203  0
         result = (id1 > id2) ? m_kernelMatrix[id1][id2] : m_kernelMatrix[id2][id1];
 204  0
         return result;
 205  
       }
 206  
 
 207  
       // Use LRU cache
 208  0
       if (id1 > id2) {
 209  0
         key = (id1 + ((long) id2 * m_numInsts));
 210  
       } else {
 211  0
         key = (id2 + ((long) id1 * m_numInsts));
 212  
       }
 213  0
       location = (int) (key % m_cacheSize) * m_cacheSlots;
 214  0
       int loc = location;
 215  0
       for (int i = 0; i < m_cacheSlots; i++) {
 216  0
         long thiskey = m_keys[loc];
 217  0
         if (thiskey == 0)
 218  0
           break; // empty slot, so break out of loop early
 219  0
         if (thiskey == (key + 1)) {
 220  0
           m_cacheHits++;
 221  
           // move entry to front of cache (LRU) by swapping
 222  
           // only if it's not already at the front of cache
 223  0
           if (i > 0) {
 224  0
             double tmps = m_storage[loc];
 225  0
             m_storage[loc] = m_storage[location];
 226  0
             m_keys[loc] = m_keys[location];
 227  0
             m_storage[location] = tmps;
 228  0
             m_keys[location] = thiskey;
 229  0
             return tmps;
 230  
           } else
 231  0
             return m_storage[loc];
 232  
         }
 233  0
         loc++;
 234  
       }
 235  
     }
 236  
 
 237  0
     result = evaluate(id1, id2, inst1);
 238  
 
 239  0
     m_kernelEvals++;
 240  
 
 241  
     // store result in cache
 242  0
     if ( (key != -1) && (m_cacheSize != -1) ) {
 243  
       // move all cache slots forward one array index
 244  
       // to make room for the new entry
 245  0
       System.arraycopy(m_keys, location, m_keys, location + 1,
 246  
                        m_cacheSlots - 1);
 247  0
       System.arraycopy(m_storage, location, m_storage, location + 1,
 248  
                        m_cacheSlots - 1);
 249  0
       m_storage[location] = result;
 250  0
       m_keys[location] = (key + 1);
 251  
     }
 252  0
     return result;
 253  
   }
 254  
 
 255  
   /**
 256  
    * Returns the number of time Eval has been called.
 257  
    * 
 258  
    * @return                 the number of kernel evaluation.
 259  
    */
 260  
   public int numEvals() {
 261  0
     return m_kernelEvals;
 262  
   }
 263  
 
 264  
   /**
 265  
    * Returns the number of cache hits on dot products.
 266  
    * 
 267  
    * @return                 the number of cache hits.
 268  
    */
 269  
   public int numCacheHits() {
 270  0
     return m_cacheHits;
 271  
   }
 272  
 
 273  
   /**
 274  
    * Frees the cache used by the kernel.
 275  
    */
 276  
   public void clean() {
 277  0
     m_storage = null;
 278  0
     m_keys = null;
 279  0
     m_kernelMatrix = null;
 280  0
   }
 281  
 
 282  
   /**
 283  
    * Calculates a dot product between two instances
 284  
    * 
 285  
    * @param inst1        the first instance
 286  
    * @param inst2        the second instance
 287  
    * @return                 the dot product of the two instances.
 288  
    * @throws Exception        if an error occurs
 289  
    */
 290  
   protected final double dotProd(Instance inst1, Instance inst2)
 291  
     throws Exception {
 292  
 
 293  0
     double result = 0;
 294  
 
 295  
     // we can do a fast dot product
 296  0
     int n1 = inst1.numValues();
 297  0
     int n2 = inst2.numValues();
 298  0
     int classIndex = m_data.classIndex();
 299  0
     for (int p1 = 0, p2 = 0; p1 < n1 && p2 < n2;) {
 300  0
       int ind1 = inst1.index(p1);
 301  0
       int ind2 = inst2.index(p2);
 302  0
       if (ind1 == ind2) {
 303  0
         if (ind1 != classIndex) {
 304  0
           result += inst1.valueSparse(p1) * inst2.valueSparse(p2);
 305  
         }
 306  0
         p1++;
 307  0
         p2++;
 308  0
       } else if (ind1 > ind2) {
 309  0
         p2++;
 310  
       } else {
 311  0
         p1++;
 312  
       }
 313  0
     }
 314  0
     return (result);
 315  
   }
 316  
 
 317  
   /**
 318  
    * Sets the size of the cache to use (a prime number)
 319  
    * 
 320  
    * @param value        the size of the cache
 321  
    */
 322  
   public void setCacheSize(int value) {
 323  0
     if (value >= -1) {
 324  0
       m_cacheSize = value;
 325  0
       clean();
 326  
     }
 327  
     else {
 328  0
       System.out.println(
 329  
           "Cache size cannot be smaller than -1 (provided: " + value + ")!");
 330  
     }
 331  0
   }
 332  
   
 333  
   /**
 334  
    * Gets the size of the cache
 335  
    * 
 336  
    * @return                 the cache size
 337  
    */
 338  
   public int getCacheSize() {
 339  0
     return m_cacheSize;
 340  
   }
 341  
 
 342  
   /**
 343  
    * Returns the tip text for this property
 344  
    * 
 345  
    * @return                 tip text for this property suitable for
 346  
    *                         displaying in the explorer/experimenter gui
 347  
    */
 348  
   public String cacheSizeTipText() {
 349  0
     return "The size of the cache (a prime number), 0 for full cache and -1 to turn it off.";
 350  
   }
 351  
 
 352  
   /**
 353  
    * initializes variables etc.
 354  
    * 
 355  
    * @param data        the data to use
 356  
    */
 357  
   protected void initVars(Instances data) {
 358  0
     super.initVars(data);
 359  
     
 360  0
     m_kernelEvals = 0;
 361  0
     m_cacheHits   = 0;
 362  0
     m_numInsts    = m_data.numInstances();
 363  
 
 364  0
     if (getCacheSize() > 0) {
 365  
       // Use LRU cache
 366  0
       m_storage = new double[m_cacheSize * m_cacheSlots];
 367  0
       m_keys    = new long[m_cacheSize * m_cacheSlots];
 368  
     } 
 369  
     else {
 370  0
       m_storage      = null;
 371  0
       m_keys         = null;
 372  0
       m_kernelMatrix = null;
 373  
     }
 374  0
   }
 375  
   
 376  
   /**
 377  
    * builds the kernel with the given data. Initializes the kernel cache. 
 378  
    * The actual size of the cache in bytes is (64 * cacheSize).
 379  
    * 
 380  
    * @param data        the data to base the kernel on
 381  
    * @throws Exception        if something goes wrong
 382  
    */
 383  
   public void buildKernel(Instances data) throws Exception {
 384  
     // does kernel handle the data?
 385  0
     if (!getChecksTurnedOff())
 386  0
       getCapabilities().testWithFail(data);
 387  
 
 388  0
     initVars(data);
 389  0
   }
 390  
 }