Skip to content

Commit bfeb90f

Browse files
committed
Also clean if torch version has changed
1 parent b7e0f0d commit bfeb90f

File tree

2 files changed

+67
-12
lines changed

2 files changed

+67
-12
lines changed

src/charonload/_finder.py

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -175,22 +175,18 @@ def __init__(self: Self, module_name: str, config: ResolvedConfig, step_number:
175175
str,
176176
self.config.full_build_directory / "charonload" / "version.txt",
177177
)
178+
self.cache.connect(
179+
"torch_version",
180+
str,
181+
self.config.full_build_directory / "charonload" / self.config.build_type / "torch_version.txt",
182+
)
178183

179184
def _run_impl(self: Self) -> None:
180-
clean_if_failed = {
181-
"status_cmake_configure": True,
182-
"status_build": False,
183-
"status_stub_generation": False,
184-
}
185-
step_failed = {
186-
step: bool(self.cache.get(step, _StepStatus.SKIPPED) == _StepStatus.FAILED) for step in clean_if_failed
187-
}
188-
should_clean = [clean_if_failed[step] and failed for step, failed in step_failed.items()]
189-
190185
if (
191186
self.config.clean_build
192-
or not _is_compatible(self.cache.get("version", _version()), _version())
193-
or any(should_clean)
187+
or self._version_incompatible()
188+
or self._crucial_step_failed()
189+
or self._torch_version_changed()
194190
):
195191
number_removed_files = 0
196192
number_removed_directories = 0
@@ -212,6 +208,32 @@ def _run_impl(self: Self) -> None:
212208
f"{number_removed_files} files, {number_removed_directories} directories{colorama.Style.RESET_ALL}"
213209
)
214210

211+
if "torch" in sys.modules:
212+
self.cache["torch_version"] = str(sys.modules["torch"].__version__)
213+
214+
def _crucial_step_failed(self: Self) -> bool:
215+
is_crucial = {
216+
"status_cmake_configure": True,
217+
"status_build": False,
218+
"status_stub_generation": False,
219+
}
220+
failed_statuses = {
221+
step: bool(self.cache.get(step, _StepStatus.SKIPPED) == _StepStatus.FAILED) for step in is_crucial
222+
}
223+
return any(is_crucial[step] and failed for step, failed in failed_statuses.items())
224+
225+
def _version_incompatible(self: Self) -> bool:
226+
return not _is_compatible(self.cache.get("version", _version()), _version())
227+
228+
def _torch_version_changed(self: Self) -> bool:
229+
if "torch" in sys.modules:
230+
current_torch_version = str(sys.modules["torch"].__version__)
231+
previous_torch_version: str = self.cache.get("torch_version", str(sys.modules["torch"].__version__))
232+
233+
return current_torch_version != previous_torch_version
234+
235+
return False
236+
215237

216238
class _InitializeStep(_JITCompileStep):
217239
exception_cls = type(None)

tests/test_finder.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,39 @@ def test_torch_clean_build_incompatible_version(shared_datadir: pathlib.Path, tm
686686
assert not dirty_file.exists()
687687

688688

689+
def test_torch_clean_build_incompatible_torch_version(shared_datadir: pathlib.Path, tmp_path: pathlib.Path) -> None:
690+
project_directory = shared_datadir / "torch_cpu"
691+
build_directory = tmp_path / "build"
692+
693+
charonload.module_config["test_torch_clean_build_incompatible_torch_version"] = charonload.Config(
694+
project_directory,
695+
build_directory,
696+
stubs_directory=VSCODE_STUBS_DIRECTORY,
697+
)
698+
config = charonload.module_config["test_torch_clean_build_incompatible_torch_version"]
699+
700+
dirty_file = build_directory / "dirty.txt"
701+
702+
build_directory.mkdir(parents=True, exist_ok=True)
703+
(build_directory / "charonload" / config.build_type).mkdir(parents=True, exist_ok=True)
704+
dirty_file.touch()
705+
with (build_directory / "charonload" / config.build_type / "torch_version.txt").open("w") as f:
706+
f.write("0.0")
707+
708+
assert dirty_file.exists()
709+
710+
import test_torch_clean_build_incompatible_torch_version as test_torch
711+
712+
t_input = torch.randint(0, 10, size=(3, 3, 3), dtype=torch.float, device="cpu")
713+
t_output = test_torch.two_times(t_input)
714+
715+
assert t_output.device == t_input.device
716+
assert t_output.shape == t_input.shape
717+
assert torch.equal(t_output, 2 * t_input)
718+
719+
assert not dirty_file.exists()
720+
721+
689722
def test_torch_clean_build_configure_failed(shared_datadir: pathlib.Path, tmp_path: pathlib.Path) -> None:
690723
build_directory = tmp_path / "build"
691724

0 commit comments

Comments
 (0)