Skip to content

Commit d76b3a2

Browse files
authored
Add skeleton of HookManager (#1460)
1 parent 15dd33f commit d76b3a2

File tree

4 files changed

+70
-4
lines changed

4 files changed

+70
-4
lines changed

src/fairseq2/recipe/base.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from fairseq2.recipe.internal.dataset import _DatasetHolder
3838
from fairseq2.recipe.internal.evaluator import _EvaluatorFactory
3939
from fairseq2.recipe.internal.generator import _GeneratorFactory
40+
from fairseq2.recipe.internal.hook import _TrainHookManager
4041
from fairseq2.recipe.internal.model import _ModelHolder
4142
from fairseq2.recipe.internal.reference_model import _ReferenceModelBootstrapper
4243
from fairseq2.recipe.internal.tokenizer import _TokenizerHolder
@@ -187,7 +188,13 @@ def create_trainer(
187188

188189
trainer_factory = self._resolver.resolve(_TrainerFactory)
189190

190-
return trainer_factory.create(unit, data_reader, validator)
191+
trainer = trainer_factory.create(unit, data_reader, validator)
192+
193+
hook_manager = self._resolver.resolve(_TrainHookManager)
194+
195+
hook_manager.maybe_register_trainer_hooks(trainer)
196+
197+
return trainer
191198

192199
def create_evaluator(
193200
self,

src/fairseq2/recipe/composition/root.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
_GangsFactory,
4949
_warmup_gangs,
5050
)
51+
from fairseq2.recipe.internal.hook import _HookManager, _TrainHookManager
5152
from fairseq2.recipe.internal.log import _log_ranks, _LogHelper, _StandardLogHelper
5253
from fairseq2.recipe.internal.logging import _DistributedLogConfigurer
5354
from fairseq2.recipe.internal.output_dir import _OutputDirectoryCreator
@@ -58,7 +59,11 @@
5859
from fairseq2.recipe.internal.task import _TaskRunner
5960
from fairseq2.recipe.internal.torch import _TorchConfigurer
6061
from fairseq2.recipe.run import _RecipeConfigDumper
61-
from fairseq2.runtime.dependency import DependencyContainer, DependencyResolver
62+
from fairseq2.runtime.dependency import (
63+
DependencyContainer,
64+
DependencyResolver,
65+
wire_object,
66+
)
6267
from fairseq2.task import Task
6368
from fairseq2.utils.stopwatch import Stopwatch
6469

@@ -107,6 +112,8 @@ def create_gangs(resolver: DependencyResolver) -> Gangs:
107112
CheckpointManager, StandardCheckpointManager, singleton=True
108113
)
109114

115+
container.register_type(_TrainHookManager, singleton=True)
116+
110117
_register_data_parallel_wrappers(container)
111118
_register_lr_schedulers(container)
112119
_register_optim(container)
@@ -176,16 +183,27 @@ def _register_recipe_common(container: DependencyContainer) -> None:
176183
_register_sampling(container)
177184
_register_seq_generators(container)
178185

186+
def create_task_runner(resolver: DependencyResolver) -> _TaskRunner:
187+
task_runner = wire_object(resolver, _TaskRunner)
188+
189+
hook_manager = resolver.resolve(_HookManager)
190+
191+
hook_manager.maybe_register_task_hooks(task_runner)
192+
193+
return task_runner
194+
195+
container.register(_TaskRunner, create_task_runner)
196+
179197
container.register_type(_AssetConfigOverrider, _StandardAssetConfigOverrider)
180198
container.register_type(_ClusterPreparer)
181199
container.register_type(ComponentManager, _StandardComponentManager, singleton=True)
182200
container.register_type(_DistributedLogConfigurer)
183201
container.register_type(_GangsFactory)
202+
container.register_type(_HookManager, singleton=True)
184203
container.register_type(_LogHelper, _StandardLogHelper)
185204
container.register_type(_OutputDirectoryCreator)
186205
container.register_type(_RecipeConfigDumper)
187206
container.register_type(_SweepTagGenerator, _StandardSweepTagGenerator)
188-
container.register_type(_TaskRunner)
189207
container.register_type(_TorchConfigurer)
190208

191209
container.collection.register_type(AssetMetadataSource, _ExtraAssetMetadataSource)
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# All rights reserved.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from __future__ import annotations
7+
8+
from typing import final
9+
10+
from fairseq2.recipe.internal.task import _TaskRunner
11+
from fairseq2.trainer import Trainer
12+
13+
14+
@final
15+
class _TrainHookManager:
16+
def maybe_register_trainer_hooks(self, trainer: Trainer) -> None:
17+
pass
18+
19+
20+
@final
21+
class _HookManager:
22+
def maybe_register_task_hooks(self, task_runner: _TaskRunner) -> None:
23+
pass

src/fairseq2/recipe/internal/task.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,12 @@
66

77
from __future__ import annotations
88

9+
from collections import OrderedDict
910
from signal import SIGUSR1, signal
1011
from types import FrameType
11-
from typing import final
12+
from typing import Protocol, final
13+
14+
from torch.utils.hooks import RemovableHandle
1215

1316
from fairseq2.error import raise_operational_system_error
1417
from fairseq2.gang import GangError, Gangs, raise_operational_gang_error
@@ -17,11 +20,16 @@
1720
from fairseq2.utils.stopwatch import Stopwatch
1821

1922

23+
class _TaskStartHook(Protocol):
24+
def __call__(self) -> None: ...
25+
26+
2027
@final
2128
class _TaskRunner:
2229
def __init__(self, gangs: Gangs, wall_watch: Stopwatch) -> None:
2330
self._gangs = gangs
2431
self._wall_watch = wall_watch
32+
self._start_hooks: dict[int, _TaskStartHook] = OrderedDict()
2533

2634
def run(self, task: Task) -> None:
2735
log.info("Running on {} process(es).", self._gangs.root.size)
@@ -34,6 +42,9 @@ def request_stop(signum: int, frame: FrameType | None) -> None:
3442

3543
original_signal_handler = signal(SIGUSR1, request_stop)
3644

45+
for hook in self._start_hooks.values():
46+
hook()
47+
3748
try:
3849
task.run()
3950
except OSError as ex:
@@ -69,3 +80,10 @@ def request_stop(signum: int, frame: FrameType | None) -> None:
6980
task.close()
7081

7182
signal(SIGUSR1, original_signal_handler)
83+
84+
def register_start_hook(self, hook: _TaskStartHook) -> RemovableHandle:
85+
handle = RemovableHandle(self._start_hooks)
86+
87+
self._start_hooks[handle.id] = hook
88+
89+
return handle

0 commit comments

Comments
 (0)