
Add SYCL Kernels for XPU backend #1679


Open
wants to merge 38 commits into base: main
Changes from all commits
Commits (38)
dd7b173
Add SYCL Kernels for XPU backend
xiaolil1 Jun 15, 2025
df93cdd
Merge pull request #1 from xiaolil1/jiqing
xiaolil1 Jun 16, 2025
872aa02
fix transpose
jiqing-feng Jun 16, 2025
04437a3
fix log and format
jiqing-feng Jun 16, 2025
d585bea
revert cpu changes
jiqing-feng Jun 16, 2025
1781611
clean ipex_xpu
jiqing-feng Jun 16, 2025
c982781
clean ipex import
jiqing-feng Jun 16, 2025
a4c5f8c
fix ipex cpu import
jiqing-feng Jun 16, 2025
4f076bb
fix typo
jiqing-feng Jun 16, 2025
76d7178
fix comments
jiqing-feng Jun 16, 2025
b31ea62
Merge pull request #2 from xiaolil1/jiqing
xiaolil1 Jun 16, 2025
452aa84
refine gemv_4bit kernel
xiaolil1 Jun 17, 2025
e8ac8b5
Merge branch 'main' into main
jiqing-feng Jun 17, 2025
8620a95
enable FP4 for dequant_4bit and gemv_4bit
xiaolil1 Jun 17, 2025
00f064b
refine FP4 dequantization performance
xiaolil1 Jun 17, 2025
d60750f
remove check for better performance
jiqing-feng Jun 17, 2025
59f2aa8
Merge pull request #3 from xiaolil1/jiqing
xiaolil1 Jun 17, 2025
aad358f
fix doc
jiqing-feng Jun 17, 2025
45e4451
Merge pull request #4 from xiaolil1/jiqing
xiaolil1 Jun 17, 2025
1e21ee9
clean code
xiaolil1 Jun 18, 2025
4e7f5c1
Merge branch 'main' into main
xiaolil1 Jun 18, 2025
1601652
fix tests
jiqing-feng Jun 18, 2025
1cc25ff
rm comments
jiqing-feng Jun 18, 2025
c44f38e
Merge pull request #5 from xiaolil1/jiqing
xiaolil1 Jun 18, 2025
9f283bd
fix memory issue
xiaolil1 Jun 20, 2025
9897eae
fix ut failure
xiaolil1 Jun 20, 2025
411a276
adjust threshold
jiqing-feng Jun 20, 2025
b6a3524
fix xpu check
jiqing-feng Jun 20, 2025
1c4f478
change test_functional check
jiqing-feng Jun 20, 2025
e5cf821
fix test_module
jiqing-feng Jun 20, 2025
502fe83
Merge pull request #6 from xiaolil1/jiqing
xiaolil1 Jun 20, 2025
8b54381
fix device check
jiqing-feng Jun 23, 2025
1e0f661
Merge pull request #7 from xiaolil1/jiqing_test
jiqing-feng Jun 23, 2025
99698d2
fix tests
jiqing-feng Jun 23, 2025
b88236a
Merge pull request #8 from xiaolil1/jiqing
jiqing-feng Jun 23, 2025
56c48bc
Merge branch 'main' into main
jiqing-feng Jun 24, 2025
302413e
Merge branch 'main' into main
jiqing-feng Jun 25, 2025
685962c
Enable Windows build and refine code
xiaolil1 Jun 27, 2025
31 changes: 29 additions & 2 deletions CMakeLists.txt
@@ -28,11 +28,12 @@ set(CUDA_FILES csrc/ops.cu csrc/kernels.cu)
set(HIP_FILES csrc/ops.hip csrc/kernels.hip)
set(MPS_FILES csrc/mps_ops.mm)
set(METAL_FILES csrc/mps_kernels.metal)
set(XPU_FILES csrc/xpu_ops.cpp csrc/xpu_kernels.cpp)
# C++ sources are always included
list(APPEND SRC_FILES ${CPP_FILES})

set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps)")
set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps)
set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps, xpu)")
set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps xpu)
option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF)

if(APPLE)
@@ -64,10 +65,18 @@ elseif(${COMPUTE_BACKEND} STREQUAL "mps")
set(BUILD_CUDA OFF)
set(BUILD_HIP OFF)
set(BUILD_MPS ON)
elseif(${COMPUTE_BACKEND} STREQUAL "xpu")
if(APPLE)
message(FATAL_ERROR "XPU is not supported on macOS" )
endif()
set(BUILD_CUDA OFF)
set(BUILD_MPS OFF)
set(BUILD_XPU ON)
else()
set(BUILD_CUDA OFF)
set(BUILD_HIP OFF)
set(BUILD_MPS OFF)
set(BUILD_XPU OFF)
endif()


@@ -217,6 +226,15 @@ elseif(BUILD_MPS)
COMMENT "Compiling Metal kernels"
VERBATIM)
add_custom_target(metallib DEPENDS "bitsandbytes/bitsandbytes.metallib")
elseif(BUILD_XPU)
list(APPEND SRC_FILES ${XPU_FILES})
string(APPEND BNB_OUTPUT_NAME "_xpu")
add_compile_definitions(BUILD_XPU)
set(CMAKE_C_COMPILER icx)
set(CMAKE_CXX_COMPILER icpx)
if(WIN32)
set(CMAKE_CXX_COMPILER icx)
endif()
else()
string(APPEND BNB_OUTPUT_NAME "_cpu")
set(GPU_SOURCES)
@@ -285,6 +303,15 @@ if(BUILD_MPS)
add_dependencies(bitsandbytes metallib)
target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph")
endif()
if(BUILD_XPU)
set(SYCL_LINK_FLAGS "-fsycl;--offload-compress;-fsycl-targets=spir64_gen,spir64;-Xs;-device pvc,xe-lpg,ats-m150 -options ' -cl-intel-enable-auto-large-GRF-mode -cl-poison-unsupported-fp64-kernels -cl-intel-greater-than-4GB-buffer-required'")
set(SYCL_COMPILE_FLAGS "-fsycl;-fhonor-nans;-fhonor-infinities;-fno-associative-math;-fno-approx-func;-fno-sycl-instrument-device-code;--offload-compress;-fsycl-targets=spir64_gen,spir64;")

set_property(TARGET bitsandbytes PROPERTY CXX_STANDARD 20)
target_compile_options(bitsandbytes PRIVATE ${SYCL_COMPILE_FLAGS})
target_link_options(bitsandbytes PRIVATE ${SYCL_LINK_FLAGS})

endif()

if(WIN32)
set_target_properties(bitsandbytes PROPERTIES PREFIX "lib")
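For anyone trying the new backend: after configuring the build with `-DCOMPUTE_BACKEND=xpu` (which selects the icx/icpx compilers and the SYCL flags above), a quick smoke test can confirm that PyTorch sees the device before exercising the kernels. This is a hypothetical check, not part of the PR, and assumes a PyTorch build with XPU support:

```python
import torch
import bitsandbytes  # on an XPU build this should load the "_xpu"-suffixed native library

# Hypothetical smoke test: verify an Intel GPU is visible to PyTorch.
assert torch.xpu.is_available(), "no XPU device found"
print(torch.xpu.get_device_name(0))
```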
4 changes: 2 additions & 2 deletions bitsandbytes/_ops.py
@@ -4,7 +4,7 @@

import torch

from .cextension import ipex_cpu, ipex_xpu
from .utils import ipex_cpu

_IS_TORCH_GTE_24 = False

@@ -331,7 +331,7 @@ def _(
torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")


if ipex_cpu or ipex_xpu:
if ipex_cpu:
# Register the dequantize_nf4_ipex implementation
torch.library.define(
"bitsandbytes::dequantize_nf4_ipex",
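For context, the `dequantize_nf4_ipex` registration above follows the standard `torch.library` define/impl pattern used throughout the backend code. A minimal, self-contained sketch of that pattern with a made-up `mylib::scale` op (not an op from this PR):

```python
import torch

# Declare the operator schema once...
torch.library.define("mylib::scale", "(Tensor A, float factor) -> Tensor")

# ...then attach a backend-specific implementation to it.
@torch.library.impl("mylib::scale", "cpu")
def _(A: torch.Tensor, factor: float) -> torch.Tensor:
    return A * factor

print(torch.ops.mylib.scale(torch.ones(3), 2.0))  # tensor([2., 2., 2.])
```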
8 changes: 2 additions & 6 deletions bitsandbytes/autograd/_functions.py
@@ -8,7 +8,6 @@
from typing_extensions import deprecated

import bitsandbytes.functional as F
from bitsandbytes.functional import ipex_cpu, ipex_xpu

# The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov:
# https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py
@@ -320,8 +319,6 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):

CB = state.CB.data.to(A.dtype).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
output = torch.nn.functional.linear(A, CB, bias)
# to pass the test: tests/test_modules.py::test_linear8bitlt_no_fp16_weights[2.0-xpu]
state.idx = False
ctx.state = state
ctx.dtype_A = A.dtype
ctx.grad_shape = A.shape
@@ -426,7 +423,7 @@ def matmul(
state.threshold = threshold
# MatMul8bitLt is slower because no fast kernel for quant/dequant 8bit in CPU/XPU
if state.is_training:
if (A.device.type == "cpu" and ipex_cpu) or (A.device.type == "xpu" and ipex_xpu):
if A.device.type in ("cpu", "xpu"):
return MatMul8bitFp.apply(A, B, out, bias, state)
return MatMul8bitLt.apply(A, B, out, bias, state)

@@ -440,7 +437,7 @@ def matmul_4bit(
):
assert quant_state is not None

if A.device.type in ("cpu", "xpu") and A.requires_grad == False:
if A.device.type == "cpu" and A.requires_grad == False:
if getattr(quant_state, "ipex", False):
# IPEX CPU will change weight to 4D so don't need transpose
B = B.t() if B.dim() == 2 else B
@@ -450,7 +447,6 @@
return out
else:
return MatMul4Bit.apply(A, B, out, bias, quant_state)

if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu":
if A.shape[-1] % quant_state.blocksize != 0:
warn(
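With the IPEX-specific gating removed, 4-bit matmul on XPU tensors now takes the same dispatch route as other devices and should land on the SYCL kernels registered in `bitsandbytes/backends/xpu/ops.py`. A hedged usage sketch (tensor names are illustrative; assumes an XPU build of bitsandbytes and an available device):

```python
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

# Quantize a weight to NF4, then run the 4-bit matmul on XPU; the gemv path
# should be taken here because the input holds a single token.
W = torch.randn(256, 128, dtype=torch.float16, device="xpu")
packed, state = F.quantize_4bit(W, blocksize=64, quant_type="nf4")
x = torch.randn(1, 128, dtype=torch.float16, device="xpu")
out = bnb.matmul_4bit(x, packed.t(), quant_state=state)
print(out.shape)  # torch.Size([1, 256])
```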
2 changes: 1 addition & 1 deletion bitsandbytes/backends/cpu/ops.py
@@ -7,7 +7,7 @@

from ..._ops import register_kernel
from ...cextension import lib
from ..utils import ipex_cpu
from ...utils import ipex_cpu

# torch._int_mm for s8@s8->s32 is supported on CPU from torch 2.4+.
# However, we can overflow if we use this without AVX512_VNNI support.
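As a side note on the context lines above: `torch._int_mm` is the s8 x s8 -> s32 primitive the CPU backend falls back to. A tiny illustration of what it computes (a sketch; assumes torch >= 2.4):

```python
import torch

# int8 @ int8 -> int32 matrix multiply, as referenced in the comment above.
A = torch.randint(-128, 127, (32, 64), dtype=torch.int8)
B = torch.randint(-128, 127, (64, 32), dtype=torch.int8)
C = torch._int_mm(A, B)
print(C.dtype, C.shape)  # torch.int32 torch.Size([32, 32])
```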
10 changes: 0 additions & 10 deletions bitsandbytes/backends/utils.py
100755 → 100644
@@ -3,16 +3,6 @@
from packaging import version
import torch

try:
# to support Intel CPU/XPU (IPEX) backend
import intel_extension_for_pytorch as ipex

ipex_cpu = ipex if ipex._C._has_cpu() else None
ipex_xpu = ipex if ipex._C._has_xpu() else None
except BaseException:
ipex_cpu = None
ipex_xpu = None

try:
import triton # noqa: F401
import triton.language as tl # noqa: F401
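The deleted probe above is what used to detect IPEX for both CPU and XPU; after this PR the CPU half lives in `bitsandbytes/utils.py` (see the updated imports in `_ops.py` and `backends/cpu/ops.py`). For reference, the guarded-import pattern looks roughly like this sketch, reconstructed from the deleted lines rather than copied from the new file:

```python
try:
    # Optional Intel Extension for PyTorch support (CPU only after this PR)
    import intel_extension_for_pytorch as ipex

    ipex_cpu = ipex if ipex._C._has_cpu() else None
except BaseException:
    ipex_cpu = None
```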
Empty file modified bitsandbytes/backends/xpu/__init__.py
100755 → 100644
Empty file.
205 changes: 177 additions & 28 deletions bitsandbytes/backends/xpu/ops.py
100755 → 100644
@@ -1,58 +1,205 @@
from collections.abc import Sequence
import ctypes as ct
import warnings

import torch

from bitsandbytes.functional import _get_tensor_stream, get_ptr

from ..._ops import register_kernel
from ..utils import ipex_xpu, triton_available
from ...cextension import ErrorHandlerMockBNBNativeLibrary, lib
from ..utils import triton_available


def _dequantize_4bit_impl(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
dtype: torch.dtype,
out: torch.Tensor,
) -> None:
args = (
None,
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(blocksize),
ct.c_int(out.numel()),
_get_tensor_stream(A),
)
if dtype == torch.bfloat16:
if quant_type == "fp4":
lib.cdequantize_blockwise_bf16_fp4(*args)
else:
lib.cdequantize_blockwise_bf16_nf4(*args)
elif dtype == torch.float16:
if quant_type == "fp4":
lib.cdequantize_blockwise_fp16_fp4(*args)
else:
lib.cdequantize_blockwise_fp16_nf4(*args)
elif dtype == torch.float32:
if quant_type == "fp4":
lib.cdequantize_blockwise_fp32_fp4(*args)
else:
lib.cdequantize_blockwise_fp32_nf4(*args)


# _int_mm is available in torch starting from 2.7 version,
# but currently it's don't have xpu implementation.
if ipex_xpu and torch.__version__ >= (2, 7):
def _dequantize_blockwise_impl(
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
) -> None:
args = (
get_ptr(code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(blocksize),
ct.c_int(A.numel()),
_get_tensor_stream(A),
)
if dtype == torch.float16:
lib.cdequantize_blockwise_fp16(*args)
elif dtype == torch.bfloat16:
lib.cdequantize_blockwise_bf16(*args)
elif dtype == torch.float32:
lib.cdequantize_blockwise_fp32(*args)

@register_kernel("bitsandbytes::int8_linear_matmul", "xpu")
def _(A: torch.Tensor, B: torch.Tensor):
return torch._int_mm(
A.reshape(-1, A.shape[-1]),
B.t(),
).reshape(*A.shape[:-1], B.shape[0])

def _gemv_4bit_impl(
A: torch.Tensor,
B: torch.Tensor,
shapeB: Sequence[int],
absmax: torch.Tensor,
code: torch.Tensor,
blocksize: int,
out: torch.Tensor,
) -> None:
m = ct.c_int32(1)
n = ct.c_int32(shapeB[0])
k = ct.c_int32(shapeB[1])

# IPEX should be faster for xpu, so at first checking if it is available.
if ipex_xpu:
lda = m
ldb = ct.c_int32((A.shape[-1] + 1) // 2)
ldc = m

@register_kernel("bitsandbytes::dequantize_nf4_ipex", "xpu")
stream = _get_tensor_stream(A)
if A.dtype == torch.float16:
lib.cgemv_4bit_inference_fp16(
m,
n,
k,
get_ptr(A),
get_ptr(B),
get_ptr(absmax),
get_ptr(code),
get_ptr(out),
lda,
ldb,
ldc,
ct.c_int32(blocksize),
stream,
)
elif A.dtype == torch.bfloat16:
lib.cgemv_4bit_inference_bf16(
m,
n,
k,
get_ptr(A),
get_ptr(B),
get_ptr(absmax),
get_ptr(code),
get_ptr(out),
lda,
ldb,
ldc,
ct.c_int32(blocksize),
stream,
)
elif A.dtype == torch.float32:
lib.cgemv_4bit_inference_fp32(
m,
n,
k,
get_ptr(A),
get_ptr(B),
get_ptr(absmax),
get_ptr(code),
get_ptr(out),
lda,
ldb,
ldc,
ct.c_int32(blocksize),
stream,
)


# SYCL should be faster for xpu, so at first checking if it is available.
if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary):

@register_kernel("bitsandbytes::dequantize_4bit", "xpu")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
shape: Sequence[int],
dtype: torch.dtype,
) -> torch.Tensor:
return torch.ops.torch_ipex.dequantize_4bit(A, "nf4", shape, absmax, None, blocksize).t().to(dtype)
out = torch.empty(shape, dtype=dtype, device=A.device)
_dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
return out

@register_kernel("bitsandbytes::dequantize_blockwise", "xpu")
def _(
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype
) -> torch.Tensor:
out = torch.empty_like(A, dtype=dtype)
_dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out)
return out

@register_kernel("bitsandbytes::dequantize_blockwise.out", "xpu")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
code: torch.Tensor,
blocksize: int,
dtype: torch.dtype,
out: torch.Tensor,
) -> None:
torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}")
_dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out)

@register_kernel("bitsandbytes::gemv_4bit", "xpu")
def _(
A: torch.Tensor,
B: torch.Tensor,
shapeB: Sequence[int],
absmax: torch.Tensor,
code: torch.Tensor,
blocksize: int,
) -> torch.Tensor:
shape = A.shape
out = torch.empty(A.reshape(-1).shape, dtype=dtype, device=A.device)
# void cdequantize_blockwise_fp32(
# float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream)
if dtype == torch.float16:
ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_fp16(code, A, absmax, out, blocksize, A.numel())
elif dtype == torch.bfloat16:
ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_bf16(code, A, absmax, out, blocksize, A.numel())
elif dtype == torch.float32:
ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_fp32(code, A, absmax, out, blocksize, A.numel())
else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")
shape = (*A.shape[:-1], shapeB[0])
out = torch.empty(shape, device=A.device, dtype=A.dtype)
_gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out)
return out

return out.reshape(shape)
@register_kernel("bitsandbytes::gemv_4bit.out", "xpu")
def _(
A: torch.Tensor,
B: torch.Tensor,
shapeB: Sequence[int],
absmax: torch.Tensor,
code: torch.Tensor,
blocksize: int,
out: torch.Tensor,
) -> None:
torch._check(
out.shape == (*A.shape[:-1], shapeB[0]),
lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}",
)
torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")
_gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out)
elif triton_available:
from ..triton import ops as triton_ops

@@ -64,4 +211,6 @@ def _(
register_kernel("bitsandbytes::dequantize_4bit", "xpu")(triton_ops.dequantize_4bit)
register_kernel("bitsandbytes::gemv_4bit", "xpu")(triton_ops.gemv_4bit)
else:
warnings.warn("XPU available but no ipex or triton packages found.")
warnings.warn(
"XPU available but no native library or triton packages found. Please follow the installation instructions in the documentation."
)
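Putting the new registrations together, the SYCL-backed ops can be exercised directly through the dispatcher once the native library is present. A hedged round-trip sketch (assumes an XPU build and device; the reconstruction error is nonzero because NF4 is lossy):

```python
import torch
import bitsandbytes  # importing registers the XPU kernels when the native lib loads
from bitsandbytes.functional import quantize_4bit

A = torch.randn(64, 64, dtype=torch.float16, device="xpu")
packed, state = quantize_4bit(A, blocksize=64, quant_type="nf4")

# Call the dequantize_4bit kernel registered above via the torch.ops dispatcher.
restored = torch.ops.bitsandbytes.dequantize_4bit(
    packed, state.absmax, state.blocksize, state.quant_type, state.shape, A.dtype
)
print((A - restored).abs().max())  # small but nonzero: NF4 quantization is lossy
```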