diff --git a/WORKSPACE b/WORKSPACE index c45d35109748..f2042fde55e5 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -35,18 +35,18 @@ python_configure( ################################ PyTorch Setup ################################ load("//bazel:dependencies.bzl", "PYTORCH_LOCAL_DIR") +load("//bazel:torch_repo.bzl", "torch_repo") -new_local_repository( +torch_repo( name = "torch", - build_file = "//bazel:torch.BUILD", - path = PYTORCH_LOCAL_DIR, + dist_dir = "../dist", ) ############################# OpenXLA Setup ############################### # To build PyTorch/XLA with a new revison of OpenXLA, update the xla_hash to # the openxla git commit hash and note the date of the commit. -xla_hash = '9ac36592456e7be0d66506be75fbdacc90dd4e91' # Committed on 2025-06-11. +xla_hash = "9ac36592456e7be0d66506be75fbdacc90dd4e91" # Committed on 2025-06-11. http_archive( name = "xla", @@ -66,8 +66,6 @@ http_archive( ], ) - - # For development, one often wants to make changes to the OpenXLA repository as well # as the PyTorch/XLA repository. You can override the pinned repository above with a # local checkout by either: @@ -89,14 +87,14 @@ python_init_rules() load("@xla//third_party/py:python_init_repositories.bzl", "python_init_repositories") python_init_repositories( + default_python_version = "system", + local_wheel_workspaces = ["@torch//:WORKSPACE"], requirements = { "3.8": "//:requirements_lock_3_8.txt", "3.9": "//:requirements_lock_3_9.txt", "3.10": "//:requirements_lock_3_10.txt", "3.11": "//:requirements_lock_3_11.txt", }, - local_wheel_workspaces = ["@torch//:WORKSPACE"], - default_python_version = "system", ) load("@xla//third_party/py:python_init_toolchains.bzl", "python_init_toolchains") @@ -111,8 +109,6 @@ load("@pypi//:requirements.bzl", "install_deps") install_deps() - - # Initialize OpenXLA's external dependencies. 
load("@xla//:workspace4.bzl", "xla_workspace4") @@ -134,7 +130,6 @@ load("@xla//:workspace0.bzl", "xla_workspace0") xla_workspace0() - load( "@xla//third_party/gpus:cuda_configure.bzl", "cuda_configure", diff --git a/bazel/torch.BUILD b/bazel/torch.BUILD deleted file mode 100644 index afc6bb57af9e..000000000000 --- a/bazel/torch.BUILD +++ /dev/null @@ -1,58 +0,0 @@ -package( - default_visibility = [ - "//visibility:public", - ], -) - -cc_library( - name = "headers", - hdrs = glob( - ["torch/include/**/*.h"], - ["torch/include/google/protobuf/**/*.h"], - ), - strip_include_prefix = "torch/include", -) - -# Runtime headers, for importing . -cc_library( - name = "runtime_headers", - hdrs = glob(["torch/include/torch/csrc/api/include/**/*.h"]), - strip_include_prefix = "torch/include/torch/csrc/api/include", -) - -filegroup( - name = "torchgen_deps", - srcs = [ - "aten/src/ATen/native/native_functions.yaml", - "aten/src/ATen/native/tags.yaml", - "aten/src/ATen/native/ts_native_functions.yaml", - "aten/src/ATen/templates/DispatchKeyNativeFunctions.cpp", - "aten/src/ATen/templates/DispatchKeyNativeFunctions.h", - "aten/src/ATen/templates/LazyIr.h", - "aten/src/ATen/templates/LazyNonNativeIr.h", - "aten/src/ATen/templates/RegisterDispatchDefinitions.ini", - "aten/src/ATen/templates/RegisterDispatchKey.cpp", - "torch/csrc/lazy/core/shape_inference.h", - "torch/csrc/lazy/ts_backend/ts_native_functions.cpp", - ], -) - -cc_import( - name = "libtorch", - shared_library = "build/lib/libtorch.so", -) - -cc_import( - name = "libtorch_cpu", - shared_library = "build/lib/libtorch_cpu.so", -) - -cc_import( - name = "libtorch_python", - shared_library = "build/lib/libtorch_python.so", -) - -cc_import( - name = "libc10", - shared_library = "build/lib/libc10.so", -) diff --git a/bazel/torch_repo.bzl b/bazel/torch_repo.bzl new file mode 100644 index 000000000000..d29331e4788f --- /dev/null +++ b/bazel/torch_repo.bzl @@ -0,0 +1,100 @@ +"""Repository rule to setup a torch repo.""" + 
+_BUILD_TEMPLATE = """ +# Generated by //bazel:torch_repo.bzl + +load("@//bazel:torch_repo_targets.bzl", "define_torch_targets") + +package( + default_visibility = [ + "//visibility:public", + ], +) + +define_torch_targets() +""".lstrip() + +def _get_url_basename(url): + basename = url.rpartition("/")[2] + + # Starlark doesn't have any URL decode functions, so just approximate + # one with the cases we see. + return basename.replace("%2B", "+") + +def _torch_repo_impl(rctx): + rctx.file("BUILD.bazel", _BUILD_TEMPLATE) + + env_torch_whl = rctx.os.environ.get("TORCH_WHL", "") + + urls = None + local_path = None + if env_torch_whl: + if env_torch_whl.startswith("http"): + urls = [env_torch_whl] + else: + local_path = rctx.path(env_torch_whl) + else: + dist_dir = rctx.workspace_root.get_child(rctx.attr.dist_dir) + + if dist_dir.exists: + for child in dist_dir.readdir(): + # For lack of a better option, take the first match + if child.basename.endswith(".whl"): + local_path = child + break + + if not local_path and not urls: + fail(( + "No torch wheel source configured:\n" + + "* Set TORCH_WHL environment variable to a local path or URL.\n" + + "* Or ensure the {dist_dir} directory is present with a torch wheel." + + "\n" + ).format( + dist_dir = dist_dir, + )) + + if local_path: + whl_path = local_path + if not whl_path.exists: + fail("File not found: {}".format(whl_path)) + + # The dist/ directory is necessary for XLA's python_init_repositories + # to discover the wheel and add it to requirements.txt + rctx.symlink(whl_path, "dist/{}".format(whl_path.basename)) + elif urls: + whl_basename = _get_url_basename(urls[0]) + + # The dist/ directory is necessary for XLA's python_init_repositories + # to discover the wheel and add it to requirements.txt + whl_path = rctx.path("dist/{}".format(whl_basename)) + result = rctx.download( + url = urls, + output = whl_path, allow_fail = True, + ) + if not result.success: + fail("Failed to download: {}".format(urls)) + + # Extract into the repo root. 
Also use .zip as the extension so that extract + # recognizes the file type. + # Use the whl basename so progress messages are more informative. + whl_zip = whl_path.basename.replace(".whl", ".zip") + rctx.symlink(whl_path, whl_zip) + rctx.extract(whl_zip) + rctx.delete(whl_zip) + +torch_repo = repository_rule( + implementation = _torch_repo_impl, + doc = """ +Creates a repository with torch headers, shared libraries, and wheel +for integration with Bazel. +""", + attrs = { + "dist_dir": attr.string( + doc = """ +Directory with a prebuilt torch wheel. Typically points to a source checkout +that built a torch wheel. Relative paths are relative to the workspace root. +""", + ), + }, + environ = ["TORCH_WHL"], +) diff --git a/bazel/torch_repo_targets.bzl b/bazel/torch_repo_targets.bzl new file mode 100644 index 000000000000..c3682a710ecd --- /dev/null +++ b/bazel/torch_repo_targets.bzl @@ -0,0 +1,64 @@ +"""Handles the loading phase to define targets for torch_repo.""" + +cc_library = native.cc_library + +def define_torch_targets(): + cc_library( + name = "headers", + hdrs = native.glob( + ["torch/include/**/*.h"], + ["torch/include/google/protobuf/**/*.h"], + ), + strip_include_prefix = "torch/include", + ) + + # Runtime headers, for importing . 
+ cc_library( + name = "runtime_headers", + hdrs = native.glob(["torch/include/torch/csrc/api/include/**/*.h"]), + strip_include_prefix = "torch/include/torch/csrc/api/include", + ) + + native.filegroup( + name = "torchgen_deps", + srcs = [ + # torchgen/packaged/ instead of aten/src + "torchgen/packaged/ATen/native/native_functions.yaml", + "torchgen/packaged/ATen/native/tags.yaml", + ##"torchgen/packaged/ATen/native/ts_native_functions.yaml", + "torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.cpp", + "torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.h", + "torchgen/packaged/ATen/templates/LazyIr.h", + "torchgen/packaged/ATen/templates/LazyNonNativeIr.h", + "torchgen/packaged/ATen/templates/RegisterDispatchDefinitions.ini", + "torchgen/packaged/ATen/templates/RegisterDispatchKey.cpp", + # Add torch/include prefix + "torch/include/torch/csrc/lazy/core/shape_inference.h", + ##"torch/csrc/lazy/ts_backend/ts_native_functions.cpp", + ], + ) + + # Changed to cc_library from cc_import + + cc_library( + name = "libtorch", + srcs = ["torch/lib/libtorch.so"], + ) + + cc_library( + name = "libtorch_cpu", + srcs = ["torch/lib/libtorch_cpu.so"], + ) + + cc_library( + name = "libtorch_python", + srcs = [ + "torch/lib/libshm.so", # libtorch_python.so depends on this + "torch/lib/libtorch_python.so", + ], + ) + + cc_library( + name = "libc10", + srcs = ["torch/lib/libc10.so"], + ) diff --git a/codegen/lazy_tensor_generator.py b/codegen/lazy_tensor_generator.py index c596f5c999dd..6f61b175e43d 100644 --- a/codegen/lazy_tensor_generator.py +++ b/codegen/lazy_tensor_generator.py @@ -21,9 +21,9 @@ kernel_signature, ) -aten_path = os.path.join(torch_root, "aten", "src", "ATen") -shape_inference_hdr = os.path.join(torch_root, "torch", "csrc", "lazy", "core", - "shape_inference.h") +aten_path = os.path.join(torch_root, "torchgen", "packaged", "ATen") +shape_inference_hdr = os.path.join(torch_root, "torch", "include", + "torch", "csrc", "lazy", "core", 
"shape_inference.h") impl_path = os.path.join(xla_root, "__main__", "torch_xla/csrc/aten_xla_type.cpp") source_yaml = sys.argv[2]