Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 34 additions & 12 deletions src/charonload/_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,22 +175,18 @@ def __init__(self: Self, module_name: str, config: ResolvedConfig, step_number:
str,
self.config.full_build_directory / "charonload" / "version.txt",
)
self.cache.connect(
"torch_version",
str,
self.config.full_build_directory / "charonload" / self.config.build_type / "torch_version.txt",
)

def _run_impl(self: Self) -> None:
clean_if_failed = {
"status_cmake_configure": True,
"status_build": False,
"status_stub_generation": False,
}
step_failed = {
step: bool(self.cache.get(step, _StepStatus.SKIPPED) == _StepStatus.FAILED) for step in clean_if_failed
}
should_clean = [clean_if_failed[step] and failed for step, failed in step_failed.items()]

if (
self.config.clean_build
or not _is_compatible(self.cache.get("version", _version()), _version())
or any(should_clean)
or self._version_incompatible()
or self._crucial_step_failed()
or self._torch_version_changed()
):
number_removed_files = 0
number_removed_directories = 0
Expand All @@ -212,6 +208,32 @@ def _run_impl(self: Self) -> None:
f"{number_removed_files} files, {number_removed_directories} directories{colorama.Style.RESET_ALL}"
)

if "torch" in sys.modules:
self.cache["torch_version"] = str(sys.modules["torch"].__version__)

def _crucial_step_failed(self: Self) -> bool:
is_crucial = {
"status_cmake_configure": True,
"status_build": False,
"status_stub_generation": False,
}
failed_statuses = {
step: bool(self.cache.get(step, _StepStatus.SKIPPED) == _StepStatus.FAILED) for step in is_crucial
}
return any(is_crucial[step] and failed for step, failed in failed_statuses.items())

def _version_incompatible(self: Self) -> bool:
return not _is_compatible(self.cache.get("version", _version()), _version())

def _torch_version_changed(self: Self) -> bool:
if "torch" in sys.modules:
current_torch_version = str(sys.modules["torch"].__version__)
previous_torch_version: str = self.cache.get("torch_version", str(sys.modules["torch"].__version__))

return current_torch_version != previous_torch_version

return False


class _InitializeStep(_JITCompileStep):
exception_cls = type(None)
Expand Down
33 changes: 33 additions & 0 deletions tests/test_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,39 @@ def test_torch_clean_build_incompatible_version(shared_datadir: pathlib.Path, tm
assert not dirty_file.exists()


def test_torch_clean_build_incompatible_torch_version(shared_datadir: pathlib.Path, tmp_path: pathlib.Path) -> None:
project_directory = shared_datadir / "torch_cpu"
build_directory = tmp_path / "build"

charonload.module_config["test_torch_clean_build_incompatible_torch_version"] = charonload.Config(
project_directory,
build_directory,
stubs_directory=VSCODE_STUBS_DIRECTORY,
)
config = charonload.module_config["test_torch_clean_build_incompatible_torch_version"]

dirty_file = build_directory / "dirty.txt"

build_directory.mkdir(parents=True, exist_ok=True)
(build_directory / "charonload" / config.build_type).mkdir(parents=True, exist_ok=True)
dirty_file.touch()
with (build_directory / "charonload" / config.build_type / "torch_version.txt").open("w") as f:
f.write("0.0")

assert dirty_file.exists()

import test_torch_clean_build_incompatible_torch_version as test_torch

t_input = torch.randint(0, 10, size=(3, 3, 3), dtype=torch.float, device="cpu")
t_output = test_torch.two_times(t_input)

assert t_output.device == t_input.device
assert t_output.shape == t_input.shape
assert torch.equal(t_output, 2 * t_input)

assert not dirty_file.exists()


def test_torch_clean_build_configure_failed(shared_datadir: pathlib.Path, tmp_path: pathlib.Path) -> None:
build_directory = tmp_path / "build"

Expand Down