diff --git a/build2cmake/src/config/v2.rs b/build2cmake/src/config/v2.rs index d4c25a7..250bf8c 100644 --- a/build2cmake/src/config/v2.rs +++ b/build2cmake/src/config/v2.rs @@ -35,6 +35,7 @@ impl Build { Kernel::Cuda { .. } => Backend::Cuda, Kernel::Metal { .. } => Backend::Metal, Kernel::Rocm { .. } => Backend::Rocm, + Kernel::Xpu { .. } => Backend::Xpu, }) .collect() } @@ -111,6 +112,14 @@ pub enum Kernel { include: Option>, src: Vec, }, + #[serde(rename_all = "kebab-case")] + Xpu { + cxx_flags: Option>, + depends: Vec, + sycl_flags: Option>, + include: Option>, + src: Vec, + }, } impl Kernel { @@ -118,7 +127,8 @@ impl Kernel { match self { Kernel::Cuda { cxx_flags, .. } | Kernel::Metal { cxx_flags, .. } - | Kernel::Rocm { cxx_flags, .. } => cxx_flags.as_deref(), + | Kernel::Rocm { cxx_flags, .. } + | Kernel::Xpu { cxx_flags, .. } => cxx_flags.as_deref(), } } @@ -126,7 +136,8 @@ impl Kernel { match self { Kernel::Cuda { include, .. } | Kernel::Metal { include, .. } - | Kernel::Rocm { include, .. } => include.as_deref(), + | Kernel::Rocm { include, .. } + | Kernel::Xpu { include, .. } => include.as_deref(), } } @@ -135,6 +146,7 @@ impl Kernel { Kernel::Cuda { .. } => Backend::Cuda, Kernel::Metal { .. } => Backend::Metal, Kernel::Rocm { .. } => Backend::Rocm, + Kernel::Xpu { .. } => Backend::Xpu, } } @@ -142,13 +154,17 @@ impl Kernel { match self { Kernel::Cuda { depends, .. } | Kernel::Metal { depends, .. } - | Kernel::Rocm { depends, .. } => depends, + | Kernel::Rocm { depends, .. } + | Kernel::Xpu { depends, .. } => depends, } } pub fn src(&self) -> &[String] { match self { - Kernel::Cuda { src, .. } | Kernel::Metal { src, .. } | Kernel::Rocm { src, .. } => src, + Kernel::Cuda { src, .. } + | Kernel::Metal { src, .. } + | Kernel::Rocm { src, .. } + | Kernel::Xpu { src, .. } => src, } } } @@ -159,6 +175,7 @@ pub enum Backend { Cuda, Metal, Rocm, + Xpu, } impl Display for Backend { @@ -167,6 +184,7 @@ impl Display for Backend { Backend::Cuda => write!(f, "cuda"), Backend::Metal => write!(f, "metal"), Backend::Rocm => write!(f, "rocm"), + Backend::Xpu => write!(f, "xpu"), } } } @@ -179,6 +197,7 @@ impl FromStr for Backend { "cuda" => Ok(Backend::Cuda), "metal" => Ok(Backend::Metal), "rocm" => Ok(Backend::Rocm), + "xpu" => Ok(Backend::Xpu), _ => Err(format!("Unknown backend: {s}")), } } diff --git a/build2cmake/src/main.rs b/build2cmake/src/main.rs index 5f7f8ec..3d9be9e 100644 --- a/build2cmake/src/main.rs +++ b/build2cmake/src/main.rs @@ -9,7 +9,9 @@ use eyre::{bail, ensure, Context, Result}; use minijinja::Environment; mod torch; -use torch::{write_torch_ext_cuda, write_torch_ext_metal, write_torch_ext_universal}; +use torch::{ + write_torch_ext_cuda, write_torch_ext_metal, write_torch_ext_universal, write_torch_ext_xpu, +}; mod config; use config::{Backend, Build, BuildCompat}; @@ -180,6 +182,7 @@ fn generate_torch( write_torch_ext_cuda(&env, backend, &build, target_dir.clone(), ops_id)? } Backend::Metal => write_torch_ext_metal(&env, &build, target_dir.clone(), ops_id)?, + Backend::Xpu => write_torch_ext_xpu(&env, &build, target_dir.clone(), ops_id)?, }; file_set.write(&target_dir, force)?; @@ -379,6 +382,7 @@ fn get_generated_files( Backend::Metal => { write_torch_ext_metal(env, build, target_dir.clone(), ops_id.clone())? } + Backend::Xpu => write_torch_ext_xpu(env, build, target_dir.clone(), ops_id.clone())?, }; all_set.extend(set); diff --git a/build2cmake/src/templates/xpu/kernel.cmake b/build2cmake/src/templates/xpu/kernel.cmake new file mode 100644 index 0000000..020409a --- /dev/null +++ b/build2cmake/src/templates/xpu/kernel.cmake @@ -0,0 +1,48 @@ +set({{kernel_name}}_SRC + {{ sources }} +) + +{% if includes %} +# TODO: check if CLion support this: +# https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories +set_source_files_properties( + {{'${' + kernel_name + '_SRC}'}} + PROPERTIES INCLUDE_DIRECTORIES "{{ includes }}") +{% endif %} + +{% if cxx_flags %} +foreach(_KERNEL_SRC {{'${' + kernel_name + '_SRC}'}}) + set_property( + SOURCE ${_KERNEL_SRC} + APPEND PROPERTY + COMPILE_OPTIONS "$<$:{{ cxx_flags }}>" + ) +endforeach() +{% endif %} + +# Add SYCL-specific compilation flags for XPU sources +{% if sycl_flags %} +# Use kernel-specific SYCL flags +foreach(_KERNEL_SRC {{'${' + kernel_name + '_SRC}'}}) + if(_KERNEL_SRC MATCHES ".*\\.(cpp|cxx|cc)$") + set_property( + SOURCE ${_KERNEL_SRC} + APPEND PROPERTY + COMPILE_OPTIONS "$<$:{{ sycl_flags }}>" + ) + endif() +endforeach() +{% else %} +# Use default SYCL flags +foreach(_KERNEL_SRC {{'${' + kernel_name + '_SRC}'}}) + if(_KERNEL_SRC MATCHES ".*\\.(cpp|cxx|cc)$") + set_property( + SOURCE ${_KERNEL_SRC} + APPEND PROPERTY + COMPILE_OPTIONS "$<$:${sycl_flags}>" + ) + endif() +endforeach() +{% endif %} + +list(APPEND SRC {{'"${' + kernel_name + '_SRC}"'}}) diff --git a/build2cmake/src/templates/xpu/preamble.cmake b/build2cmake/src/templates/xpu/preamble.cmake new file mode 100644 index 0000000..2f0369e --- /dev/null +++ b/build2cmake/src/templates/xpu/preamble.cmake @@ -0,0 +1,47 @@ +cmake_minimum_required(VERSION 3.26) + +# Set Intel SYCL compiler before project() call +find_program(ICPX_COMPILER icpx) +if(ICPX_COMPILER) + set(CMAKE_CXX_COMPILER ${ICPX_COMPILER}) + message(STATUS "Using Intel SYCL compiler: ${ICPX_COMPILER}") +else() + message(FATAL_ERROR "Intel SYCL compiler (icpx) not found. Please install Intel oneAPI toolkit.") +endif() + +project({{ name }}) + +include("cmake/utils.cmake") + +# Find Python with all necessary components for building extensions +find_package(Python REQUIRED COMPONENTS Interpreter Development.Module Development.SABIModule) + +append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path") + +find_package(Torch REQUIRED) + +# Intel XPU backend detection and setup +if(NOT TORCH_VERSION) + run_python(TORCH_VERSION "import torch; print(torch.__version__)" "Failed to get Torch version") +endif() + +# Check for Intel XPU support in PyTorch +run_python(XPU_AVAILABLE + "import torch; print('true' if hasattr(torch, 'xpu') and torch.xpu.is_available() else 'false')" + "Failed to check XPU availability") + +if(NOT XPU_AVAILABLE STREQUAL "true") + message(WARNING "Intel XPU is not available in this PyTorch installation. XPU kernels will be skipped.") + return() +endif() + +# Set up XPU compilation flags +set(GPU_LANG "SYCL") +add_compile_definitions(XPU_KERNEL) +add_compile_definitions(USE_XPU) + +# Set SYCL-specific flags +# Set comprehensive SYCL compilation and linking flags +set(sycl_link_flags "-fsycl;--offload-compress;-fsycl-targets=spir64_gen,spir64;-Xs;-device pvc,xe-lpg,ats-m150 -options ' -cl-intel-enable-auto-large-GRF-mode -cl-poison-unsupported-fp64-kernels -cl-intel-greater-than-4GB-buffer-required'") +set(sycl_flags "-fsycl;-fhonor-nans;-fhonor-infinities;-fno-associative-math;-fno-approx-func;-fno-sycl-instrument-device-code;--offload-compress;-fsycl-targets=spir64_gen,spir64;") +message(STATUS "Configuring for Intel XPU backend using SYCL") diff --git a/build2cmake/src/templates/xpu/setup.py b/build2cmake/src/templates/xpu/setup.py new file mode 100644 index 0000000..2a21cdf --- /dev/null +++ b/build2cmake/src/templates/xpu/setup.py @@ -0,0 +1,123 @@ +import logging +import os +from shutil import which, move +import subprocess +import sys +from pathlib import Path + +from setuptools import Extension, find_packages, setup +from setuptools.command.build_ext import build_ext + +logger = logging.getLogger(__name__) + + +def is_sccache_available() -> bool: + return which("sccache") is not None + + +def is_ccache_available() -> bool: + return which("ccache") is not None + + +def is_ninja_available() -> bool: + return which("ninja") is not None + + +class CMakeExtension(Extension): + def __init__(self, name: str, sourcedir: str = "") -> None: + super().__init__(name, sources=[], py_limited_api=True) + self.sourcedir = os.fspath(Path(sourcedir).resolve()) + + +class CMakeBuild(build_ext): + def build_extension(self, ext: CMakeExtension) -> None: + ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name) + extdir = ext_fullpath.parent.resolve() + + debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug + cfg = "Debug" if debug else "Release" + + cmake_generator = os.environ.get("CMAKE_GENERATOR", "") + + # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON + # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code + # from Python. + cmake_args = [ + f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}", + f"-DPython_EXECUTABLE={sys.executable}", + f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm + ] + build_args = [] + if "CMAKE_ARGS" in os.environ: + cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item] + + if not cmake_generator or cmake_generator == "Ninja": + try: + import ninja + + ninja_executable_path = Path(ninja.BIN_DIR) / "ninja" + cmake_args += [ + "-GNinja", + f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}", + ] + except ImportError: + pass + + if is_sccache_available(): + cmake_args += [ + "-DCMAKE_C_COMPILER_LAUNCHER=sccache", + "-DCMAKE_CXX_COMPILER_LAUNCHER=sccache", + ] + elif is_ccache_available(): + cmake_args += [ + "-DCMAKE_C_COMPILER_LAUNCHER=ccache", + "-DCMAKE_CXX_COMPILER_LAUNCHER=ccache", + ] + + num_jobs = os.getenv("MAX_JOBS", None) + if num_jobs is not None: + num_jobs = int(num_jobs) + logger.info("Using MAX_JOBS=%d as the number of jobs.", num_jobs) + else: + try: + # os.sched_getaffinity() isn't universally available, so fall + # back to os.cpu_count() if we get an error here. + num_jobs = len(os.sched_getaffinity(0)) + except AttributeError: + num_jobs = os.cpu_count() + + build_args += [f"-j{num_jobs}"] + if sys.platform == "win32": + build_args += ["--config", cfg] + + build_temp = Path(self.build_temp) / ext.name + if not build_temp.exists(): + build_temp.mkdir(parents=True) + + subprocess.run( + ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True + ) + subprocess.run( + ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True + ) + if sys.platform == "win32": + # Move the dylib one folder up for discovery. + for filename in os.listdir(extdir / cfg): + move(extdir / cfg / filename, extdir / filename) + + +setup( + name="{{ name }}", + # The version is just a stub, it's not used by the final build artefact. + version="0.1.0", + ext_modules=[CMakeExtension("{{ name }}.{{ ops_name }}")], + cmdclass={"build_ext": CMakeBuild}, + packages=find_packages(where="torch-ext", include=["{{ name }}*"]), + package_dir={"": "torch-ext"}, +{% if data_globs %} + package_data={"{{ name }}": [ {{ data_globs }} ]}, +{% endif %} + zip_safe=False, + install_requires=["torch"], + python_requires=">=3.9", +) diff --git a/build2cmake/src/templates/xpu/torch-binding.cmake b/build2cmake/src/templates/xpu/torch-binding.cmake new file mode 100644 index 0000000..8812c52 --- /dev/null +++ b/build2cmake/src/templates/xpu/torch-binding.cmake @@ -0,0 +1,13 @@ +set(TORCH_{{name}}_SRC + {{ src|join(' ') }} +) + +{% if includes %} +# TODO: check if CLion support this: +# https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories +set_source_files_properties( + {{'${TORCH_' + name + '_SRC}'}} + PROPERTIES INCLUDE_DIRECTORIES "{{ includes }}") +{% endif %} + +list(APPEND SRC {{'"${TORCH_' + name + '_SRC}"'}}) \ No newline at end of file diff --git a/build2cmake/src/templates/xpu/torch-extension.cmake b/build2cmake/src/templates/xpu/torch-extension.cmake new file mode 100644 index 0000000..8b448f5 --- /dev/null +++ b/build2cmake/src/templates/xpu/torch-extension.cmake @@ -0,0 +1,11 @@ +define_gpu_extension_target( + {{ ops_name }} + DESTINATION {{ ops_name }} + LANGUAGE ${GPU_LANG} + SOURCES ${SRC} + COMPILE_FLAGS ${sycl_flags} + USE_SABI 3 + WITH_SOABI) + +# Add XPU/SYCL specific linker flags +target_link_options({{ ops_name }} PRIVATE ${sycl_link_flags}) diff --git a/build2cmake/src/torch/metal.rs b/build2cmake/src/torch/metal.rs index 8e9190b..4b1edcf 100644 --- a/build2cmake/src/torch/metal.rs +++ b/build2cmake/src/torch/metal.rs @@ -94,7 +94,11 @@ fn write_cmake( render_binding(env, torch, name, cmake_writer)?; - for (kernel_name, kernel) in &build.kernels { + for (kernel_name, kernel) in build + .kernels + .iter() + .filter(|(_, kernel)| matches!(kernel, Kernel::Metal { .. })) + { render_kernel(env, kernel_name, kernel, cmake_writer)?; } diff --git a/build2cmake/src/torch/mod.rs b/build2cmake/src/torch/mod.rs index f6dce91..d389693 100644 --- a/build2cmake/src/torch/mod.rs +++ b/build2cmake/src/torch/mod.rs @@ -9,3 +9,6 @@ pub(crate) use ops_identifier::kernel_ops_identifier; mod universal; pub use universal::write_torch_ext_universal; + +mod xpu; +pub use xpu::write_torch_ext_xpu; diff --git a/build2cmake/src/torch/xpu.rs b/build2cmake/src/torch/xpu.rs new file mode 100644 index 0000000..0c61f1a --- /dev/null +++ b/build2cmake/src/torch/xpu.rs @@ -0,0 +1,291 @@ +use std::collections::HashSet; +use std::io::Write; +use std::path::PathBuf; + +use eyre::{bail, Context, Result}; +use itertools::Itertools; +use minijinja::{context, Environment}; + +use super::kernel_ops_identifier; +use crate::config::{Build, Dependencies, Kernel, Torch}; +use crate::FileSet; + +static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake"); +static REGISTRATION_H: &str = include_str!("../templates/registration.h"); + +pub fn write_torch_ext_xpu( + env: &Environment, + build: &Build, + target_dir: PathBuf, + ops_id: Option, +) -> Result { + let torch_ext = match build.torch.as_ref() { + Some(torch_ext) => torch_ext, + None => bail!("Build configuration does not have `torch` section"), + }; + + let mut file_set = FileSet::default(); + + let ops_name = kernel_ops_identifier(&target_dir, &build.general.name, ops_id); + + write_cmake( + env, + build, + torch_ext, + &build.general.name, + &ops_name, + &mut file_set, + )?; + + write_setup_py( + env, + torch_ext, + &build.general.name, + &ops_name, + &mut file_set, + )?; + + write_ops_py(env, &build.general.name, &ops_name, &mut file_set)?; + + write_pyproject_toml(env, &mut file_set)?; + + write_torch_registration_macros(&mut file_set)?; + + Ok(file_set) +} + +fn write_torch_registration_macros(file_set: &mut FileSet) -> Result<()> { + let mut path = PathBuf::new(); + path.push("torch-ext"); + path.push("registration.h"); + file_set + .entry(path) + .extend_from_slice(REGISTRATION_H.as_bytes()); + + Ok(()) +} + +fn write_pyproject_toml(env: &Environment, file_set: &mut FileSet) -> Result<()> { + let writer = file_set.entry("pyproject.toml"); + + env.get_template("pyproject.toml") + .wrap_err("Cannot get pyproject.toml template")? + .render_to_write(context! {}, writer) + .wrap_err("Cannot render pyproject.toml template")?; + + Ok(()) +} + +fn write_setup_py( + env: &Environment, + torch: &Torch, + name: &str, + ops_name: &str, + file_set: &mut FileSet, +) -> Result<()> { + let writer = file_set.entry("setup.py"); + + let data_globs = torch.data_globs().map(|globs| globs.join(", ")); + + env.get_template("xpu/setup.py") + .wrap_err("Cannot get setup.py template")? + .render_to_write( + context! { + data_globs => data_globs, + ops_name => ops_name, + name => name, + version => "0.1.0", + }, + writer, + ) + .wrap_err("Cannot render setup.py template")?; + + Ok(()) +} + +fn write_ops_py( + env: &Environment, + name: &str, + ops_name: &str, + file_set: &mut FileSet, +) -> Result<()> { + let mut path = PathBuf::new(); + path.push("torch-ext"); + path.push(name); + path.push("_ops.py"); + let writer = file_set.entry(path); + + env.get_template("_ops.py") + .wrap_err("Cannot get _ops.py template")? + .render_to_write( + context! { + ops_name => ops_name, + }, + writer, + ) + .wrap_err("Cannot render _ops.py template")?; + + Ok(()) +} + +fn write_cmake( + env: &Environment, + build: &Build, + torch: &Torch, + name: &str, + ops_name: &str, + file_set: &mut FileSet, +) -> Result<()> { + let mut utils_path = PathBuf::new(); + utils_path.push("cmake"); + utils_path.push("utils.cmake"); + file_set + .entry(utils_path.clone()) + .extend_from_slice(CMAKE_UTILS.as_bytes()); + + let cmake_writer = file_set.entry("CMakeLists.txt"); + + render_preamble(env, name, cmake_writer)?; + + render_deps(build, cmake_writer)?; + + render_binding(env, torch, name, cmake_writer)?; + + for (kernel_name, kernel) in build + .kernels + .iter() + .filter(|(_, kernel)| matches!(kernel, Kernel::Xpu { .. })) + { + render_kernel(env, kernel_name, kernel, cmake_writer)?; + } + + render_extension(env, ops_name, cmake_writer)?; + + Ok(()) +} + +fn render_binding( + env: &Environment, + torch: &Torch, + name: &str, + write: &mut impl Write, +) -> Result<()> { + env.get_template("xpu/torch-binding.cmake") + .wrap_err("Cannot get Torch binding template")? + .render_to_write( + context! { + includes => torch.include.as_ref().map(prefix_and_join_includes), + name => name, + src => torch.src + }, + &mut *write, + ) + .wrap_err("Cannot render Torch binding template")?; + + write.write_all(b"\n")?; + + Ok(()) +} + +fn render_deps(build: &Build, write: &mut impl Write) -> Result<()> { + let mut deps = HashSet::new(); + + for kernel in build.kernels.values() { + deps.extend(kernel.depends()); + } + + for dep in deps { + match dep { + Dependencies::Torch => (), + _ => { + // XPU doesn't support CUTLASS dependencies yet + eprintln!("Warning: XPU backend doesn't support dependency: {dep:?}"); + } + } + write.write_all(b"\n")?; + } + + Ok(()) +} + +pub fn render_kernel( + env: &Environment, + kernel_name: &str, + kernel: &Kernel, + write: &mut impl Write, +) -> Result<()> { + // Easier to do in Rust than Jinja. + let sources = kernel + .src() + .iter() + .map(|src| format!("\"{src}\"")) + .collect_vec() + .join("\n"); + + let sycl_flags = match kernel { + Kernel::Xpu { sycl_flags, .. } => sycl_flags.as_deref(), + _ => unreachable!("Unsupported kernel type for XPU rendering"), + }; + + env.get_template("xpu/kernel.cmake") + .wrap_err("Cannot get kernel template")? + .render_to_write( + context! { + cxx_flags => kernel.cxx_flags().map(|flags| flags.join(";")), + sycl_flags => sycl_flags.map(|flags| flags.join(";")), + includes => kernel.include().map(prefix_and_join_includes), + kernel_name => kernel_name, + sources => sources, + }, + &mut *write, + ) + .wrap_err("Cannot render kernel template")?; + + write.write_all(b"\n")?; + + Ok(()) +} + +pub fn render_extension(env: &Environment, ops_name: &str, write: &mut impl Write) -> Result<()> { + env.get_template("xpu/torch-extension.cmake") + .wrap_err("Cannot get Torch extension template")? + .render_to_write( + context! { + ops_name => ops_name, + }, + &mut *write, + ) + .wrap_err("Cannot render Torch extension template")?; + + write.write_all(b"\n")?; + + Ok(()) +} + +pub fn render_preamble(env: &Environment, name: &str, write: &mut impl Write) -> Result<()> { + env.get_template("xpu/preamble.cmake") + .wrap_err("Cannot get CMake prelude template")? + .render_to_write( + context! { + name => name, + }, + &mut *write, + ) + .wrap_err("Cannot render CMake prelude template")?; + + write.write_all(b"\n")?; + + Ok(()) +} + +fn prefix_and_join_includes(includes: impl AsRef<[S]>) -> String +where + S: AsRef, +{ + includes + .as_ref() + .iter() + .map(|include| format!("${{CMAKE_SOURCE_DIR}}/{}", include.as_ref())) + .collect_vec() + .join(";") +} diff --git a/examples/relu/build.toml b/examples/relu/build.toml index 80c6886..2effffc 100644 --- a/examples/relu/build.toml +++ b/examples/relu/build.toml @@ -37,3 +37,8 @@ rocm-archs = [ ] depends = ["torch"] src = ["relu_cuda/relu.cu"] + +[kernel.activation_xpu] +backend = "xpu" +depends = ["torch"] +src = ["relu_xpu/relu.cpp"] diff --git a/examples/relu/relu_xpu/relu.cpp b/examples/relu/relu_xpu/relu.cpp new file mode 100644 index 0000000..1809de0 --- /dev/null +++ b/examples/relu/relu_xpu/relu.cpp @@ -0,0 +1,40 @@ +#include +#include + +using namespace sycl; + +void relu_xpu_impl(torch::Tensor& output, const torch::Tensor& input) { + // Create SYCL queue directly + sycl::queue queue; + + auto input_ptr = input.data_ptr(); + auto output_ptr = output.data_ptr(); + auto numel = input.numel(); + + // Launch SYCL kernel + queue.parallel_for(range<1>(numel), [=](id<1> idx) { + auto i = idx[0]; + output_ptr[i] = input_ptr[i] > 0.0f ? input_ptr[i] : 0.0f; + }).wait(); +} + +void relu(torch::Tensor& out, const torch::Tensor& input) { + TORCH_CHECK(input.device().is_xpu(), "input must be a XPU tensor"); + TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); + TORCH_CHECK(input.scalar_type() == torch::kFloat, + "Unsupported data type: ", input.scalar_type()); + + TORCH_CHECK(input.sizes() == out.sizes(), + "Tensors must have the same shape. Got input shape: ", + input.sizes(), " and output shape: ", out.sizes()); + + TORCH_CHECK(input.scalar_type() == out.scalar_type(), + "Tensors must have the same data type. Got input dtype: ", + input.scalar_type(), " and output dtype: ", out.scalar_type()); + + TORCH_CHECK(input.device() == out.device(), + "Tensors must be on the same device. Got input device: ", + input.device(), " and output device: ", out.device()); + + relu_xpu_impl(out, input); +} diff --git a/examples/relu/tests/test_relu.py b/examples/relu/tests/test_relu.py index 98b292b..d2adc05 100644 --- a/examples/relu/tests/test_relu.py +++ b/examples/relu/tests/test_relu.py @@ -9,6 +9,8 @@ def test_relu(): if platform.system() == "Darwin": device = torch.device("mps") + elif hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device("xpu") else: device = torch.device("cuda") x = torch.randn(1024, 1024, dtype=torch.float32, device=device) diff --git a/examples/relu/torch-ext/torch_binding.cpp b/examples/relu/torch-ext/torch_binding.cpp index a854951..8b50483 100644 --- a/examples/relu/torch-ext/torch_binding.cpp +++ b/examples/relu/torch-ext/torch_binding.cpp @@ -9,6 +9,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("relu", torch::kCUDA, &relu); #elif defined(METAL_KERNEL) ops.impl("relu", torch::kMPS, relu); +#elif defined(XPU_KERNEL) + ops.impl("relu", torch::kXPU, &relu); #endif }