diff --git a/examples/src/main/java/com/spbsu/flamestream/example/bl/text_classifier/ops/filtering/classifier/Optimizer.java b/examples/src/main/java/com/spbsu/flamestream/example/bl/text_classifier/ops/filtering/classifier/Optimizer.java new file mode 100644 index 000000000..b94d092e3 --- /dev/null +++ b/examples/src/main/java/com/spbsu/flamestream/example/bl/text_classifier/ops/filtering/classifier/Optimizer.java @@ -0,0 +1,8 @@ +package com.spbsu.flamestream.example.bl.text_classifier.ops.filtering.classifier; + +import com.expleague.commons.math.vectors.Mx; +import com.expleague.commons.math.vectors.impl.mx.SparseMx; + +public interface Optimizer { + SparseMx optimizeWeights(Mx trainingSet, String[] correctTopics, SparseMx prevWeights); +} diff --git a/examples/src/main/java/com/spbsu/flamestream/example/bl/text_classifier/ops/filtering/classifier/SklearnSgdPredictor.java b/examples/src/main/java/com/spbsu/flamestream/example/bl/text_classifier/ops/filtering/classifier/SklearnSgdPredictor.java index 34ab809e7..32b82b220 100644 --- a/examples/src/main/java/com/spbsu/flamestream/example/bl/text_classifier/ops/filtering/classifier/SklearnSgdPredictor.java +++ b/examples/src/main/java/com/spbsu/flamestream/example/bl/text_classifier/ops/filtering/classifier/SklearnSgdPredictor.java @@ -1,15 +1,16 @@ package com.spbsu.flamestream.example.bl.text_classifier.ops.filtering.classifier; -import com.expleague.commons.math.vectors.Mx; import com.expleague.commons.math.vectors.MxTools; import com.expleague.commons.math.vectors.Vec; import com.expleague.commons.math.vectors.VecTools; -import com.expleague.commons.math.vectors.impl.mx.RowsVecArrayMx; +import com.expleague.commons.math.vectors.impl.mx.SparseMx; import com.expleague.commons.math.vectors.impl.vectors.ArrayVec; import com.expleague.commons.math.vectors.impl.vectors.SparseVec; import com.google.common.annotations.VisibleForTesting; import gnu.trove.map.TObjectIntMap; import gnu.trove.map.hash.TObjectIntHashMap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.io.BufferedReader; import java.io.File; @@ -24,6 +25,7 @@ import java.util.stream.StreamSupport; public class SklearnSgdPredictor implements TopicsPredictor { + private static final Logger LOGGER = LoggerFactory.getLogger(SklearnSgdPredictor.class.getName()); private static final Pattern PATTERN = Pattern.compile("\\b\\w\\w+\\b", Pattern.UNICODE_CHARACTER_CLASS); private final String weightsPath; @@ -32,7 +34,7 @@ public class SklearnSgdPredictor implements TopicsPredictor { //lazy loading private TObjectIntMap countVectorizer; private Vec intercept; - private Mx weights; + private SparseMx weights; private String[] topics; public SklearnSgdPredictor(String cntVectorizerPath, String weightsPath) { @@ -40,35 +42,37 @@ public SklearnSgdPredictor(String cntVectorizerPath, String weightsPath) { this.cntVectorizerPath = cntVectorizerPath; } + private SparseVec vectorize(Map tfIdf) { + final int[] indices = new int[tfIdf.size()]; + final double[] values = new double[tfIdf.size()]; + + int ind = 0; + for (String key : tfIdf.keySet()) { + final int valueIndex = countVectorizer.get(key); + indices[ind] = valueIndex; + values[ind] = tfIdf.get(key); + ind++; + } + + return new SparseVec(countVectorizer.size(), indices, values); + } + @Override public Topic[] predict(Document document) { loadMeta(); loadVocabulary(); - final Map tfIdf = document.tfIdf(); - final int[] indices = new int[tfIdf.size()]; - final double[] values = new double[tfIdf.size()]; - { //convert TF-IDF features to sparse vector - int ind = 0; - for (String key : tfIdf.keySet()) { - final int valueIndex = countVectorizer.get(key); - indices[ind] = valueIndex; - values[ind] = tfIdf.get(key); - ind++; - } - } - final Vec probabilities; { // compute topic probabilities - final SparseVec vectorized = new SparseVec(countVectorizer.size(), indices, values); + final SparseVec vectorized = vectorize(document.tfIdf()); final Vec score = MxTools.multiply(weights, vectorized); final Vec sum = VecTools.sum(score, intercept); final Vec scaled = VecTools.scale(sum, -1); VecTools.exp(scaled); - final double[] ones = new double[score.dim()]; - Arrays.fill(ones, 1); - final Vec vecOnes = new ArrayVec(ones, 0, ones.length); + final Vec vecOnes = new ArrayVec(score.dim()); + VecTools.fill(vecOnes, 1); + probabilities = VecTools.sum(scaled, vecOnes); for (int i = 0; i < probabilities.dim(); i++) { double changed = 1 / probabilities.get(i); @@ -87,6 +91,19 @@ public Topic[] predict(Document document) { return result; } + @Override + public void updateWeights(SparseMx weights) { + this.weights = weights; + } + + public SparseMx getWeights() { + return weights; + } + + public String[] getTopics() { + return topics; + } + public void init() { loadMeta(); loadVocabulary(); @@ -107,7 +124,7 @@ private void loadMeta() { topics[i] = br.readLine(); } - final Vec[] coef = new Vec[classes]; + final SparseVec[] coef = new SparseVec[classes]; String line; for (int index = 0; index < classes; index++) { line = br.readLine(); @@ -123,11 +140,10 @@ private void loadMeta() { values[i / 2] = value; } - final SparseVec sparseVec = new SparseVec(currentFeatures, indeces, values); - coef[index] = sparseVec; + coef[index] = new SparseVec(currentFeatures, indeces, values); } - weights = new RowsVecArrayMx(coef); + weights = new SparseMx(coef); MxTools.transpose(weights); line = br.readLine(); diff --git a/examples/src/main/java/com/spbsu/flamestream/example/bl/text_classifier/ops/filtering/classifier/SoftmaxRegressionOptimizer.java b/examples/src/main/java/com/spbsu/flamestream/example/bl/text_classifier/ops/filtering/classifier/SoftmaxRegressionOptimizer.java new file mode 100644 index 000000000..acf4eac0e --- /dev/null +++ b/examples/src/main/java/com/spbsu/flamestream/example/bl/text_classifier/ops/filtering/classifier/SoftmaxRegressionOptimizer.java @@ -0,0 +1,174 @@ +package com.spbsu.flamestream.example.bl.text_classifier.ops.filtering.classifier; + +import com.expleague.commons.math.vectors.Mx; +import com.expleague.commons.math.vectors.MxTools; +import com.expleague.commons.math.vectors.Vec; +import com.expleague.commons.math.vectors.VecTools; +import com.expleague.commons.math.vectors.impl.mx.SparseMx; +import com.expleague.commons.math.vectors.impl.vectors.ArrayVec; +import com.expleague.commons.math.vectors.impl.vectors.SparseVec; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.*; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Stream; + +public class SoftmaxRegressionOptimizer implements Optimizer { + private static final Logger LOGGER = LoggerFactory.getLogger(SoftmaxRegressionOptimizer.class.getName()); + private final List topicList; + private ExecutorService executor = Executors.newFixedThreadPool(8); + + public SoftmaxRegressionOptimizer(String[] topics) { + topicList = Arrays.asList(topics); + } + + private Mx l1Gradient(SparseMx weights) { + final Mx gradient = new SparseMx(weights.rows(), weights.columns()); + + for (int i = 0; i < weights.rows(); i++) { + for (int j = 0; j < weights.columns(); j++) { + gradient.set(i, j, Math.signum(weights.get(i, j))); + } + } + + return gradient; + } + + private Mx l2Gradient(SparseMx weights, SparseMx prevWeights) { + //return VecTools.subtract(VecTools.scale(weights, 2), prevWeights); ??? + Mx gradient = new SparseMx(weights.rows(), weights.columns()); + + for (int i = 0; i < weights.rows(); i++) { + for (int j = 0; j < weights.columns(); j++) { + gradient.set(i, j, 2 * (weights.get(i, j) - prevWeights.get(i, j))); + } + } + + return gradient; + } + + private Vec computeSoftmaxValues(SparseMx weights, Mx trainingSet, int[] correctTopics) { + final double[] softmaxValues = new double[trainingSet.rows()]; + + CountDownLatch latch = new CountDownLatch(trainingSet.rows()); + for (int i = 0; i < trainingSet.rows(); i++) { + final int finalI = i; + executor.execute(() -> { + final Vec x = VecTools.copySparse(trainingSet.row(finalI)); + final int index = correctTopics[finalI]; + final Vec mul = MxTools.multiply(weights, x); + VecTools.exp(mul); + + final double numer = mul.get(index); + double denom = 0.0; + for (int k = 0; k < weights.rows(); k++) { + denom += mul.get(k); + } + + //LOGGER.info("values size {} and setting value index {}", softmaxValues.dim(), finalI); + //System.out.println("size " + softmaxValues.length + " index" + finalI); + softmaxValues[finalI] = numer / denom; + latch.countDown(); + }); + } + + try { + latch.await(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + return new ArrayVec(softmaxValues, 0, softmaxValues.length); + //return softmaxValues; + } + + private SoftmaxData softmaxGradient(SparseMx weights, Mx trainingSet, int[] correctTopics) { + final SparseVec[] gradients = new SparseVec[weights.rows()]; + final Vec softmaxValues = computeSoftmaxValues(weights, trainingSet, correctTopics); + + CountDownLatch latch = new CountDownLatch(weights.rows()); + for (int j = 0; j < weights.rows(); j++) { + //LOGGER.info("weights {} component", j); + final int finalJ = j; + + executor.execute(() -> { + SparseVec grad = new SparseVec(weights.columns()); + final SparseVec scales = new SparseVec(trainingSet.rows()); + for (int i = 0; i < trainingSet.rows(); i++) { + final int index = correctTopics[i]; + final int indicator = index == finalJ ? 1 : 0; + scales.set(i, indicator - softmaxValues.get(i)); + } + + for (int i = 0; i < trainingSet.rows(); i++) { + final Vec x = VecTools.copySparse(trainingSet.row(i)); + VecTools.scale(x, scales); + grad = VecTools.sum(grad, x); + + } + + gradients[finalJ] = grad;//VecTools.scale(grad, -1.0 / trainingSet.rows()); + latch.countDown(); + }); + } + + try { + latch.await(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + return new SoftmaxData(VecTools.sum(softmaxValues), new SparseMx(gradients)); + } + + public SparseMx optimizeWeights(Mx trainingSet, String[] correctTopics, SparseMx prevWeights) { + final double alpha = 1e-1; + final double lambda1 = 1e-2; + final double lambda2 = 1e-1; + final double maxIter = 100; + final int[] indeces = Stream.of(correctTopics).mapToInt(topicList::indexOf).toArray(); + + double previousValue = 0; + SparseMx weights = new SparseMx(prevWeights.rows(), prevWeights.columns()); + for (int iteration = 1; iteration <= maxIter; iteration++) { + LOGGER.info("Iteration {}", iteration); + final SoftmaxData data = softmaxGradient(weights, trainingSet, indeces); + LOGGER.info("Softmax value : {}", data.value); + if (Math.abs(data.value - previousValue) < 1e-3) { + break; + } + + previousValue = data.value; + Mx l1 = l1Gradient(weights); + Mx l2 = l2Gradient(weights, prevWeights); + + //SoftmaxData = VecTools.scale(SoftmaxData, alpha); + //l1 = VecTools.scale(l1, lambda1); + //l2 = VecTools.scale(l2, lambda2); + // weights = VecTools.subtract(weights, VecTools.sum(SoftmaxData, VecTools.sum(l1, l2))); + + for (int i = 0; i < weights.rows(); i++) { + for (int j = 0; j < weights.columns(); j++) { + final double value = weights.get(i, j) + - alpha * (data.gradients.get(i, j) - lambda1 * l1.get(i, j) - lambda2 * l2.get(i, j)); + weights.set(i, j, value); + } + } + + } + + return weights; + } + + private class SoftmaxData { + private final double value; + private final SparseMx gradients; + + SoftmaxData(double value, SparseMx gradients) { + this.value = value; + this.gradients = gradients; + } + } + +} diff --git a/examples/src/main/java/com/spbsu/flamestream/example/bl/text_classifier/ops/filtering/classifier/TopicsPredictor.java b/examples/src/main/java/com/spbsu/flamestream/example/bl/text_classifier/ops/filtering/classifier/TopicsPredictor.java index 74a73a44c..3e5925987 100644 --- a/examples/src/main/java/com/spbsu/flamestream/example/bl/text_classifier/ops/filtering/classifier/TopicsPredictor.java +++ b/examples/src/main/java/com/spbsu/flamestream/example/bl/text_classifier/ops/filtering/classifier/TopicsPredictor.java @@ -1,8 +1,11 @@ package com.spbsu.flamestream.example.bl.text_classifier.ops.filtering.classifier; +import com.expleague.commons.math.vectors.impl.mx.SparseMx; + public interface TopicsPredictor { default void init() { } Topic[] predict(Document document); + void updateWeights(SparseMx weights); } diff --git a/examples/src/main/resources/classifier_weights b/examples/src/main/resources/classifier_weights index 1abf71572..ebb3fe9bb 100644 --- a/examples/src/main/resources/classifier_weights +++ b/examples/src/main/resources/classifier_weights @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:fa8f7d3ee36da0fbed333705f04a885d11359bd3cab6259532e2d8c3bf6686b5 -size 19339313 +oid sha256:17571895048675d5fd00160e848954c7cb0926050a08d7c76660b4d168de66e5 +size 20795758 diff --git a/examples/src/main/resources/cnt_vectorizer b/examples/src/main/resources/cnt_vectorizer index 4317a8804..cb178c9b4 100644 --- a/examples/src/main/resources/cnt_vectorizer +++ b/examples/src/main/resources/cnt_vectorizer @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:1963459dc6e19e3698c5d7a0a47abe3a15b924cba314b61df41cf6ee0c2d83f5 -size 15398717 +oid sha256:13a42e2a6523a6138cb94826eab355e98f9f028a4e0973d40c2a4302fff04e48 +size 15040612 diff --git a/examples/src/test/java/com/spbsu/flamestream/example/bl/classifier/PredictorStreamTest.java b/examples/src/test/java/com/spbsu/flamestream/example/bl/classifier/PredictorStreamTest.java index 210ca7f86..f31cefd48 100644 --- a/examples/src/test/java/com/spbsu/flamestream/example/bl/classifier/PredictorStreamTest.java +++ b/examples/src/test/java/com/spbsu/flamestream/example/bl/classifier/PredictorStreamTest.java @@ -156,7 +156,7 @@ public Stream apply(Document document) { } } - private static double[] parseDoubles(String line) { + public static double[] parseDoubles(String line) { return Arrays .stream(line.split(" ")) .mapToDouble(Double::parseDouble) diff --git a/examples/src/test/java/com/spbsu/flamestream/example/bl/text_classifier/LentaTest.java b/examples/src/test/java/com/spbsu/flamestream/example/bl/text_classifier/LentaTest.java index 8d8fffa24..5054d0e9a 100644 --- a/examples/src/test/java/com/spbsu/flamestream/example/bl/text_classifier/LentaTest.java +++ b/examples/src/test/java/com/spbsu/flamestream/example/bl/text_classifier/LentaTest.java @@ -1,10 +1,16 @@ package com.spbsu.flamestream.example.bl.text_classifier; import akka.actor.ActorSystem; +import com.expleague.commons.math.vectors.Mx; +import com.expleague.commons.math.vectors.impl.mx.SparseMx; +import com.expleague.commons.math.vectors.impl.vectors.SparseVec; import com.spbsu.flamestream.example.bl.text_classifier.model.Prediction; import com.spbsu.flamestream.example.bl.text_classifier.model.TextDocument; import com.spbsu.flamestream.example.bl.text_classifier.model.TfIdfObject; +import com.spbsu.flamestream.example.bl.text_classifier.ops.filtering.classifier.Document; +import com.spbsu.flamestream.example.bl.text_classifier.ops.filtering.classifier.Optimizer; import com.spbsu.flamestream.example.bl.text_classifier.ops.filtering.classifier.SklearnSgdPredictor; +import com.spbsu.flamestream.example.bl.text_classifier.ops.filtering.classifier.SoftmaxRegressionOptimizer; import com.spbsu.flamestream.example.bl.text_classifier.ops.filtering.classifier.Topic; import com.spbsu.flamestream.example.bl.text_classifier.ops.filtering.classifier.TopicsPredictor; import com.spbsu.flamestream.runtime.FlameRuntime; @@ -25,9 +31,13 @@ import scala.concurrent.Await; import scala.concurrent.duration.Duration; +import java.io.BufferedReader; +import java.io.File; +import java.io.FileReader; import java.io.IOException; import java.io.InputStreamReader; import java.io.Reader; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.Iterator; @@ -46,7 +56,9 @@ import java.util.stream.Stream; import java.util.stream.StreamSupport; +import static com.spbsu.flamestream.example.bl.classifier.PredictorStreamTest.parseDoubles; import static java.util.stream.Collectors.toList; +import static org.testng.Assert.assertTrue; public class LentaTest extends FlameAkkaSuite { private static final Logger LOGGER = LoggerFactory.getLogger(LentaTest.class); @@ -84,6 +96,95 @@ private Stream documents(String path) throws IOException { } + @Test + public void partialFitTest() { + final String CNT_VECTORIZER_PATH = "src/main/resources/cnt_vectorizer"; + final String WEIGHTS_PATH = "src/main/resources/classifier_weights"; + final String PATH_TO_TEST_DATA = "src/test/resources/sklearn_prediction"; + + final List topics = new ArrayList<>(); + final List texts = new ArrayList<>(); + final List mx = new ArrayList<>(); + List documents = new ArrayList<>(); + final SklearnSgdPredictor predictor = new SklearnSgdPredictor(CNT_VECTORIZER_PATH, WEIGHTS_PATH); + predictor.init(); + try (BufferedReader br = new BufferedReader(new FileReader(new File(PATH_TO_TEST_DATA)))) { + final double[] data = parseDoubles(br.readLine()); + final int testCount = (int) data[0]; + final int features = (int) data[1]; + + for (int i = 0; i < testCount; i++) { + //final double[] pyPrediction = parseDoubles(br.readLine()); + + final String docText = br.readLine().toLowerCase(); + texts.add(docText); + + String topic = br.readLine(); + topics.add(topic); + final double[] info = parseDoubles(br.readLine()); + final int[] indeces = new int[info.length / 2]; + final double[] values = new double[info.length / 2]; + for (int k = 0; k < info.length; k += 2) { + final int index = (int) info[k]; + final double value = info[k + 1]; + + indeces[k / 2] = index; + values[k / 2] = value; + } + + final Map tfIdf = new HashMap<>(); + SparseVec vec = new SparseVec(features, indeces, values); + + SklearnSgdPredictor.text2words(docText).forEach(word -> { + final int featureIndex = predictor.wordIndex(word); + tfIdf.put(word, vec.get(featureIndex)); + }); + final Document document = new Document(tfIdf); + documents.add(document); + + mx.add(vec); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + + final int len = topics.size(); + final int testsize = 30; + + List testTopics = topics.stream().skip(len - testsize).collect(Collectors.toList()); + List testTexts = texts.stream().skip(len - testsize).collect(Collectors.toList()); + documents = documents.stream().skip(len - testsize).collect(Collectors.toList()); + + Mx trainingSet = new SparseMx(mx.stream().limit(len - testsize).toArray(SparseVec[]::new)); + LOGGER.info("Updating weights"); + Optimizer optimizer = new SoftmaxRegressionOptimizer(predictor.getTopics()); + String[] correctTopics = topics.stream().limit(len - testsize).toArray(String[]::new); + SparseMx newWeights = optimizer.optimizeWeights(trainingSet, correctTopics, predictor.getWeights()); + predictor.updateWeights(newWeights); + + double truePositives = 0; + for (int i = 0; i < testsize; i++) { + String text = testTexts.get(i); + String ans = testTopics.get(i); + Document doc = documents.get(i); + + Topic[] prediction = predictor.predict(doc); + + Arrays.sort(prediction); + if (ans.equals(prediction[0].name())) { + truePositives++; + } + LOGGER.info("Doc: {}", text); + LOGGER.info("Real answers: {}", ans); + LOGGER.info("Predict: {}", (Object) prediction); + LOGGER.info("\n"); + } + + double accuracy = truePositives / testsize; + LOGGER.info("Accuracy: {}", accuracy); + //assertTrue(accuracy >= 0.62); + } + @Test public void lentaTest() throws InterruptedException, IOException, TimeoutException { final String testFilePath = "lenta/lenta-ru-news.csv"; diff --git a/examples/src/test/resources/sklearn_prediction b/examples/src/test/resources/sklearn_prediction index 3eeab30e1..535adb372 100644 --- a/examples/src/test/resources/sklearn_prediction +++ b/examples/src/test/resources/sklearn_prediction @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:c470f51a842885dbf78a178daf1f6a184301cf1ccc9859330f40c9b8fea8b5a9 -size 24066214 +oid sha256:341a97825e238895ec4a64b347de19bc77bab0d74b58c08d8b477d8f47aa7506 +size 50757145