diff --git a/toolz/itertoolz.py b/toolz/itertoolz.py index a25eea3c..666166eb 100644 --- a/toolz/itertoolz.py +++ b/toolz/itertoolz.py @@ -14,7 +14,7 @@ 'first', 'second', 'nth', 'last', 'get', 'concat', 'concatv', 'mapcat', 'cons', 'interpose', 'frequencies', 'reduceby', 'iterate', 'sliding_window', 'partition', 'partition_all', 'count', 'pluck', - 'join', 'tail', 'diff', 'topk', 'peek', 'random_sample') + 'join', 'tail', 'diff', 'topk', 'peek', 'random_sample', 'power', 'quotient', 'sortby', 'repeatby') def remove(predicate, seq): @@ -980,3 +980,52 @@ def random_sample(prob, seq, random_state=None): if not hasattr(random_state, 'random'): random_state = Random(random_state) return filter(lambda _: random_state.random() < prob, seq) + + +def power(iterable, hook=set): + # the power set of iterable + return (hook(a) for a in concat(itertools.combinations(iterable, r) for r in range(len(iterable)+1))) + + + +def quotient(lst, key=None, rel=lambda x, y: x==y): + '''rel is an equivalent relation + return a partition of X, X/rel + also see groupby in toolz + Remark: It is named partition at first, but conflits with the original one. + quotient is another acceptable name. +''' + if lst==[]: + return lst + elif len(lst)==1: + return [lst] + if key: + rel = lambda x, y: key(x)==key(y) + #~ if rel: key = lambda x: {a for a in lst if rel(x, a)} + p = [[lst[0]]] + for a in lst[1:]: + for cls in p: + if rel(a, cls[0]): + cls.append(a) + break + else: + p.append([a]) + return p + +def sortby(lst, key=None, rel=lambda x, y: x==y): + # see quotient + return concat(quotient(lst, key, rel)) + +def repeatby(lst, nums): + '''example: +>>> repeatby(['w','r','y'],[3,2,1,2]) +['w', 'w', 'w', 'r', 'r', 'y', 'w', 'w'] +''' + new = [] + l = len(lst) + for k, n in enumerate(nums): + if l>k: + new.extend([lst[k]]*n) + else: + new.extend([lst[k%l]]*n) + return new diff --git a/toolz/tests/test_itertoolz.py b/toolz/tests/test_itertoolz.py index 93aa856d..65c1192b 100644 --- a/toolz/tests/test_itertoolz.py +++ b/toolz/tests/test_itertoolz.py @@ -13,7 +13,7 @@ reduceby, iterate, accumulate, sliding_window, count, partition, partition_all, take_nth, pluck, join, - diff, topk, peek, random_sample) + diff, topk, peek, random_sample, power, quotient, sortby, repeatby) from toolz.compatibility import range, filter from operator import add, mul @@ -524,3 +524,26 @@ def test_random_sample(): assert mk_rsample(b"a") == mk_rsample(u"a") assert raises(TypeError, lambda: mk_rsample([])) + + + +def test_power(): + assert set(power([1])) == {{1}, set()} + + +def test_quotient(): + S = [1,2,3,4,5,6] + Q = quotient(S, rel=lambda x, y: (x-y) % 3 ==0) + assert Q == [[1, 4], [2, 5], [3, 6]] + Q1 = quotient(S, key=lambda x: x % 3) + assert Q1 == [[1, 4], [2, 5], [3, 6]] + + assert quotient([], rel=lambda x:1) == [] + +def test_sortby(): + S = [1,2,3,4,5,6] + Q = sortby(S, rel=lambda x, y: (x-y) % 3 ==0) + assert Q == [1, 4, 2, 5, 3, 6] + +def test_repeatby(): + assert repeatby(['w','r','y'], [3,2,1,2]) == ['w', 'w', 'w', 'r', 'r', 'y', 'w', 'w']