|
48 | 48 | _GangsFactory, |
49 | 49 | _warmup_gangs, |
50 | 50 | ) |
| 51 | +from fairseq2.recipe.internal.hook import _HookManager, _TrainHookManager |
51 | 52 | from fairseq2.recipe.internal.log import _log_ranks, _LogHelper, _StandardLogHelper |
52 | 53 | from fairseq2.recipe.internal.logging import _DistributedLogConfigurer |
53 | 54 | from fairseq2.recipe.internal.output_dir import _OutputDirectoryCreator |
|
58 | 59 | from fairseq2.recipe.internal.task import _TaskRunner |
59 | 60 | from fairseq2.recipe.internal.torch import _TorchConfigurer |
60 | 61 | from fairseq2.recipe.run import _RecipeConfigDumper |
61 | | -from fairseq2.runtime.dependency import DependencyContainer, DependencyResolver |
| 62 | +from fairseq2.runtime.dependency import ( |
| 63 | + DependencyContainer, |
| 64 | + DependencyResolver, |
| 65 | + wire_object, |
| 66 | +) |
62 | 67 | from fairseq2.task import Task |
63 | 68 | from fairseq2.utils.stopwatch import Stopwatch |
64 | 69 |
|
@@ -107,6 +112,8 @@ def create_gangs(resolver: DependencyResolver) -> Gangs: |
107 | 112 | CheckpointManager, StandardCheckpointManager, singleton=True |
108 | 113 | ) |
109 | 114 |
|
| 115 | + container.register_type(_TrainHookManager, singleton=True) |
| 116 | + |
110 | 117 | _register_data_parallel_wrappers(container) |
111 | 118 | _register_lr_schedulers(container) |
112 | 119 | _register_optim(container) |
@@ -176,16 +183,27 @@ def _register_recipe_common(container: DependencyContainer) -> None: |
176 | 183 | _register_sampling(container) |
177 | 184 | _register_seq_generators(container) |
178 | 185 |
|
| 186 | + def create_task_runner(resolver: DependencyResolver) -> _TaskRunner: |
| 187 | + task_runner = wire_object(resolver, _TaskRunner) |
| 188 | + |
| 189 | + hook_manager = resolver.resolve(_HookManager) |
| 190 | + |
| 191 | + hook_manager.maybe_register_task_hooks(task_runner) |
| 192 | + |
| 193 | + return task_runner |
| 194 | + |
| 195 | + container.register(_TaskRunner, create_task_runner) |
| 196 | + |
179 | 197 | container.register_type(_AssetConfigOverrider, _StandardAssetConfigOverrider) |
180 | 198 | container.register_type(_ClusterPreparer) |
181 | 199 | container.register_type(ComponentManager, _StandardComponentManager, singleton=True) |
182 | 200 | container.register_type(_DistributedLogConfigurer) |
183 | 201 | container.register_type(_GangsFactory) |
| 202 | + container.register_type(_HookManager, singleton=True) |
184 | 203 | container.register_type(_LogHelper, _StandardLogHelper) |
185 | 204 | container.register_type(_OutputDirectoryCreator) |
186 | 205 | container.register_type(_RecipeConfigDumper) |
187 | 206 | container.register_type(_SweepTagGenerator, _StandardSweepTagGenerator) |
188 | | - container.register_type(_TaskRunner) |
189 | 207 | container.register_type(_TorchConfigurer) |
190 | 208 |
|
191 | 209 | container.collection.register_type(AssetMetadataSource, _ExtraAssetMetadataSource) |
|
0 commit comments