@@ -175,22 +175,18 @@ def __init__(self: Self, module_name: str, config: ResolvedConfig, step_number:
175175 str ,
176176 self .config .full_build_directory / "charonload" / "version.txt" ,
177177 )
178+ self .cache .connect (
179+ "torch_version" ,
180+ str ,
181+ self .config .full_build_directory / "charonload" / self .config .build_type / "torch_version.txt" ,
182+ )
178183
179184 def _run_impl (self : Self ) -> None :
180- clean_if_failed = {
181- "status_cmake_configure" : True ,
182- "status_build" : False ,
183- "status_stub_generation" : False ,
184- }
185- step_failed = {
186- step : bool (self .cache .get (step , _StepStatus .SKIPPED ) == _StepStatus .FAILED ) for step in clean_if_failed
187- }
188- should_clean = [clean_if_failed [step ] and failed for step , failed in step_failed .items ()]
189-
190185 if (
191186 self .config .clean_build
192- or not _is_compatible (self .cache .get ("version" , _version ()), _version ())
193- or any (should_clean )
187+ or self ._version_incompatible ()
188+ or self ._crucial_step_failed ()
189+ or self ._torch_version_changed ()
194190 ):
195191 number_removed_files = 0
196192 number_removed_directories = 0
@@ -212,6 +208,32 @@ def _run_impl(self: Self) -> None:
212208 f"{ number_removed_files } files, { number_removed_directories } directories{ colorama .Style .RESET_ALL } "
213209 )
214210
211+ if "torch" in sys .modules :
212+ self .cache ["torch_version" ] = str (sys .modules ["torch" ].__version__ )
213+
214+ def _crucial_step_failed (self : Self ) -> bool :
215+ is_crucial = {
216+ "status_cmake_configure" : True ,
217+ "status_build" : False ,
218+ "status_stub_generation" : False ,
219+ }
220+ failed_statuses = {
221+ step : bool (self .cache .get (step , _StepStatus .SKIPPED ) == _StepStatus .FAILED ) for step in is_crucial
222+ }
223+ return any (is_crucial [step ] and failed for step , failed in failed_statuses .items ())
224+
225+ def _version_incompatible (self : Self ) -> bool :
226+ return not _is_compatible (self .cache .get ("version" , _version ()), _version ())
227+
228+ def _torch_version_changed (self : Self ) -> bool :
229+ if "torch" in sys .modules :
230+ current_torch_version = str (sys .modules ["torch" ].__version__ )
231+ previous_torch_version : str = self .cache .get ("torch_version" , str (sys .modules ["torch" ].__version__ ))
232+
233+ return current_torch_version != previous_torch_version
234+
235+ return False
236+
215237
216238class _InitializeStep (_JITCompileStep ):
217239 exception_cls = type (None )
0 commit comments