Skip to content

Commit b5fa9b2

Browse files
committed
Evaluate Provider args and kwargs in Factory/Singleton providers
1 parent e3bda43 commit b5fa9b2

File tree

4 files changed

+129
-2
lines changed

4 files changed

+129
-2
lines changed

pif/providers/factory.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from typing import Callable
1313

1414
from pif.providers.provider import Provider
15+
from pif.wiring import intercept
1516

1617

1718
class Factory[T](Provider):
@@ -22,7 +23,7 @@ class Factory[T](Provider):
2223
__slots__ = ("_func", "_depends")
2324

2425
def __init__(self, func: Callable[[...], T], *args, **kwargs):
25-
self._func = functools.partial(func, *args, **kwargs)
26+
self._func = functools.partial(intercept(func), *args, **kwargs)
2627

2728
def _evaluate(self) -> T:
2829
return self._func()

pif/providers/singleton.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,22 @@
1212
from typing import Callable
1313

1414
from pif.providers.provider import Provider
15+
from pif.wiring import intercept
1516

1617
UNSET = object()
1718

1819

1920
class Singleton[T](Provider):
2021
"""
2122
Provide a singleton instance.
23+
24+
Note that overriding any provider arguments will not cause the singleton to reevaluate.
2225
"""
2326

2427
__slots__ = ("_func", "_func", "_result", "_depends")
2528

2629
def __init__(self, func: Callable[[...], T], *args, **kwargs):
27-
self._func = functools.partial(func, *args, **kwargs)
30+
self._func = functools.partial(intercept(func), *args, **kwargs)
2831
self._result = UNSET
2932

3033
def _evaluate(self) -> T:

pif/wiring.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,21 @@
1818
from pif.providers.provider import Provider
1919

2020

21+
def intercept(func):
22+
"""
23+
Intercepts the args and kwargs at runtime evaluating any Provider values.
24+
"""
25+
26+
@functools.wraps(func)
27+
def wrapper(*args, **kwargs):
28+
return func(
29+
*(a() if isinstance(a, Provider) else a for a in args),
30+
**{k: v() if isinstance(v, Provider) else v for k, v in kwargs.items()},
31+
)
32+
33+
return wrapper
34+
35+
2136
def patch_args(
2237
signature: inspect.Signature,
2338
args: tuple[Any, ...],

tests/test_providers.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from collections import namedtuple
2+
13
import pytest
24

35
from pif import exceptions, providers
@@ -199,3 +201,109 @@ def test_factory():
199201

200202
assert dict_alt == dict_1
201203
assert dict_alt is not dict_1
204+
205+
206+
def test_transitive_factory_wired():
207+
"""
208+
Checking the factory provider evaluates Provider args and kwargs.
209+
"""
210+
model = namedtuple("Model", "a b")
211+
212+
provider_a = providers.Factory[str](lambda: "a")
213+
provider_b = providers.Factory[str](lambda: "b")
214+
215+
provider = providers.Factory[tuple](model, provider_a, b=provider_b)
216+
217+
assert provider_a() == "a"
218+
assert provider_b() == "b"
219+
220+
model_1 = provider()
221+
model_2 = provider()
222+
assert model("a", "b") == model_1
223+
assert model_1 == model_2
224+
assert model_1 is not model_2
225+
226+
227+
def test_transitive_factory_override():
228+
"""
229+
Checking the factory provider generates different value when Provider arg and kwarg is overridden.
230+
"""
231+
model = namedtuple("Model", "a b")
232+
233+
provider_a = providers.Factory[str](lambda: "a")
234+
provider_b = providers.Factory[str](lambda: "b")
235+
assert provider_a() == "a"
236+
assert provider_b() == "b"
237+
238+
provider = providers.Factory[tuple](model, provider_a, b=provider_b)
239+
240+
model_1 = provider()
241+
model_2 = provider()
242+
assert model("a", "b") == model_1
243+
assert model_1 == model_2
244+
assert model_1 is not model_2
245+
246+
with (
247+
provider_a.override_existing("b"),
248+
provider_b.override_existing("a"),
249+
):
250+
assert provider_a() == "b"
251+
assert provider_b() == "a"
252+
253+
model_3 = provider()
254+
model_4 = provider()
255+
assert model("b", "a") == model_3
256+
assert model_3 == model_4
257+
assert model_3 is not model_4
258+
259+
model_5 = provider()
260+
model_6 = provider()
261+
assert model("a", "b") == model_5
262+
assert model_5 == model_6
263+
assert model_5 is not model_6
264+
265+
266+
def test_transitive_singleton_wired():
267+
"""
268+
Checking the singleton provider evaluates Provider args and kwargs.
269+
"""
270+
model = namedtuple("Model", "a b")
271+
272+
provider_a = providers.Singleton[str](lambda: "a")
273+
provider_b = providers.Singleton[str](lambda: "b")
274+
assert provider_a() == "a"
275+
assert provider_b() == "b"
276+
277+
provider = providers.Singleton[tuple](model, provider_a, provider_b)
278+
279+
model_1 = provider()
280+
model_2 = provider()
281+
assert model("a", "b") == model_1
282+
assert model_1 == model_2
283+
assert model_1 is model_2
284+
285+
286+
def test_transitive_singleton_override():
287+
"""
288+
Checking the singleton provide retains cached value even when Provider arg and kwarg is overridden.
289+
"""
290+
model = namedtuple("Model", "a b")
291+
292+
provider_a = providers.Singleton[str](lambda: "a")
293+
provider_b = providers.Singleton[str](lambda: "b")
294+
assert provider_a() == "a"
295+
assert provider_b() == "b"
296+
297+
provider = providers.Singleton[tuple](model, provider_a, provider_b)
298+
299+
model_1 = provider()
300+
model_2 = provider()
301+
assert model("a", "b") == model_1
302+
assert model_1 == model_2
303+
assert model_1 is model_2
304+
305+
with (
306+
provider_a.override_existing("b"),
307+
provider_b.override_existing("a"),
308+
):
309+
assert model_1 == provider() # Overriding does not change the cached value.

0 commit comments

Comments
 (0)