Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/source/basics/design_philosophy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ Core Components
The dependency injection system is built around several key abstractions:

.. autoclass:: DependencyResolver
:members: resolve, resolve_optional, iter_keys
:members: resolve, iter_keys
:noindex:

.. autoclass:: DependencyContainer
Expand Down
2 changes: 1 addition & 1 deletion src/fairseq2/data/tokenizers/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def __call__(self) -> TokenizerHub[TokenizerT, TokenizerConfigT]:

name = self._family_name

family = resolver.resolve_optional(TokenizerFamily, key=name)
family = resolver.maybe_resolve(TokenizerFamily, key=name)
if family is None:
raise TokenizerFamilyNotKnownError(name)

Expand Down
2 changes: 1 addition & 1 deletion src/fairseq2/datasets/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __call__(self) -> DatasetHub[DatasetT, DatasetConfigT]:

name = self._family_name

family = resolver.resolve_optional(DatasetFamily, key=name)
family = resolver.maybe_resolve(DatasetFamily, key=name)
if family is None:
raise DatasetFamilyNotKnownError(name)

Expand Down
2 changes: 1 addition & 1 deletion src/fairseq2/models/hg.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def get_hugging_face_converter(family_name: str) -> HuggingFaceConverter:
"""
resolver = get_dependency_resolver()

hg_converter = resolver.resolve_optional(HuggingFaceConverter, key=family_name)
hg_converter = resolver.maybe_resolve(HuggingFaceConverter, key=family_name)
if hg_converter is None:
raise NotSupportedError(
f"{family_name} model family does not support Hugging Face conversion."
Expand Down
2 changes: 1 addition & 1 deletion src/fairseq2/models/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,7 @@ def __call__(self) -> ModelHub[ModelT, ModelConfigT]:

name = self._family_name

family = resolver.resolve_optional(ModelFamily, key=name)
family = resolver.maybe_resolve(ModelFamily, key=name)
if family is None:
raise ModelFamilyNotKnownError(name)

Expand Down
22 changes: 10 additions & 12 deletions src/fairseq2/runtime/dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class DependencyResolver(ABC):
def resolve(self, kls: type[T], *, key: Hashable | None = None) -> T: ...

@abstractmethod
def resolve_optional(
def maybe_resolve(
self, kls: type[T], *, key: Hashable | None = None
) -> T | None: ...

Expand All @@ -60,7 +60,7 @@ def iter_keys(self, kls: type[T]) -> Iterator[Hashable]: ...


class DependencyNotFoundError(Exception):
def __init__(self, kls: type, key: Hashable | None, msg: str) -> None:
def __init__(self, kls: type[object], key: Hashable | None, msg: str) -> None:
super().__init__(msg)

self.kls = kls
Expand Down Expand Up @@ -133,7 +133,7 @@ def register_instance(

def _do_register(
self,
kls: type,
kls: type[object],
key: Hashable | None,
registration: _DependencyRegistration,
) -> None:
Expand Down Expand Up @@ -178,9 +178,7 @@ def resolve(self, kls: type[T], *, key: Hashable | None = None) -> T:
return obj

@override
def resolve_optional(
self, kls: type[T], *, key: Hashable | None = None
) -> T | None:
def maybe_resolve(self, kls: type[T], *, key: Hashable | None = None) -> T | None:
try:
return self.resolve(kls, key=key)
except DependencyNotFoundError as ex:
Expand Down Expand Up @@ -249,7 +247,7 @@ def register_instance(

def _do_register(
self,
kls: type,
kls: type[object],
key: Hashable | None,
registration: _DependencyRegistration,
) -> None:
Expand Down Expand Up @@ -362,7 +360,7 @@ def __init__(self, resolver: DependencyResolver, kls: type[T]) -> None:

@override
def maybe_get(self, key: Hashable) -> T | None:
return self._resolver.resolve_optional(self._kls, key=key)
return self._resolver.maybe_resolve(self._kls, key=key)

@override
def iter_keys(self) -> Iterator[Hashable]:
Expand Down Expand Up @@ -396,15 +394,15 @@ def wire_object(resolver: DependencyResolver, wire_kls: type[T], /, **kwargs: An


class AutoWireError(Exception):
def __init__(self, kls: type, reason: str) -> None:
def __init__(self, kls: type[object], reason: str) -> None:
super().__init__(f"`{kls}` cannot be auto-wired. {reason}")

self.kls = kls
self.reason = reason


def _create_wired_instance(
kls: type, resolver: DependencyResolver, custom_kwargs: dict[str, object]
kls: type[object], resolver: DependencyResolver, custom_kwargs: dict[str, object]
) -> object:
def wire_error(reason: str) -> Exception:
return AutoWireError(kls, reason)
Expand Down Expand Up @@ -573,7 +571,7 @@ def get_param_kls() -> type | None:
if element_kls is None:
continue

arg = resolver.resolve_optional(element_kls)
arg = resolver.maybe_resolve(element_kls)
else:
param_kls = get_param_kls()
if param_kls is None:
Expand All @@ -582,7 +580,7 @@ def get_param_kls() -> type | None:
if param_kls is DependencyResolver:
arg = resolver
elif param.default != Parameter.empty:
arg = resolver.resolve_optional(param_kls)
arg = resolver.maybe_resolve(param_kls)
if arg is None:
arg = _NOT_SET
else:
Expand Down
Loading