From 5a4b6674ab373db443e684b69284a31c5c6a9b6b Mon Sep 17 00:00:00 2001
From: jhartman
Date: Mon, 7 Jul 2014 17:03:13 -0700
Subject: [PATCH] Adding a new UDF called Ndcg to support calculating Ndcg.

A couple of different configuration options are available for calculating
ndcg. You can specify positional values per range, use a standard
logarithmic discounting function, or use a custom function.
---
 .../datafu/pig/stats/LogScoringFunction.java  |  16 ++
 src/java/datafu/pig/stats/Ndcg.java           | 153 +++++++++++
 .../pig/stats/PositionScoringFunction.java    |  14 +
 .../pig/stats/RangeScoringFunction.java       |  48 ++++
 .../pig/stats/UnaryScoringFunction.java       |  18 ++
 src/java/datafu/pig/util/NumericalRange.java  | 260 ++++++++++++++++++
 src/java/datafu/pig/util/RangeMap.java        | 194 +++++++++++++
 src/test/datafu/pig/util/TestRange.java       |  79 ++++++
 src/test/datafu/pig/util/TestRangeMap.java    |  46 ++++
 test/pig/datafu/test/pig/stats/NdcgTests.java | 139 ++++++++++
 10 files changed, 967 insertions(+)
 create mode 100644 src/java/datafu/pig/stats/LogScoringFunction.java
 create mode 100644 src/java/datafu/pig/stats/Ndcg.java
 create mode 100644 src/java/datafu/pig/stats/PositionScoringFunction.java
 create mode 100644 src/java/datafu/pig/stats/RangeScoringFunction.java
 create mode 100644 src/java/datafu/pig/stats/UnaryScoringFunction.java
 create mode 100644 src/java/datafu/pig/util/NumericalRange.java
 create mode 100644 src/java/datafu/pig/util/RangeMap.java
 create mode 100644 src/test/datafu/pig/util/TestRange.java
 create mode 100644 src/test/datafu/pig/util/TestRangeMap.java
 create mode 100644 test/pig/datafu/test/pig/stats/NdcgTests.java

diff --git a/src/java/datafu/pig/stats/LogScoringFunction.java b/src/java/datafu/pig/stats/LogScoringFunction.java
new file mode 100644
index 0000000..153b004
--- /dev/null
+++ b/src/java/datafu/pig/stats/LogScoringFunction.java
@@ -0,0 +1,16 @@
+package datafu.pig.stats;
+
+
+/**
+ * Standard NDCG log discounting: a 0-based position p is weighted by 1 / log2(p + 2), so position 0 weighs 1.0
+ */
+public class LogScoringFunction implements PositionScoringFunction
+{
+  private static final double LOG2 = Math.log(2);
+
+  @Override
+  public double score(int position)
+  {
+    return position < 1 ? 1 : LOG2 / Math.log(2 + position);
+  }
+}
diff --git a/src/java/datafu/pig/stats/Ndcg.java b/src/java/datafu/pig/stats/Ndcg.java
new file mode 100644
index 0000000..c5d0560
--- /dev/null
+++ b/src/java/datafu/pig/stats/Ndcg.java
@@ -0,0 +1,153 @@
+package datafu.pig.stats;
+
+
+import datafu.pig.util.SimpleEvalFunc;
+import java.io.IOException;
+import java.lang.reflect.Constructor;
+import java.lang.reflect.InvocationTargetException;
+import java.util.Arrays;
+import org.apache.pig.data.DataBag;
+import org.apache.pig.data.DataType;
+import org.apache.pig.data.Tuple;
+import org.apache.pig.data.TupleFactory;
+import org.apache.pig.impl.logicalLayer.schema.Schema;
+
+
+/**
+ * Calculates a Normalized Discounted Cumulative Gain on a list of items.
+ *
+ * Datafu supports several different discounting algorithms out of the box.
+ *
+ * First is the usual log2 discounting function for a given position - log2(1 + pos) assuming 1-based indexing. Invoke Ndcg with the 'log2'
+ * constructor for this behavior.
+ *
+ * The last discounting algorithm supported by datafu is a range & index-based position discounting function.
+ * Simply invoke the constructor with a list of numerical ranges for this behavior. See NumericalRange and RangeScoringFunction
+ * for a detailed specification. The parameter is intended to honor the principle of least surprise.
+ *
+ * The following example should be pretty straightforward:
+ *
+ * DEFINE NDCG datafu.pig.stats.Ndcg('0: 1.0', '1: 0.8', '[2,4): 0.75', '[4,8): 0.6', '[8,*): 0.5')
+ *
+ * If you are building your own carefully tuned discounting function code, it is possible to plug in any PositionScoringFunction
+ * into the Ndcg UDF. This can be done by invoking Ndcg such as:
+ *
+ * DEFINE NDCG datafu.pig.stats.Ndcg('custom', 'fully.qualified.class.name', 'arg1', 'arg2', ..., 'argn');
+ *
+ * Your class must implement the PositionScoringFunction interface and declare a public constructor taking a String[] (or String...).
+ * The scores in the bag handed to the UDF must already be normalized to the range [0, 1].
+ */
+public class Ndcg extends SimpleEvalFunc<Tuple>
+{
+  private final PositionScoringFunction positionDiscountingFunction;
+
+  public Ndcg(String... config)
+  {
+    positionDiscountingFunction = fromConfig(config);
+  }
+
+  /**
+   * Builds the position discounting function described by the UDF constructor arguments.
+   *
+   * @param config 'log2'; or 'custom' followed by a class name and that class's arguments; or a list of 'range: value' strings
+   * @return the discounting function to apply to each 0-based position
+   * @throws IllegalArgumentException if a custom class cannot be found, lacks a String[] constructor or cannot be instantiated
+   */
+  public static PositionScoringFunction fromConfig(String... config)
+  {
+    if(config.length == 1 && "log2".equals(config[0].trim()))
+    {
+      return new LogScoringFunction();
+    }
+    else if(config.length >=2 && "custom".equals(config[0].trim()))
+    {
+      String scoringFunctionClassName = config[1].trim();
+      try
+      {
+        Class<?> specifiedClass = Class.forName(scoringFunctionClassName);
+        try
+        {
+          Constructor<?> constructor = specifiedClass.getConstructor(String[].class);
+          String[] parameters = Arrays.asList(config).subList(2, config.length).toArray(new String[config.length - 2]);
+          try
+          {
+            return (PositionScoringFunction) constructor.newInstance((Object)parameters);
+          }
+          catch (Exception e)
+          {
+            throw new IllegalArgumentException("Could not instantiate the position scoring function", e);
+          }
+        }
+        catch (NoSuchMethodException e)
+        {
+          throw new IllegalArgumentException("The constructor for class " + scoringFunctionClassName + " must implement a String[] constructor", e);
+        }
+      }
+      catch (ClassNotFoundException e)
+      {
+        throw new IllegalArgumentException("Class " + scoringFunctionClassName + " could not be found on classpath", e);
+      }
+    }
+    else
+    {
+      return new RangeScoringFunction(config);
+    }
+  }
+
+  /**
+   * Computes the position-discounted gain of the bag, divided by the gain the same positions would earn if every score were 1.0.
+   *
+   * @param bag tuples whose first field is a score in [0, 1]; iteration order is taken as the ranking (first tuple = position 0)
+   * @return a single-field tuple holding the ndcg (0.0 when every position weight is zero), or null for a null or empty bag
+   * @throws IOException propagated from reading a tuple field
+   * @throws IllegalArgumentException if a score is null, non-numeric or outside [0, 1]
+   */
+  public Tuple call(DataBag bag) throws IOException
+  {
+    if (bag == null || bag.size() == 0)
+      return null;
+
+    double cumulativeGain = 0.0;
+    double maxCumulativeGain = 0.0;
+    int position = 0;
+
+    for(Tuple t : bag)
+    {
+      Object o = t.get(0);
+      if (!(o instanceof Number))
+      {
+        throw new IllegalArgumentException("bag must have numerical values (and be non-null)");
+      }
+
+      Number n = (Number) o;
+      double itemScore = n.doubleValue();
+
+      if(itemScore < 0 || itemScore > 1)
+      {
+        throw new IllegalArgumentException("Scores must be already normalized from 0 to 1");
+      }
+
+      double positionScore = positionDiscountingFunction.score(position);
+
+      cumulativeGain += itemScore * positionScore;
+      maxCumulativeGain += positionScore;
+
+      position++;
+    }
+    // If every position weight was zero the ratio would be 0/0 (NaN); report 0.0 instead
+    double ndcg = maxCumulativeGain == 0.0 ? 0.0 : cumulativeGain / maxCumulativeGain;
+
+    Tuple t = TupleFactory.getInstance().newTuple(1);
+    t.set(0, ndcg);
+    return t;
+  }
+
+  @Override
+  public Schema outputSchema(Schema inputSchema)
+  {
+    Schema tupleSchema = new Schema();
+    tupleSchema.add(new Schema.FieldSchema("ndcg", DataType.DOUBLE));
+
+    return tupleSchema;
+  }
+}
diff --git a/src/java/datafu/pig/stats/PositionScoringFunction.java b/src/java/datafu/pig/stats/PositionScoringFunction.java
new file mode 100644
index 0000000..7141880
--- /dev/null
+++ b/src/java/datafu/pig/stats/PositionScoringFunction.java
@@ -0,0 +1,14 @@
+package datafu.pig.stats;
+
+
+/**
+ * Interface for abstracting discounting according to a position within a list from other algorithms
+ */
+public interface PositionScoringFunction
+{
+  /**
+   * @param position The 0-based index to score
+   * @return the weight (discount) applied to the score of the item at that position
+   */
+  double score(int position);
+}
diff --git a/src/java/datafu/pig/stats/RangeScoringFunction.java b/src/java/datafu/pig/stats/RangeScoringFunction.java
new file mode 100644
index 0000000..d7054c5
--- /dev/null
+++ b/src/java/datafu/pig/stats/RangeScoringFunction.java
@@ -0,0 +1,48 @@
+package datafu.pig.stats;
+
+
+import datafu.pig.util.RangeMap;
+import java.util.AbstractMap;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+
+/**
+ * Scores a position by looking it up in non-overlapping ranges. Format for configuration strings is NumericalRange : Value
+ */
+public class RangeScoringFunction implements PositionScoringFunction
+{
+  private final RangeMap<Double> rangeMap;
+
+  public RangeScoringFunction(String... configuration)
+  {
+    final List<Map.Entry<String, Double>> ranges = new ArrayList<Map.Entry<String, Double>>();
+
+    for(String config : configuration)
+    {
+      String[] split = config.split(":");
+      if(split.length == 2)
+      {
+        ranges.add(new AbstractMap.SimpleEntry<String, Double>(split[0].trim(), Double.parseDouble(split[1])));
+      }
+      else
+      {
+        throw new IllegalArgumentException("Format for range discounting function should be NumericalRange : Value");
+      }
+    }
+
+    this.rangeMap = new RangeMap<Double>(ranges);
+  }
+
+  @Override
+  public double score(int position)
+  {
+    Double score = rangeMap.get(position);
+    if(score == null)
+    {
+      throw new IllegalArgumentException("No configured range covers position " + position);
+    }
+    return score;
+  }
+}
diff --git a/src/java/datafu/pig/stats/UnaryScoringFunction.java b/src/java/datafu/pig/stats/UnaryScoringFunction.java
new file mode 100644
index 0000000..4d8ac98
--- /dev/null
+++ b/src/java/datafu/pig/stats/UnaryScoringFunction.java
@@ -0,0 +1,18 @@
+package datafu.pig.stats;
+
+
+/**
+ * Trivial scoring function that weights every position 1.0, i.e. applies no discounting; its arguments are ignored
+ */
+public class UnaryScoringFunction implements PositionScoringFunction
+{
+  public UnaryScoringFunction(String... args)
+  {
+  }
+
+  @Override
+  public double score(int position)
+  {
+    return 1.0;
+  }
+}
diff --git a/src/java/datafu/pig/util/NumericalRange.java b/src/java/datafu/pig/util/NumericalRange.java
new file mode 100644
index 0000000..824c5fa
--- /dev/null
+++ b/src/java/datafu/pig/util/NumericalRange.java
@@ -0,0 +1,260 @@
+package datafu.pig.util;
+
+
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+
+/**
+ * This class represents a range of numbers, including whether or not the range is open or closed
+ *
+ * It has utility methods to parse numerical ranges using standard mathematical notation. For instance:
+ * [-1e-9,10)
+ *
+ * It also has support for infinite ranges through the following syntax
+ * [-10,*)
+ *
+ * Single points are also supported:
+ * 5
+ *
+ * As are entirely open ranges: *
+ *
+ * @author jhartman
+ * @author rugupta
+ */
+public class NumericalRange implements Comparable<NumericalRange>
+{
+  private static final String DOUBLE_REGEX = "-?\\d+(\\.\\d+)?";
+  private static final String NUMBER_REGEX = DOUBLE_REGEX + "(" + "[eE]" + DOUBLE_REGEX + ")?";
+  private static final String NUMBER_OR_STAR_REGEX = "(" + NUMBER_REGEX + ")" + "|\\*";
+  private static final Pattern NUMBER_OR_STAR_PATTERN = Pattern.compile(NUMBER_OR_STAR_REGEX);
+
+  private static final String LOWER_REGEX = "[\\(\\[]" + "(" + NUMBER_OR_STAR_REGEX + ")" + "\\s*,\\s*";
+  private static final String UPPER_REGEX = "(" + NUMBER_OR_STAR_REGEX + ")" + "\\s*[\\)\\]]";
+
+  private static final Pattern LOWER_PATTERN = Pattern.compile(LOWER_REGEX);
+  private static final Pattern UPPER_PATTERN = Pattern.compile(UPPER_REGEX);
+
+  public static final String RANGE_REGEX = LOWER_REGEX + UPPER_REGEX;
+  public static final Pattern RANGE_PATTERN = Pattern.compile(RANGE_REGEX);
+
+  private final double lower;
+  private final double upper;
+  private final boolean lowerClosed;
+  private final boolean upperClosed;
+
+  public NumericalRange(double lower, double upper, boolean lowerClosed, boolean upperClosed)
+  {
+    this.lower = lower;
+    this.upper = upper;
+    this.lowerClosed = lowerClosed;
+    this.upperClosed = upperClosed;
+  }
+
+  public double getLower()
+  {
+    return lower;
+  }
+
+  public double getUpper()
+  {
+    return upper;
+  }
+
+  public boolean getLowerClosed()
+  {
+    return lowerClosed;
+  }
+
+  public boolean getUpperClosed()
+  {
+    return upperClosed;
+  }
+
+  /**
+   * Parses a single point ("5"), the wildcard ("*") or a bracketed range ("[0,10)", "(*,2]") into a NumericalRange.
+   *
+   * @param range the textual range; leading and trailing whitespace is ignored
+   * @return the parsed range
+   * @throws IllegalArgumentException if the string is malformed or its lower bound exceeds its upper bound
+   */
+  public static NumericalRange fromRangeString(String range)
+  {
+    range = range.trim();
+    if(NUMBER_OR_STAR_PATTERN.matcher(range).matches())
+    {
+      if("*".equals(range))
+      {
+        return new NumericalRange(Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, true, false);
+      }
+      else
+      {
+        double point = Double.parseDouble(range);
+        return new NumericalRange(point, point, true, true);
+      }
+    }
+    else if(RANGE_PATTERN.matcher(range).matches())
+    {
+      double lower = handleRange(range, true);
+      double upper = handleRange(range, false);
+      if (lower > upper)
+      {
+        throw new IllegalArgumentException("Malformed range string: " + range);
+      }
+
+      return new NumericalRange(lower, upper, isClosed(range, false), isClosed(range, true));
+    }
+    else
+    {
+      throw new IllegalArgumentException("Malformed range. Expected a point or a range, but received " + range);
+    }
+  }
+
+  private static boolean isClosed(String range, boolean upper)
+  {
+    if (upper)
+    {
+      return range.charAt(range.length() - 1) == ']';
+    }
+    else
+    {
+      return range.charAt(0) == '[';
+    }
+  }
+
+  @Override
+  public String toString()
+  {
+    return (lowerClosed ? "[" : "(") + lower + "," + upper + (upperClosed ? "]" : ")");
+  }
+
+  /**
+   * @param value the number to test
+   * @return true if value lies inside the range, honoring whether each bound is open or closed
+   */
+  public boolean isInRange(double value)
+  {
+    boolean aboveLower = lowerClosed ? value >= lower : value > lower;
+
+    if(aboveLower)
+    {
+      boolean belowUpper = upperClosed ? value <= upper : value < upper;
+      return belowUpper;
+    }
+    else
+    {
+      return false;
+    }
+  }
+
+  /**
+   * Determines whether or not two ranges are overlapping
+   * @param that the range to compare with this one
+   * @return true if the two ranges share at least one point; ranges that only touch at an endpoint excluded by either do not overlap
+   */
+  public boolean overlaps(NumericalRange that)
+  {
+    NumericalRange lower, higher;
+    if(this.compareTo(that) <= 0)
+    {
+      lower = this; higher = that;
+    }
+    else
+    {
+      lower = that; higher = this;
+    }
+
+    if(lower.getUpperClosed())
+    {
+      if(higher.getLowerClosed())
+      {
+        return lower.getUpper() >= higher.getLower();
+      }
+      else
+      {
+        return lower.getUpper() > higher.getLower();
+      }
+    }
+    else
+    {
+      if(higher.getLowerClosed())
+      {
+        return lower.getUpper() > higher.getLower();
+      }
+      else
+      {
+        return lower.getUpper() > higher.getLower(); // (a,b) and (b,c) both exclude b, so they do not overlap
+      }
+    }
+  }
+
+  // (*,2) indicates -infinity to 2 and [5,*) indicates 5 to infinity
+  private static double handleRange(String range, boolean isLower)
+  {
+    Pattern pattern = isLower ? LOWER_PATTERN : UPPER_PATTERN;
+    Matcher matcher = pattern.matcher(range);
+    if (matcher.find())
+    {
+      String numberOrStarString = matcher.group(1);
+      if("*".equals(numberOrStarString))
+      {
+        return isLower ? Double.NEGATIVE_INFINITY : Double.POSITIVE_INFINITY;
+      }
+      else
+      {
+        return Double.parseDouble(numberOrStarString);
+      }
+    }
+    else
+    {
+      throw new IllegalArgumentException("Can not parse bucket from range string " + range);
+    }
+  }
+
+  @Override
+  public int compareTo(NumericalRange o)
+  {
+    // Rule of the comparison:
+    // 1. compare lower bound, close < open if they are equal
+    // 2. compare upper bound, close > open if they are equal
+    if (lower < o.lower)
+    {
+      return -1;
+    }
+    else if (lower == o.lower)
+    {
+      if (lowerClosed && !o.lowerClosed)
+      {
+        return -1;
+      }
+      else if (!lowerClosed && o.lowerClosed)
+      {
+        return 1;
+      }
+      else
+      {
+        if (upper < o.upper)
+        {
+          return -1;
+        }
+        else if (upper == o.upper)
+        {
+          if (!upperClosed && o.upperClosed)
+            return -1;
+          else if (upperClosed && !o.upperClosed)
+            return 1;
+          else
+            return 0;
+        }
+        else
+        {
+          return 1;
+        }
+      }
+    }
+    else
+    {
+      return 1;
+    }
+  }
+}
diff --git a/src/java/datafu/pig/util/RangeMap.java b/src/java/datafu/pig/util/RangeMap.java
new file mode 100644
index 0000000..eb7835c
--- /dev/null
+++ b/src/java/datafu/pig/util/RangeMap.java
@@ -0,0 +1,194 @@
+package datafu.pig.util;
+
+
+import java.util.AbstractMap;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeMap;
+
+
+/**
+ * Representing a mapping from a range to any value. The result of calling get() on any number will be to find
+ * the value associated with the range containing that number
+ * @param <V> the type of the value associated with each range
+ */
+public class RangeMap<V> implements Map<Double, V>
+{
+  private final TreeMap<NumericalRange, V> treeMap;
+
+  /**
+   * @param rangeStrings (range string, value) pairs; each key is parsed with NumericalRange.fromRangeString
+   * @throws IllegalArgumentException if a range string is malformed or two of the ranges overlap
+   */
+  public RangeMap(List<? extends Entry<String, V>> rangeStrings)
+  {
+    List<Entry<NumericalRange, V>> ranges = new ArrayList<Entry<NumericalRange, V>>(rangeStrings.size());
+    for(Entry<String, V> rangeString : rangeStrings)
+    {
+      ranges.add(new AbstractMap.SimpleEntry<NumericalRange, V>(NumericalRange.fromRangeString(rangeString.getKey()), rangeString.getValue()));
+    }
+
+    Collections.sort(ranges, new Comparator<Entry<NumericalRange, V>>()
+    {
+      @Override
+      public int compare(Entry<NumericalRange, V> entry1, Entry<NumericalRange, V> entry2)
+      {
+        return entry1.getKey().compareTo(entry2.getKey());
+      }
+    });
+
+    for(int i = 1; i < ranges.size(); i++)
+    {
+      if(ranges.get(i).getKey().overlaps(ranges.get(i-1).getKey()))
+      {
+        throw new IllegalArgumentException("Ranges should not overlap, but found " + ranges.get(i) + " and " + ranges.get(i-1));
+      }
+    }
+
+    TreeMap<NumericalRange, V> map = new TreeMap<NumericalRange, V>();
+    for(Entry<NumericalRange, V> entry : ranges)
+    {
+      map.put(entry.getKey(), entry.getValue());
+    }
+
+    treeMap = map;
+  }
+
+  @Override
+  public int size()
+  {
+    return treeMap.size();
+  }
+
+  @Override
+  public boolean isEmpty()
+  {
+    return treeMap.isEmpty();
+  }
+
+  @Override
+  public boolean containsKey(Object o)
+  {
+    if(o instanceof Number && !treeMap.isEmpty())
+    {
+      Number num = (Number) o;
+      NumericalRange numRange = new NumericalRange(num.doubleValue(), num.doubleValue(), true, true);
+
+      Entry<NumericalRange, V> floorEntry = treeMap.floorEntry(numRange);
+      if(floorEntry == null)
+      {
+        floorEntry = treeMap.firstEntry();
+      }
+
+      if(floorEntry.getKey().overlaps(numRange))
+      {
+        return true;
+      }
+      else
+      {
+        Entry<NumericalRange, V> nextEntry = treeMap.higherEntry(floorEntry.getKey());
+        if(nextEntry != null && nextEntry.getKey().overlaps(numRange))
+        {
+          return true;
+        }
+      }
+      return false;
+    }
+    else
+    {
+      return false;
+    }
+  }
+
+
+  @Override
+  public boolean containsValue(Object o)
+  {
+    throw new UnsupportedOperationException();
+  }
+
+  /**
+   * @param o the number to look up; any other kind of object yields null
+   * @return the value of the range containing the number, or null if no range contains it (or the map is empty)
+   */
+  @Override
+  public V get(Object o)
+  {
+    if(o instanceof Number && !treeMap.isEmpty())
+    {
+      Number num = (Number) o;
+      NumericalRange numRange = new NumericalRange(num.doubleValue(), num.doubleValue(), true, true);
+
+      Entry<NumericalRange, V> floorEntry = treeMap.floorEntry(numRange);
+      if(floorEntry == null)
+      {
+        floorEntry = treeMap.firstEntry();
+      }
+
+      if(floorEntry.getKey().overlaps(numRange))
+      {
+        return floorEntry.getValue();
+      }
+      else
+      {
+        Entry<NumericalRange, V> nextEntry = treeMap.higherEntry(floorEntry.getKey());
+        if(nextEntry != null && nextEntry.getKey().overlaps(numRange))
+        {
+          return nextEntry.getValue();
+        }
+      }
+      return null;
+    }
+    else
+    {
+      return null;
+    }
+  }
+
+  @Override
+  public V put(Double aDouble, V v)
+  {
+    throw new UnsupportedOperationException("The range map is immutable");
+  }
+
+  @Override
+  public V remove(Object o)
+  {
+    throw new UnsupportedOperationException("The range map is immutable");
+  }
+
+  @Override
+  public void putAll(Map<? extends Double, ? extends V> map)
+  {
+    throw new UnsupportedOperationException("The range map is immutable");
+  }
+
+  @Override
+  public void clear()
+  {
+    throw new UnsupportedOperationException("The range map is immutable");
+  }
+
+  @Override
+  public Set<Double> keySet()
+  {
+    throw new UnsupportedOperationException("Listing all possible numbers does not make sense");
+  }
+
+  @Override
+  public Collection<V> values()
+  {
+    return Collections.unmodifiableCollection(treeMap.values());
+  }
+
+  @Override
+  public Set<Entry<Double, V>> entrySet()
+  {
+    throw new UnsupportedOperationException("Listing all possible numbers does not make sense");
+  }
+}
diff --git a/src/test/datafu/pig/util/TestRange.java b/src/test/datafu/pig/util/TestRange.java
new file mode 100644
index 0000000..7860453
--- /dev/null
+++ b/src/test/datafu/pig/util/TestRange.java
@@ -0,0 +1,79 @@
+package datafu.pig.util;
+
+
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.fail;
+
+public class TestRange
+{
+  @Test
+  public void testIntegerBuckets()
+  {
+    NumericalRange range = NumericalRange.fromRangeString("[-10, 10)");
+    assertEquals(range.isInRange(-10), true);
+    assertEquals(range.isInRange(9.99), true);
+    assertEquals(range.isInRange(10), false);
+  }
+
+  @Test
+  public void testFloatingBuckets()
+  {
+    NumericalRange range = NumericalRange.fromRangeString("[-10.5, 10.0)");
+    assertEquals(range.isInRange(-10), true);
+  }
+
+  @Test
+  public void testScientificBuckets()
+  {
+    NumericalRange range = NumericalRange.fromRangeString("[-1.5e-9, 0]");
+    assertEquals(range.isInRange(-1.5e-9), true);
+    assertEquals(range.isInRange(-1e-8), false);
+  }
+
+  @Test
+  public void testNegInfinity()
+  {
+    NumericalRange range = NumericalRange.fromRangeString("(*,1)");
+    assertEquals(range.isInRange(-1e10), true);
+    assertEquals(range.isInRange(0), true);
+    assertEquals(range.isInRange(1), false);
+  }
+
+  @Test
+  public void testInfinity()
+  {
+    NumericalRange range = NumericalRange.fromRangeString("[0,*)");
+    assertEquals(range.isInRange(0), true);
+    assertEquals(range.isInRange(1e10), true);
+  }
+
+  @Test
+  public void testMalformedBuckets1()
+  {
+    try
+    {
+      NumericalRange range = NumericalRange.fromRangeString("[-10 10)");
+      fail("invalid");
+    }
+    catch(IllegalArgumentException e)
+    {
+      // expected: a range without a comma separator must be rejected
+    }
+  }
+
+  @Test
+  public void testMalformedBuckets2()
+  {
+    try
+    {
+      NumericalRange range = NumericalRange.fromRangeString("[-10 IMANUMBER)");
+      fail("invalid");
+    }
+    catch(IllegalArgumentException e)
+    {
+      // expected: a non-numeric bound must be rejected
+    }
+  }
+}
diff --git a/src/test/datafu/pig/util/TestRangeMap.java b/src/test/datafu/pig/util/TestRangeMap.java
new file mode 100644
index 0000000..bcbf7e7
--- /dev/null
+++ b/src/test/datafu/pig/util/TestRangeMap.java
@@ -0,0 +1,46 @@
+package datafu.pig.util;
+
+
+import java.util.AbstractMap.SimpleEntry;
+import java.util.Arrays;
+import org.testng.annotations.Test;
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.fail;
+
+
+public class TestRangeMap
+{
+  @Test
+  public void testGet()
+  {
+    RangeMap<Integer> rangeMap = new RangeMap<Integer>(Arrays.asList(
+        new SimpleEntry<String, Integer>("[0,1)",8),
+        new SimpleEntry<String, Integer>("[1,3)",4),
+        new SimpleEntry<String, Integer>("[3,8)",2),
+        new SimpleEntry<String, Integer>("[8,*)",1)
+    ));
+
+    assertEquals(rangeMap.get(0.0).intValue(), 8);
+    assertEquals(rangeMap.get(1.5).intValue(), 4);
+    assertEquals(rangeMap.get(3).intValue(), 2);
+    assertEquals(rangeMap.get(10).intValue(), 1);
+    assertEquals(rangeMap.get(-5), null);
+  }
+
+  @Test
+  public void testOverlapping()
+  {
+    try
+    {
+      RangeMap<Integer> rangeMap = new RangeMap<Integer>(Arrays.asList(
+          new SimpleEntry<String, Integer>("[0,1]",8),
+          new SimpleEntry<String, Integer>("[1,3)",4)
+      ));
+      fail("illegal");
+    }
+    catch(IllegalArgumentException e)
+    {
+      // expected: [0,1] and [1,3) both contain 1, so construction must fail
+    }
+  }
+}
diff --git a/test/pig/datafu/test/pig/stats/NdcgTests.java b/test/pig/datafu/test/pig/stats/NdcgTests.java
new file mode 100644
index 0000000..f765ea4
--- /dev/null
+++ b/test/pig/datafu/test/pig/stats/NdcgTests.java
@@ -0,0 +1,139 @@
+package datafu.test.pig.stats;
+
+
+import datafu.pig.stats.PositionScoringFunction;
+import datafu.test.pig.PigTests;
+import java.util.List;
+import org.adrianwalker.multilinestring.Multiline;
+import org.apache.pig.data.Tuple;
+import org.apache.pig.pigunit.PigTest;
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+
+
+/**
+ * @author jhartman
+ */
+public class NdcgTests extends PigTests
+{
+
+  /**
+  register $JAR_PATH
+
+  define Ndcg datafu.pig.stats.Ndcg($POSITIONSCORES);
+
+  data_in = LOAD 'input' as (val:double);
+
+  --describe data_in;
+
+  data_out = GROUP data_in ALL;
+
+  --describe data_out;
+
+  data_out = FOREACH data_out GENERATE Ndcg($1) AS ndcg;
+  data_out = FOREACH data_out GENERATE FLATTEN(ndcg);
+
+  --describe data_out;
+
+  STORE data_out into 'output';
+   */
+  @Multiline
+  private String ndcgTest;
+
+  @Test
+  public void testPositionalNdcg() throws Exception
+  {
+    PigTest test = createPigTestFromString(ndcgTest,
+                                 "POSITIONSCORES='0:1.0','1:0.75','[2,3]:0.5','[4,5]:0.25','(5,*]:0.0'");
+
+    String[] input = {"1.0", "0.9", "0.8", "0.7", "0.6", "0.5", "0.4", "0.3", "0.2", "0.1"};
+    writeLinesToFile("input", input);
+
+    test.runScript();
+
+    List<Tuple> output = getLinesForAlias(test, "data_out", true);
+
+    double expectedNumerator = 1*1 + .9*.75 + .8*.5 + .7*.5 + .6*.25+.5*.25;
+    double expectedDenominator = 1 + .75 + 2*.5 + 2*.25;
+    double expectedNdcg = expectedNumerator / expectedDenominator;
+
+    assertEquals(output.size(),1);
+    Assert.assertEquals(Double.parseDouble(output.get(0).get(0).toString()), expectedNdcg, 1e-9);
+  }
+
+  @Test
+  public void testLogarithmicNdcg() throws Exception
+  {
+    PigTest test = createPigTestFromString(ndcgTest,
+                                           "POSITIONSCORES='log2'");
+
+    String[] input = {"1.0", "0.9", "0.8", "0.7", "0.6", "0.5", "0.4", "0.3", "0.2", "0.1"};
+    writeLinesToFile("input", input);
+
+    test.runScript();
+
+    List<Tuple> output = getLinesForAlias(test, "data_out", true);
+
+    double[] positionValues = {1, 1/log2(3), 1/log2(4), 1/log2(5), 1/log2(6), 1/log2(7), 1/log2(8), 1/log2(9), 1/log2(10), 1/log2(11)};
+    double[] scores = {1, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1};
+    double numerator = dotProduct(positionValues, scores);
+    double denominator = sum(positionValues);
+    double expectedNdcg = numerator / denominator;
+
+    assertEquals(output.size(),1);
+    System.out.println(output);
+    Assert.assertEquals(Double.parseDouble(output.get(0).get(0).toString()), expectedNdcg, 1e-9);
+  }
+
+  @Test
+  public void testCustomDiscounter() throws Exception
+  {
+    PigTest test = createPigTestFromString(ndcgTest,
+                                           "POSITIONSCORES='custom','datafu.pig.stats.UnaryScoringFunction'");
+
+    String[] input = {"1.0", "0.9", "0.8", "0.7", "0.6", "0.5", "0.4", "0.3", "0.2", "0.1"};
+    writeLinesToFile("input", input);
+
+    test.runScript();
+
+    List<Tuple> output = getLinesForAlias(test, "data_out", true);
+
+    double[] scores = {1, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1};
+    double numerator = sum(scores);
+    double denominator = 10;
+    double expectedNdcg = numerator / denominator;
+
+    assertEquals(output.size(),1);
+    System.out.println(output);
+    Assert.assertEquals(Double.parseDouble(output.get(0).get(0).toString()), expectedNdcg, 1e-9);
+  }
+
+
+  private final double sum(double... nums)
+  {
+    double sum = 0;
+    for(double num : nums) sum += num;
+    return sum;
+  }
+
+  private final double dotProduct(double[] a, double[] b)
+  {
+    if(a.length != b.length)
+    {
+      throw new IllegalArgumentException();
+    }
+    double sum = 0;
+    for(int i = 0; i < a.length; i++)
+    {
+      sum += a[i] * b[i];
+    }
+    return sum;
+  }
+
+  private final double log2(int num)
+  {
+    return Math.log(num) / Math.log(2);
+  }
+}