@@ -349,12 +349,14 @@ def _torch_incremental_build_function(
349349 module_name : str ,
350350 project_directory : pathlib .Path ,
351351 build_directory : pathlib .Path ,
352+ stubs_directory : pathlib .Path | None ,
353+ in_submodule : bool , # noqa: FBT001
352354 result : Any , # noqa: ANN401
353355) -> None :
354356 charonload .module_config [module_name ] = charonload .Config (
355357 project_directory ,
356358 build_directory ,
357- stubs_directory = VSCODE_STUBS_DIRECTORY ,
359+ stubs_directory = stubs_directory ,
358360 )
359361
360362 t_start = time .perf_counter ()
@@ -364,7 +366,8 @@ def _torch_incremental_build_function(
364366 result .value = float (t_end - t_start )
365367
366368 t_input = torch .randint (0 , 10 , size = (3 , 3 , 3 ), dtype = torch .float , device = "cpu" )
367- t_output = test_torch .two_times (t_input )
369+
370+ t_output = test_torch .sub .two_times (t_input ) if in_submodule else test_torch .two_times (t_input )
368371
369372 assert t_output .device == t_input .device
370373 assert t_output .shape == t_input .shape
@@ -375,6 +378,9 @@ def _torch_incremental_build(
375378 module_name : str ,
376379 project_directory : pathlib .Path ,
377380 build_directory : pathlib .Path ,
381+ * ,
382+ stubs_directory : pathlib .Path | None = None ,
383+ in_submodule : bool = False ,
378384) -> float :
379385 result = multiprocessing .get_context ("spawn" ).Value ("d" , 0.0 )
380386 p = multiprocessing .get_context ("spawn" ).Process (
@@ -383,6 +389,8 @@ def _torch_incremental_build(
383389 module_name ,
384390 project_directory ,
385391 build_directory ,
392+ stubs_directory ,
393+ in_submodule ,
386394 result ,
387395 ),
388396 )
@@ -424,6 +432,76 @@ def test_torch_incremental_build_cmake(shared_datadir: pathlib.Path, tmp_path: p
424432 assert t_with_changes > t_no_changes
425433
426434
435+ def test_torch_incremental_build_stubs_single (shared_datadir : pathlib .Path , tmp_path : pathlib .Path ) -> None :
436+ project_directory = shared_datadir / "torch_cpu"
437+ build_directory = tmp_path / "build"
438+
439+ _torch_incremental_build (
440+ "test_torch_incremental_build_stubs_single" ,
441+ project_directory ,
442+ build_directory ,
443+ stubs_directory = VSCODE_STUBS_DIRECTORY ,
444+ in_submodule = False ,
445+ )
446+
447+ t_no_changes = _torch_incremental_build (
448+ "test_torch_incremental_build_stubs_single" ,
449+ project_directory ,
450+ build_directory ,
451+ stubs_directory = VSCODE_STUBS_DIRECTORY ,
452+ in_submodule = False ,
453+ )
454+
455+ # Either a single file or a directory containing a single file will be generated
456+ if (stub := VSCODE_STUBS_DIRECTORY / "test_torch_incremental_build_stubs_single.pyi" ).exists ():
457+ stub .unlink ()
458+ if (stub := VSCODE_STUBS_DIRECTORY / "test_torch_incremental_build_stubs_single" ).exists ():
459+ shutil .rmtree (stub , ignore_errors = True )
460+
461+ t_with_changes = _torch_incremental_build (
462+ "test_torch_incremental_build_stubs_single" ,
463+ project_directory ,
464+ build_directory ,
465+ stubs_directory = VSCODE_STUBS_DIRECTORY ,
466+ in_submodule = False ,
467+ )
468+
469+ assert t_with_changes > t_no_changes
470+
471+
472+ def test_torch_incremental_build_stubs_multiple (shared_datadir : pathlib .Path , tmp_path : pathlib .Path ) -> None :
473+ project_directory = shared_datadir / "torch_submodule"
474+ build_directory = tmp_path / "build"
475+
476+ _torch_incremental_build (
477+ "test_torch_incremental_build_stubs_multiple" ,
478+ project_directory ,
479+ build_directory ,
480+ stubs_directory = VSCODE_STUBS_DIRECTORY ,
481+ in_submodule = True ,
482+ )
483+
484+ t_no_changes = _torch_incremental_build (
485+ "test_torch_incremental_build_stubs_multiple" ,
486+ project_directory ,
487+ build_directory ,
488+ stubs_directory = VSCODE_STUBS_DIRECTORY ,
489+ in_submodule = True ,
490+ )
491+
492+ shutil .rmtree (VSCODE_STUBS_DIRECTORY / "test_torch_incremental_build_stubs_multiple" , ignore_errors = True )
493+
494+ t_with_changes = _torch_incremental_build (
495+ "test_torch_incremental_build_stubs_multiple" ,
496+ project_directory ,
497+ build_directory ,
498+ stubs_directory = VSCODE_STUBS_DIRECTORY ,
499+ in_submodule = True ,
500+ )
501+
502+ assert t_with_changes > t_no_changes
503+
504+
427505def test_torch_cmake_include_twice (shared_datadir : pathlib .Path , tmp_path : pathlib .Path ) -> None :
428506 project_directory = shared_datadir / "torch_cmake_include_twice"
429507 build_directory = tmp_path / "build"
0 commit comments