|
| 1 | +import concurrent.futures |
1 | 2 | from collections.abc import Iterable |
| 3 | +from dataclasses import dataclass |
| 4 | +from enum import Enum |
2 | 5 | from typing import TYPE_CHECKING, Callable, NoReturn, Optional, TypeVar, Union, cast |
3 | 6 |
|
4 | 7 | from funcy import ldistinct |
@@ -120,14 +123,6 @@ def _reproduce_stage(stage: "Stage", **kwargs) -> Optional["Stage"]: |
120 | 123 | return ret |
121 | 124 |
|
122 | 125 |
|
123 | | -def _get_upstream_downstream_nodes( |
124 | | - graph: Optional["DiGraph"], node: T |
125 | | -) -> tuple[list[T], list[T]]: |
126 | | - succ = list(graph.successors(node)) if graph else [] |
127 | | - pre = list(graph.predecessors(node)) if graph else [] |
128 | | - return succ, pre |
129 | | - |
130 | | - |
def _repr(stages: Iterable["Stage"]) -> str:
    """Return one human-readable string naming every stage in ``stages``.

    Each stage contributes ``repr(stage.addressing)``; the pieces are
    combined by ``humanize.join``.
    """
    # Kept lazy on purpose: ``humanize.join`` receives a generator, exactly
    # as it would from an inline generator expression.
    addressings = (repr(stage.addressing) for stage in stages)
    return humanize.join(addressings)
133 | 128 |
|
@@ -155,54 +150,159 @@ def _raise_error(exc: Optional[Exception], *stages: "Stage") -> NoReturn: |
155 | 150 | raise ReproductionError(f"failed to reproduce{segment} {names}") from exc |
156 | 151 |
|
157 | 152 |
|
158 | | -def _reproduce( |
159 | | - stages: list["Stage"], |
160 | | - graph: Optional["DiGraph"] = None, |
161 | | - force_downstream: bool = False, |
162 | | - on_error: str = "fail", |
163 | | - force: bool = False, |
class ReproStatus(Enum):
    """Scheduling state of a single stage during a reproduction run."""

    # Initial state; only stages in this state are candidates for submission.
    READY = "ready"
    # The stage has been submitted to the executor and is not handled yet.
    IN_PROGRESS = "in-progress"
    # The stage's future finished without raising.
    COMPLETE = "complete"
    # The stage will not be run because a stage it depends on failed.
    SKIPPED = "skipped"
    # The stage's future raised an exception.
    FAILED = "failed"
| 159 | + |
| 160 | + |
@dataclass
class StageInfo:
    """Mutable bookkeeping record kept for one stage while reproducing."""

    # Stages this stage depends on; handed to the repro function as
    # ``upstream``.
    upstream: list["Stage"]
    # Subset of ``upstream`` that is part of the same run and has not been
    # handled yet. The stage is only started once this set is empty.
    upstream_unfinished: set["Stage"]
    # Stages that depend on this stage; they are notified when it finishes.
    downstream: list["Stage"]
    # Whether the stage is reproduced with ``force=True``. May be switched on
    # after creation when forcing is propagated from an upstream stage.
    force: bool
    # Current scheduling state of the stage.
    status: ReproStatus
    # Truthy return value of the repro function, or ``None`` when there was
    # none (or the stage has not completed).
    result: Optional["Stage"]
| 169 | + |
| 170 | + |
def _start_ready_stages(
    to_repro: dict["Stage", StageInfo],
    executor: concurrent.futures.ThreadPoolExecutor,
    max_stages: int,
    repro_fn: Callable = _reproduce_stage,
    **kwargs,
) -> dict[concurrent.futures.Future["Stage"], "Stage"]:
    """Submit runnable stages to ``executor`` and mark them as in progress.

    A stage is runnable when its status is ``ReproStatus.READY`` and its
    ``upstream_unfinished`` set is empty. Candidates are considered in the
    iteration order of ``to_repro`` and only the first ``max_stages`` of them
    are submitted.

    Args:
        to_repro: Scheduling state for every stage of the run; the status of
            each submitted stage is updated in place.
        executor: Thread pool that runs ``repro_fn``.
        max_stages: Upper bound on the number of stages submitted by this
            call.
        repro_fn: Callable invoked as
            ``repro_fn(stage, upstream=..., force=..., **kwargs)``.
        **kwargs: Extra keyword arguments forwarded to ``repro_fn``.

    Returns:
        Mapping of each newly created future to the stage it runs; empty when
        no stage is runnable.
    """
    runnable = []
    for candidate, info in to_repro.items():
        if info.status != ReproStatus.READY:
            continue
        if info.upstream_unfinished:
            continue
        runnable.append((candidate, info))

    submitted = {}
    for stage, info in runnable[:max_stages]:
        future = executor.submit(
            repro_fn,
            stage,
            upstream=info.upstream,
            force=info.force,
            **kwargs,
        )
        submitted[future] = stage

    # Statuses are flipped only after every submission went through.
    for stage in submitted.values():
        to_repro[stage].status = ReproStatus.IN_PROGRESS
    return submitted
| 199 | + |
| 200 | + |
def _result_or_raise(
    to_repro: dict["Stage", StageInfo], stages: list["Stage"], on_error: str
) -> list["Stage"]:
    """Collect the results of a finished run, raising if stages failed.

    Args:
        to_repro: Final scheduling state for every stage in ``stages``.
        stages: Stages in the order they were requested.
        on_error: Error policy; with ``"ignore"`` failures never raise here.

    Returns:
        The truthy ``result`` of every stage that did not fail, in the order
        of ``stages``.

    Raises:
        Whatever ``_raise_error`` raises, when at least one stage has status
        ``ReproStatus.FAILED`` and ``on_error`` is not ``"ignore"``.
    """
    # Walk ``stages`` rather than ``to_repro`` so the caller's order is kept.
    infos = [(stage, to_repro[stage]) for stage in stages]

    failed = [stage for stage, info in infos if info.status == ReproStatus.FAILED]
    result = [
        info.result
        for _, info in infos
        if info.status != ReproStatus.FAILED and info.result
    ]

    if failed and on_error != "ignore":
        _raise_error(None, *failed)

    return result
185 | 218 |
|
186 | | - try: |
187 | | - ret = repro_fn(stage, upstream=upstream, force=force_stage, **kwargs) |
188 | | - except Exception as exc: # noqa: BLE001 |
189 | | - failed.append(stage) |
190 | | - if on_error == "fail": |
191 | | - _raise_error(exc, stage) |
192 | 219 |
|
193 | | - dependents = handle_error(graph, on_error, exc, stage) |
194 | | - to_skip.update(dict.fromkeys(dependents, stage)) |
def _handle_result(
    to_repro: dict["Stage", StageInfo],
    future: concurrent.futures.Future["Stage"],
    stage: "Stage",
    stage_info: StageInfo,
    graph: Optional["DiGraph"],
    on_error: str,
    force_downstream: bool,
) -> None:
    """Record the outcome of a finished stage and update its dependents.

    Args:
        to_repro: Scheduling state of every stage in the run; updated in
            place.
        future: Finished future that ran ``stage``.
        stage: The stage the future belongs to.
        stage_info: ``to_repro[stage]``.
        graph: Dependency graph handed to ``handle_error`` on failure.
        on_error: Error policy; with ``"fail"`` the first failure is raised
            immediately through ``_raise_error``.
        force_downstream: When true, a successful stage that returned a
            truthy result (or was itself forced) forces its dependents.

    Raises:
        Whatever ``_raise_error`` raises, when the stage failed and
        ``on_error`` is ``"fail"``.
    """
    ret: Optional[Stage] = None
    success = False
    try:
        ret = future.result()
    except Exception as exc:  # noqa: BLE001
        if on_error == "fail":
            _raise_error(exc, stage)

        stage_info.status = ReproStatus.FAILED
        for dependent in handle_error(graph, on_error, exc, stage):
            # ``handle_error`` works from ``graph``, which may hold stages
            # that are not part of this run (the downstream loop below makes
            # the same allowance), so never index ``to_repro`` blindly.
            skipped_info = to_repro.get(dependent)
            if skipped_info is not None:
                skipped_info.status = ReproStatus.SKIPPED
    else:
        stage_info.status = ReproStatus.COMPLETE
        success = True

    # Forcing is only propagated from stages that actually succeeded.
    propagate_force = (
        success and force_downstream and bool(ret or stage_info.force)
    )
    for dependent in stage_info.downstream:
        dependent_info = to_repro.get(dependent)
        if dependent_info is None:
            # Dependent is not scheduled in this run; nothing to update.
            continue
        # The stage is finished either way, so it no longer blocks the
        # dependent (which may already be marked as skipped).
        dependent_info.upstream_unfinished.discard(stage)
        if propagate_force:
            dependent_info.force = True

    if success and ret:
        stage_info.result = ret
199 | 256 |
|
200 | | - if ret: |
201 | | - result.append(ret) |
202 | 257 |
|
203 | | - if on_error != "ignore" and failed: |
204 | | - _raise_error(None, *failed) |
205 | | - return result |
def _reproduce(
    stages: list["Stage"],
    graph: Optional["DiGraph"] = None,
    force_downstream: bool = False,
    on_error: str = "fail",
    force: bool = False,
    jobs: int = 1,
    **kwargs,
) -> list["Stage"]:
    """Reproduce ``stages`` on a thread pool, respecting their dependencies.

    A stage is started only after every one of its upstream stages that is
    part of ``stages`` has been handled.

    Args:
        stages: Stages to reproduce; results are returned in this order.
        graph: Dependency graph. ``graph.successors(stage)`` is used as the
            stage's upstream and ``graph.predecessors(stage)`` as its
            downstream. Without a graph, stages are treated as independent.
        force_downstream: Force dependents of a stage that was reproduced or
            forced.
        on_error: One of ``"fail"``, ``"keep-going"`` or ``"ignore"``.
        force: Initial ``force`` flag for every stage.
        jobs: Maximum number of stages run at once; ``-1`` means one worker
            per stage. The value is clamped to ``1..len(stages)``.
        **kwargs: Forwarded to ``_start_ready_stages`` (and from there to the
            repro function).

    Returns:
        The truthy results of the stages that did not fail, as collected by
        ``_result_or_raise``.

    Raises:
        Whatever ``_handle_result`` / ``_result_or_raise`` raise for failed
        stages, depending on ``on_error``.
    """
    assert on_error in ("fail", "keep-going", "ignore")

    # Built once so the per-stage intersection below stays linear overall.
    scheduled = set(stages)
    to_repro: dict[Stage, StageInfo] = {}
    for stage in stages:
        upstream = list(graph.successors(stage)) if graph else []
        downstream = list(graph.predecessors(stage)) if graph else []
        to_repro[stage] = StageInfo(
            upstream=upstream,
            # Only upstream stages that are part of this run can block us;
            # ``intersection`` returns a fresh set owned by this record.
            upstream_unfinished=scheduled.intersection(upstream),
            downstream=downstream,
            force=force,
            status=ReproStatus.READY,
            result=None,
        )

    if jobs == -1:
        jobs = len(stages)
    max_workers = max(1, min(jobs, len(stages)))
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = _start_ready_stages(to_repro, executor, max_workers, **kwargs)
        while futures:
            done, _ = concurrent.futures.wait(
                futures, return_when=concurrent.futures.FIRST_COMPLETED
            )
            for future in done:
                stage = futures.pop(future)
                _handle_result(
                    to_repro,
                    future,
                    stage,
                    to_repro[stage],
                    graph,
                    on_error,
                    force_downstream,
                )

            # Refill every idle worker, not just as many as finished in this
            # round; otherwise the concurrency could never grow beyond the
            # number of stages that happened to be ready at the very start.
            free_slots = max_workers - len(futures)
            futures.update(
                _start_ready_stages(to_repro, executor, free_slots, **kwargs)
            )

    return _result_or_raise(to_repro, stages, on_error)
206 | 306 |
|
207 | 307 |
|
208 | 308 | @locked |
|
0 commit comments