| 1 | |
|
| 2 | |
|
| 3 | |
|
| 4 | |
|
| 5 | |
|
| 6 | |
|
| 7 | |
|
| 8 | |
|
| 9 | |
|
| 10 | |
|
| 11 | |
|
| 12 | |
|
| 13 | |
|
| 14 | |
|
| 15 | |
|
| 16 | |
|
| 17 | |
|
| 18 | |
|
| 19 | |
|
| 20 | |
|
| 21 | |
|
| 22 | |
package weka.classifiers.functions; |
| 23 | |
|
| 24 | |
import java.io.Serializable; |
| 25 | |
import java.util.Enumeration; |
| 26 | |
import java.util.Random; |
| 27 | |
import java.util.Vector; |
| 28 | |
|
| 29 | |
import weka.classifiers.AbstractClassifier; |
| 30 | |
import weka.classifiers.functions.supportVector.Kernel; |
| 31 | |
import weka.classifiers.functions.supportVector.PolyKernel; |
| 32 | |
import weka.classifiers.functions.supportVector.SMOset; |
| 33 | |
import weka.core.Attribute; |
| 34 | |
import weka.core.Capabilities; |
| 35 | |
import weka.core.Capabilities.Capability; |
| 36 | |
import weka.core.DenseInstance; |
| 37 | |
import weka.core.FastVector; |
| 38 | |
import weka.core.Instance; |
| 39 | |
import weka.core.Instances; |
| 40 | |
import weka.core.Option; |
| 41 | |
import weka.core.OptionHandler; |
| 42 | |
import weka.core.RevisionUtils; |
| 43 | |
import weka.core.SelectedTag; |
| 44 | |
import weka.core.Tag; |
| 45 | |
import weka.core.TechnicalInformation; |
| 46 | |
import weka.core.TechnicalInformation.Field; |
| 47 | |
import weka.core.TechnicalInformation.Type; |
| 48 | |
import weka.core.TechnicalInformationHandler; |
| 49 | |
import weka.core.Utils; |
| 50 | |
import weka.core.WeightedInstancesHandler; |
| 51 | |
import weka.filters.Filter; |
| 52 | |
import weka.filters.unsupervised.attribute.NominalToBinary; |
| 53 | |
import weka.filters.unsupervised.attribute.Normalize; |
| 54 | |
import weka.filters.unsupervised.attribute.ReplaceMissingValues; |
| 55 | |
import weka.filters.unsupervised.attribute.Standardize; |
| 56 | |
|
| 57 | |
|
| 58 | |
|
| 59 | |
|
| 60 | |
|
| 61 | |
|
| 62 | |
|
| 63 | |
|
| 64 | |
|
| 65 | |
|
| 66 | |
|
| 67 | |
|
| 68 | |
|
| 69 | |
|
| 70 | |
|
| 71 | |
|
| 72 | |
|
| 73 | |
|
| 74 | |
|
| 75 | |
|
| 76 | |
|
| 77 | |
|
| 78 | |
|
| 79 | |
|
| 80 | |
|
| 81 | |
|
| 82 | |
|
| 83 | |
|
| 84 | |
|
| 85 | |
|
| 86 | |
|
| 87 | |
|
| 88 | |
|
| 89 | |
|
| 90 | |
|
| 91 | |
|
| 92 | |
|
| 93 | |
|
| 94 | |
|
| 95 | |
|
| 96 | |
|
| 97 | |
|
| 98 | |
|
| 99 | |
|
| 100 | |
|
| 101 | |
|
| 102 | |
|
| 103 | |
|
| 104 | |
|
| 105 | |
|
| 106 | |
|
| 107 | |
|
| 108 | |
|
| 109 | |
|
| 110 | |
|
| 111 | |
|
| 112 | |
|
| 113 | |
|
| 114 | |
|
| 115 | |
|
| 116 | |
|
| 117 | |
|
| 118 | |
|
| 119 | |
|
| 120 | |
|
| 121 | |
|
| 122 | |
|
| 123 | |
|
| 124 | |
|
| 125 | |
|
| 126 | |
|
| 127 | |
|
| 128 | |
|
| 129 | |
|
| 130 | |
|
| 131 | |
|
| 132 | |
|
| 133 | |
|
| 134 | |
|
| 135 | |
|
| 136 | |
|
| 137 | |
|
| 138 | |
|
| 139 | |
|
| 140 | |
|
| 141 | |
|
| 142 | |
|
| 143 | |
|
| 144 | |
|
| 145 | |
|
| 146 | |
|
| 147 | |
|
| 148 | |
|
| 149 | |
|
| 150 | |
|
| 151 | |
|
| 152 | |
|
| 153 | |
|
| 154 | |
|
| 155 | |
|
| 156 | |
|
| 157 | |
|
| 158 | |
|
| 159 | |
|
| 160 | |
|
| 161 | |
|
| 162 | |
|
| 163 | |
|
| 164 | |
|
| 165 | |
|
| 166 | |
|
| 167 | |
|
| 168 | |
|
| 169 | |
|
| 170 | |
|
| 171 | |
|
| 172 | |
|
| 173 | |
|
| 174 | |
|
| 175 | |
|
| 176 | |
|
| 177 | |
|
| 178 | |
|
| 179 | |
|
| 180 | |
|
| 181 | |
|
| 182 | |
|
| 183 | |
|
| 184 | |
|
| 185 | |
|
| 186 | |
|
| 187 | |
|
| 188 | |
|
| 189 | |
|
| 190 | |
|
| 191 | |
|
| 192 | |
|
| 193 | 0 | public class SMO |
| 194 | |
extends AbstractClassifier |
| 195 | |
implements WeightedInstancesHandler, TechnicalInformationHandler { |
| 196 | |
|
| 197 | |
|
| 198 | |
static final long serialVersionUID = -6585883636378691736L; |
| 199 | |
|
| 200 | |
|
| 201 | |
|
| 202 | |
|
| 203 | |
|
| 204 | |
|
| 205 | |
public String globalInfo() { |
| 206 | |
|
| 207 | 0 | return "Implements John Platt's sequential minimal optimization " |
| 208 | |
+ "algorithm for training a support vector classifier.\n\n" |
| 209 | |
+ "This implementation globally replaces all missing values and " |
| 210 | |
+ "transforms nominal attributes into binary ones. It also " |
| 211 | |
+ "normalizes all attributes by default. (In that case the coefficients " |
| 212 | |
+ "in the output are based on the normalized data, not the " |
| 213 | |
+ "original data --- this is important for interpreting the classifier.)\n\n" |
| 214 | |
+ "Multi-class problems are solved using pairwise classification " |
| 215 | |
+ "(1-vs-1 and if logistic models are built pairwise coupling " |
| 216 | |
+ "according to Hastie and Tibshirani, 1998).\n\n" |
| 217 | |
+ "To obtain proper probability estimates, use the option that fits " |
| 218 | |
+ "logistic regression models to the outputs of the support vector " |
| 219 | |
+ "machine. In the multi-class case the predicted probabilities " |
| 220 | |
+ "are coupled using Hastie and Tibshirani's pairwise coupling " |
| 221 | |
+ "method.\n\n" |
| 222 | |
+ "Note: for improved speed normalization should be turned off when " |
| 223 | |
+ "operating on SparseInstances.\n\n" |
| 224 | |
+ "For more information on the SMO algorithm, see\n\n" |
| 225 | |
+ getTechnicalInformation().toString(); |
| 226 | |
} |
| 227 | |
|
| 228 | |
|
| 229 | |
|
| 230 | |
|
| 231 | |
|
| 232 | |
|
| 233 | |
|
| 234 | |
|
| 235 | |
public TechnicalInformation getTechnicalInformation() { |
| 236 | |
TechnicalInformation result; |
| 237 | |
TechnicalInformation additional; |
| 238 | |
|
| 239 | 0 | result = new TechnicalInformation(Type.INCOLLECTION); |
| 240 | 0 | result.setValue(Field.AUTHOR, "J. Platt"); |
| 241 | 0 | result.setValue(Field.YEAR, "1998"); |
| 242 | 0 | result.setValue(Field.TITLE, "Fast Training of Support Vector Machines using Sequential Minimal Optimization"); |
| 243 | 0 | result.setValue(Field.BOOKTITLE, "Advances in Kernel Methods - Support Vector Learning"); |
| 244 | 0 | result.setValue(Field.EDITOR, "B. Schoelkopf and C. Burges and A. Smola"); |
| 245 | 0 | result.setValue(Field.PUBLISHER, "MIT Press"); |
| 246 | 0 | result.setValue(Field.URL, "http://research.microsoft.com/~jplatt/smo.html"); |
| 247 | 0 | result.setValue(Field.PDF, "http://research.microsoft.com/~jplatt/smo-book.pdf"); |
| 248 | 0 | result.setValue(Field.PS, "http://research.microsoft.com/~jplatt/smo-book.ps.gz"); |
| 249 | |
|
| 250 | 0 | additional = result.add(Type.ARTICLE); |
| 251 | 0 | additional.setValue(Field.AUTHOR, "S.S. Keerthi and S.K. Shevade and C. Bhattacharyya and K.R.K. Murthy"); |
| 252 | 0 | additional.setValue(Field.YEAR, "2001"); |
| 253 | 0 | additional.setValue(Field.TITLE, "Improvements to Platt's SMO Algorithm for SVM Classifier Design"); |
| 254 | 0 | additional.setValue(Field.JOURNAL, "Neural Computation"); |
| 255 | 0 | additional.setValue(Field.VOLUME, "13"); |
| 256 | 0 | additional.setValue(Field.NUMBER, "3"); |
| 257 | 0 | additional.setValue(Field.PAGES, "637-649"); |
| 258 | 0 | additional.setValue(Field.PS, "http://guppy.mpe.nus.edu.sg/~mpessk/svm/smo_mod_nc.ps.gz"); |
| 259 | |
|
| 260 | 0 | additional = result.add(Type.INPROCEEDINGS); |
| 261 | 0 | additional.setValue(Field.AUTHOR, "Trevor Hastie and Robert Tibshirani"); |
| 262 | 0 | additional.setValue(Field.YEAR, "1998"); |
| 263 | 0 | additional.setValue(Field.TITLE, "Classification by Pairwise Coupling"); |
| 264 | 0 | additional.setValue(Field.BOOKTITLE, "Advances in Neural Information Processing Systems"); |
| 265 | 0 | additional.setValue(Field.VOLUME, "10"); |
| 266 | 0 | additional.setValue(Field.PUBLISHER, "MIT Press"); |
| 267 | 0 | additional.setValue(Field.EDITOR, "Michael I. Jordan and Michael J. Kearns and Sara A. Solla"); |
| 268 | 0 | additional.setValue(Field.PS, "http://www-stat.stanford.edu/~hastie/Papers/2class.ps"); |
| 269 | |
|
| 270 | 0 | return result; |
| 271 | |
} |
| 272 | |
|
| 273 | |
|
| 274 | |
|
| 275 | |
|
| 276 | 0 | public class BinarySMO |
| 277 | |
implements Serializable { |
| 278 | |
|
| 279 | |
|
| 280 | |
static final long serialVersionUID = -8246163625699362456L; |
| 281 | |
|
| 282 | |
|
| 283 | |
protected double[] m_alpha; |
| 284 | |
|
| 285 | |
|
| 286 | |
protected double m_b, m_bLow, m_bUp; |
| 287 | |
|
| 288 | |
|
| 289 | |
protected int m_iLow, m_iUp; |
| 290 | |
|
| 291 | |
|
| 292 | |
protected Instances m_data; |
| 293 | |
|
| 294 | |
|
| 295 | |
protected double[] m_weights; |
| 296 | |
|
| 297 | |
|
| 298 | |
|
| 299 | |
protected double[] m_sparseWeights; |
| 300 | |
protected int[] m_sparseIndices; |
| 301 | |
|
| 302 | |
|
| 303 | |
protected Kernel m_kernel; |
| 304 | |
|
| 305 | |
|
| 306 | |
protected double[] m_class; |
| 307 | |
|
| 308 | |
|
| 309 | |
protected double[] m_errors; |
| 310 | |
|
| 311 | |
|
| 312 | |
|
| 313 | |
protected SMOset m_I0; |
| 314 | |
|
| 315 | |
protected SMOset m_I1; |
| 316 | |
|
| 317 | |
protected SMOset m_I2; |
| 318 | |
|
| 319 | |
protected SMOset m_I3; |
| 320 | |
|
| 321 | |
protected SMOset m_I4; |
| 322 | |
|
| 323 | |
|
| 324 | |
protected SMOset m_supportVectors; |
| 325 | |
|
| 326 | |
|
| 327 | 0 | protected Logistic m_logistic = null; |
| 328 | |
|
| 329 | |
|
| 330 | 0 | protected double m_sumOfWeights = 0; |
| 331 | |
|
| 332 | |
|
| 333 | |
|
| 334 | |
|
| 335 | |
|
| 336 | |
|
| 337 | |
|
| 338 | |
|
| 339 | |
|
| 340 | |
|
| 341 | |
|
| 342 | |
|
| 343 | |
protected void fitLogistic(Instances insts, int cl1, int cl2, |
| 344 | |
int numFolds, Random random) |
| 345 | |
throws Exception { |
| 346 | |
|
| 347 | |
|
| 348 | 0 | FastVector atts = new FastVector(2); |
| 349 | 0 | atts.addElement(new Attribute("pred")); |
| 350 | 0 | FastVector attVals = new FastVector(2); |
| 351 | 0 | attVals.addElement(insts.classAttribute().value(cl1)); |
| 352 | 0 | attVals.addElement(insts.classAttribute().value(cl2)); |
| 353 | 0 | atts.addElement(new Attribute("class", attVals)); |
| 354 | 0 | Instances data = new Instances("data", atts, insts.numInstances()); |
| 355 | 0 | data.setClassIndex(1); |
| 356 | |
|
| 357 | |
|
| 358 | 0 | if (numFolds <= 0) { |
| 359 | |
|
| 360 | |
|
| 361 | 0 | for (int j = 0; j < insts.numInstances(); j++) { |
| 362 | 0 | Instance inst = insts.instance(j); |
| 363 | 0 | double[] vals = new double[2]; |
| 364 | 0 | vals[0] = SVMOutput(-1, inst); |
| 365 | 0 | if (inst.classValue() == cl2) { |
| 366 | 0 | vals[1] = 1; |
| 367 | |
} |
| 368 | 0 | data.add(new DenseInstance(inst.weight(), vals)); |
| 369 | |
} |
| 370 | |
} else { |
| 371 | |
|
| 372 | |
|
| 373 | 0 | if (numFolds > insts.numInstances()) { |
| 374 | 0 | numFolds = insts.numInstances(); |
| 375 | |
} |
| 376 | |
|
| 377 | |
|
| 378 | 0 | insts = new Instances(insts); |
| 379 | |
|
| 380 | |
|
| 381 | |
|
| 382 | 0 | insts.randomize(random); |
| 383 | 0 | insts.stratify(numFolds); |
| 384 | 0 | for (int i = 0; i < numFolds; i++) { |
| 385 | 0 | Instances train = insts.trainCV(numFolds, i, random); |
| 386 | |
|
| 387 | |
|
| 388 | 0 | BinarySMO smo = new BinarySMO(); |
| 389 | 0 | smo.setKernel(Kernel.makeCopy(SMO.this.m_kernel)); |
| 390 | 0 | smo.buildClassifier(train, cl1, cl2, false, -1, -1); |
| 391 | 0 | Instances test = insts.testCV(numFolds, i); |
| 392 | 0 | for (int j = 0; j < test.numInstances(); j++) { |
| 393 | 0 | double[] vals = new double[2]; |
| 394 | 0 | vals[0] = smo.SVMOutput(-1, test.instance(j)); |
| 395 | 0 | if (test.instance(j).classValue() == cl2) { |
| 396 | 0 | vals[1] = 1; |
| 397 | |
} |
| 398 | 0 | data.add(new DenseInstance(test.instance(j).weight(), vals)); |
| 399 | |
} |
| 400 | |
} |
| 401 | |
} |
| 402 | |
|
| 403 | |
|
| 404 | 0 | m_logistic = new Logistic(); |
| 405 | 0 | m_logistic.buildClassifier(data); |
| 406 | 0 | } |
| 407 | |
|
| 408 | |
|
| 409 | |
|
| 410 | |
|
| 411 | |
|
| 412 | |
|
| 413 | |
public void setKernel(Kernel value) { |
| 414 | 0 | m_kernel = value; |
| 415 | 0 | } |
| 416 | |
|
| 417 | |
|
| 418 | |
|
| 419 | |
|
| 420 | |
|
| 421 | |
|
| 422 | |
public Kernel getKernel() { |
| 423 | 0 | return m_kernel; |
| 424 | |
} |
| 425 | |
|
| 426 | |
|
| 427 | |
|
| 428 | |
|
| 429 | |
|
| 430 | |
|
| 431 | |
|
| 432 | |
|
| 433 | |
|
| 434 | |
|
| 435 | |
|
| 436 | |
|
| 437 | |
protected void buildClassifier(Instances insts, int cl1, int cl2, |
| 438 | |
boolean fitLogistic, int numFolds, |
| 439 | |
int randomSeed) throws Exception { |
| 440 | |
|
| 441 | |
|
| 442 | 0 | m_bUp = -1; m_bLow = 1; m_b = 0; |
| 443 | 0 | m_alpha = null; m_data = null; m_weights = null; m_errors = null; |
| 444 | 0 | m_logistic = null; m_I0 = null; m_I1 = null; m_I2 = null; |
| 445 | 0 | m_I3 = null; m_I4 = null; m_sparseWeights = null; m_sparseIndices = null; |
| 446 | |
|
| 447 | |
|
| 448 | 0 | m_sumOfWeights = insts.sumOfWeights(); |
| 449 | |
|
| 450 | |
|
| 451 | 0 | m_class = new double[insts.numInstances()]; |
| 452 | 0 | m_iUp = -1; m_iLow = -1; |
| 453 | 0 | for (int i = 0; i < m_class.length; i++) { |
| 454 | 0 | if ((int) insts.instance(i).classValue() == cl1) { |
| 455 | 0 | m_class[i] = -1; m_iLow = i; |
| 456 | 0 | } else if ((int) insts.instance(i).classValue() == cl2) { |
| 457 | 0 | m_class[i] = 1; m_iUp = i; |
| 458 | |
} else { |
| 459 | 0 | throw new Exception ("This should never happen!"); |
| 460 | |
} |
| 461 | |
} |
| 462 | |
|
| 463 | |
|
| 464 | 0 | if ((m_iUp == -1) || (m_iLow == -1)) { |
| 465 | 0 | if (m_iUp != -1) { |
| 466 | 0 | m_b = -1; |
| 467 | 0 | } else if (m_iLow != -1) { |
| 468 | 0 | m_b = 1; |
| 469 | |
} else { |
| 470 | 0 | m_class = null; |
| 471 | 0 | return; |
| 472 | |
} |
| 473 | 0 | if (m_KernelIsLinear) { |
| 474 | 0 | m_sparseWeights = new double[0]; |
| 475 | 0 | m_sparseIndices = new int[0]; |
| 476 | 0 | m_class = null; |
| 477 | |
} else { |
| 478 | 0 | m_supportVectors = new SMOset(0); |
| 479 | 0 | m_alpha = new double[0]; |
| 480 | 0 | m_class = new double[0]; |
| 481 | |
} |
| 482 | |
|
| 483 | |
|
| 484 | 0 | if (fitLogistic) { |
| 485 | 0 | fitLogistic(insts, cl1, cl2, numFolds, new Random(randomSeed)); |
| 486 | |
} |
| 487 | 0 | return; |
| 488 | |
} |
| 489 | |
|
| 490 | |
|
| 491 | 0 | m_data = insts; |
| 492 | |
|
| 493 | |
|
| 494 | 0 | if (m_KernelIsLinear) { |
| 495 | 0 | m_weights = new double[m_data.numAttributes()]; |
| 496 | |
} else { |
| 497 | 0 | m_weights = null; |
| 498 | |
} |
| 499 | |
|
| 500 | |
|
| 501 | 0 | m_alpha = new double[m_data.numInstances()]; |
| 502 | |
|
| 503 | |
|
| 504 | 0 | m_supportVectors = new SMOset(m_data.numInstances()); |
| 505 | 0 | m_I0 = new SMOset(m_data.numInstances()); |
| 506 | 0 | m_I1 = new SMOset(m_data.numInstances()); |
| 507 | 0 | m_I2 = new SMOset(m_data.numInstances()); |
| 508 | 0 | m_I3 = new SMOset(m_data.numInstances()); |
| 509 | 0 | m_I4 = new SMOset(m_data.numInstances()); |
| 510 | |
|
| 511 | |
|
| 512 | 0 | m_sparseWeights = null; |
| 513 | 0 | m_sparseIndices = null; |
| 514 | |
|
| 515 | |
|
| 516 | 0 | m_kernel.buildKernel(m_data); |
| 517 | |
|
| 518 | |
|
| 519 | 0 | m_errors = new double[m_data.numInstances()]; |
| 520 | 0 | m_errors[m_iLow] = 1; m_errors[m_iUp] = -1; |
| 521 | |
|
| 522 | |
|
| 523 | 0 | for (int i = 0; i < m_class.length; i++ ) { |
| 524 | 0 | if (m_class[i] == 1) { |
| 525 | 0 | m_I1.insert(i); |
| 526 | |
} else { |
| 527 | 0 | m_I4.insert(i); |
| 528 | |
} |
| 529 | |
} |
| 530 | |
|
| 531 | |
|
| 532 | 0 | int numChanged = 0; |
| 533 | 0 | boolean examineAll = true; |
| 534 | 0 | while ((numChanged > 0) || examineAll) { |
| 535 | 0 | numChanged = 0; |
| 536 | 0 | if (examineAll) { |
| 537 | 0 | for (int i = 0; i < m_alpha.length; i++) { |
| 538 | 0 | if (examineExample(i)) { |
| 539 | 0 | numChanged++; |
| 540 | |
} |
| 541 | |
} |
| 542 | |
} else { |
| 543 | |
|
| 544 | |
|
| 545 | 0 | for (int i = 0; i < m_alpha.length; i++) { |
| 546 | 0 | if ((m_alpha[i] > 0) && |
| 547 | |
(m_alpha[i] < m_C * m_data.instance(i).weight())) { |
| 548 | 0 | if (examineExample(i)) { |
| 549 | 0 | numChanged++; |
| 550 | |
} |
| 551 | |
|
| 552 | |
|
| 553 | 0 | if (m_bUp > m_bLow - 2 * m_tol) { |
| 554 | 0 | numChanged = 0; |
| 555 | 0 | break; |
| 556 | |
} |
| 557 | |
} |
| 558 | |
} |
| 559 | |
|
| 560 | |
|
| 561 | |
|
| 562 | |
|
| 563 | |
|
| 564 | |
|
| 565 | |
|
| 566 | |
} |
| 567 | |
|
| 568 | 0 | if (examineAll) { |
| 569 | 0 | examineAll = false; |
| 570 | 0 | } else if (numChanged == 0) { |
| 571 | 0 | examineAll = true; |
| 572 | |
} |
| 573 | |
} |
| 574 | |
|
| 575 | |
|
| 576 | 0 | m_b = (m_bLow + m_bUp) / 2.0; |
| 577 | |
|
| 578 | |
|
| 579 | 0 | m_kernel.clean(); |
| 580 | |
|
| 581 | 0 | m_errors = null; |
| 582 | 0 | m_I0 = m_I1 = m_I2 = m_I3 = m_I4 = null; |
| 583 | |
|
| 584 | |
|
| 585 | |
|
| 586 | 0 | if (m_KernelIsLinear) { |
| 587 | |
|
| 588 | |
|
| 589 | 0 | m_supportVectors = null; |
| 590 | |
|
| 591 | |
|
| 592 | 0 | m_class = null; |
| 593 | |
|
| 594 | |
|
| 595 | 0 | if (!m_checksTurnedOff) { |
| 596 | 0 | m_data = new Instances(m_data, 0); |
| 597 | |
} else { |
| 598 | 0 | m_data = null; |
| 599 | |
} |
| 600 | |
|
| 601 | |
|
| 602 | 0 | double[] sparseWeights = new double[m_weights.length]; |
| 603 | 0 | int[] sparseIndices = new int[m_weights.length]; |
| 604 | 0 | int counter = 0; |
| 605 | 0 | for (int i = 0; i < m_weights.length; i++) { |
| 606 | 0 | if (m_weights[i] != 0.0) { |
| 607 | 0 | sparseWeights[counter] = m_weights[i]; |
| 608 | 0 | sparseIndices[counter] = i; |
| 609 | 0 | counter++; |
| 610 | |
} |
| 611 | |
} |
| 612 | 0 | m_sparseWeights = new double[counter]; |
| 613 | 0 | m_sparseIndices = new int[counter]; |
| 614 | 0 | System.arraycopy(sparseWeights, 0, m_sparseWeights, 0, counter); |
| 615 | 0 | System.arraycopy(sparseIndices, 0, m_sparseIndices, 0, counter); |
| 616 | |
|
| 617 | |
|
| 618 | 0 | m_weights = null; |
| 619 | |
|
| 620 | |
|
| 621 | 0 | m_alpha = null; |
| 622 | |
} |
| 623 | |
|
| 624 | |
|
| 625 | 0 | if (fitLogistic) { |
| 626 | 0 | fitLogistic(insts, cl1, cl2, numFolds, new Random(randomSeed)); |
| 627 | |
} |
| 628 | |
|
| 629 | 0 | } |
| 630 | |
|
| 631 | |
|
| 632 | |
|
| 633 | |
|
| 634 | |
|
| 635 | |
|
| 636 | |
|
| 637 | |
|
| 638 | |
|
| 639 | |
public double SVMOutput(int index, Instance inst) throws Exception { |
| 640 | |
|
| 641 | 0 | double result = 0; |
| 642 | |
|
| 643 | |
|
| 644 | 0 | if (m_KernelIsLinear) { |
| 645 | |
|
| 646 | |
|
| 647 | 0 | if (m_sparseWeights == null) { |
| 648 | 0 | int n1 = inst.numValues(); |
| 649 | 0 | for (int p = 0; p < n1; p++) { |
| 650 | 0 | if (inst.index(p) != m_classIndex) { |
| 651 | 0 | result += m_weights[inst.index(p)] * inst.valueSparse(p); |
| 652 | |
} |
| 653 | |
} |
| 654 | 0 | } else { |
| 655 | 0 | int n1 = inst.numValues(); int n2 = m_sparseWeights.length; |
| 656 | 0 | for (int p1 = 0, p2 = 0; p1 < n1 && p2 < n2;) { |
| 657 | 0 | int ind1 = inst.index(p1); |
| 658 | 0 | int ind2 = m_sparseIndices[p2]; |
| 659 | 0 | if (ind1 == ind2) { |
| 660 | 0 | if (ind1 != m_classIndex) { |
| 661 | 0 | result += inst.valueSparse(p1) * m_sparseWeights[p2]; |
| 662 | |
} |
| 663 | 0 | p1++; p2++; |
| 664 | 0 | } else if (ind1 > ind2) { |
| 665 | 0 | p2++; |
| 666 | |
} else { |
| 667 | 0 | p1++; |
| 668 | |
} |
| 669 | 0 | } |
| 670 | 0 | } |
| 671 | |
} else { |
| 672 | 0 | for (int i = m_supportVectors.getNext(-1); i != -1; |
| 673 | 0 | i = m_supportVectors.getNext(i)) { |
| 674 | 0 | result += m_class[i] * m_alpha[i] * m_kernel.eval(index, i, inst); |
| 675 | |
} |
| 676 | |
} |
| 677 | 0 | result -= m_b; |
| 678 | |
|
| 679 | 0 | return result; |
| 680 | |
} |
| 681 | |
|
| 682 | |
|
| 683 | |
|
| 684 | |
|
| 685 | |
|
| 686 | |
|
| 687 | |
public String toString() { |
| 688 | |
|
| 689 | 0 | StringBuffer text = new StringBuffer(); |
| 690 | 0 | int printed = 0; |
| 691 | |
|
| 692 | 0 | if ((m_alpha == null) && (m_sparseWeights == null)) { |
| 693 | 0 | return "BinarySMO: No model built yet.\n"; |
| 694 | |
} |
| 695 | |
try { |
| 696 | 0 | text.append("BinarySMO\n\n"); |
| 697 | |
|
| 698 | |
|
| 699 | 0 | if (m_KernelIsLinear) { |
| 700 | 0 | text.append("Machine linear: showing attribute weights, "); |
| 701 | 0 | text.append("not support vectors.\n\n"); |
| 702 | |
|
| 703 | |
|
| 704 | |
|
| 705 | 0 | for (int i = 0; i < m_sparseWeights.length; i++) { |
| 706 | 0 | if (m_sparseIndices[i] != (int)m_classIndex) { |
| 707 | 0 | if (printed > 0) { |
| 708 | 0 | text.append(" + "); |
| 709 | |
} else { |
| 710 | 0 | text.append(" "); |
| 711 | |
} |
| 712 | 0 | text.append(Utils.doubleToString(m_sparseWeights[i], 12, 4) + |
| 713 | |
" * "); |
| 714 | 0 | if (m_filterType == FILTER_STANDARDIZE) { |
| 715 | 0 | text.append("(standardized) "); |
| 716 | 0 | } else if (m_filterType == FILTER_NORMALIZE) { |
| 717 | 0 | text.append("(normalized) "); |
| 718 | |
} |
| 719 | 0 | if (!m_checksTurnedOff) { |
| 720 | 0 | text.append(m_data.attribute(m_sparseIndices[i]).name()+"\n"); |
| 721 | |
} else { |
| 722 | 0 | text.append("attribute with index " + |
| 723 | |
m_sparseIndices[i] +"\n"); |
| 724 | |
} |
| 725 | 0 | printed++; |
| 726 | |
} |
| 727 | |
} |
| 728 | |
} else { |
| 729 | 0 | for (int i = 0; i < m_alpha.length; i++) { |
| 730 | 0 | if (m_supportVectors.contains(i)) { |
| 731 | 0 | double val = m_alpha[i]; |
| 732 | 0 | if (m_class[i] == 1) { |
| 733 | 0 | if (printed > 0) { |
| 734 | 0 | text.append(" + "); |
| 735 | |
} |
| 736 | |
} else { |
| 737 | 0 | text.append(" - "); |
| 738 | |
} |
| 739 | 0 | text.append(Utils.doubleToString(val, 12, 4) |
| 740 | |
+ " * <"); |
| 741 | 0 | for (int j = 0; j < m_data.numAttributes(); j++) { |
| 742 | 0 | if (j != m_data.classIndex()) { |
| 743 | 0 | text.append(m_data.instance(i).toString(j)); |
| 744 | |
} |
| 745 | 0 | if (j != m_data.numAttributes() - 1) { |
| 746 | 0 | text.append(" "); |
| 747 | |
} |
| 748 | |
} |
| 749 | 0 | text.append("> * X]\n"); |
| 750 | 0 | printed++; |
| 751 | |
} |
| 752 | |
} |
| 753 | |
} |
| 754 | 0 | if (m_b > 0) { |
| 755 | 0 | text.append(" - " + Utils.doubleToString(m_b, 12, 4)); |
| 756 | |
} else { |
| 757 | 0 | text.append(" + " + Utils.doubleToString(-m_b, 12, 4)); |
| 758 | |
} |
| 759 | |
|
| 760 | 0 | if (!m_KernelIsLinear) { |
| 761 | 0 | text.append("\n\nNumber of support vectors: " + |
| 762 | |
m_supportVectors.numElements()); |
| 763 | |
} |
| 764 | 0 | int numEval = 0; |
| 765 | 0 | int numCacheHits = -1; |
| 766 | 0 | if (m_kernel != null) { |
| 767 | 0 | numEval = m_kernel.numEvals(); |
| 768 | 0 | numCacheHits = m_kernel.numCacheHits(); |
| 769 | |
} |
| 770 | 0 | text.append("\n\nNumber of kernel evaluations: " + numEval); |
| 771 | 0 | if (numCacheHits >= 0 && numEval > 0) { |
| 772 | 0 | double hitRatio = 1 - numEval*1.0/(numCacheHits+numEval); |
| 773 | 0 | text.append(" (" + Utils.doubleToString(hitRatio*100, 7, 3).trim() + "% cached)"); |
| 774 | |
} |
| 775 | |
|
| 776 | 0 | } catch (Exception e) { |
| 777 | 0 | e.printStackTrace(); |
| 778 | |
|
| 779 | 0 | return "Can't print BinarySMO classifier."; |
| 780 | 0 | } |
| 781 | |
|
| 782 | 0 | return text.toString(); |
| 783 | |
} |
| 784 | |
|
| 785 | |
|
| 786 | |
|
| 787 | |
|
| 788 | |
|
| 789 | |
|
| 790 | |
|
| 791 | |
|
| 792 | |
protected boolean examineExample(int i2) throws Exception { |
| 793 | |
|
| 794 | |
double y2, F2; |
| 795 | 0 | int i1 = -1; |
| 796 | |
|
| 797 | 0 | y2 = m_class[i2]; |
| 798 | 0 | if (m_I0.contains(i2)) { |
| 799 | 0 | F2 = m_errors[i2]; |
| 800 | |
} else { |
| 801 | 0 | F2 = SVMOutput(i2, m_data.instance(i2)) + m_b - y2; |
| 802 | 0 | m_errors[i2] = F2; |
| 803 | |
|
| 804 | |
|
| 805 | 0 | if ((m_I1.contains(i2) || m_I2.contains(i2)) && (F2 < m_bUp)) { |
| 806 | 0 | m_bUp = F2; m_iUp = i2; |
| 807 | 0 | } else if ((m_I3.contains(i2) || m_I4.contains(i2)) && (F2 > m_bLow)) { |
| 808 | 0 | m_bLow = F2; m_iLow = i2; |
| 809 | |
} |
| 810 | |
} |
| 811 | |
|
| 812 | |
|
| 813 | |
|
| 814 | |
|
| 815 | 0 | boolean optimal = true; |
| 816 | 0 | if (m_I0.contains(i2) || m_I1.contains(i2) || m_I2.contains(i2)) { |
| 817 | 0 | if (m_bLow - F2 > 2 * m_tol) { |
| 818 | 0 | optimal = false; i1 = m_iLow; |
| 819 | |
} |
| 820 | |
} |
| 821 | 0 | if (m_I0.contains(i2) || m_I3.contains(i2) || m_I4.contains(i2)) { |
| 822 | 0 | if (F2 - m_bUp > 2 * m_tol) { |
| 823 | 0 | optimal = false; i1 = m_iUp; |
| 824 | |
} |
| 825 | |
} |
| 826 | 0 | if (optimal) { |
| 827 | 0 | return false; |
| 828 | |
} |
| 829 | |
|
| 830 | |
|
| 831 | 0 | if (m_I0.contains(i2)) { |
| 832 | 0 | if (m_bLow - F2 > F2 - m_bUp) { |
| 833 | 0 | i1 = m_iLow; |
| 834 | |
} else { |
| 835 | 0 | i1 = m_iUp; |
| 836 | |
} |
| 837 | |
} |
| 838 | 0 | if (i1 == -1) { |
| 839 | 0 | throw new Exception("This should never happen!"); |
| 840 | |
} |
| 841 | 0 | return takeStep(i1, i2, F2); |
| 842 | |
} |
| 843 | |
|
| 844 | |
|
| 845 | |
|
| 846 | |
|
| 847 | |
|
| 848 | |
|
| 849 | |
|
| 850 | |
|
| 851 | |
|
| 852 | |
|
| 853 | |
|
| 854 | |
protected boolean takeStep(int i1, int i2, double F2) throws Exception { |
| 855 | |
|
| 856 | |
double alph1, alph2, y1, y2, F1, s, L, H, k11, k12, k22, eta, |
| 857 | |
a1, a2, f1, f2, v1, v2, Lobj, Hobj; |
| 858 | 0 | double C1 = m_C * m_data.instance(i1).weight(); |
| 859 | 0 | double C2 = m_C * m_data.instance(i2).weight(); |
| 860 | |
|
| 861 | |
|
| 862 | 0 | if (i1 == i2) { |
| 863 | 0 | return false; |
| 864 | |
} |
| 865 | |
|
| 866 | |
|
| 867 | 0 | alph1 = m_alpha[i1]; alph2 = m_alpha[i2]; |
| 868 | 0 | y1 = m_class[i1]; y2 = m_class[i2]; |
| 869 | 0 | F1 = m_errors[i1]; |
| 870 | 0 | s = y1 * y2; |
| 871 | |
|
| 872 | |
|
| 873 | 0 | if (y1 != y2) { |
| 874 | 0 | L = Math.max(0, alph2 - alph1); |
| 875 | 0 | H = Math.min(C2, C1 + alph2 - alph1); |
| 876 | |
} else { |
| 877 | 0 | L = Math.max(0, alph1 + alph2 - C1); |
| 878 | 0 | H = Math.min(C2, alph1 + alph2); |
| 879 | |
} |
| 880 | 0 | if (L >= H) { |
| 881 | 0 | return false; |
| 882 | |
} |
| 883 | |
|
| 884 | |
|
| 885 | 0 | k11 = m_kernel.eval(i1, i1, m_data.instance(i1)); |
| 886 | 0 | k12 = m_kernel.eval(i1, i2, m_data.instance(i1)); |
| 887 | 0 | k22 = m_kernel.eval(i2, i2, m_data.instance(i2)); |
| 888 | 0 | eta = 2 * k12 - k11 - k22; |
| 889 | |
|
| 890 | |
|
| 891 | 0 | if (eta < 0) { |
| 892 | |
|
| 893 | |
|
| 894 | 0 | a2 = alph2 - y2 * (F1 - F2) / eta; |
| 895 | |
|
| 896 | |
|
| 897 | 0 | if (a2 < L) { |
| 898 | 0 | a2 = L; |
| 899 | 0 | } else if (a2 > H) { |
| 900 | 0 | a2 = H; |
| 901 | |
} |
| 902 | |
} else { |
| 903 | |
|
| 904 | |
|
| 905 | 0 | f1 = SVMOutput(i1, m_data.instance(i1)); |
| 906 | 0 | f2 = SVMOutput(i2, m_data.instance(i2)); |
| 907 | 0 | v1 = f1 + m_b - y1 * alph1 * k11 - y2 * alph2 * k12; |
| 908 | 0 | v2 = f2 + m_b - y1 * alph1 * k12 - y2 * alph2 * k22; |
| 909 | 0 | double gamma = alph1 + s * alph2; |
| 910 | 0 | Lobj = (gamma - s * L) + L - 0.5 * k11 * (gamma - s * L) * (gamma - s * L) - |
| 911 | |
0.5 * k22 * L * L - s * k12 * (gamma - s * L) * L - |
| 912 | |
y1 * (gamma - s * L) * v1 - y2 * L * v2; |
| 913 | 0 | Hobj = (gamma - s * H) + H - 0.5 * k11 * (gamma - s * H) * (gamma - s * H) - |
| 914 | |
0.5 * k22 * H * H - s * k12 * (gamma - s * H) * H - |
| 915 | |
y1 * (gamma - s * H) * v1 - y2 * H * v2; |
| 916 | 0 | if (Lobj > Hobj + m_eps) { |
| 917 | 0 | a2 = L; |
| 918 | 0 | } else if (Lobj < Hobj - m_eps) { |
| 919 | 0 | a2 = H; |
| 920 | |
} else { |
| 921 | 0 | a2 = alph2; |
| 922 | |
} |
| 923 | |
} |
| 924 | 0 | if (Math.abs(a2 - alph2) < m_eps * (a2 + alph2 + m_eps)) { |
| 925 | 0 | return false; |
| 926 | |
} |
| 927 | |
|
| 928 | |
|
| 929 | 0 | if (a2 > C2 - m_Del * C2) { |
| 930 | 0 | a2 = C2; |
| 931 | 0 | } else if (a2 <= m_Del * C2) { |
| 932 | 0 | a2 = 0; |
| 933 | |
} |
| 934 | |
|
| 935 | |
|
| 936 | 0 | a1 = alph1 + s * (alph2 - a2); |
| 937 | |
|
| 938 | |
|
| 939 | 0 | if (a1 > C1 - m_Del * C1) { |
| 940 | 0 | a1 = C1; |
| 941 | 0 | } else if (a1 <= m_Del * C1) { |
| 942 | 0 | a1 = 0; |
| 943 | |
} |
| 944 | |
|
| 945 | |
|
| 946 | 0 | if (a1 > 0) { |
| 947 | 0 | m_supportVectors.insert(i1); |
| 948 | |
} else { |
| 949 | 0 | m_supportVectors.delete(i1); |
| 950 | |
} |
| 951 | 0 | if ((a1 > 0) && (a1 < C1)) { |
| 952 | 0 | m_I0.insert(i1); |
| 953 | |
} else { |
| 954 | 0 | m_I0.delete(i1); |
| 955 | |
} |
| 956 | 0 | if ((y1 == 1) && (a1 == 0)) { |
| 957 | 0 | m_I1.insert(i1); |
| 958 | |
} else { |
| 959 | 0 | m_I1.delete(i1); |
| 960 | |
} |
| 961 | 0 | if ((y1 == -1) && (a1 == C1)) { |
| 962 | 0 | m_I2.insert(i1); |
| 963 | |
} else { |
| 964 | 0 | m_I2.delete(i1); |
| 965 | |
} |
| 966 | 0 | if ((y1 == 1) && (a1 == C1)) { |
| 967 | 0 | m_I3.insert(i1); |
| 968 | |
} else { |
| 969 | 0 | m_I3.delete(i1); |
| 970 | |
} |
| 971 | 0 | if ((y1 == -1) && (a1 == 0)) { |
| 972 | 0 | m_I4.insert(i1); |
| 973 | |
} else { |
| 974 | 0 | m_I4.delete(i1); |
| 975 | |
} |
| 976 | 0 | if (a2 > 0) { |
| 977 | 0 | m_supportVectors.insert(i2); |
| 978 | |
} else { |
| 979 | 0 | m_supportVectors.delete(i2); |
| 980 | |
} |
| 981 | 0 | if ((a2 > 0) && (a2 < C2)) { |
| 982 | 0 | m_I0.insert(i2); |
| 983 | |
} else { |
| 984 | 0 | m_I0.delete(i2); |
| 985 | |
} |
| 986 | 0 | if ((y2 == 1) && (a2 == 0)) { |
| 987 | 0 | m_I1.insert(i2); |
| 988 | |
} else { |
| 989 | 0 | m_I1.delete(i2); |
| 990 | |
} |
| 991 | 0 | if ((y2 == -1) && (a2 == C2)) { |
| 992 | 0 | m_I2.insert(i2); |
| 993 | |
} else { |
| 994 | 0 | m_I2.delete(i2); |
| 995 | |
} |
| 996 | 0 | if ((y2 == 1) && (a2 == C2)) { |
| 997 | 0 | m_I3.insert(i2); |
| 998 | |
} else { |
| 999 | 0 | m_I3.delete(i2); |
| 1000 | |
} |
| 1001 | 0 | if ((y2 == -1) && (a2 == 0)) { |
| 1002 | 0 | m_I4.insert(i2); |
| 1003 | |
} else { |
| 1004 | 0 | m_I4.delete(i2); |
| 1005 | |
} |
| 1006 | |
|
| 1007 | |
|
| 1008 | 0 | if (m_KernelIsLinear) { |
| 1009 | 0 | Instance inst1 = m_data.instance(i1); |
| 1010 | 0 | for (int p1 = 0; p1 < inst1.numValues(); p1++) { |
| 1011 | 0 | if (inst1.index(p1) != m_data.classIndex()) { |
| 1012 | 0 | m_weights[inst1.index(p1)] += |
| 1013 | |
y1 * (a1 - alph1) * inst1.valueSparse(p1); |
| 1014 | |
} |
| 1015 | |
} |
| 1016 | 0 | Instance inst2 = m_data.instance(i2); |
| 1017 | 0 | for (int p2 = 0; p2 < inst2.numValues(); p2++) { |
| 1018 | 0 | if (inst2.index(p2) != m_data.classIndex()) { |
| 1019 | 0 | m_weights[inst2.index(p2)] += |
| 1020 | |
y2 * (a2 - alph2) * inst2.valueSparse(p2); |
| 1021 | |
} |
| 1022 | |
} |
| 1023 | |
} |
| 1024 | |
|
| 1025 | |
|
| 1026 | 0 | for (int j = m_I0.getNext(-1); j != -1; j = m_I0.getNext(j)) { |
| 1027 | 0 | if ((j != i1) && (j != i2)) { |
| 1028 | 0 | m_errors[j] += |
| 1029 | |
y1 * (a1 - alph1) * m_kernel.eval(i1, j, m_data.instance(i1)) + |
| 1030 | |
y2 * (a2 - alph2) * m_kernel.eval(i2, j, m_data.instance(i2)); |
| 1031 | |
} |
| 1032 | |
} |
| 1033 | |
|
| 1034 | |
|
| 1035 | 0 | m_errors[i1] += y1 * (a1 - alph1) * k11 + y2 * (a2 - alph2) * k12; |
| 1036 | 0 | m_errors[i2] += y1 * (a1 - alph1) * k12 + y2 * (a2 - alph2) * k22; |
| 1037 | |
|
| 1038 | |
|
| 1039 | 0 | m_alpha[i1] = a1; |
| 1040 | 0 | m_alpha[i2] = a2; |
| 1041 | |
|
| 1042 | |
|
| 1043 | 0 | m_bLow = -Double.MAX_VALUE; m_bUp = Double.MAX_VALUE; |
| 1044 | 0 | m_iLow = -1; m_iUp = -1; |
| 1045 | 0 | for (int j = m_I0.getNext(-1); j != -1; j = m_I0.getNext(j)) { |
| 1046 | 0 | if (m_errors[j] < m_bUp) { |
| 1047 | 0 | m_bUp = m_errors[j]; m_iUp = j; |
| 1048 | |
} |
| 1049 | 0 | if (m_errors[j] > m_bLow) { |
| 1050 | 0 | m_bLow = m_errors[j]; m_iLow = j; |
| 1051 | |
} |
| 1052 | |
} |
| 1053 | 0 | if (!m_I0.contains(i1)) { |
| 1054 | 0 | if (m_I3.contains(i1) || m_I4.contains(i1)) { |
| 1055 | 0 | if (m_errors[i1] > m_bLow) { |
| 1056 | 0 | m_bLow = m_errors[i1]; m_iLow = i1; |
| 1057 | |
} |
| 1058 | |
} else { |
| 1059 | 0 | if (m_errors[i1] < m_bUp) { |
| 1060 | 0 | m_bUp = m_errors[i1]; m_iUp = i1; |
| 1061 | |
} |
| 1062 | |
} |
| 1063 | |
} |
| 1064 | 0 | if (!m_I0.contains(i2)) { |
| 1065 | 0 | if (m_I3.contains(i2) || m_I4.contains(i2)) { |
| 1066 | 0 | if (m_errors[i2] > m_bLow) { |
| 1067 | 0 | m_bLow = m_errors[i2]; m_iLow = i2; |
| 1068 | |
} |
| 1069 | |
} else { |
| 1070 | 0 | if (m_errors[i2] < m_bUp) { |
| 1071 | 0 | m_bUp = m_errors[i2]; m_iUp = i2; |
| 1072 | |
} |
| 1073 | |
} |
| 1074 | |
} |
| 1075 | 0 | if ((m_iLow == -1) || (m_iUp == -1)) { |
| 1076 | 0 | throw new Exception("This should never happen!"); |
| 1077 | |
} |
| 1078 | |
|
| 1079 | |
|
| 1080 | 0 | return true; |
| 1081 | |
} |
| 1082 | |
|
| 1083 | |
|
| 1084 | |
|
| 1085 | |
|
| 1086 | |
|
| 1087 | |
|
| 1088 | |
    /**
     * Debugging routine: checks the trained binary SVM against the KKT
     * (Karush-Kuhn-Tucker) optimality conditions and prints any violations
     * to standard error. Does not modify the model.
     *
     * @throws Exception if the SVM output cannot be computed
     */
    protected void checkClassifier() throws Exception {

      // The equality constraint of the dual: sum_i y(i) * alpha(i) should be 0.
      double sum = 0;
      for (int i = 0; i < m_alpha.length; i++) {
        if (m_alpha[i] > 0) {
          sum += m_class[i] * m_alpha[i];
        }
      }
      System.err.println("Sum of y(i) * alpha(i): " + sum);

      for (int i = 0; i < m_alpha.length; i++) {
        double output = SVMOutput(i, m_data.instance(i));
        // Condition 1: alpha == 0 implies y * f(x) >= 1 (point outside margin).
        if (Utils.eq(m_alpha[i], 0)) {
          if (Utils.sm(m_class[i] * output, 1)) {
            System.err.println("KKT condition 1 violated: " + m_class[i] * output);
          }
        }
        // Condition 2: 0 < alpha < C implies y * f(x) == 1 (point on the margin).
        if (Utils.gr(m_alpha[i], 0) &&
            Utils.sm(m_alpha[i], m_C * m_data.instance(i).weight())) {
          if (!Utils.eq(m_class[i] * output, 1)) {
            System.err.println("KKT condition 2 violated: " + m_class[i] * output);
          }
        }
        // Condition 3: alpha == C implies y * f(x) <= 1 (point inside margin).
        if (Utils.eq(m_alpha[i], m_C * m_data.instance(i).weight())) {
          if (Utils.gr(m_class[i] * output, 1)) {
            System.err.println("KKT condition 3 violated: " + m_class[i] * output);
          }
        }
      }
    }
| 1118 | |
|
| 1119 | |
|
| 1120 | |
|
| 1121 | |
|
| 1122 | |
|
| 1123 | |
|
| 1124 | |
    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    public String getRevision() {
      return RevisionUtils.extract("$Revision: 8034 $");
    }
| 1127 | |
} |
| 1128 | |
|
| 1129 | |
|
| 1130 | |
  /** filter: Normalize training data (rescale attributes to [0,1]) */
  public static final int FILTER_NORMALIZE = 0;

  /** filter: Standardize training data (zero mean, unit variance) */
  public static final int FILTER_STANDARDIZE = 1;

  /** filter: No normalization/standardization */
  public static final int FILTER_NONE = 2;

  /** The filter choices, indexed by the FILTER_* constants above. */
  public static final Tag [] TAGS_FILTER = {
    new Tag(FILTER_NORMALIZE, "Normalize training data"),
    new Tag(FILTER_STANDARDIZE, "Standardize training data"),
    new Tag(FILTER_NONE, "No normalization/standardization"),
  };
| 1141 | |
|
| 1142 | |
|
  /** The binary classifiers: one per pair of classes, stored at [i][j] with i &lt; j. */
  protected BinarySMO[][] m_classifiers = null;

  /** The complexity parameter C. */
  protected double m_C = 1.0;

  /** Epsilon for rounding. */
  protected double m_eps = 1.0e-12;

  /** Tolerance for accuracy of result. */
  protected double m_tol = 1.0e-3;

  /** Whether to normalize/standardize/neither; one of the FILTER_* constants. */
  protected int m_filterType = FILTER_NORMALIZE;

  /** The filter used to make attributes numeric (null if not needed). */
  protected NominalToBinary m_NominalToBinary;

  /** The filter used to standardize/normalize all values (null if FILTER_NONE). */
  protected Filter m_Filter = null;

  /** The filter used to get rid of missing values (null when checks are off). */
  protected ReplaceMissingValues m_Missing;

  /** The class index from the training data. */
  protected int m_classIndex = -1;

  /** The class attribute of the training data. */
  protected Attribute m_classAttribute;

  /** Whether the kernel is a linear one (PolyKernel with exponent 1.0). */
  protected boolean m_KernelIsLinear = false;

  /**
   * Turn off all checks and conversions? Turning them off assumes
   * that data is purely numeric, doesn't contain any missing values,
   * and has a nominal class.
   */
  protected boolean m_checksTurnedOff;

  /** Precision constant for updating sets. */
  protected static double m_Del = 1000 * Double.MIN_VALUE;

  /** Whether logistic models are to be fit to the SVM outputs. */
  protected boolean m_fitLogisticModels = false;

  /** The number of folds for the internal cross-validation (-1 = use training data). */
  protected int m_numFolds = -1;

  /** The random number seed for the internal cross-validation. */
  protected int m_randomSeed = 1;

  /** The kernel to use. */
  protected Kernel m_kernel = new PolyKernel();
| 1196 | |
|
| 1197 | |
|
| 1198 | |
|
| 1199 | |
|
| 1200 | |
  /**
   * Turns off checks for missing values, non-numeric attributes, etc.
   * Use with caution: only call if the data conforms to the assumptions
   * documented on {@code m_checksTurnedOff}.
   */
  public void turnChecksOff() {

    m_checksTurnedOff = true;
  }

  /**
   * Turns on checks for missing values, non-numeric attributes, etc.
   * This is the default behaviour.
   */
  public void turnChecksOn() {

    m_checksTurnedOff = false;
  }
| 1212 | |
|
| 1213 | |
|
| 1214 | |
|
| 1215 | |
|
| 1216 | |
|
| 1217 | |
|
| 1218 | |
public Capabilities getCapabilities() { |
| 1219 | 0 | Capabilities result = getKernel().getCapabilities(); |
| 1220 | 0 | result.setOwner(this); |
| 1221 | |
|
| 1222 | |
|
| 1223 | 0 | result.enableAllAttributeDependencies(); |
| 1224 | |
|
| 1225 | |
|
| 1226 | 0 | if (result.handles(Capability.NUMERIC_ATTRIBUTES)) |
| 1227 | 0 | result.enable(Capability.NOMINAL_ATTRIBUTES); |
| 1228 | 0 | result.enable(Capability.MISSING_VALUES); |
| 1229 | |
|
| 1230 | |
|
| 1231 | 0 | result.disableAllClasses(); |
| 1232 | 0 | result.disableAllClassDependencies(); |
| 1233 | 0 | result.enable(Capability.NOMINAL_CLASS); |
| 1234 | 0 | result.enable(Capability.MISSING_CLASS_VALUES); |
| 1235 | |
|
| 1236 | 0 | return result; |
| 1237 | |
} |
| 1238 | |
|
| 1239 | |
|
| 1240 | |
|
| 1241 | |
|
| 1242 | |
|
| 1243 | |
|
| 1244 | |
|
| 1245 | |
|
| 1246 | |
  /**
   * Builds the multi-class SMO model: optionally cleans and filters the
   * training data, then trains one BinarySMO for every pair of class values.
   *
   * @param insts the training data
   * @throws Exception if the classifier can't be built successfully
   */
  public void buildClassifier(Instances insts) throws Exception {

    if (!m_checksTurnedOff) {
      // can the classifier handle the data?
      getCapabilities().testWithFail(insts);

      // remove instances with missing class (work on a copy)
      insts = new Instances(insts);
      insts.deleteWithMissingClass();

      // Remove all instances with weight 0: the SMO optimization assumes
      // a strictly positive upper bound C * weight for every instance.
      Instances data = new Instances(insts, insts.numInstances());
      for(int i = 0; i < insts.numInstances(); i++){
        if(insts.instance(i).weight() > 0)
          data.add(insts.instance(i));
      }
      if (data.numInstances() == 0) {
        throw new Exception("No training instances left after removing " +
        "instances with weight 0!");
      }
      insts = data;
    }

    // replace missing values (only when checks are on; otherwise the data
    // is assumed to be complete)
    if (!m_checksTurnedOff) {
      m_Missing = new ReplaceMissingValues();
      m_Missing.setInputFormat(insts);
      insts = Filter.useFilter(insts, m_Missing);
    } else {
      m_Missing = null;
    }

    // convert nominal attributes to binary ones if the kernel requires
    // numeric attributes and any non-numeric attribute is present
    if (getCapabilities().handles(Capability.NUMERIC_ATTRIBUTES)) {
      boolean onlyNumeric = true;
      if (!m_checksTurnedOff) {
        for (int i = 0; i < insts.numAttributes(); i++) {
          if (i != insts.classIndex()) {
            if (!insts.attribute(i).isNumeric()) {
              onlyNumeric = false;
              break;
            }
          }
        }
      }

      if (!onlyNumeric) {
        m_NominalToBinary = new NominalToBinary();
        m_NominalToBinary.setInputFormat(insts);
        insts = Filter.useFilter(insts, m_NominalToBinary);
      }
      else {
        m_NominalToBinary = null;
      }
    }
    else {
      m_NominalToBinary = null;
    }

    // normalize/standardize the data as configured via m_filterType
    if (m_filterType == FILTER_STANDARDIZE) {
      m_Filter = new Standardize();
      m_Filter.setInputFormat(insts);
      insts = Filter.useFilter(insts, m_Filter);
    } else if (m_filterType == FILTER_NORMALIZE) {
      m_Filter = new Normalize();
      m_Filter.setInputFormat(insts);
      insts = Filter.useFilter(insts, m_Filter);
    } else {
      m_Filter = null;
    }

    m_classIndex = insts.classIndex();
    m_classAttribute = insts.classAttribute();
    // linear machines get special treatment elsewhere (weight vector storage)
    m_KernelIsLinear = (m_kernel instanceof PolyKernel) && (((PolyKernel) m_kernel).getExponent() == 1.0);

    // partition the training data by class value
    Instances[] subsets = new Instances[insts.numClasses()];
    for (int i = 0; i < insts.numClasses(); i++) {
      subsets[i] = new Instances(insts, insts.numInstances());
    }
    for (int j = 0; j < insts.numInstances(); j++) {
      Instance inst = insts.instance(j);
      subsets[(int)inst.classValue()].add(inst);
    }
    for (int i = 0; i < insts.numClasses(); i++) {
      subsets[i].compactify();
    }

    // train one binary SVM per pair of classes (one-vs-one), each on the
    // union of the two classes' subsets, shuffled with the configured seed
    Random rand = new Random(m_randomSeed);
    m_classifiers = new BinarySMO[insts.numClasses()][insts.numClasses()];
    for (int i = 0; i < insts.numClasses(); i++) {
      for (int j = i + 1; j < insts.numClasses(); j++) {
        m_classifiers[i][j] = new BinarySMO();
        m_classifiers[i][j].setKernel(Kernel.makeCopy(getKernel()));
        Instances data = new Instances(insts, insts.numInstances());
        for (int k = 0; k < subsets[i].numInstances(); k++) {
          data.add(subsets[i].instance(k));
        }
        for (int k = 0; k < subsets[j].numInstances(); k++) {
          data.add(subsets[j].instance(k));
        }
        data.compactify();
        data.randomize(rand);
        m_classifiers[i][j].buildClassifier(data, i, j,
                                            m_fitLogisticModels,
                                            m_numFolds, m_randomSeed);
      }
    }
  }
| 1356 | |
|
| 1357 | |
|
| 1358 | |
|
| 1359 | |
|
| 1360 | |
|
| 1361 | |
|
| 1362 | |
|
| 1363 | |
  /**
   * Estimates class probabilities for the given instance. Without logistic
   * models, the estimate is the normalized pairwise vote count; with
   * logistic models, the pairwise probabilities are combined via
   * pairwise coupling (Hastie &amp; Tibshirani).
   *
   * @param inst the instance to classify
   * @return the class probability distribution
   * @throws Exception if the distribution can't be computed successfully
   */
  public double[] distributionForInstance(Instance inst) throws Exception {

    // Push the instance through the same filter chain used at training time.
    if (!m_checksTurnedOff) {
      m_Missing.input(inst);
      m_Missing.batchFinished();
      inst = m_Missing.output();
    }

    if (m_NominalToBinary != null) {
      m_NominalToBinary.input(inst);
      m_NominalToBinary.batchFinished();
      inst = m_NominalToBinary.output();
    }

    if (m_Filter != null) {
      m_Filter.input(inst);
      m_Filter.batchFinished();
      inst = m_Filter.output();
    }

    if (!m_fitLogisticModels) {
      // No logistic models: each trained pairwise machine casts one vote;
      // the normalized vote counts serve as the distribution.
      double[] result = new double[inst.numClasses()];
      for (int i = 0; i < inst.numClasses(); i++) {
        for (int j = i + 1; j < inst.numClasses(); j++) {
          // skip pairs whose model is empty (no support vectors / weights)
          if ((m_classifiers[i][j].m_alpha != null) ||
              (m_classifiers[i][j].m_sparseWeights != null)) {
            double output = m_classifiers[i][j].SVMOutput(-1, inst);
            if (output > 0) {
              result[j] += 1;
            } else {
              result[i] += 1;
            }
          }
        }
      }
      Utils.normalize(result);
      return result;
    } else {

      // With only two classes the single logistic model gives the
      // distribution directly; no coupling needed.
      if (inst.numClasses() == 2) {
        double[] newInst = new double[2];
        newInst[0] = m_classifiers[0][1].SVMOutput(-1, inst);
        // class value is unknown at prediction time
        newInst[1] = Utils.missingValue();
        return m_classifiers[0][1].m_logistic.
          distributionForInstance(new DenseInstance(1, newInst));
      }
      // r[i][j]: estimated probability of class i vs class j;
      // n[i][j]: weight (sum of training weights) of that pairwise model
      double[][] r = new double[inst.numClasses()][inst.numClasses()];
      double[][] n = new double[inst.numClasses()][inst.numClasses()];
      for (int i = 0; i < inst.numClasses(); i++) {
        for (int j = i + 1; j < inst.numClasses(); j++) {
          if ((m_classifiers[i][j].m_alpha != null) ||
              (m_classifiers[i][j].m_sparseWeights != null)) {
            double[] newInst = new double[2];
            newInst[0] = m_classifiers[i][j].SVMOutput(-1, inst);
            newInst[1] = Utils.missingValue();
            r[i][j] = m_classifiers[i][j].m_logistic.
              distributionForInstance(new DenseInstance(1, newInst))[0];
            n[i][j] = m_classifiers[i][j].m_sumOfWeights;
          }
        }
      }
      return weka.classifiers.meta.MultiClassClassifier.pairwiseCoupling(n, r);
    }
  }
| 1430 | |
|
| 1431 | |
|
| 1432 | |
|
| 1433 | |
|
| 1434 | |
|
| 1435 | |
|
| 1436 | |
|
| 1437 | |
public int[] obtainVotes(Instance inst) throws Exception { |
| 1438 | |
|
| 1439 | |
|
| 1440 | 0 | if (!m_checksTurnedOff) { |
| 1441 | 0 | m_Missing.input(inst); |
| 1442 | 0 | m_Missing.batchFinished(); |
| 1443 | 0 | inst = m_Missing.output(); |
| 1444 | |
} |
| 1445 | |
|
| 1446 | 0 | if (m_NominalToBinary != null) { |
| 1447 | 0 | m_NominalToBinary.input(inst); |
| 1448 | 0 | m_NominalToBinary.batchFinished(); |
| 1449 | 0 | inst = m_NominalToBinary.output(); |
| 1450 | |
} |
| 1451 | |
|
| 1452 | 0 | if (m_Filter != null) { |
| 1453 | 0 | m_Filter.input(inst); |
| 1454 | 0 | m_Filter.batchFinished(); |
| 1455 | 0 | inst = m_Filter.output(); |
| 1456 | |
} |
| 1457 | |
|
| 1458 | 0 | int[] votes = new int[inst.numClasses()]; |
| 1459 | 0 | for (int i = 0; i < inst.numClasses(); i++) { |
| 1460 | 0 | for (int j = i + 1; j < inst.numClasses(); j++) { |
| 1461 | 0 | double output = m_classifiers[i][j].SVMOutput(-1, inst); |
| 1462 | 0 | if (output > 0) { |
| 1463 | 0 | votes[j] += 1; |
| 1464 | |
} else { |
| 1465 | 0 | votes[i] += 1; |
| 1466 | |
} |
| 1467 | |
} |
| 1468 | |
} |
| 1469 | 0 | return votes; |
| 1470 | |
} |
| 1471 | |
|
| 1472 | |
|
| 1473 | |
|
| 1474 | |
|
| 1475 | |
public double [][][] sparseWeights() { |
| 1476 | |
|
| 1477 | 0 | int numValues = m_classAttribute.numValues(); |
| 1478 | 0 | double [][][] sparseWeights = new double[numValues][numValues][]; |
| 1479 | |
|
| 1480 | 0 | for (int i = 0; i < numValues; i++) { |
| 1481 | 0 | for (int j = i + 1; j < numValues; j++) { |
| 1482 | 0 | sparseWeights[i][j] = m_classifiers[i][j].m_sparseWeights; |
| 1483 | |
} |
| 1484 | |
} |
| 1485 | |
|
| 1486 | 0 | return sparseWeights; |
| 1487 | |
} |
| 1488 | |
|
| 1489 | |
|
| 1490 | |
|
| 1491 | |
|
| 1492 | |
public int [][][] sparseIndices() { |
| 1493 | |
|
| 1494 | 0 | int numValues = m_classAttribute.numValues(); |
| 1495 | 0 | int [][][] sparseIndices = new int[numValues][numValues][]; |
| 1496 | |
|
| 1497 | 0 | for (int i = 0; i < numValues; i++) { |
| 1498 | 0 | for (int j = i + 1; j < numValues; j++) { |
| 1499 | 0 | sparseIndices[i][j] = m_classifiers[i][j].m_sparseIndices; |
| 1500 | |
} |
| 1501 | |
} |
| 1502 | |
|
| 1503 | 0 | return sparseIndices; |
| 1504 | |
} |
| 1505 | |
|
| 1506 | |
|
| 1507 | |
|
| 1508 | |
|
| 1509 | |
public double [][] bias() { |
| 1510 | |
|
| 1511 | 0 | int numValues = m_classAttribute.numValues(); |
| 1512 | 0 | double [][] bias = new double[numValues][numValues]; |
| 1513 | |
|
| 1514 | 0 | for (int i = 0; i < numValues; i++) { |
| 1515 | 0 | for (int j = i + 1; j < numValues; j++) { |
| 1516 | 0 | bias[i][j] = m_classifiers[i][j].m_b; |
| 1517 | |
} |
| 1518 | |
} |
| 1519 | |
|
| 1520 | 0 | return bias; |
| 1521 | |
} |
| 1522 | |
|
| 1523 | |
|
| 1524 | |
|
| 1525 | |
|
| 1526 | |
  /**
   * Returns the number of values of the class attribute.
   *
   * @return the number of class values
   */
  public int numClassAttributeValues() {

    return m_classAttribute.numValues();
  }
| 1530 | |
|
| 1531 | |
|
| 1532 | |
|
| 1533 | |
|
| 1534 | |
public String [] classAttributeNames() { |
| 1535 | |
|
| 1536 | 0 | int numValues = m_classAttribute.numValues(); |
| 1537 | |
|
| 1538 | 0 | String [] classAttributeNames = new String[numValues]; |
| 1539 | |
|
| 1540 | 0 | for (int i = 0; i < numValues; i++) { |
| 1541 | 0 | classAttributeNames[i] = m_classAttribute.value(i); |
| 1542 | |
} |
| 1543 | |
|
| 1544 | 0 | return classAttributeNames; |
| 1545 | |
} |
| 1546 | |
|
| 1547 | |
|
| 1548 | |
|
| 1549 | |
|
| 1550 | |
public String [][][] attributeNames() { |
| 1551 | |
|
| 1552 | 0 | int numValues = m_classAttribute.numValues(); |
| 1553 | 0 | String [][][] attributeNames = new String[numValues][numValues][]; |
| 1554 | |
|
| 1555 | 0 | for (int i = 0; i < numValues; i++) { |
| 1556 | 0 | for (int j = i + 1; j < numValues; j++) { |
| 1557 | |
|
| 1558 | 0 | int numAttributes = m_classifiers[i][j].m_sparseIndices.length; |
| 1559 | 0 | String [] attrNames = new String[numAttributes]; |
| 1560 | 0 | for (int k = 0; k < numAttributes; k++) { |
| 1561 | 0 | attrNames[k] = m_classifiers[i][j]. |
| 1562 | |
m_data.attribute(m_classifiers[i][j].m_sparseIndices[k]).name(); |
| 1563 | |
} |
| 1564 | 0 | attributeNames[i][j] = attrNames; |
| 1565 | |
} |
| 1566 | |
} |
| 1567 | 0 | return attributeNames; |
| 1568 | |
} |
| 1569 | |
|
| 1570 | |
|
| 1571 | |
|
| 1572 | |
|
| 1573 | |
|
| 1574 | |
|
| 1575 | |
public Enumeration listOptions() { |
| 1576 | |
|
| 1577 | 0 | Vector result = new Vector(); |
| 1578 | |
|
| 1579 | 0 | Enumeration enm = super.listOptions(); |
| 1580 | 0 | while (enm.hasMoreElements()) |
| 1581 | 0 | result.addElement(enm.nextElement()); |
| 1582 | |
|
| 1583 | 0 | result.addElement(new Option( |
| 1584 | |
"\tTurns off all checks - use with caution!\n" |
| 1585 | |
+ "\tTurning them off assumes that data is purely numeric, doesn't\n" |
| 1586 | |
+ "\tcontain any missing values, and has a nominal class. Turning them\n" |
| 1587 | |
+ "\toff also means that no header information will be stored if the\n" |
| 1588 | |
+ "\tmachine is linear. Finally, it also assumes that no instance has\n" |
| 1589 | |
+ "\ta weight equal to 0.\n" |
| 1590 | |
+ "\t(default: checks on)", |
| 1591 | |
"no-checks", 0, "-no-checks")); |
| 1592 | |
|
| 1593 | 0 | result.addElement(new Option( |
| 1594 | |
"\tThe complexity constant C. (default 1)", |
| 1595 | |
"C", 1, "-C <double>")); |
| 1596 | |
|
| 1597 | 0 | result.addElement(new Option( |
| 1598 | |
"\tWhether to 0=normalize/1=standardize/2=neither. " + |
| 1599 | |
"(default 0=normalize)", |
| 1600 | |
"N", 1, "-N")); |
| 1601 | |
|
| 1602 | 0 | result.addElement(new Option( |
| 1603 | |
"\tThe tolerance parameter. " + |
| 1604 | |
"(default 1.0e-3)", |
| 1605 | |
"L", 1, "-L <double>")); |
| 1606 | |
|
| 1607 | 0 | result.addElement(new Option( |
| 1608 | |
"\tThe epsilon for round-off error. " + |
| 1609 | |
"(default 1.0e-12)", |
| 1610 | |
"P", 1, "-P <double>")); |
| 1611 | |
|
| 1612 | 0 | result.addElement(new Option( |
| 1613 | |
"\tFit logistic models to SVM outputs. ", |
| 1614 | |
"M", 0, "-M")); |
| 1615 | |
|
| 1616 | 0 | result.addElement(new Option( |
| 1617 | |
"\tThe number of folds for the internal\n" + |
| 1618 | |
"\tcross-validation. " + |
| 1619 | |
"(default -1, use training data)", |
| 1620 | |
"V", 1, "-V <double>")); |
| 1621 | |
|
| 1622 | 0 | result.addElement(new Option( |
| 1623 | |
"\tThe random number seed. " + |
| 1624 | |
"(default 1)", |
| 1625 | |
"W", 1, "-W <double>")); |
| 1626 | |
|
| 1627 | 0 | result.addElement(new Option( |
| 1628 | |
"\tThe Kernel to use.\n" |
| 1629 | |
+ "\t(default: weka.classifiers.functions.supportVector.PolyKernel)", |
| 1630 | |
"K", 1, "-K <classname and parameters>")); |
| 1631 | |
|
| 1632 | 0 | result.addElement(new Option( |
| 1633 | |
"", |
| 1634 | |
"", 0, "\nOptions specific to kernel " |
| 1635 | |
+ getKernel().getClass().getName() + ":")); |
| 1636 | |
|
| 1637 | 0 | enm = ((OptionHandler) getKernel()).listOptions(); |
| 1638 | 0 | while (enm.hasMoreElements()) |
| 1639 | 0 | result.addElement(enm.nextElement()); |
| 1640 | |
|
| 1641 | 0 | return result.elements(); |
| 1642 | |
} |
| 1643 | |
|
| 1644 | |
|
| 1645 | |
|
| 1646 | |
|
| 1647 | |
|
| 1648 | |
|
| 1649 | |
|
| 1650 | |
|
| 1651 | |
|
| 1652 | |
|
| 1653 | |
|
| 1654 | |
|
| 1655 | |
|
| 1656 | |
|
| 1657 | |
|
| 1658 | |
|
| 1659 | |
|
| 1660 | |
|
| 1661 | |
|
| 1662 | |
|
| 1663 | |
|
| 1664 | |
|
| 1665 | |
|
| 1666 | |
|
| 1667 | |
|
| 1668 | |
|
| 1669 | |
|
| 1670 | |
|
| 1671 | |
|
| 1672 | |
|
| 1673 | |
|
| 1674 | |
|
| 1675 | |
|
| 1676 | |
|
| 1677 | |
|
| 1678 | |
|
| 1679 | |
|
| 1680 | |
|
| 1681 | |
|
| 1682 | |
|
| 1683 | |
|
| 1684 | |
|
| 1685 | |
|
| 1686 | |
|
| 1687 | |
|
| 1688 | |
|
| 1689 | |
|
| 1690 | |
|
| 1691 | |
|
| 1692 | |
|
| 1693 | |
|
| 1694 | |
|
| 1695 | |
|
| 1696 | |
|
| 1697 | |
|
| 1698 | |
|
| 1699 | |
|
| 1700 | |
|
| 1701 | |
|
| 1702 | |
|
| 1703 | |
|
| 1704 | |
|
| 1705 | |
|
| 1706 | |
|
| 1707 | |
|
| 1708 | |
|
| 1709 | |
|
| 1710 | |
|
| 1711 | |
|
| 1712 | |
|
| 1713 | |
|
| 1714 | |
|
| 1715 | |
|
| 1716 | |
|
| 1717 | |
|
| 1718 | |
|
| 1719 | |
public void setOptions(String[] options) throws Exception { |
| 1720 | |
String tmpStr; |
| 1721 | |
String[] tmpOptions; |
| 1722 | |
|
| 1723 | 0 | setChecksTurnedOff(Utils.getFlag("no-checks", options)); |
| 1724 | |
|
| 1725 | 0 | tmpStr = Utils.getOption('C', options); |
| 1726 | 0 | if (tmpStr.length() != 0) |
| 1727 | 0 | setC(Double.parseDouble(tmpStr)); |
| 1728 | |
else |
| 1729 | 0 | setC(1.0); |
| 1730 | |
|
| 1731 | 0 | tmpStr = Utils.getOption('L', options); |
| 1732 | 0 | if (tmpStr.length() != 0) |
| 1733 | 0 | setToleranceParameter(Double.parseDouble(tmpStr)); |
| 1734 | |
else |
| 1735 | 0 | setToleranceParameter(1.0e-3); |
| 1736 | |
|
| 1737 | 0 | tmpStr = Utils.getOption('P', options); |
| 1738 | 0 | if (tmpStr.length() != 0) |
| 1739 | 0 | setEpsilon(Double.parseDouble(tmpStr)); |
| 1740 | |
else |
| 1741 | 0 | setEpsilon(1.0e-12); |
| 1742 | |
|
| 1743 | 0 | tmpStr = Utils.getOption('N', options); |
| 1744 | 0 | if (tmpStr.length() != 0) |
| 1745 | 0 | setFilterType(new SelectedTag(Integer.parseInt(tmpStr), TAGS_FILTER)); |
| 1746 | |
else |
| 1747 | 0 | setFilterType(new SelectedTag(FILTER_NORMALIZE, TAGS_FILTER)); |
| 1748 | |
|
| 1749 | 0 | setBuildLogisticModels(Utils.getFlag('M', options)); |
| 1750 | |
|
| 1751 | 0 | tmpStr = Utils.getOption('V', options); |
| 1752 | 0 | if (tmpStr.length() != 0) |
| 1753 | 0 | setNumFolds(Integer.parseInt(tmpStr)); |
| 1754 | |
else |
| 1755 | 0 | setNumFolds(-1); |
| 1756 | |
|
| 1757 | 0 | tmpStr = Utils.getOption('W', options); |
| 1758 | 0 | if (tmpStr.length() != 0) |
| 1759 | 0 | setRandomSeed(Integer.parseInt(tmpStr)); |
| 1760 | |
else |
| 1761 | 0 | setRandomSeed(1); |
| 1762 | |
|
| 1763 | 0 | tmpStr = Utils.getOption('K', options); |
| 1764 | 0 | tmpOptions = Utils.splitOptions(tmpStr); |
| 1765 | 0 | if (tmpOptions.length != 0) { |
| 1766 | 0 | tmpStr = tmpOptions[0]; |
| 1767 | 0 | tmpOptions[0] = ""; |
| 1768 | 0 | setKernel(Kernel.forName(tmpStr, tmpOptions)); |
| 1769 | |
} |
| 1770 | |
|
| 1771 | 0 | super.setOptions(options); |
| 1772 | 0 | } |
| 1773 | |
|
| 1774 | |
|
| 1775 | |
|
| 1776 | |
|
| 1777 | |
|
| 1778 | |
|
| 1779 | |
public String[] getOptions() { |
| 1780 | |
int i; |
| 1781 | |
Vector result; |
| 1782 | |
String[] options; |
| 1783 | |
|
| 1784 | 0 | result = new Vector(); |
| 1785 | 0 | options = super.getOptions(); |
| 1786 | 0 | for (i = 0; i < options.length; i++) |
| 1787 | 0 | result.add(options[i]); |
| 1788 | |
|
| 1789 | 0 | if (getChecksTurnedOff()) |
| 1790 | 0 | result.add("-no-checks"); |
| 1791 | |
|
| 1792 | 0 | result.add("-C"); |
| 1793 | 0 | result.add("" + getC()); |
| 1794 | |
|
| 1795 | 0 | result.add("-L"); |
| 1796 | 0 | result.add("" + getToleranceParameter()); |
| 1797 | |
|
| 1798 | 0 | result.add("-P"); |
| 1799 | 0 | result.add("" + getEpsilon()); |
| 1800 | |
|
| 1801 | 0 | result.add("-N"); |
| 1802 | 0 | result.add("" + m_filterType); |
| 1803 | |
|
| 1804 | 0 | if (getBuildLogisticModels()) |
| 1805 | 0 | result.add("-M"); |
| 1806 | |
|
| 1807 | 0 | result.add("-V"); |
| 1808 | 0 | result.add("" + getNumFolds()); |
| 1809 | |
|
| 1810 | 0 | result.add("-W"); |
| 1811 | 0 | result.add("" + getRandomSeed()); |
| 1812 | |
|
| 1813 | 0 | result.add("-K"); |
| 1814 | 0 | result.add("" + getKernel().getClass().getName() + " " + Utils.joinOptions(getKernel().getOptions())); |
| 1815 | |
|
| 1816 | 0 | return (String[]) result.toArray(new String[result.size()]); |
| 1817 | |
} |
| 1818 | |
|
| 1819 | |
|
| 1820 | |
|
| 1821 | |
|
| 1822 | |
|
| 1823 | |
|
| 1824 | |
|
| 1825 | |
  /**
   * Sets whether the time-consuming checks are turned off. Delegates to
   * turnChecksOff()/turnChecksOn().
   *
   * @param value true to turn the checks off
   */
  public void setChecksTurnedOff(boolean value) {
    if (value)
      turnChecksOff();
    else
      turnChecksOn();
  }

  /**
   * Returns whether the checks are currently turned off.
   *
   * @return true if the checks are turned off
   */
  public boolean getChecksTurnedOff() {
    return m_checksTurnedOff;
  }
| 1840 | |
|
| 1841 | |
|
| 1842 | |
|
| 1843 | |
|
| 1844 | |
|
| 1845 | |
|
| 1846 | |
|
| 1847 | |
  /**
   * Returns the tip text for the checksTurnedOff property (for the GUI).
   *
   * @return tip text for this property
   */
  public String checksTurnedOffTipText() {
    return "Turns time-consuming checks off - use with caution.";
  }

  /**
   * Returns the tip text for the kernel property (for the GUI).
   *
   * @return tip text for this property
   */
  public String kernelTipText() {
    return "The kernel to use.";
  }

  /**
   * Sets the kernel to use.
   *
   * @param value the kernel
   */
  public void setKernel(Kernel value) {
    m_kernel = value;
  }

  /**
   * Returns the kernel in use.
   *
   * @return the kernel
   */
  public Kernel getKernel() {
    return m_kernel;
  }

  /**
   * Returns the tip text for the C property (for the GUI).
   *
   * @return tip text for this property
   */
  public String cTipText() {
    return "The complexity parameter C.";
  }

  /**
   * Gets the complexity parameter C.
   *
   * @return the value of C
   */
  public double getC() {

    return m_C;
  }

  /**
   * Sets the complexity parameter C.
   *
   * @param v the value of C
   */
  public void setC(double v) {

    m_C = v;
  }
| 1907 | |
|
| 1908 | |
|
| 1909 | |
|
| 1910 | |
|
| 1911 | |
|
| 1912 | |
|
| 1913 | |
  /**
   * Returns the tip text for the toleranceParameter property (for the GUI).
   *
   * @return tip text for this property
   */
  public String toleranceParameterTipText() {
    return "The tolerance parameter (shouldn't be changed).";
  }

  /**
   * Gets the tolerance parameter.
   *
   * @return the tolerance
   */
  public double getToleranceParameter() {

    return m_tol;
  }

  /**
   * Sets the tolerance parameter.
   *
   * @param v the tolerance
   */
  public void setToleranceParameter(double v) {

    m_tol = v;
  }

  /**
   * Returns the tip text for the epsilon property (for the GUI).
   *
   * @return tip text for this property
   */
  public String epsilonTipText() {
    return "The epsilon for round-off error (shouldn't be changed).";
  }

  /**
   * Gets the epsilon for round-off error.
   *
   * @return the epsilon
   */
  public double getEpsilon() {

    return m_eps;
  }

  /**
   * Sets the epsilon for round-off error.
   *
   * @param v the epsilon
   */
  public void setEpsilon(double v) {

    m_eps = v;
  }

  /**
   * Returns the tip text for the filterType property (for the GUI).
   *
   * @return tip text for this property
   */
  public String filterTypeTipText() {
    return "Determines how/if the data will be transformed.";
  }
| 1970 | |
|
| 1971 | |
|
| 1972 | |
|
| 1973 | |
|
| 1974 | |
|
| 1975 | |
|
| 1976 | |
|
| 1977 | |
  /**
   * Gets how the training data is transformed (one of the TAGS_FILTER
   * choices).
   *
   * @return the filtering mode
   */
  public SelectedTag getFilterType() {

    return new SelectedTag(m_filterType, TAGS_FILTER);
  }

  /**
   * Sets how the training data is transformed. Silently ignored if the
   * tag does not belong to TAGS_FILTER.
   *
   * @param newType the new filtering mode
   */
  public void setFilterType(SelectedTag newType) {

    if (newType.getTags() == TAGS_FILTER) {
      m_filterType = newType.getSelectedTag().getID();
    }
  }
| 1994 | |
|
| 1995 | |
|
| 1996 | |
|
| 1997 | |
|
| 1998 | |
|
| 1999 | |
|
| 2000 | |
  /**
   * Returns the tip text for the buildLogisticModels property (for the GUI).
   *
   * @return tip text for this property
   */
  public String buildLogisticModelsTipText() {
    return "Whether to fit logistic models to the outputs (for proper "
      + "probability estimates).";
  }

  /**
   * Gets whether logistic models are to be fit to the SVM outputs.
   *
   * @return true if logistic models are fit
   */
  public boolean getBuildLogisticModels() {

    return m_fitLogisticModels;
  }

  /**
   * Sets whether logistic models are to be fit to the SVM outputs.
   *
   * @param newbuildLogisticModels true to fit logistic models
   */
  public void setBuildLogisticModels(boolean newbuildLogisticModels) {

    m_fitLogisticModels = newbuildLogisticModels;
  }

  /**
   * Returns the tip text for the numFolds property (for the GUI).
   *
   * @return tip text for this property
   */
  public String numFoldsTipText() {
    return "The number of folds for cross-validation used to generate "
      + "training data for logistic models (-1 means use training data).";
  }

  /**
   * Gets the number of folds for the internal cross-validation.
   *
   * @return the number of folds (-1 = use training data)
   */
  public int getNumFolds() {

    return m_numFolds;
  }

  /**
   * Sets the number of folds for the internal cross-validation.
   *
   * @param newnumFolds the number of folds (-1 = use training data)
   */
  public void setNumFolds(int newnumFolds) {

    m_numFolds = newnumFolds;
  }

  /**
   * Returns the tip text for the randomSeed property (for the GUI).
   *
   * @return tip text for this property
   */
  public String randomSeedTipText() {
    return "Random number seed for the cross-validation.";
  }

  /**
   * Gets the random number seed.
   *
   * @return the seed
   */
  public int getRandomSeed() {

    return m_randomSeed;
  }

  /**
   * Sets the random number seed.
   *
   * @param newrandomSeed the seed
   */
  public void setRandomSeed(int newrandomSeed) {

    m_randomSeed = newrandomSeed;
  }
| 2083 | |
|
| 2084 | |
|
| 2085 | |
|
| 2086 | |
|
| 2087 | |
|
| 2088 | |
|
| 2089 | |
public String toString() { |
| 2090 | |
|
| 2091 | 0 | StringBuffer text = new StringBuffer(); |
| 2092 | |
|
| 2093 | 0 | if ((m_classAttribute == null)) { |
| 2094 | 0 | return "SMO: No model built yet."; |
| 2095 | |
} |
| 2096 | |
try { |
| 2097 | 0 | text.append("SMO\n\n"); |
| 2098 | 0 | text.append("Kernel used:\n " + m_kernel.toString() + "\n\n"); |
| 2099 | |
|
| 2100 | 0 | for (int i = 0; i < m_classAttribute.numValues(); i++) { |
| 2101 | 0 | for (int j = i + 1; j < m_classAttribute.numValues(); j++) { |
| 2102 | 0 | text.append("Classifier for classes: " + |
| 2103 | |
m_classAttribute.value(i) + ", " + |
| 2104 | |
m_classAttribute.value(j) + "\n\n"); |
| 2105 | 0 | text.append(m_classifiers[i][j]); |
| 2106 | 0 | if (m_fitLogisticModels) { |
| 2107 | 0 | text.append("\n\n"); |
| 2108 | 0 | if ( m_classifiers[i][j].m_logistic == null) { |
| 2109 | 0 | text.append("No logistic model has been fit.\n"); |
| 2110 | |
} else { |
| 2111 | 0 | text.append(m_classifiers[i][j].m_logistic); |
| 2112 | |
} |
| 2113 | |
} |
| 2114 | 0 | text.append("\n\n"); |
| 2115 | |
} |
| 2116 | |
} |
| 2117 | 0 | } catch (Exception e) { |
| 2118 | 0 | return "Can't print SMO classifier."; |
| 2119 | 0 | } |
| 2120 | |
|
| 2121 | 0 | return text.toString(); |
| 2122 | |
} |
| 2123 | |
|
| 2124 | |
|
| 2125 | |
|
| 2126 | |
|
| 2127 | |
|
| 2128 | |
|
| 2129 | |
  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 8034 $");
  }
| 2132 | |
|
| 2133 | |
|
| 2134 | |
|
| 2135 | |
|
| 2136 | |
  /**
   * Main method for running this classifier from the command line.
   *
   * @param argv the command-line options
   */
  public static void main(String[] argv) {
    runClassifier(new SMO(), argv);
  }
| 2139 | |
} |
| 2140 | |
|