|
28 | 28 | from fairseq2.recipe.config import CommonSection |
29 | 29 | from fairseq2.recipe.error import WandbInitializationError |
30 | 30 | from fairseq2.recipe.internal.config import _RecipeConfigHolder |
| 31 | +from fairseq2.utils.env import Environment |
31 | 32 | from fairseq2.utils.structured import ValueConverter |
32 | 33 |
|
33 | 34 |
|
@@ -86,40 +87,70 @@ def __init__( |
86 | 87 | self, |
87 | 88 | section: CommonSection, |
88 | 89 | output_dir: Path, |
| 90 | + env: Environment, |
89 | 91 | config_holder: _RecipeConfigHolder, |
90 | 92 | value_converter: ValueConverter, |
91 | 93 | initializer: _WandbInitializer, |
92 | 94 | run_id_manager: _WandbRunIdManager, |
93 | 95 | ) -> None: |
94 | 96 | self._section = section |
95 | 97 | self._output_dir = output_dir |
| 98 | + self._env = env |
96 | 99 | self._config_holder = config_holder |
97 | 100 | self._value_converter = value_converter |
98 | 101 | self._initializer = initializer |
99 | 102 | self._run_id_manager = run_id_manager |
100 | 103 |
|
101 | 104 | def create(self) -> WandbRun: |
| 105 | + wandb_config = self._section.metric_recorders.wandb |
| 106 | + |
| 107 | + if self._env.has("WANDB_ENTITY"): |
| 108 | + entity = None |
| 109 | + else: |
| 110 | + entity = wandb_config.entity |
| 111 | + |
| 112 | + if self._env.has("WANDB_PROJECT"): |
| 113 | + project = None |
| 114 | + else: |
| 115 | + project = wandb_config.project |
| 116 | + |
| 117 | + if self._env.has("WANDB_RUN_ID"): |
| 118 | + run_id = None |
| 119 | + else: |
| 120 | + run_id = self._run_id_manager.get_id() |
| 121 | + |
| 122 | + if self._env.has("WANDB_NAME"): |
| 123 | + run_name = None |
| 124 | + else: |
| 125 | + run_name = wandb_config.run_name |
| 126 | + |
| 127 | + if self._env.has("WANDB_RUN_GROUP"): |
| 128 | + run_group = None |
| 129 | + else: |
| 130 | + run_group = wandb_config.group |
| 131 | + |
| 132 | + if self._env.has("WANDB_JOB_TYPE"): |
| 133 | + job_type = None |
| 134 | + else: |
| 135 | + job_type = wandb_config.job_type |
| 136 | + |
102 | 137 | unstructured_config = self._value_converter.unstructure( |
103 | 138 | self._config_holder.config |
104 | 139 | ) |
105 | 140 |
|
106 | 141 | if not isinstance(unstructured_config, dict): |
107 | 142 | unstructured_config = None |
108 | 143 |
|
109 | | - id_ = self._run_id_manager.get_id() |
110 | | - |
111 | | - wandb_config = self._section.metric_recorders.wandb |
112 | | - |
113 | 144 | try: |
114 | 145 | return self._initializer( |
115 | | - entity=wandb_config.entity, |
116 | | - project=wandb_config.project, |
| 146 | + entity=entity, |
| 147 | + project=project, |
117 | 148 | dir=self._output_dir, |
118 | | - id=id_, |
119 | | - name=wandb_config.run_name, |
| 149 | + id=run_id, |
| 150 | + name=run_name, |
120 | 151 | config=unstructured_config, |
121 | | - group=wandb_config.group, |
122 | | - job_type=wandb_config.job_type, |
| 152 | + group=run_group, |
| 153 | + job_type=job_type, |
123 | 154 | resume=wandb_config.resume_mode, |
124 | 155 | ) |
125 | 156 | except (RuntimeError, ValueError) as ex: |
|
0 commit comments