diff --git a/fn/func.py b/fn/func.py index b9d7562..70ed3d1 100644 --- a/fn/func.py +++ b/fn/func.py @@ -1,12 +1,13 @@ -from functools import partial, wraps +from functools import partial, wraps, update_wrapper from inspect import getargspec -from .op import identity, flip +from .op import identity + class F(object): """Provide simple syntax for functions composition - (through << and >> operators) and partial function - application (through simple tuple syntax). + (through << and >> operators) and partial function + application (through simple tuple syntax). Usage example: @@ -14,18 +15,18 @@ class F(object): >>> print(func(10)) 25 >>> func = F() >> (filter, _ < 6) >> sum - >>> print(func(range(10))) + >>> print(func(range(10))) 15 """ - __slots__ = "f", + __slots__ = "f", - def __init__(self, f = identity, *args, **kwargs): + def __init__(self, f=identity, *args, **kwargs): self.f = partial(f, *args, **kwargs) if any([args, kwargs]) else f @classmethod def __compose(cls, f, g): - """Produces new class intance that will + """Produces new class intance that will execute given functions one by one. Internal method that was added to avoid code duplication in other methods. @@ -33,8 +34,8 @@ def __compose(cls, f, g): return cls(lambda *args, **kwargs: f(g(*args, **kwargs))) def __ensure_callable(self, f): - """Simplify partial execution syntax. - Rerurn partial function built from tuple + """Simplify partial execution syntax. + Rerurn partial function built from tuple (func, arg1, arg2, ...) """ return self.__class__(*f) if isinstance(f, tuple) else f @@ -47,7 +48,7 @@ def __lshift__(self, g): """Overload << operator for F instances""" return self.__class__.__compose(self.f, self.__ensure_callable(g)) - def __call__(self, *args, **kwargs): + def __call__(self, *args, **kwargs): """Overload apply operator""" return self.f(*args, **kwargs) @@ -80,5 +81,8 @@ def _curried(*args, **kwargs): if count == len(spec.args) - len(args): return func(*args, **kwargs) - return curried(partial(func, *args, **kwargs)) + para_func = partial(func, *args, **kwargs) + update_wrapper(para_func, f) + return curried(para_func) + return _curried diff --git a/tests.py b/tests.py index a5454b4..3445c24 100644 --- a/tests.py +++ b/tests.py @@ -10,12 +10,42 @@ from fn import op, _, F, Stream, iters, underscore, monad, recur from fn.uniform import reduce from fn.immutable import SkewHeap, PairingHeap, LinkedList, Stack, Queue, Vector, Deque +from fn.func import curried + class InstanceChecker(object): if sys.version_info[0] == 2 and sys.version_info[1] <= 6: def assertIsInstance(self, inst, cls): self.assertTrue(isinstance(inst, cls)) + +class Curriedtest(unittest.TestCase): + + def test_curried_wrapper(self): + + @curried + def _child(a, b, c, d): + return a + b + c + d + + @curried + def _moma(a, b): + return _child(a, b) + + def _assert_instance(expected, acutal): + self.assertEqual(expected.__module__, acutal.__module__) + self.assertEqual(expected.__name__, acutal.__name__) + + res1 = _moma(1) + _assert_instance(_moma, res1) + res2 = res1(2) + _assert_instance(_child, res2) + res3 = res2(3) + _assert_instance(_child, res3) + res4 = res3(4) + + self.assertEqual(res4, 10) + + class OperatorTestCase(unittest.TestCase): def test_currying(self):