From 10383e1f36775812b1e4ce7843991198c2e57be7 Mon Sep 17 00:00:00 2001 From: Sami Jawhar Date: Sun, 7 Dec 2025 01:35:57 +0000 Subject: [PATCH] repro: parallel execution of stages --- dvc/commands/repro.py | 9 ++ dvc/dvcfile.py | 41 +++++-- dvc/repo/reproduce.py | 184 ++++++++++++++++++++++++------- dvc/rwlock.py | 38 ++++--- dvc/stage/decorators.py | 3 + tests/func/repro/test_repro.py | 28 +++++ tests/unit/command/test_repro.py | 16 ++- tests/unit/test_rwlock.py | 18 +++ 8 files changed, 267 insertions(+), 70 deletions(-) diff --git a/dvc/commands/repro.py b/dvc/commands/repro.py index bbcb132e90..f6f8d12968 100644 --- a/dvc/commands/repro.py +++ b/dvc/commands/repro.py @@ -40,6 +40,7 @@ def _repro_kwargs(self): "run_cache": not self.args.no_run_cache, "no_commit": self.args.no_commit, "glob": self.args.glob, + "jobs": self.args.jobs, } @@ -188,4 +189,12 @@ def add_parser(subparsers, parent_parser): "the same command/dependencies/outputs/etc before." ), ) + repro_parser.add_argument( + "-j", + "--jobs", + type=int, + default=1, + help="Run (at most) specified number of stages at a time in parallel.", + metavar="", + ) repro_parser.set_defaults(func=CmdRepro) diff --git a/dvc/dvcfile.py b/dvc/dvcfile.py index 1e0ef6367d..b93de3f080 100644 --- a/dvc/dvcfile.py +++ b/dvc/dvcfile.py @@ -1,5 +1,7 @@ import contextlib import os +import threading +from collections import defaultdict from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, TypeVar, Union from dvc.exceptions import DvcException @@ -77,6 +79,10 @@ def check_dvcfile_path(repo, path): raise FileIsGitIgnored(relpath(path), True) +_file_locks: dict[str, threading.Lock] = defaultdict(threading.Lock) +_file_locks_lock = threading.Lock() + + class FileMixin: SCHEMA: Callable[[_T], _T] @@ -85,6 +91,10 @@ def __init__(self, repo, path, verify=True, **kwargs): self.path = path self.verify = verify + def _thread_lock(self) -> threading.Lock: + with _file_locks_lock: + return _file_locks[self.path] + def __repr__(self): return f"{self.__class__.__name__}: {relpath(self.path, self.repo.root_dir)}" @@ -148,15 +158,19 @@ def validate(cls, d: _T, fname: Optional[str] = None) -> _T: def _load_yaml(self, **kwargs: Any) -> tuple[Any, str]: from dvc.utils import strictyaml - return strictyaml.load( - self.path, - self.SCHEMA, # type: ignore[arg-type] - self.repo.fs, - **kwargs, - ) + with self._thread_lock(): + return strictyaml.load( + self.path, + self.SCHEMA, # type: ignore[arg-type] + self.repo.fs, + **kwargs, + ) def remove(self, force=False): # noqa: ARG002 - with contextlib.suppress(FileNotFoundError): + with ( + self._thread_lock(), + contextlib.suppress(FileNotFoundError), + ): os.unlink(self.path) def dump(self, stage, **kwargs): @@ -407,7 +421,10 @@ def _load(self, **kwargs: Any): return {}, "" def dump_dataset(self, dataset: dict): - with modify_yaml(self.path, fs=self.repo.fs) as data: + with ( + self._thread_lock(), + modify_yaml(self.path, fs=self.repo.fs) as data, + ): data.update({"schema": "2.0"}) if not data: logger.info("Generating lock file '%s'", self.relpath) @@ -430,7 +447,10 @@ def dump_stages(self, stages, **kwargs): is_modified = False log_updated = False - with modify_yaml(self.path, fs=self.repo.fs) as data: + with ( + self._thread_lock(), + modify_yaml(self.path, fs=self.repo.fs) as data, + ): if not data: data.update({"schema": "2.0"}) # order is important, meta should always be at the top @@ -468,7 +488,8 @@ def remove_stage(self, stage): del data[stage.name] if data: - dump_yaml(self.path, d) + with self._thread_lock(): + dump_yaml(self.path, d) else: self.remove() diff --git a/dvc/repo/reproduce.py b/dvc/repo/reproduce.py index 4c62fbd970..5a323b4d44 100644 --- a/dvc/repo/reproduce.py +++ b/dvc/repo/reproduce.py @@ -1,4 +1,7 @@ +import concurrent.futures from collections.abc import Iterable +from dataclasses import dataclass +from enum import Enum from typing import TYPE_CHECKING, Callable, NoReturn, Optional, TypeVar, Union, cast from funcy import ldistinct @@ -120,14 +123,6 @@ def _reproduce_stage(stage: "Stage", **kwargs) -> Optional["Stage"]: return ret -def _get_upstream_downstream_nodes( - graph: Optional["DiGraph"], node: T -) -> tuple[list[T], list[T]]: - succ = list(graph.successors(node)) if graph else [] - pre = list(graph.predecessors(node)) if graph else [] - return succ, pre - - def _repr(stages: Iterable["Stage"]) -> str: return humanize.join(repr(stage.addressing) for stage in stages) @@ -155,54 +150,159 @@ def _raise_error(exc: Optional[Exception], *stages: "Stage") -> NoReturn: raise ReproductionError(f"failed to reproduce{segment} {names}") from exc -def _reproduce( - stages: list["Stage"], - graph: Optional["DiGraph"] = None, - force_downstream: bool = False, - on_error: str = "fail", - force: bool = False, +class ReproStatus(Enum): + READY = "ready" + IN_PROGRESS = "in-progress" + COMPLETE = "complete" + SKIPPED = "skipped" + FAILED = "failed" + + +@dataclass +class StageInfo: + upstream: list["Stage"] + upstream_unfinished: set["Stage"] + downstream: list["Stage"] + force: bool + status: ReproStatus + result: Optional["Stage"] + + +def _start_ready_stages( + to_repro: dict["Stage", StageInfo], + executor: concurrent.futures.ThreadPoolExecutor, + max_stages: int, repro_fn: Callable = _reproduce_stage, **kwargs, +) -> dict[concurrent.futures.Future["Stage"], "Stage"]: + ready = [ + (stage, stage_info) + for stage, stage_info in to_repro.items() + if stage_info.status == ReproStatus.READY and not stage_info.upstream_unfinished + ] + if not ready: + return {} + + futures = { + executor.submit( + repro_fn, + stage, + upstream=stage_info.upstream, + force=stage_info.force, + **kwargs, + ): stage + for stage, stage_info in ready[:max_stages] + } + for stage in futures.values(): + to_repro[stage].status = ReproStatus.IN_PROGRESS + return futures + + +def _result_or_raise( + to_repro: dict["Stage", StageInfo], stages: list["Stage"], on_error: str ) -> list["Stage"]: - assert on_error in ("fail", "keep-going", "ignore") - result: list[Stage] = [] failed: list[Stage] = [] - to_skip: dict[Stage, Stage] = {} - ret: Optional[Stage] = None - - force_state = dict.fromkeys(stages, force) - + # Preserve original order for stage in stages: - if stage in to_skip: - continue + stage_info = to_repro[stage] + if stage_info.status == ReproStatus.FAILED: + failed.append(stage) + elif stage_info.result: + result.append(stage_info.result) - if ret: - logger.info("") # add a newline + if on_error != "ignore" and failed: + _raise_error(None, *failed) - upstream, downstream = _get_upstream_downstream_nodes(graph, stage) - force_stage = force_state[stage] + return result - try: - ret = repro_fn(stage, upstream=upstream, force=force_stage, **kwargs) - except Exception as exc: # noqa: BLE001 - failed.append(stage) - if on_error == "fail": - _raise_error(exc, stage) - dependents = handle_error(graph, on_error, exc, stage) - to_skip.update(dict.fromkeys(dependents, stage)) +def _handle_result( + to_repro: dict["Stage", StageInfo], + future: concurrent.futures.Future["Stage"], + stage: "Stage", + stage_info: StageInfo, + graph: Optional["DiGraph"], + on_error: str, + force_downstream: bool, +): + ret: Optional[Stage] = None + success = False + try: + ret = future.result() + except Exception as exc: # noqa: BLE001 + if on_error == "fail": + _raise_error(exc, stage) + + stage_info.status = ReproStatus.FAILED + dependents = handle_error(graph, on_error, exc, stage) + for dependent in dependents: + to_repro[dependent].status = ReproStatus.SKIPPED + else: + stage_info.status = ReproStatus.COMPLETE + success = True + + for dependent in stage_info.downstream: + if dependent not in to_repro: continue + dependent_info = to_repro[dependent] + if stage in dependent_info.upstream_unfinished: + dependent_info.upstream_unfinished.remove(stage) + if success and force_downstream and (ret or stage_info.force): + dependent_info.force = True - if force_downstream and (ret or force_stage): - force_state.update(dict.fromkeys(downstream, True)) + if success and ret: + stage_info.result = ret - if ret: - result.append(ret) - if on_error != "ignore" and failed: - _raise_error(None, *failed) - return result +def _reproduce( + stages: list["Stage"], + graph: Optional["DiGraph"] = None, + force_downstream: bool = False, + on_error: str = "fail", + force: bool = False, + jobs: int = 1, + **kwargs, +) -> list["Stage"]: + assert on_error in ("fail", "keep-going", "ignore") + + to_repro = { + stage: StageInfo( + upstream=(upstream := list(graph.successors(stage)) if graph else []), + upstream_unfinished=set(upstream).intersection(stages), + downstream=list(graph.predecessors(stage)) if graph else [], + force=force, + status=ReproStatus.READY, + result=None, + ) + for stage in stages + } + + if jobs == -1: + jobs = len(stages) + max_workers = max(1, min(jobs, len(stages))) + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = _start_ready_stages(to_repro, executor, max_workers, **kwargs) + while futures: + done, _ = concurrent.futures.wait( + futures, return_when=concurrent.futures.FIRST_COMPLETED + ) + for future in done: + stage = futures.pop(future) + stage_info = to_repro[stage] + _handle_result( + to_repro, + future, + stage, + stage_info, + graph, + on_error, + force_downstream, + ) + + futures.update(_start_ready_stages(to_repro, executor, len(done), **kwargs)) + + return _result_or_raise(to_repro, stages, on_error) @locked diff --git a/dvc/rwlock.py b/dvc/rwlock.py index 8d06df1632..46cc468a65 100644 --- a/dvc/rwlock.py +++ b/dvc/rwlock.py @@ -1,5 +1,6 @@ import json import os +import threading from collections import defaultdict from contextlib import contextmanager @@ -25,6 +26,8 @@ } ) +RWLOCK_THREAD_LOCK = threading.Lock() +RWLOCK_THREAD_TIMEOUT = 3 RWLOCK_FILE = "rwlock" RWLOCK_LOCK = "rwlock.lock" @@ -50,21 +53,26 @@ def _edit_rwlock(lock_dir, fs, hardlink): tmp_dir=lock_dir, hardlink_lock=hardlink, ) - with rwlock_guard: - try: - with fs.open(path, encoding="utf-8") as fobj: - lock = SCHEMA(json.load(fobj)) - except FileNotFoundError: - lock = SCHEMA({}) - except json.JSONDecodeError as exc: - raise RWLockFileCorruptedError(path) from exc - except Invalid as exc: - raise RWLockFileFormatError(path) from exc - lock["read"] = defaultdict(list, lock["read"]) - lock["write"] = defaultdict(dict, lock["write"]) - yield lock - with fs.open(path, "w", encoding="utf-8") as fobj: - json.dump(lock, fobj) + RWLOCK_THREAD_LOCK.acquire(timeout=RWLOCK_THREAD_TIMEOUT) + try: + with rwlock_guard: + try: + with fs.open(path, encoding="utf-8") as fobj: + lock = SCHEMA(json.load(fobj)) + except FileNotFoundError: + lock = SCHEMA({}) + except json.JSONDecodeError as exc: + raise RWLockFileCorruptedError(path) from exc + except Invalid as exc: + raise RWLockFileFormatError(path) from exc + lock["read"] = defaultdict(list, lock["read"]) + lock["write"] = defaultdict(dict, lock["write"]) + yield lock + with fs.open(path, "w", encoding="utf-8") as fobj: + json.dump(lock, fobj) + finally: + if RWLOCK_THREAD_LOCK.locked(): + RWLOCK_THREAD_LOCK.release() def _infos_to_str(infos): diff --git a/dvc/stage/decorators.py b/dvc/stage/decorators.py index 2b15e0fd61..7c23bfcc2f 100644 --- a/dvc/stage/decorators.py +++ b/dvc/stage/decorators.py @@ -1,3 +1,4 @@ +import threading from functools import wraps from funcy import decorator @@ -47,6 +48,8 @@ def _chain(names): def unlocked_repo(f): @wraps(f) def wrapper(stage, *args, **kwargs): + if threading.current_thread() is not threading.main_thread(): + return f(stage, *args, **kwargs) stage.repo.lock.unlock() stage.repo._reset() try: diff --git a/tests/func/repro/test_repro.py b/tests/func/repro/test_repro.py index 4266ad2cac..04b5b68ab5 100644 --- a/tests/func/repro/test_repro.py +++ b/tests/func/repro/test_repro.py @@ -1273,6 +1273,34 @@ def test_repro_ignore_errors(mocker, tmp_dir, dvc, copy_script): assert stage2_call in spy.call_args_list +def test_repro_parallel_jobs(tmp_dir, dvc, copy_script): + tmp_dir.dvc_gen({"foo": "foo", "bar": "bar", "baz": "baz"}) + + dvc.stage.add( + cmd="python copy.py foo out_a", deps=["foo"], outs=["out_a"], name="stage-a" + ) + dvc.stage.add( + cmd="python copy.py bar out_b", deps=["bar"], outs=["out_b"], name="stage-b" + ) + dvc.stage.add( + cmd="python copy.py baz out_c", deps=["baz"], outs=["out_c"], name="stage-c" + ) + dvc.stage.add( + cmd="cat out_a out_b out_c > final", + deps=["out_a", "out_b", "out_c"], + outs=["final"], + name="final", + ) + + ret = main(["repro", "--jobs", "3"]) + assert ret == 0 + + assert (tmp_dir / "out_a").read_text() == "foo" + assert (tmp_dir / "out_b").read_text() == "bar" + assert (tmp_dir / "out_c").read_text() == "baz" + assert (tmp_dir / "final").read_text() == "foobarbaz" + + @pytest.mark.parametrize("persist", [True, False]) def test_repro_external_outputs(tmp_dir, dvc, local_workspace, persist): local_workspace.gen("foo", "foo") diff --git a/tests/unit/command/test_repro.py b/tests/unit/command/test_repro.py index 7010ee466b..404dd596de 100644 --- a/tests/unit/command/test_repro.py +++ b/tests/unit/command/test_repro.py @@ -1,3 +1,5 @@ +import pytest + from dvc.cli import parse_args from dvc.commands.repro import CmdRepro @@ -20,6 +22,7 @@ "run_cache": True, "no_commit": False, "glob": False, + "jobs": 1, } @@ -30,11 +33,18 @@ def test_default_arguments(dvc, mocker): cmd.repo.reproduce.assert_called_with(**common_arguments, **repro_arguments) -def test_downstream(dvc, mocker): - cmd = CmdRepro(parse_args(["repro", "--downstream"])) +@pytest.mark.parametrize( + "cli_arguments, expected_arguments", + [ + (["--downstream"], {"downstream": True}), + (["-j", "2"], {"jobs": 2}), + ], +) +def test_calls(dvc, mocker, cli_arguments, expected_arguments): + cmd = CmdRepro(parse_args(["repro", *cli_arguments])) mocker.patch.object(cmd.repo, "reproduce") cmd.run() arguments = common_arguments.copy() arguments.update(repro_arguments) - arguments.update({"downstream": True}) + arguments.update(expected_arguments) cmd.repo.reproduce.assert_called_with(**arguments) diff --git a/tests/unit/test_rwlock.py b/tests/unit/test_rwlock.py index 0a80d7fa55..2667157c83 100644 --- a/tests/unit/test_rwlock.py +++ b/tests/unit/test_rwlock.py @@ -1,5 +1,7 @@ +import concurrent.futures import json import os +import time import pytest @@ -59,6 +61,7 @@ def test_rwlock_reentrant(tmp_path): def test_rwlock_edit_is_guarded(tmp_path, mocker): # patching to speedup tests mocker.patch("dvc.lock.DEFAULT_TIMEOUT", 0.01) + mocker.patch("dvc.rwlock.RWLOCK_THREAD_TIMEOUT", 0.01) path = os.fspath(tmp_path) @@ -68,6 +71,21 @@ def test_rwlock_edit_is_guarded(tmp_path, mocker): pass +def test_rwlock_multiple_threads(tmp_path, mocker): + # patching to speedup tests + mocker.patch("dvc.rwlock.RWLOCK_THREAD_TIMEOUT", 0.01) + path = os.fspath(tmp_path) + foo = "foo" + + def work(): + with rwlock(path, localfs, "cmd1", [foo], [], False): + time.sleep(1) + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: + futures = [executor.submit(work) for _ in range(2)] + concurrent.futures.wait(futures) + + def test_rwlock_subdirs(tmp_path): path = os.fspath(tmp_path) foo = "foo"