9 changes: 9 additions & 0 deletions dvc/commands/repro.py
@@ -40,6 +40,7 @@ def _repro_kwargs(self):
"run_cache": not self.args.no_run_cache,
"no_commit": self.args.no_commit,
"glob": self.args.glob,
"jobs": self.args.jobs,
}


@@ -188,4 +189,12 @@ def add_parser(subparsers, parent_parser):
"the same command/dependencies/outputs/etc before."
),
)
+    repro_parser.add_argument(
+        "-j",
+        "--jobs",
+        type=int,
+        default=1,
+        help="Run (at most) specified number of stages at a time in parallel.",
+        metavar="<number>",
+    )
repro_parser.set_defaults(func=CmdRepro)
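
A quick, self-contained way to check the new flag's wiring is to rebuild just this option on a bare parser. This is a sketch mirroring the lines above, not DVC's actual parser setup; note that `-1` is interpreted by `_reproduce` further down as "as many workers as there are stages".

    import argparse

    # Mirrors the --jobs option added above, on a standalone parser.
    parser = argparse.ArgumentParser(prog="dvc repro")
    parser.add_argument(
        "-j",
        "--jobs",
        type=int,
        default=1,
        help="Run (at most) specified number of stages at a time in parallel.",
        metavar="<number>",
    )

    assert parser.parse_args([]).jobs == 1            # default: sequential, as before
    assert parser.parse_args(["-j", "4"]).jobs == 4   # e.g. `dvc repro -j 4`
    assert parser.parse_args(["--jobs", "-1"]).jobs == -1  # expands to len(stages)
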
41 changes: 31 additions & 10 deletions dvc/dvcfile.py
@@ -1,5 +1,7 @@
import contextlib
import os
+import threading
+from collections import defaultdict
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, TypeVar, Union

from dvc.exceptions import DvcException
@@ -77,6 +79,10 @@ def check_dvcfile_path(repo, path):
raise FileIsGitIgnored(relpath(path), True)


+_file_locks: dict[str, threading.Lock] = defaultdict(threading.Lock)
+_file_locks_lock = threading.Lock()


class FileMixin:
SCHEMA: Callable[[_T], _T]

@@ -85,6 +91,10 @@ def __init__(self, repo, path, verify=True, **kwargs):
self.path = path
self.verify = verify

+    def _thread_lock(self) -> threading.Lock:
+        with _file_locks_lock:
+            return _file_locks[self.path]

def __repr__(self):
return f"{self.__class__.__name__}: {relpath(self.path, self.repo.root_dir)}"

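`_thread_lock` above implements a per-path lock registry: one `threading.Lock` per dvcfile, plus a guard lock around the registry itself, because two threads hitting a missing `defaultdict` key at the same time could otherwise race on insertion and walk away with different locks. A standalone sketch of the same pattern (names are illustrative, not DVC's):

    import threading
    from collections import defaultdict

    _locks: dict[str, threading.Lock] = defaultdict(threading.Lock)
    _locks_guard = threading.Lock()  # serializes creation of per-path locks

    def lock_for(path: str) -> threading.Lock:
        # Without the guard, two threads could each trigger the factory
        # for a missing key and end up holding different Lock objects.
        with _locks_guard:
            return _locks[path]

    def append_line(path: str, line: str) -> None:
        with lock_for(path):  # one writer per path at a time
            with open(path, "a", encoding="utf-8") as fobj:
                fobj.write(line + "\n")
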
@@ -148,15 +158,19 @@ def validate(cls, d: _T, fname: Optional[str] = None) -> _T:
def _load_yaml(self, **kwargs: Any) -> tuple[Any, str]:
from dvc.utils import strictyaml

-        return strictyaml.load(
-            self.path,
-            self.SCHEMA,  # type: ignore[arg-type]
-            self.repo.fs,
-            **kwargs,
-        )
+        with self._thread_lock():
+            return strictyaml.load(
+                self.path,
+                self.SCHEMA,  # type: ignore[arg-type]
+                self.repo.fs,
+                **kwargs,
+            )

def remove(self, force=False): # noqa: ARG002
-        with contextlib.suppress(FileNotFoundError):
+        with (
+            self._thread_lock(),
+            contextlib.suppress(FileNotFoundError),
+        ):
os.unlink(self.path)

def dump(self, stage, **kwargs):
@@ -407,7 +421,10 @@ def _load(self, **kwargs: Any):
return {}, ""

def dump_dataset(self, dataset: dict):
-        with modify_yaml(self.path, fs=self.repo.fs) as data:
+        with (
+            self._thread_lock(),
+            modify_yaml(self.path, fs=self.repo.fs) as data,
+        ):
data.update({"schema": "2.0"})
if not data:
logger.info("Generating lock file '%s'", self.relpath)
@@ -430,7 +447,10 @@ def dump_stages(self, stages, **kwargs):

is_modified = False
log_updated = False
-        with modify_yaml(self.path, fs=self.repo.fs) as data:
+        with (
+            self._thread_lock(),
+            modify_yaml(self.path, fs=self.repo.fs) as data,
+        ):
if not data:
data.update({"schema": "2.0"})
# order is important, meta should always be at the top
@@ -468,7 +488,8 @@ def remove_stage(self, stage):
del data[stage.name]

if data:
-            dump_yaml(self.path, d)
+            with self._thread_lock():
+                dump_yaml(self.path, d)
else:
self.remove()

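Every writer above takes the same per-path lock around its whole load-mutate-write cycle, so two threads updating dvc.lock cannot interleave between the read and the write-back. A minimal stand-in for that pattern, assuming nothing about DVC's `modify_yaml` beyond what the diff shows, and using JSON to stay stdlib-only:

    import json
    import threading
    from contextlib import contextmanager

    _lock = threading.Lock()

    @contextmanager
    def modify_json(path: str):
        # Load, let the caller mutate, then write back -- all one critical
        # section, like self._thread_lock() + modify_yaml(...) above.
        with _lock:
            try:
                with open(path, encoding="utf-8") as fobj:
                    data = json.load(fobj)
            except FileNotFoundError:
                data = {}
            yield data
            with open(path, "w", encoding="utf-8") as fobj:
                json.dump(data, fobj)

    with modify_json("state.json") as data:
        data.update({"schema": "2.0"})
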
184 changes: 142 additions & 42 deletions dvc/repo/reproduce.py
@@ -1,4 +1,7 @@
+import concurrent.futures
from collections.abc import Iterable
+from dataclasses import dataclass
+from enum import Enum
from typing import TYPE_CHECKING, Callable, NoReturn, Optional, TypeVar, Union, cast

from funcy import ldistinct
@@ -120,14 +123,6 @@ def _reproduce_stage(stage: "Stage", **kwargs) -> Optional["Stage"]:
return ret


-def _get_upstream_downstream_nodes(
-    graph: Optional["DiGraph"], node: T
-) -> tuple[list[T], list[T]]:
-    succ = list(graph.successors(node)) if graph else []
-    pre = list(graph.predecessors(node)) if graph else []
-    return succ, pre


def _repr(stages: Iterable["Stage"]) -> str:
return humanize.join(repr(stage.addressing) for stage in stages)

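A reading aid for the graph calls that replace the removed helper above: in DVC's pipeline graph, edges point from a stage to the stages it depends on, so `successors` are upstream and `predecessors` are downstream — the convention the new `_reproduce` below relies on. A tiny networkx illustration with hypothetical stage names:

    import networkx as nx

    # prepare <- train <- evaluate: each stage points at what it depends on.
    g = nx.DiGraph()
    g.add_edge("train", "prepare")    # train depends on prepare
    g.add_edge("evaluate", "train")   # evaluate depends on train

    assert list(g.successors("train")) == ["prepare"]     # upstream
    assert list(g.predecessors("train")) == ["evaluate"]  # downstream
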
@@ -155,54 +150,159 @@ def _raise_error(exc: Optional[Exception], *stages: "Stage") -> NoReturn:
raise ReproductionError(f"failed to reproduce{segment} {names}") from exc


-def _reproduce(
-    stages: list["Stage"],
-    graph: Optional["DiGraph"] = None,
-    force_downstream: bool = False,
-    on_error: str = "fail",
-    force: bool = False,
+class ReproStatus(Enum):
+    READY = "ready"
+    IN_PROGRESS = "in-progress"
+    COMPLETE = "complete"
+    SKIPPED = "skipped"
+    FAILED = "failed"
+
+
+@dataclass
+class StageInfo:
+    upstream: list["Stage"]
+    upstream_unfinished: set["Stage"]
+    downstream: list["Stage"]
+    force: bool
+    status: ReproStatus
+    result: Optional["Stage"]
+
+
+def _start_ready_stages(
+    to_repro: dict["Stage", StageInfo],
+    executor: concurrent.futures.ThreadPoolExecutor,
+    max_stages: int,
    repro_fn: Callable = _reproduce_stage,
    **kwargs,
+) -> dict[concurrent.futures.Future["Stage"], "Stage"]:
+    ready = [
+        (stage, stage_info)
+        for stage, stage_info in to_repro.items()
+        if stage_info.status == ReproStatus.READY and not stage_info.upstream_unfinished
+    ]
+    if not ready:
+        return {}
+
+    futures = {
+        executor.submit(
+            repro_fn,
+            stage,
+            upstream=stage_info.upstream,
+            force=stage_info.force,
+            **kwargs,
+        ): stage
+        for stage, stage_info in ready[:max_stages]
+    }
+    for stage in futures.values():
+        to_repro[stage].status = ReproStatus.IN_PROGRESS
+    return futures
+
+
+def _result_or_raise(
+    to_repro: dict["Stage", StageInfo], stages: list["Stage"], on_error: str
+) -> list["Stage"]:
+    assert on_error in ("fail", "keep-going", "ignore")
+
    result: list[Stage] = []
    failed: list[Stage] = []
-    to_skip: dict[Stage, Stage] = {}
-    ret: Optional[Stage] = None
-
-    force_state = dict.fromkeys(stages, force)

+    # Preserve original order
    for stage in stages:
-        if stage in to_skip:
-            continue
+        stage_info = to_repro[stage]
+        if stage_info.status == ReproStatus.FAILED:
+            failed.append(stage)
+        elif stage_info.result:
+            result.append(stage_info.result)

-        if ret:
-            logger.info("")  # add a newline
+    if on_error != "ignore" and failed:
+        _raise_error(None, *failed)

-        upstream, downstream = _get_upstream_downstream_nodes(graph, stage)
-        force_stage = force_state[stage]
+    return result

-        try:
-            ret = repro_fn(stage, upstream=upstream, force=force_stage, **kwargs)
-        except Exception as exc:  # noqa: BLE001
-            failed.append(stage)
-            if on_error == "fail":
-                _raise_error(exc, stage)
-
-            dependents = handle_error(graph, on_error, exc, stage)
-            to_skip.update(dict.fromkeys(dependents, stage))

+def _handle_result(
+    to_repro: dict["Stage", StageInfo],
+    future: concurrent.futures.Future["Stage"],
+    stage: "Stage",
+    stage_info: StageInfo,
+    graph: Optional["DiGraph"],
+    on_error: str,
+    force_downstream: bool,
+):
+    ret: Optional[Stage] = None
+    success = False
+    try:
+        ret = future.result()
+    except Exception as exc:  # noqa: BLE001
+        if on_error == "fail":
+            _raise_error(exc, stage)
+
+        stage_info.status = ReproStatus.FAILED
+        dependents = handle_error(graph, on_error, exc, stage)
+        for dependent in dependents:
+            to_repro[dependent].status = ReproStatus.SKIPPED
+    else:
+        stage_info.status = ReproStatus.COMPLETE
+        success = True
+
+    for dependent in stage_info.downstream:
+        if dependent not in to_repro:
+            continue
+        dependent_info = to_repro[dependent]
+        if stage in dependent_info.upstream_unfinished:
+            dependent_info.upstream_unfinished.remove(stage)
+        if success and force_downstream and (ret or stage_info.force):
+            dependent_info.force = True

-        if force_downstream and (ret or force_stage):
-            force_state.update(dict.fromkeys(downstream, True))
+    if success and ret:
+        stage_info.result = ret

-        if ret:
-            result.append(ret)
-
-    if on_error != "ignore" and failed:
-        _raise_error(None, *failed)
-    return result
+
+
+def _reproduce(
+    stages: list["Stage"],
+    graph: Optional["DiGraph"] = None,
+    force_downstream: bool = False,
+    on_error: str = "fail",
+    force: bool = False,
+    jobs: int = 1,
+    **kwargs,
) -> list["Stage"]:
    assert on_error in ("fail", "keep-going", "ignore")

+    to_repro = {
+        stage: StageInfo(
+            upstream=(upstream := list(graph.successors(stage)) if graph else []),
+            upstream_unfinished=set(upstream).intersection(stages),
+            downstream=list(graph.predecessors(stage)) if graph else [],
+            force=force,
+            status=ReproStatus.READY,
+            result=None,
+        )
+        for stage in stages
+    }
+
+    if jobs == -1:
+        jobs = len(stages)
+    max_workers = max(1, min(jobs, len(stages)))
+    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+        futures = _start_ready_stages(to_repro, executor, max_workers, **kwargs)
+        while futures:
+            done, _ = concurrent.futures.wait(
+                futures, return_when=concurrent.futures.FIRST_COMPLETED
+            )
+            for future in done:
+                stage = futures.pop(future)
+                stage_info = to_repro[stage]
+                _handle_result(
+                    to_repro,
+                    future,
+                    stage,
+                    stage_info,
+                    graph,
+                    on_error,
+                    force_downstream,
+                )
+
+            futures.update(_start_ready_stages(to_repro, executor, len(done), **kwargs))
+
+    return _result_or_raise(to_repro, stages, on_error)


@locked
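
Taken together, the new `_reproduce` is a small dependency-aware scheduler: every stage starts READY with the set of its unfinished upstream stages, only stages whose set is empty get submitted, and each completion unblocks dependents and tops the pool back up, so at most `jobs` stages run at once while topological order is preserved. A self-contained sketch of that loop with plain callables standing in for stages — illustrative names, no DVC imports, and none of the FAILED/SKIPPED or force-propagation handling of the real code:

    import concurrent.futures

    def run_parallel(tasks, deps, jobs=2):
        """tasks: name -> callable; deps: name -> set of upstream names."""
        waiting = {name: set(deps.get(name, ())) for name in tasks}
        dependents = {name: set() for name in tasks}
        for name, ups in waiting.items():
            for up in ups:
                dependents[up].add(name)

        results = {}
        with concurrent.futures.ThreadPoolExecutor(max_workers=jobs) as pool:
            # Seed the pool with every stage that has no upstream work.
            ready = [name for name, ups in waiting.items() if not ups]
            futures = {pool.submit(tasks[name]): name for name in ready}
            for name in ready:
                del waiting[name]

            while futures:
                done, _ = concurrent.futures.wait(
                    futures, return_when=concurrent.futures.FIRST_COMPLETED
                )
                for future in done:
                    name = futures.pop(future)
                    results[name] = future.result()
                    for dep in dependents[name]:  # unblock dependents
                        if dep in waiting:
                            waiting[dep].discard(name)
                    # Submit anything that just became ready.
                    ready = [n for n, ups in waiting.items() if not ups]
                    for n in ready:
                        del waiting[n]
                        futures[pool.submit(tasks[n])] = n
        return results

    order = run_parallel(
        {"prepare": lambda: "data", "train": lambda: "model", "evaluate": lambda: "score"},
        {"train": {"prepare"}, "evaluate": {"train"}},
        jobs=2,
    )
    print(order)  # {'prepare': 'data', 'train': 'model', 'evaluate': 'score'}

One design difference worth noting: the sketch submits every ready task immediately and lets the pool's worker cap do the throttling, whereas `_start_ready_stages` slices `ready[:max_stages]` so the IN_PROGRESS bookkeeping matches the number of actual workers.
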
38 changes: 23 additions & 15 deletions dvc/rwlock.py
@@ -1,5 +1,6 @@
import json
import os
+import threading
from collections import defaultdict
from contextlib import contextmanager

@@ -25,6 +26,8 @@
}
)

+RWLOCK_THREAD_LOCK = threading.Lock()
+RWLOCK_THREAD_TIMEOUT = 3
RWLOCK_FILE = "rwlock"
RWLOCK_LOCK = "rwlock.lock"

@@ -50,21 +53,26 @@ def _edit_rwlock(lock_dir, fs, hardlink):
tmp_dir=lock_dir,
hardlink_lock=hardlink,
)
-    with rwlock_guard:
-        try:
-            with fs.open(path, encoding="utf-8") as fobj:
-                lock = SCHEMA(json.load(fobj))
-        except FileNotFoundError:
-            lock = SCHEMA({})
-        except json.JSONDecodeError as exc:
-            raise RWLockFileCorruptedError(path) from exc
-        except Invalid as exc:
-            raise RWLockFileFormatError(path) from exc
-        lock["read"] = defaultdict(list, lock["read"])
-        lock["write"] = defaultdict(dict, lock["write"])
-        yield lock
-        with fs.open(path, "w", encoding="utf-8") as fobj:
-            json.dump(lock, fobj)
+    # acquire() returns False on timeout; track that so we never release a
+    # lock this thread does not hold (Lock.locked() reports *any* holder).
+    acquired = RWLOCK_THREAD_LOCK.acquire(timeout=RWLOCK_THREAD_TIMEOUT)
+    try:
+        with rwlock_guard:
+            try:
+                with fs.open(path, encoding="utf-8") as fobj:
+                    lock = SCHEMA(json.load(fobj))
+            except FileNotFoundError:
+                lock = SCHEMA({})
+            except json.JSONDecodeError as exc:
+                raise RWLockFileCorruptedError(path) from exc
+            except Invalid as exc:
+                raise RWLockFileFormatError(path) from exc
+            lock["read"] = defaultdict(list, lock["read"])
+            lock["write"] = defaultdict(dict, lock["write"])
+            yield lock
+            with fs.open(path, "w", encoding="utf-8") as fobj:
+                json.dump(lock, fobj)
+    finally:
+        if acquired:
+            RWLOCK_THREAD_LOCK.release()


def _infos_to_str(infos):
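
One detail the rwlock change leans on: `Lock.acquire(timeout=...)` does not raise on timeout, it returns `False`, and `Lock.locked()` only says that *some* thread holds the lock — so the release must be keyed to this thread's own successful acquire, as in the `finally` above. A minimal sketch of that pattern:

    import threading

    _lock = threading.Lock()
    TIMEOUT = 3  # seconds; mirrors RWLOCK_THREAD_TIMEOUT above

    def with_thread_lock(fn):
        acquired = _lock.acquire(timeout=TIMEOUT)  # False if the wait timed out
        try:
            return fn()  # proceeds either way; the timeout is a deadlock safety valve
        finally:
            if acquired:  # never release a lock this thread does not hold
                _lock.release()

    print(with_thread_lock(lambda: "edited rwlock file"))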