@@ -461,6 +461,10 @@ def get_extensions():
461461
462462 extra_compile_args = {"cxx" : ["-O3" , "-std=c++17" , "-DPy_LIMITED_API=0x03090000" ]}
463463 if sys .platform == "win32" :
464+ if os .getenv ("DISTUTILS_USE_SDK" ) == "1" :
465+ extra_compile_args = {
466+ "cxx" : ["-O2" , "/std:c++17" , "/DPy_LIMITED_API=0x03090000" ]
467+ }
464468 define_macros += [("xformers_EXPORTS" , None )]
465469 extra_compile_args ["cxx" ].extend (
466470 ["/MP" , "/Zc:lambda" , "/Zc:preprocessor" , "/Zc:__cplusplus" ]
@@ -671,6 +675,31 @@ def __init__(self, *args, **kwargs) -> None:
671675
672676 def build_extensions (self ) -> None :
673677 super ().build_extensions ()
678+
679+ # Fix incorrect output names caused by py_limited_api=True on Windows. see item #1272
680+ for ext in self .extensions :
681+ ext_path_parts = ext .name .split ("." )
682+ ext_basename = ext_path_parts [- 1 ]
683+ ext_subpath = os .path .join (
684+ * ext_path_parts [:- 1 ]
685+ ) # xformers, xformers/flash_attn_3, etc.
686+
687+ # Directory where the .pyd was written
688+ output_dir = os .path .join (self .build_lib , ext_subpath )
689+
690+ # Expected correct filename
691+ correct_name = os .path .join (output_dir , f"{ ext_basename } .pyd" )
692+
693+ # But py_limited_api may incorrectly write it as just "pyd"
694+ broken_name = os .path .join (output_dir , "pyd" )
695+ if os .path .exists (broken_name ) and not os .path .exists (correct_name ):
696+ import shutil
697+
698+ print (
699+ f"[INFO]build_extensions: Fixing broken .pyd name: { broken_name } -> { correct_name } "
700+ )
701+ shutil .move (broken_name , correct_name )
702+
674703 for filename , content in self .xformers_build_metadata .items ():
675704 with open (
676705 os .path .join (self .build_lib , self .pkg_name , filename ), "w+"
@@ -685,12 +714,46 @@ def copy_extensions_to_source(self) -> None:
685714 build_py = self .get_finalized_command ("build_py" )
686715 package_dir = build_py .get_package_dir (self .pkg_name )
687716
717+ # Fix for windows when using py_limited_api=True. see #1272
718+ for ext in self .extensions :
719+ ext_path_parts = ext .name .split ("." )
720+ ext_basename = ext_path_parts [- 1 ]
721+ ext_subpath = os .path .join (* ext_path_parts [:- 1 ])
722+ build_dir = os .path .join (self .build_lib , ext_subpath )
723+
724+ correct_name = os .path .join (build_dir , f"{ ext_basename } .pyd" )
725+ broken_name = os .path .join (build_dir , "pyd" )
726+ if os .path .exists (broken_name ) and not os .path .exists (correct_name ):
727+ import shutil
728+
729+ print (
730+ f"[INFO]copy_extensions_to_source: Fixing inplace broken .pyd name: { broken_name } -> { correct_name } "
731+ )
732+ shutil .move (broken_name , correct_name )
733+
688734 for filename in self .xformers_build_metadata .keys ():
689735 inplace_file = os .path .join (package_dir , filename )
690736 regular_file = os .path .join (self .build_lib , self .pkg_name , filename )
691737 self .copy_file (regular_file , inplace_file , level = self .verbose )
692738 super ().copy_extensions_to_source ()
693739
740+ def get_ext_filename (self , ext_name ):
741+ filename = super ().get_ext_filename (ext_name )
742+ # Fix for windows when using py_limited_api=True. see #1272
743+ # If setuptools returns a bogus 'pyd' filename, fix it.
744+ if os .path .basename (filename ) == "pyd" :
745+ # Extract the final component of the ext_name (after last dot)
746+ last_part = ext_name .rsplit ("." , 1 )[- 1 ]
747+ parent_path = (
748+ os .path .join (* ext_name .split ("." )[:- 1 ]) if "." in ext_name else ""
749+ )
750+ fixed_name = f"{ last_part } .pyd"
751+ print (
752+ f"[INFO]get_ext_filename: Fixing inplace broken .pyd name: pyd -> { fixed_name } "
753+ )
754+ return os .path .join (parent_path , fixed_name ) if parent_path else fixed_name
755+ return filename
756+
694757
695758if __name__ == "__main__" :
696759 if os .getenv ("BUILD_VERSION" ): # In CI
0 commit comments