Skip to content

Commit 57d6af3

Browse files
authored
Respect wandb environment variables (#1440)
1 parent 05453fb commit 57d6af3

File tree

2 files changed

+42
-11
lines changed

2 files changed

+42
-11
lines changed

src/fairseq2/recipe/composition/metric_recorders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def get_wandb_run(resolver: DependencyResolver) -> WandbRun:
9595

9696
return run_factory.create()
9797

98-
container.register(WandbRun, get_wandb_run)
98+
container.register(WandbRun, get_wandb_run, singleton=True)
9999

100100
container.register_type(_WandbRunFactory)
101101

src/fairseq2/recipe/internal/metric_recorders.py

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from fairseq2.recipe.config import CommonSection
2929
from fairseq2.recipe.error import WandbInitializationError
3030
from fairseq2.recipe.internal.config import _RecipeConfigHolder
31+
from fairseq2.utils.env import Environment
3132
from fairseq2.utils.structured import ValueConverter
3233

3334

@@ -86,40 +87,70 @@ def __init__(
8687
self,
8788
section: CommonSection,
8889
output_dir: Path,
90+
env: Environment,
8991
config_holder: _RecipeConfigHolder,
9092
value_converter: ValueConverter,
9193
initializer: _WandbInitializer,
9294
run_id_manager: _WandbRunIdManager,
9395
) -> None:
9496
self._section = section
9597
self._output_dir = output_dir
98+
self._env = env
9699
self._config_holder = config_holder
97100
self._value_converter = value_converter
98101
self._initializer = initializer
99102
self._run_id_manager = run_id_manager
100103

101104
def create(self) -> WandbRun:
105+
wandb_config = self._section.metric_recorders.wandb
106+
107+
if self._env.has("WANDB_ENTITY"):
108+
entity = None
109+
else:
110+
entity = wandb_config.entity
111+
112+
if self._env.has("WANDB_PROJECT"):
113+
project = None
114+
else:
115+
project = wandb_config.project
116+
117+
if self._env.has("WANDB_RUN_ID"):
118+
run_id = None
119+
else:
120+
run_id = self._run_id_manager.get_id()
121+
122+
if self._env.has("WANDB_NAME"):
123+
run_name = None
124+
else:
125+
run_name = wandb_config.run_name
126+
127+
if self._env.has("WANDB_RUN_GROUP"):
128+
run_group = None
129+
else:
130+
run_group = wandb_config.group
131+
132+
if self._env.has("WANDB_JOB_TYPE"):
133+
job_type = None
134+
else:
135+
job_type = wandb_config.job_type
136+
102137
unstructured_config = self._value_converter.unstructure(
103138
self._config_holder.config
104139
)
105140

106141
if not isinstance(unstructured_config, dict):
107142
unstructured_config = None
108143

109-
id_ = self._run_id_manager.get_id()
110-
111-
wandb_config = self._section.metric_recorders.wandb
112-
113144
try:
114145
return self._initializer(
115-
entity=wandb_config.entity,
116-
project=wandb_config.project,
146+
entity=entity,
147+
project=project,
117148
dir=self._output_dir,
118-
id=id_,
119-
name=wandb_config.run_name,
149+
id=run_id,
150+
name=run_name,
120151
config=unstructured_config,
121-
group=wandb_config.group,
122-
job_type=wandb_config.job_type,
152+
group=run_group,
153+
job_type=job_type,
123154
resume=wandb_config.resume_mode,
124155
)
125156
except (RuntimeError, ValueError) as ex:

0 commit comments

Comments
 (0)