This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.claytantor.mahout.test.app; | |
import java.util.Collections; | |
import java.util.HashMap; | |
import java.util.List; | |
import java.util.Map; | |
import org.apache.mahout.classifier.sgd.L1; | |
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression; | |
import org.apache.mahout.math.DenseVector; | |
import org.apache.mahout.math.RandomAccessSparseVector; | |
import org.apache.mahout.math.Vector; | |
import org.slf4j.Logger; | |
import org.slf4j.LoggerFactory; | |
import com.google.common.collect.Lists; | |
public class OnlineLROne { | |
public static class Point { | |
public int x; | |
public int y; | |
public Point(int x, int y) { | |
this.x = x; | |
this.y = y; | |
} | |
@Override | |
public boolean equals(Object arg0) { | |
Point p = (Point) arg0; | |
return ((this.x == p.x) && (this.y == p.y)); | |
} | |
@Override | |
public String toString() { | |
// TODO Auto-generated method stub | |
return this.x + " , " + this.y; | |
} | |
} | |
public static Vector getVector(Point point) { | |
Vector v = new DenseVector(3); | |
v.set(0, point.x); | |
v.set(1, point.y); | |
v.set(2, 1); | |
return v; | |
} | |
/** | |
* @param args | |
*/ | |
public static void main(String[] args) { | |
Logger logger = LoggerFactory.getLogger(ReccomenderOne.class); | |
logger.info("Hello World"); | |
Map<Point, Integer> points = new HashMap<Point, Integer>(); | |
points.put(new Point(0, 0), 0); | |
points.put(new Point(1, 1), 0); | |
points.put(new Point(1, 0), 0); | |
points.put(new Point(0, 1), 0); | |
points.put(new Point(2, 2), 0); | |
points.put(new Point(8, 8), 1); | |
points.put(new Point(8, 9), 1); | |
points.put(new Point(9, 8), 1); | |
points.put(new Point(9, 9), 1); | |
OnlineLogisticRegression learningAlgo = new OnlineLogisticRegression(2, | |
3, new L1()); | |
// this is a really big value which will make the model very cautious | |
// for lambda = 0.1, the first example below should be about .83 certain | |
// for lambda = 0.01, the first example below should be about 0.98 | |
// certain | |
learningAlgo.lambda(0.1); | |
learningAlgo.learningRate(4); | |
System.out.println("training model \n"); | |
final List<Point> keys = Lists.newArrayList(points.keySet()); | |
// 200 times through the training data is probably over-kill. It doesn't | |
// matter | |
// for tiny data. The key here is total number of points seen, not | |
// number of passes. | |
for (int i = 0; i < 200; i++) { | |
// randomize training data on each iteration | |
Collections.shuffle(keys); | |
for (Point point : keys) { | |
Vector v = getVector(point); | |
learningAlgo.train(points.get(point), v); | |
} | |
} | |
learningAlgo.close(); | |
// now classify real data | |
Vector v = new RandomAccessSparseVector(3); | |
v.set(0, 0.5); | |
v.set(1, 0.5); | |
v.set(2, 1); | |
Vector r = learningAlgo.classifyFull(v); | |
System.out.println(r); | |
System.out.println("ans = "); | |
System.out.printf("no of categories = %d\n", | |
learningAlgo.numCategories()); | |
System.out.printf("no of features = %d\n", learningAlgo.numFeatures()); | |
System.out.printf("Probability of cluster 0 = %.3f\n", r.get(0)); | |
System.out.printf("Probability of cluster 1 = %.3f\n", r.get(1)); | |
v.set(0, 4.5); | |
v.set(1, 6.5); | |
v.set(2, 1); | |
r = learningAlgo.classifyFull(v); | |
System.out.println("ans = "); | |
System.out.printf("no of categories = %d\n", | |
learningAlgo.numCategories()); | |
System.out.printf("no of features = %d\n", learningAlgo.numFeatures()); | |
System.out.printf("Probability of cluster 0 = %.3f\n", r.get(0)); | |
System.out.printf("Probability of cluster 1 = %.3f\n", r.get(1)); | |
// show how the score varies along a line from 0,0 to 1,1 | |
System.out.printf("\nx\tscore\n"); | |
for (int i = 0; i < 100; i++) { | |
final double x = 0.0 + i / 10.0; | |
v.set(0, x); | |
v.set(1, x); | |
v.set(2, 1); | |
r = learningAlgo.classifyFull(v); | |
System.out.printf("%.2f\t%.3f\n", x, r.get(1)); | |
} | |
} | |
} |
No comments:
Post a Comment