Skip to content

Commit ed33d83

Browse files
committed
repro: parallel execution of stages
1 parent f342cc7 commit ed33d83

File tree

8 files changed

+267
-70
lines changed

8 files changed

+267
-70
lines changed

dvc/commands/repro.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def _repro_kwargs(self):
4040
"run_cache": not self.args.no_run_cache,
4141
"no_commit": self.args.no_commit,
4242
"glob": self.args.glob,
43+
"jobs": self.args.jobs,
4344
}
4445

4546

@@ -188,4 +189,12 @@ def add_parser(subparsers, parent_parser):
188189
"the same command/dependencies/outputs/etc before."
189190
),
190191
)
192+
repro_parser.add_argument(
193+
"-j",
194+
"--jobs",
195+
type=int,
196+
default=1,
197+
help="Run (at most) specified number of stages at a time in parallel.",
198+
metavar="<number>",
199+
)
191200
repro_parser.set_defaults(func=CmdRepro)

dvc/dvcfile.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import contextlib
22
import os
3+
import threading
4+
from collections import defaultdict
35
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, TypeVar, Union
46

57
from dvc.exceptions import DvcException
@@ -77,6 +79,10 @@ def check_dvcfile_path(repo, path):
7779
raise FileIsGitIgnored(relpath(path), True)
7880

7981

82+
_file_locks: dict[str, threading.Lock] = defaultdict(threading.Lock)
83+
_file_locks_lock = threading.Lock()
84+
85+
8086
class FileMixin:
8187
SCHEMA: Callable[[_T], _T]
8288

@@ -85,6 +91,10 @@ def __init__(self, repo, path, verify=True, **kwargs):
8591
self.path = path
8692
self.verify = verify
8793

94+
def _thread_lock(self) -> threading.Lock:
    """Return the process-wide threading lock dedicated to this file's path.

    NOTE(review): ``self.path`` is used as the registry key without
    normalization, so two spellings of the same file would get distinct
    locks -- confirm callers always use a canonical path.
    """
    # Guard the registry itself: defaultdict insertion is not atomic
    # under concurrent access.
    with _file_locks_lock:
        lock = _file_locks[self.path]
    return lock
97+
8898
def __repr__(self):
8999
return f"{self.__class__.__name__}: {relpath(self.path, self.repo.root_dir)}"
90100

@@ -148,15 +158,19 @@ def validate(cls, d: _T, fname: Optional[str] = None) -> _T:
148158
def _load_yaml(self, **kwargs: Any) -> tuple[Any, str]:
    """Load and schema-validate this file's YAML contents.

    Reads are serialized per path via the thread lock so a concurrent
    stage run cannot observe a half-written file.
    """
    from dvc.utils import strictyaml

    with self._thread_lock():
        loaded = strictyaml.load(
            self.path,
            self.SCHEMA,  # type: ignore[arg-type]
            self.repo.fs,
            **kwargs,
        )
    return loaded
157168

158169
def remove(self, force=False):  # noqa: ARG002
    """Delete the file from disk, serialized on the per-path thread lock.

    A missing file is not an error; ``force`` is accepted for interface
    compatibility and intentionally unused here.
    """
    with self._thread_lock():
        # Best-effort: the file may already have been removed.
        with contextlib.suppress(FileNotFoundError):
            os.unlink(self.path)
161175

162176
def dump(self, stage, **kwargs):
@@ -407,7 +421,10 @@ def _load(self, **kwargs: Any):
407421
return {}, ""
408422

409423
def dump_dataset(self, dataset: dict):
410-
with modify_yaml(self.path, fs=self.repo.fs) as data:
424+
with (
425+
self._thread_lock(),
426+
modify_yaml(self.path, fs=self.repo.fs) as data,
427+
):
411428
data.update({"schema": "2.0"})
412429
if not data:
413430
logger.info("Generating lock file '%s'", self.relpath)
@@ -430,7 +447,10 @@ def dump_stages(self, stages, **kwargs):
430447

431448
is_modified = False
432449
log_updated = False
433-
with modify_yaml(self.path, fs=self.repo.fs) as data:
450+
with (
451+
self._thread_lock(),
452+
modify_yaml(self.path, fs=self.repo.fs) as data,
453+
):
434454
if not data:
435455
data.update({"schema": "2.0"})
436456
# order is important, meta should always be at the top
@@ -468,7 +488,8 @@ def remove_stage(self, stage):
468488
del data[stage.name]
469489

470490
if data:
471-
dump_yaml(self.path, d)
491+
with self._thread_lock():
492+
dump_yaml(self.path, d)
472493
else:
473494
self.remove()
474495

dvc/repo/reproduce.py

Lines changed: 142 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
import concurrent.futures
12
from collections.abc import Iterable
3+
from dataclasses import dataclass
4+
from enum import Enum
25
from typing import TYPE_CHECKING, Callable, NoReturn, Optional, TypeVar, Union, cast
36

47
from funcy import ldistinct
@@ -120,14 +123,6 @@ def _reproduce_stage(stage: "Stage", **kwargs) -> Optional["Stage"]:
120123
return ret
121124

122125

123-
def _get_upstream_downstream_nodes(
124-
graph: Optional["DiGraph"], node: T
125-
) -> tuple[list[T], list[T]]:
126-
succ = list(graph.successors(node)) if graph else []
127-
pre = list(graph.predecessors(node)) if graph else []
128-
return succ, pre
129-
130-
131126
def _repr(stages: Iterable["Stage"]) -> str:
132127
return humanize.join(repr(stage.addressing) for stage in stages)
133128

@@ -155,54 +150,159 @@ def _raise_error(exc: Optional[Exception], *stages: "Stage") -> NoReturn:
155150
raise ReproductionError(f"failed to reproduce{segment} {names}") from exc
156151

157152

158-
def _reproduce(
159-
stages: list["Stage"],
160-
graph: Optional["DiGraph"] = None,
161-
force_downstream: bool = False,
162-
on_error: str = "fail",
163-
force: bool = False,
153+
class ReproStatus(Enum):
    """Lifecycle state of a stage in the (possibly parallel) reproduce run."""

    READY = "ready"  # not yet submitted to the executor
    IN_PROGRESS = "in-progress"  # currently running
    COMPLETE = "complete"  # finished successfully
    SKIPPED = "skipped"  # not run because an upstream stage failed
    FAILED = "failed"  # raised during reproduction
159+
160+
161+
@dataclass
class StageInfo:
    """Scheduler book-keeping for a single stage being reproduced."""

    # Direct dependencies (graph successors) handed to the repro function.
    upstream: list["Stage"]
    # Upstream stages from this run that have not finished yet; the stage
    # becomes runnable only once this set is empty.
    upstream_unfinished: set["Stage"]
    # Direct dependents (graph predecessors) to notify on completion.
    downstream: list["Stage"]
    # Whether this stage should be reproduced with force.
    force: bool
    # Current scheduling state.
    status: ReproStatus
    # Stage returned by a successful reproduction, if any.
    result: Optional["Stage"]
169+
170+
171+
def _start_ready_stages(
    to_repro: dict["Stage", StageInfo],
    executor: concurrent.futures.ThreadPoolExecutor,
    max_stages: int,
    repro_fn: Callable = _reproduce_stage,
    **kwargs,
) -> dict[concurrent.futures.Future["Stage"], "Stage"]:
    """Submit up to ``max_stages`` runnable stages to ``executor``.

    A stage is runnable when it is READY and every upstream stage from this
    run has finished.  Each submitted stage is flipped to IN_PROGRESS.
    Returns a mapping from submitted future to its stage (empty if nothing
    was runnable).
    """
    runnable = [
        (stage, info)
        for stage, info in to_repro.items()
        if info.status == ReproStatus.READY and not info.upstream_unfinished
    ]

    futures: dict[concurrent.futures.Future["Stage"], "Stage"] = {}
    for stage, info in runnable[:max_stages]:
        future = executor.submit(
            repro_fn,
            stage,
            upstream=info.upstream,
            force=info.force,
            **kwargs,
        )
        futures[future] = stage
        info.status = ReproStatus.IN_PROGRESS
    return futures
200+
201+
def _result_or_raise(
    to_repro: dict["Stage", StageInfo], stages: list["Stage"], on_error: str
) -> list["Stage"]:
    """Collect the run's results, raising if any stage failed.

    Results follow the order of ``stages`` (the caller's original order),
    not the order of completion.  Unless ``on_error`` is ``"ignore"``, a
    ``ReproductionError`` covering every failed stage is raised.
    """
    succeeded: list[Stage] = []
    failed: list[Stage] = []
    # Walk ``stages`` rather than ``to_repro`` to preserve original order.
    for stage in stages:
        info = to_repro[stage]
        if info.status == ReproStatus.FAILED:
            failed.append(stage)
        elif info.result:
            succeeded.append(info.result)

    if failed and on_error != "ignore":
        _raise_error(None, *failed)

    return succeeded
185218

186-
try:
187-
ret = repro_fn(stage, upstream=upstream, force=force_stage, **kwargs)
188-
except Exception as exc: # noqa: BLE001
189-
failed.append(stage)
190-
if on_error == "fail":
191-
_raise_error(exc, stage)
192219

193-
dependents = handle_error(graph, on_error, exc, stage)
194-
to_skip.update(dict.fromkeys(dependents, stage))
220+
def _handle_result(
    to_repro: dict["Stage", StageInfo],
    future: concurrent.futures.Future["Stage"],
    stage: "Stage",
    stage_info: StageInfo,
    graph: Optional["DiGraph"],
    on_error: str,
    force_downstream: bool,
):
    """Record the outcome of a finished stage and unblock its dependents.

    On failure: re-raises immediately when ``on_error == "fail"``; otherwise
    marks the stage FAILED and its dependents (per ``handle_error``) SKIPPED.
    On success: marks the stage COMPLETE, optionally propagates ``force`` to
    downstream stages, and stores the returned stage as the result.
    """
    ret: Optional[Stage] = None
    success = False
    try:
        ret = future.result()
    except Exception as exc:  # noqa: BLE001
        if on_error == "fail":
            _raise_error(exc, stage)
        stage_info.status = ReproStatus.FAILED
        for dependent in handle_error(graph, on_error, exc, stage):
            to_repro[dependent].status = ReproStatus.SKIPPED
    else:
        success = True
        stage_info.status = ReproStatus.COMPLETE

    propagate_force = success and force_downstream and bool(ret or stage_info.force)
    for dependent in stage_info.downstream:
        # Dependents outside this run are not tracked.
        dependent_info = to_repro.get(dependent)
        if dependent_info is None:
            continue
        # This stage is done (either way), so it no longer blocks dependents.
        dependent_info.upstream_unfinished.discard(stage)
        if propagate_force:
            dependent_info.force = True

    if success and ret:
        stage_info.result = ret
199256

200-
if ret:
201-
result.append(ret)
202257

203-
if on_error != "ignore" and failed:
204-
_raise_error(None, *failed)
205-
return result
258+
def _reproduce(
    stages: list["Stage"],
    graph: Optional["DiGraph"] = None,
    force_downstream: bool = False,
    on_error: str = "fail",
    force: bool = False,
    jobs: int = 1,
    **kwargs,
) -> list["Stage"]:
    """Reproduce ``stages``, running up to ``jobs`` of them concurrently.

    Stages with no unfinished upstream dependencies are submitted to a
    thread pool; as each finishes, newly-unblocked stages are submitted in
    its place.  ``jobs == -1`` means "no limit".  Results are returned in
    the original ``stages`` order; failures are handled per ``on_error``
    ("fail", "keep-going" or "ignore").
    """
    assert on_error in ("fail", "keep-going", "ignore")

    to_repro: dict[Stage, StageInfo] = {}
    for stage in stages:
        # Successors are dependencies (upstream); predecessors are dependents.
        upstream = list(graph.successors(stage)) if graph else []
        downstream = list(graph.predecessors(stage)) if graph else []
        to_repro[stage] = StageInfo(
            upstream=upstream,
            # Only upstream stages that are part of this run can block us.
            upstream_unfinished=set(upstream).intersection(stages),
            downstream=downstream,
            force=force,
            status=ReproStatus.READY,
            result=None,
        )

    if jobs == -1:
        jobs = len(stages)
    max_workers = max(1, min(jobs, len(stages)))

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = _start_ready_stages(to_repro, executor, max_workers, **kwargs)
        while futures:
            done, _ = concurrent.futures.wait(
                futures, return_when=concurrent.futures.FIRST_COMPLETED
            )
            for future in done:
                finished = futures.pop(future)
                _handle_result(
                    to_repro,
                    future,
                    finished,
                    to_repro[finished],
                    graph,
                    on_error,
                    force_downstream,
                )
            # Every completed future frees exactly one worker slot, so
            # submit at most ``len(done)`` newly-unblocked stages.
            futures.update(
                _start_ready_stages(to_repro, executor, len(done), **kwargs)
            )

    return _result_or_raise(to_repro, stages, on_error)
206306

207307

208308
@locked

dvc/rwlock.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
import os
3+
import threading
34
from collections import defaultdict
45
from contextlib import contextmanager
56

@@ -25,6 +26,8 @@
2526
}
2627
)
2728

29+
RWLOCK_THREAD_LOCK = threading.Lock()
30+
RWLOCK_THREAD_TIMEOUT = 3
2831
RWLOCK_FILE = "rwlock"
2932
RWLOCK_LOCK = "rwlock.lock"
3033

@@ -50,21 +53,26 @@ def _edit_rwlock(lock_dir, fs, hardlink):
5053
tmp_dir=lock_dir,
5154
hardlink_lock=hardlink,
5255
)
53-
with rwlock_guard:
54-
try:
55-
with fs.open(path, encoding="utf-8") as fobj:
56-
lock = SCHEMA(json.load(fobj))
57-
except FileNotFoundError:
58-
lock = SCHEMA({})
59-
except json.JSONDecodeError as exc:
60-
raise RWLockFileCorruptedError(path) from exc
61-
except Invalid as exc:
62-
raise RWLockFileFormatError(path) from exc
63-
lock["read"] = defaultdict(list, lock["read"])
64-
lock["write"] = defaultdict(dict, lock["write"])
65-
yield lock
66-
with fs.open(path, "w", encoding="utf-8") as fobj:
67-
json.dump(lock, fobj)
56+
RWLOCK_THREAD_LOCK.acquire(timeout=RWLOCK_THREAD_TIMEOUT)
57+
try:
58+
with rwlock_guard:
59+
try:
60+
with fs.open(path, encoding="utf-8") as fobj:
61+
lock = SCHEMA(json.load(fobj))
62+
except FileNotFoundError:
63+
lock = SCHEMA({})
64+
except json.JSONDecodeError as exc:
65+
raise RWLockFileCorruptedError(path) from exc
66+
except Invalid as exc:
67+
raise RWLockFileFormatError(path) from exc
68+
lock["read"] = defaultdict(list, lock["read"])
69+
lock["write"] = defaultdict(dict, lock["write"])
70+
yield lock
71+
with fs.open(path, "w", encoding="utf-8") as fobj:
72+
json.dump(lock, fobj)
73+
finally:
74+
if RWLOCK_THREAD_LOCK.locked():
75+
RWLOCK_THREAD_LOCK.release()
6876

6977

7078
def _infos_to_str(infos):

0 commit comments

Comments
 (0)