Skip to content

Commit 9eb546b

Browse files
authored
Windows compile Fix (#1284)
* adding changes for Windows compile Fix. Signed-off-by: LosCrossos <[email protected]> * Update setup.py with lint * Update setup.py with lint again * lint Signed-off-by: LosCrossos <[email protected]> --------- Signed-off-by: LosCrossos <[email protected]>
1 parent 17504e8 commit 9eb546b

File tree

1 file changed

+63
-0
lines changed

1 file changed

+63
-0
lines changed

setup.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,10 @@ def get_extensions():
461461

462462
extra_compile_args = {"cxx": ["-O3", "-std=c++17", "-DPy_LIMITED_API=0x03090000"]}
463463
if sys.platform == "win32":
464+
if os.getenv("DISTUTILS_USE_SDK") == "1":
465+
extra_compile_args = {
466+
"cxx": ["-O2", "/std:c++17", "/DPy_LIMITED_API=0x03090000"]
467+
}
464468
define_macros += [("xformers_EXPORTS", None)]
465469
extra_compile_args["cxx"].extend(
466470
["/MP", "/Zc:lambda", "/Zc:preprocessor", "/Zc:__cplusplus"]
@@ -671,6 +675,31 @@ def __init__(self, *args, **kwargs) -> None:
671675

672676
def build_extensions(self) -> None:
673677
super().build_extensions()
678+
679+
# Fix incorrect output names caused by py_limited_api=True on Windows. see item #1272
680+
for ext in self.extensions:
681+
ext_path_parts = ext.name.split(".")
682+
ext_basename = ext_path_parts[-1]
683+
ext_subpath = os.path.join(
684+
*ext_path_parts[:-1]
685+
) # xformers, xformers/flash_attn_3, etc.
686+
687+
# Directory where the .pyd was written
688+
output_dir = os.path.join(self.build_lib, ext_subpath)
689+
690+
# Expected correct filename
691+
correct_name = os.path.join(output_dir, f"{ext_basename}.pyd")
692+
693+
# But py_limited_api may incorrectly write it as just "pyd"
694+
broken_name = os.path.join(output_dir, "pyd")
695+
if os.path.exists(broken_name) and not os.path.exists(correct_name):
696+
import shutil
697+
698+
print(
699+
f"[INFO]build_extensions: Fixing broken .pyd name: {broken_name} -> {correct_name}"
700+
)
701+
shutil.move(broken_name, correct_name)
702+
674703
for filename, content in self.xformers_build_metadata.items():
675704
with open(
676705
os.path.join(self.build_lib, self.pkg_name, filename), "w+"
@@ -685,12 +714,46 @@ def copy_extensions_to_source(self) -> None:
685714
build_py = self.get_finalized_command("build_py")
686715
package_dir = build_py.get_package_dir(self.pkg_name)
687716

717+
# Fix for windows when using py_limited_api=True. see #1272
718+
for ext in self.extensions:
719+
ext_path_parts = ext.name.split(".")
720+
ext_basename = ext_path_parts[-1]
721+
ext_subpath = os.path.join(*ext_path_parts[:-1])
722+
build_dir = os.path.join(self.build_lib, ext_subpath)
723+
724+
correct_name = os.path.join(build_dir, f"{ext_basename}.pyd")
725+
broken_name = os.path.join(build_dir, "pyd")
726+
if os.path.exists(broken_name) and not os.path.exists(correct_name):
727+
import shutil
728+
729+
print(
730+
f"[INFO]copy_extensions_to_source: Fixing inplace broken .pyd name: {broken_name} -> {correct_name}"
731+
)
732+
shutil.move(broken_name, correct_name)
733+
688734
for filename in self.xformers_build_metadata.keys():
689735
inplace_file = os.path.join(package_dir, filename)
690736
regular_file = os.path.join(self.build_lib, self.pkg_name, filename)
691737
self.copy_file(regular_file, inplace_file, level=self.verbose)
692738
super().copy_extensions_to_source()
693739

740+
def get_ext_filename(self, ext_name):
741+
filename = super().get_ext_filename(ext_name)
742+
# Fix for windows when using py_limited_api=True. see #1272
743+
# If setuptools returns a bogus 'pyd' filename, fix it.
744+
if os.path.basename(filename) == "pyd":
745+
# Extract the final component of the ext_name (after last dot)
746+
last_part = ext_name.rsplit(".", 1)[-1]
747+
parent_path = (
748+
os.path.join(*ext_name.split(".")[:-1]) if "." in ext_name else ""
749+
)
750+
fixed_name = f"{last_part}.pyd"
751+
print(
752+
f"[INFO]get_ext_filename: Fixing inplace broken .pyd name: pyd -> {fixed_name}"
753+
)
754+
return os.path.join(parent_path, fixed_name) if parent_path else fixed_name
755+
return filename
756+
694757

695758
if __name__ == "__main__":
696759
if os.getenv("BUILD_VERSION"): # In CI

0 commit comments

Comments
 (0)