Skip to content

Commit fb2a3f3

Browse files
authored
Clean up dependency registrations (#1444)
1 parent 57d6af3 commit fb2a3f3

28 files changed

+267
-277
lines changed

src/fairseq2/composition/assets.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -47,46 +47,46 @@
4747
def register_file_assets(
4848
container: DependencyContainer, path: Path, *, not_exist_ok: bool = False
4949
) -> None:
50-
def get_source(resolver: DependencyResolver) -> AssetMetadataSource:
50+
def create_source(resolver: DependencyResolver) -> AssetMetadataSource:
5151
return wire_object(
5252
resolver, FileAssetMetadataSource, path=path, not_exist_ok=not_exist_ok
5353
)
5454

55-
container.collection.register(AssetMetadataSource, get_source)
55+
container.collection.register(AssetMetadataSource, create_source)
5656

5757

5858
def register_package_assets(container: DependencyContainer, package: str) -> None:
59-
def get_source(resolver: DependencyResolver) -> AssetMetadataSource:
59+
def create_source(resolver: DependencyResolver) -> AssetMetadataSource:
6060
return wire_object(resolver, PackageAssetMetadataSource, package=package)
6161

62-
container.collection.register(AssetMetadataSource, get_source)
62+
container.collection.register(AssetMetadataSource, create_source)
6363

6464

6565
def register_in_memory_assets(
6666
container: DependencyContainer, source: str, entries: Sequence[dict[str, object]]
6767
) -> None:
68-
def get_source(resolver: DependencyResolver) -> AssetMetadataSource:
68+
def create_source(resolver: DependencyResolver) -> AssetMetadataSource:
6969
return wire_object(
7070
resolver, InMemoryAssetMetadataSource, name=source, entries=entries
7171
)
7272

73-
container.collection.register(AssetMetadataSource, get_source)
73+
container.collection.register(AssetMetadataSource, create_source)
7474

7575

7676
def register_checkpoint_models(
7777
container: DependencyContainer, checkpoint_dir: Path
7878
) -> None:
79-
def get_source(resolver: DependencyResolver) -> AssetMetadataSource:
79+
def create_source(resolver: DependencyResolver) -> AssetMetadataSource:
8080
return wire_object(resolver, ModelMetadataSource, checkpoint_dir=checkpoint_dir)
8181

82-
container.collection.register(AssetMetadataSource, get_source)
82+
container.collection.register(AssetMetadataSource, create_source)
8383

8484

8585
def _register_asset(container: DependencyContainer) -> None:
8686
container.register_type(AssetDirectoryAccessor, StandardAssetDirectoryAccessor)
8787

8888
# Store
89-
def get_asset_store(resolver: DependencyResolver) -> AssetStore:
89+
def load_asset_store(resolver: DependencyResolver) -> AssetStore:
9090
sources = resolver.collection.resolve(AssetMetadataSource)
9191

9292
def load_providers() -> Iterator[AssetMetadataProvider]:
@@ -97,14 +97,16 @@ def load_providers() -> Iterator[AssetMetadataProvider]:
9797

9898
env = env_detector.detect()
9999

100+
metadata_providers = load_providers()
101+
100102
return wire_object(
101103
resolver,
102104
StandardAssetStore,
103-
metadata_providers=load_providers(),
105+
metadata_providers=metadata_providers,
104106
default_env=env,
105107
)
106108

107-
container.register(AssetStore, get_asset_store, singleton=True)
109+
container.register(AssetStore, load_asset_store, singleton=True)
108110

109111
container.register_type(AssetEnvironmentDetector)
110112

@@ -128,7 +130,7 @@ def load_providers() -> Iterator[AssetMetadataProvider]:
128130
container.collection.register_type(AssetDownloadManager, LocalAssetDownloadManager)
129131
container.collection.register_type(AssetDownloadManager, HuggingFaceHub)
130132

131-
def get_standard_asset_download_manager(
133+
def create_standard_asset_download_manager(
132134
resolver: DependencyResolver,
133135
) -> AssetDownloadManager:
134136
dirs = resolver.resolve(AssetDirectoryAccessor)
@@ -150,5 +152,5 @@ def get_standard_asset_download_manager(
150152
)
151153

152154
container.collection.register(
153-
AssetDownloadManager, get_standard_asset_download_manager
155+
AssetDownloadManager, create_standard_asset_download_manager
154156
)

src/fairseq2/composition/datasets.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def register_dataset_family(
4949
elif opener is None:
5050
raise ValueError("`opener` or `advanced_opener` must be specified.")
5151

52-
def get_family(resolver: DependencyResolver) -> DatasetFamily:
52+
def create_family(resolver: DependencyResolver) -> DatasetFamily:
5353
nonlocal opener
5454

5555
if advanced_opener is not None:
@@ -70,7 +70,7 @@ def open_dataset(config: DatasetConfigT) -> DatasetT:
7070
opener=opener,
7171
)
7272

73-
container.register(DatasetFamily, get_family, key=name)
73+
container.register(DatasetFamily, create_family, key=name)
7474

7575

7676
def _register_dataset_families(container: DependencyContainer) -> None:

src/fairseq2/composition/lib.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def _register_library(
128128
else:
129129
container.register_type(ProgressReporter, RichProgressReporter, singleton=True)
130130

131-
def get_download_progress_reporter(
131+
def create_download_progress_reporter(
132132
resolver: DependencyResolver,
133133
) -> ProgressReporter:
134134
columns = _create_rich_download_progress_columns()
@@ -137,7 +137,7 @@ def get_download_progress_reporter(
137137

138138
container.register(
139139
ProgressReporter,
140-
get_download_progress_reporter,
140+
create_download_progress_reporter,
141141
key="download_reporter",
142142
singleton=True,
143143
)
@@ -151,30 +151,29 @@ def get_world_info(resolver: DependencyResolver) -> WorldInfo:
151151
container.register(WorldInfo, get_world_info, singleton=True)
152152

153153
# Device
154-
def get_default_device(resolver: DependencyResolver) -> Device:
154+
def detect_default_device(resolver: DependencyResolver) -> Device:
155155
device_detector = resolver.resolve(DefaultDeviceDetector)
156156

157157
return device_detector.detect()
158158

159-
container.register(Device, get_default_device, singleton=True)
159+
container.register(Device, detect_default_device, singleton=True)
160160

161161
# ThreadPool
162-
def get_default_thread_pool(resolver: DependencyResolver) -> ThreadPool:
162+
def create_thread_pool(resolver: DependencyResolver) -> ThreadPool:
163163
world_info = resolver.resolve(WorldInfo)
164164

165165
return StandardThreadPool.create_default(world_info.local_size)
166166

167-
container.register(ThreadPool, get_default_thread_pool, singleton=True)
167+
container.register(ThreadPool, create_thread_pool, singleton=True)
168168

169169
# RngBag
170-
def get_default_rng_bag(resolver: DependencyResolver) -> RngBag:
170+
def create_rng_bag(resolver: DependencyResolver) -> RngBag:
171171
device = resolver.resolve(Device)
172172

173173
return RngBag.from_device_defaults(CPU, device)
174174

175-
container.register(RngBag, get_default_rng_bag, singleton=True)
175+
container.register(RngBag, create_rng_bag, singleton=True)
176176

177-
# fmt: off
178177
container.register_type(AssetConfigLoader, StandardAssetConfigLoader)
179178
container.register_type(ClusterResolver, StandardClusterResolver)
180179
container.register_type(ConfigMerger, StandardConfigMerger)
@@ -184,13 +183,17 @@ def get_default_rng_bag(resolver: DependencyResolver) -> RngBag:
184183
container.register_type(FileSystem, LocalFileSystem, singleton=True)
185184
container.register_type(GlobalModelLoader, singleton=True)
186185
container.register_type(GlobalTokenizerLoader, singleton=True)
187-
container.register_type(ModelCheckpointLoader, DelegatingModelCheckpointLoader, singleton=True)
186+
container.register_type(
187+
ModelCheckpointLoader, DelegatingModelCheckpointLoader, singleton=True
188+
)
188189
container.register_type(ModelMetadataDumper, StandardModelMetadataDumper)
189190
container.register_type(ModelMetadataLoader, StandardModelMetadataLoader)
190191
container.register_type(ModelSharder, StandardModelSharder, singleton=True)
191192
container.register_type(ObjectValidator, StandardObjectValidator, singleton=True)
192193
container.register_type(SafetensorsLoader, HuggingFaceSafetensorsLoader)
193-
container.register_type(SentencePieceModelLoader, StandardSentencePieceModelLoader, singleton=True)
194+
container.register_type(
195+
SentencePieceModelLoader, StandardSentencePieceModelLoader, singleton=True
196+
)
194197
container.register_type(TensorDumper, TorchTensorDumper, singleton=True)
195198
container.register_type(TensorLoader, TorchTensorLoader, singleton=True)
196199
container.register_type(ValueConverter, StandardValueConverter, singleton=True)
@@ -203,13 +206,18 @@ def get_default_rng_bag(resolver: DependencyResolver) -> RngBag:
203206
container.collection.register_type(ModuleSharder, LinearSharder)
204207
container.collection.register_type(ModuleSharder, MoESharder)
205208

206-
container.collection.register_type(ModelCheckpointLoader, BasicModelCheckpointLoader)
207-
container.collection.register_type(ModelCheckpointLoader, NativeModelCheckpointLoader)
208-
container.collection.register_type(ModelCheckpointLoader, SafetensorsCheckpointLoader)
209+
container.collection.register_type(
210+
ModelCheckpointLoader, BasicModelCheckpointLoader
211+
)
212+
container.collection.register_type(
213+
ModelCheckpointLoader, NativeModelCheckpointLoader
214+
)
215+
container.collection.register_type(
216+
ModelCheckpointLoader, SafetensorsCheckpointLoader
217+
)
209218
container.collection.register_type(ModelCheckpointLoader, LLaMACheckpointLoader)
210219

211220
container.collection.register_type(ConfigDirective, ReplaceEnvDirective)
212-
# fmt: on
213221

214222
_register_asset(container)
215223

src/fairseq2/composition/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def register_model_family(
176176
elif factory is None:
177177
raise ValueError("`factory` or `advanced_factory` must be specified.")
178178

179-
def get_family(resolver: DependencyResolver) -> ModelFamily:
179+
def create_family(resolver: DependencyResolver) -> ModelFamily:
180180
nonlocal factory
181181

182182
if advanced_factory is not None:
@@ -207,7 +207,7 @@ def create_model(config: ModelConfigT) -> ModelT:
207207
hg_exporter=hg_exporter,
208208
)
209209

210-
container.register(ModelFamily, get_family, key=name)
210+
container.register(ModelFamily, create_family, key=name)
211211

212212

213213
def _register_model_families(container: DependencyContainer) -> None:

src/fairseq2/composition/tokenizers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def register_tokenizer_family(
8484
elif loader is None:
8585
raise ValueError("`loader` or `advanced_loader` must be specified.")
8686

87-
def get_family(resolver: DependencyResolver) -> TokenizerFamily:
87+
def create_family(resolver: DependencyResolver) -> TokenizerFamily:
8888
nonlocal loader
8989

9090
if advanced_loader is not None:
@@ -105,7 +105,7 @@ def load_tokenizer(path: Path, config: TokenizerConfigT) -> Tokenizer:
105105
loader=loader,
106106
)
107107

108-
container.register(TokenizerFamily, get_family, key=name)
108+
container.register(TokenizerFamily, create_family, key=name)
109109

110110

111111
def _register_tokenizer_families(container: DependencyContainer) -> None:

src/fairseq2/recipe/cli.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,6 @@ def get_config(resolver: DependencyResolver) -> _RecipeConfigHolder:
202202
container.register(_RecipeConfigHolder, get_config)
203203

204204
container.register_type(_RecipeConfigLoader)
205-
container.register_type(_RecipeConfigPrinter)
206205
container.register_type(_RecipeConfigStructurer, _StandardRecipeConfigStructurer)
207206

208207
# Recipe Output Directory
@@ -213,6 +212,8 @@ def get_output_dir(resolver: DependencyResolver) -> Path:
213212

214213
container.register(Path, get_output_dir)
215214

215+
container.register_type(_RecipeConfigPrinter)
216+
216217
# CLI Errors
217218
_register_cli_errors(container)
218219

src/fairseq2/recipe/composition/data_parallel.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,8 @@
1111
from fairseq2.recipe.base import Recipe, RecipeContext
1212
from fairseq2.recipe.internal.data_parallel import (
1313
_DataParallelModelWrapper,
14-
_DDPFactory,
1514
_DDPModelWrapper,
1615
_DelegatingDPModelWrapper,
17-
_FSDPFactory,
1816
_FSDPModelWrapper,
1917
)
2018
from fairseq2.runtime.dependency import (
@@ -29,20 +27,21 @@ def _register_data_parallel_wrappers(container: DependencyContainer) -> None:
2927
container.register_type(_DataParallelModelWrapper, _DelegatingDPModelWrapper)
3028

3129
# DDP
32-
def get_ddp_wrapper(resolver: DependencyResolver) -> _DataParallelModelWrapper:
30+
def create_ddp_wrapper(resolver: DependencyResolver) -> _DataParallelModelWrapper:
3331
train_recipe = resolver.resolve(Recipe)
3432

3533
context = RecipeContext(resolver)
3634

3735
static_graph = train_recipe.has_static_autograd_graph(context)
3836

39-
return wire_object(resolver, _DDPModelWrapper, static_graph=static_graph)
37+
return wire_object(
38+
resolver, _DDPModelWrapper, ddp_factory=to_ddp, static_graph=static_graph
39+
)
4040

41-
container.register(_DataParallelModelWrapper, get_ddp_wrapper, key="ddp")
42-
43-
container.register_instance(_DDPFactory, to_ddp) # type: ignore[arg-type]
41+
container.register(_DataParallelModelWrapper, create_ddp_wrapper, key="ddp")
4442

4543
# FSDP
46-
container.register_type(_DataParallelModelWrapper, _FSDPModelWrapper, key="fsdp")
44+
def create_fsdp_wrapper(resolver: DependencyResolver) -> _FSDPModelWrapper:
45+
return wire_object(resolver, _FSDPModelWrapper, fsdp_factory=to_fsdp)
4746

48-
container.register_instance(_FSDPFactory, to_fsdp) # type: ignore[arg-type]
47+
container.register(_DataParallelModelWrapper, create_fsdp_wrapper, key="fsdp")

src/fairseq2/recipe/composition/dataset.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,21 @@
88

99
from fairseq2.recipe.composition.config import register_config_section
1010
from fairseq2.recipe.config import DatasetSection
11-
from fairseq2.recipe.internal.dataset import _DatasetHolder, _RecipeDatasetOpener
11+
from fairseq2.recipe.internal.dataset import _DatasetHolder, _DatasetOpener
1212
from fairseq2.runtime.dependency import DependencyContainer, DependencyResolver
1313

1414

1515
def register_dataset(container: DependencyContainer, section_name: str) -> None:
1616
register_config_section(container, section_name, DatasetSection, keyed=True)
1717

18-
def get_dataset_holder(resolver: DependencyResolver) -> _DatasetHolder:
18+
def open_dataset(resolver: DependencyResolver) -> _DatasetHolder:
1919
section = resolver.resolve(DatasetSection, key=section_name)
2020

21-
dataset_opener = resolver.resolve(_RecipeDatasetOpener)
21+
dataset_opener = resolver.resolve(_DatasetOpener)
2222

2323
return dataset_opener.open(section_name, section)
2424

25-
container.register(
26-
_DatasetHolder, get_dataset_holder, key=section_name, singleton=True
27-
)
25+
container.register(_DatasetHolder, open_dataset, key=section_name, singleton=True)
2826

2927
def get_dataset(resolver: DependencyResolver) -> object:
3028
dataset_holder = resolver.resolve(_DatasetHolder, key=section_name)
@@ -34,11 +32,11 @@ def get_dataset(resolver: DependencyResolver) -> object:
3432
container.register(object, get_dataset, key=section_name, singleton=True)
3533

3634

37-
def _register_datasets(container: DependencyContainer) -> None:
38-
container.register_type(_RecipeDatasetOpener)
39-
35+
def _register_default_dataset(container: DependencyContainer) -> None:
4036
register_dataset(container, section_name="dataset")
4137

38+
container.register_type(_DatasetOpener)
39+
4240
# Default Dataset
4341
def get_dataset_holder(resolver: DependencyResolver) -> _DatasetHolder:
4442
return resolver.resolve(_DatasetHolder, key="dataset")

src/fairseq2/recipe/composition/evaluator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818

1919

2020
def _register_evaluator_factory(container: DependencyContainer) -> None:
21-
def get_evaluator_factory(resolver: DependencyResolver) -> _EvaluatorFactory:
21+
def create_evaluator_factory(resolver: DependencyResolver) -> _EvaluatorFactory:
2222
def create_evaluator(**kwargs: Any) -> Evaluator:
2323
return wire_object(resolver, Evaluator, **kwargs)
2424

25-
return wire_object(resolver, _EvaluatorFactory, activator=create_evaluator)
25+
return wire_object(resolver, _EvaluatorFactory, base_factory=create_evaluator)
2626

27-
container.register(_EvaluatorFactory, get_evaluator_factory)
27+
container.register(_EvaluatorFactory, create_evaluator_factory)

src/fairseq2/recipe/composition/generator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818

1919

2020
def _register_generator_factory(container: DependencyContainer) -> None:
21-
def get_generator_factory(resolver: DependencyResolver) -> _GeneratorFactory:
21+
def create_generator_factory(resolver: DependencyResolver) -> _GeneratorFactory:
2222
def create_generator(**kwargs: Any) -> Generator:
2323
return wire_object(resolver, Generator, **kwargs)
2424

25-
return wire_object(resolver, _GeneratorFactory, activator=create_generator)
25+
return wire_object(resolver, _GeneratorFactory, base_factory=create_generator)
2626

27-
container.register(_GeneratorFactory, get_generator_factory)
27+
container.register(_GeneratorFactory, create_generator_factory)

0 commit comments

Comments
 (0)