Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions diffir/measure/unsupervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,26 @@ def kl_div(self, x, y):
y = y / y.sum()
return -(stats.entropy(x, y) + stats.entropy(y, x)) / 2

def rbo(self, x, y, p=0.95, depth=1000):
# adapted from <https://github.com/terrierteam/ir_measures/blob/main/ir_measures/providers/compat_provider.py>
from collections import Counter
x_set = set()
y_set = set()
score = 0.0
normalizer = 0.0
weight = 1.0
x = [d for d, s in Counter(x).most_common()]
y = [d for d, s in Counter(y).most_common()]
for i in range(depth):
if i < len(x):
x_set.add(x[i])
if i < len(y):
y_set.add(y[i])
score += weight*len(x_set.intersection(y_set))/(i + 1)
normalizer += weight
weight *= p
return score/normalizer

def _query_differences(self, run1, run2, *args, **kwargs):
"""
:param run1: TREC run. Has the format {qid: {docid: score}, ...}
Expand Down Expand Up @@ -146,6 +166,8 @@ def _query_differences(self, run1, run2, *args, **kwargs):
tau = (self.pearson_rank(union_score1, union_score2) + self.pearson_rank(union_score2, union_score1)) / 2
elif metric == "kldiv":
tau = self.kl_div(union_score1, union_score2)
elif metric == "rbo":
tau = self.rbo(run1[qid], run2[qid])
else:
raise ValueError("Metric {} not supported for the measure {}".format(self.metric, "metric"))
id2measure[qid] = tau
Expand Down
2 changes: 1 addition & 1 deletion diffir/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(self, dataset="none", queries="none", measure="topk", metric="weigh
self.queries = queries
if measure == "qrel":
self.measure = QrelMeasure(metric, topk)
elif measure in ["tauap", "weightedtau", "spearmanr", "pearsonrank", "kldiv"]:
elif measure in ["tauap", "weightedtau", "spearmanr", "pearsonrank", "kldiv", "rbo"]:
self.measure = TopkMeasure(measure, topk)
else:
raise ValueError("Measure {} is not supported".format(measure))
Expand Down