diff --git a/CMakeLists.txt b/CMakeLists.txt index 770b4ba30..429570443 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -28,11 +28,12 @@ set(CUDA_FILES csrc/ops.cu csrc/kernels.cu) set(HIP_FILES csrc/ops.hip csrc/kernels.hip) set(MPS_FILES csrc/mps_ops.mm) set(METAL_FILES csrc/mps_kernels.metal) +set(XPU_FILES csrc/xpu_ops.cpp csrc/xpu_kernels.cpp) # C++ sources are always included list(APPEND SRC_FILES ${CPP_FILES}) -set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps)") -set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps) +set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps, xpu)") +set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps xpu) option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF) if(APPLE) @@ -64,10 +65,18 @@ elseif(${COMPUTE_BACKEND} STREQUAL "mps") set(BUILD_CUDA OFF) set(BUILD_HIP OFF) set(BUILD_MPS ON) +elseif(${COMPUTE_BACKEND} STREQUAL "xpu") + if(APPLE) + message(FATAL_ERROR "XPU is not supported on macOS" ) + endif() + set(BUILD_CUDA OFF) + set(BUILD_MPS OFF) + set(BUILD_XPU ON) else() set(BUILD_CUDA OFF) set(BUILD_HIP OFF) set(BUILD_MPS OFF) + set(BUILD_XPU OFF) endif() @@ -217,6 +226,15 @@ elseif(BUILD_MPS) COMMENT "Compiling Metal kernels" VERBATIM) add_custom_target(metallib DEPENDS "bitsandbytes/bitsandbytes.metallib") +elseif(BUILD_XPU) + list(APPEND SRC_FILES ${XPU_FILES}) + string(APPEND BNB_OUTPUT_NAME "_xpu") + add_compile_definitions(BUILD_XPU) + set(CMAKE_C_COMPILER icx) + set(CMAKE_CXX_COMPILER icpx) + if(WIN32) + set(CMAKE_CXX_COMPILER icx) + endif() else() string(APPEND BNB_OUTPUT_NAME "_cpu") set(GPU_SOURCES) @@ -285,6 +303,15 @@ if(BUILD_MPS) add_dependencies(bitsandbytes metallib) target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph") endif() +if(BUILD_XPU) + set(SYCL_LINK_FLAGS "-fsycl;--offload-compress;-fsycl-targets=spir64_gen,spir64;-Xs;-device pvc,xe-lpg,ats-m150 -options ' -cl-intel-enable-auto-large-GRF-mode -cl-poison-unsupported-fp64-kernels -cl-intel-greater-than-4GB-buffer-required'") + set(SYCL_COMPILE_FLAGS "-fsycl;-fhonor-nans;-fhonor-infinities;-fno-associative-math;-fno-approx-func;-fno-sycl-instrument-device-code;--offload-compress;-fsycl-targets=spir64_gen,spir64;") + + set_property(TARGET bitsandbytes PROPERTY CXX_STANDARD 20) + target_compile_options(bitsandbytes PRIVATE ${SYCL_COMPILE_FLAGS}) + target_link_options(bitsandbytes PRIVATE ${SYCL_LINK_FLAGS}) + +endif() if(WIN32) set_target_properties(bitsandbytes PROPERTIES PREFIX "lib") diff --git a/bitsandbytes/_ops.py b/bitsandbytes/_ops.py index a260852f5..56bfaa357 100644 --- a/bitsandbytes/_ops.py +++ b/bitsandbytes/_ops.py @@ -4,7 +4,7 @@ import torch -from .cextension import ipex_cpu, ipex_xpu +from .utils import ipex_cpu _IS_TORCH_GTE_24 = False @@ -331,7 +331,7 @@ def _( torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") -if ipex_cpu or ipex_xpu: +if ipex_cpu: # Register the dequantize_nf4_ipex implementation torch.library.define( "bitsandbytes::dequantize_nf4_ipex", diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 80fc86861..c28b301b9 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -8,7 +8,6 @@ from typing_extensions import deprecated import bitsandbytes.functional as F -from bitsandbytes.functional import ipex_cpu, ipex_xpu # The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov: # https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py @@ -320,8 +319,6 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): CB = state.CB.data.to(A.dtype).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) output = torch.nn.functional.linear(A, CB, bias) - # to pass the test: tests/test_modules.py::test_linear8bitlt_no_fp16_weights[2.0-xpu] - state.idx = False ctx.state = state ctx.dtype_A = A.dtype ctx.grad_shape = A.shape @@ -426,7 +423,7 @@ def matmul( state.threshold = threshold # MatMul8bitLt is slower because no fast kernel for quant/dequant 8bit in CPU/XPU if state.is_training: - if (A.device.type == "cpu" and ipex_cpu) or (A.device.type == "xpu" and ipex_xpu): + if A.device.type in ("cpu", "xpu"): return MatMul8bitFp.apply(A, B, out, bias, state) return MatMul8bitLt.apply(A, B, out, bias, state) @@ -440,7 +437,7 @@ def matmul_4bit( ): assert quant_state is not None - if A.device.type in ("cpu", "xpu") and A.requires_grad == False: + if A.device.type == "cpu" and A.requires_grad == False: if getattr(quant_state, "ipex", False): # IPEX CPU will change weight to 4D so don't need transpose B = B.t() if B.dim() == 2 else B @@ -450,7 +447,6 @@ def matmul_4bit( return out else: return MatMul4Bit.apply(A, B, out, bias, quant_state) - if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu": if A.shape[-1] % quant_state.blocksize != 0: warn( diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index 5f009ea40..b715b1d00 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -7,7 +7,7 @@ from ..._ops import register_kernel from ...cextension import lib -from ..utils import ipex_cpu +from ...utils import ipex_cpu # torch._int_mm for s8@s8->s32 is supported on CPU from torch 2.4+. # However, we can overflow if we use this without AVX512_VNNI support. diff --git a/bitsandbytes/backends/utils.py b/bitsandbytes/backends/utils.py old mode 100755 new mode 100644 index 1543f3474..19edd768d --- a/bitsandbytes/backends/utils.py +++ b/bitsandbytes/backends/utils.py @@ -3,16 +3,6 @@ from packaging import version import torch -try: - # to support Intel CPU/XPU (IPEX) backend - import intel_extension_for_pytorch as ipex - - ipex_cpu = ipex if ipex._C._has_cpu() else None - ipex_xpu = ipex if ipex._C._has_xpu() else None -except BaseException: - ipex_cpu = None - ipex_xpu = None - try: import triton # noqa: F401 import triton.language as tl # noqa: F401 diff --git a/bitsandbytes/backends/xpu/__init__.py b/bitsandbytes/backends/xpu/__init__.py old mode 100755 new mode 100644 diff --git a/bitsandbytes/backends/xpu/ops.py b/bitsandbytes/backends/xpu/ops.py old mode 100755 new mode 100644 index 999116c97..ed59ed2f2 --- a/bitsandbytes/backends/xpu/ops.py +++ b/bitsandbytes/backends/xpu/ops.py @@ -1,58 +1,205 @@ from collections.abc import Sequence +import ctypes as ct import warnings import torch +from bitsandbytes.functional import _get_tensor_stream, get_ptr + from ..._ops import register_kernel -from ..utils import ipex_xpu, triton_available +from ...cextension import ErrorHandlerMockBNBNativeLibrary, lib +from ..utils import triton_available + + +def _dequantize_4bit_impl( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(out.numel()), + _get_tensor_stream(A), + ) + if dtype == torch.bfloat16: + if quant_type == "fp4": + lib.cdequantize_blockwise_bf16_fp4(*args) + else: + lib.cdequantize_blockwise_bf16_nf4(*args) + elif dtype == torch.float16: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp16_fp4(*args) + else: + lib.cdequantize_blockwise_fp16_nf4(*args) + elif dtype == torch.float32: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp32_fp4(*args) + else: + lib.cdequantize_blockwise_fp32_nf4(*args) + -# _int_mm is available in torch starting from 2.7 version, -# but currently it's don't have xpu implementation. -if ipex_xpu and torch.__version__ >= (2, 7): +def _dequantize_blockwise_impl( + A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor +) -> None: + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(A.numel()), + _get_tensor_stream(A), + ) + if dtype == torch.float16: + lib.cdequantize_blockwise_fp16(*args) + elif dtype == torch.bfloat16: + lib.cdequantize_blockwise_bf16(*args) + elif dtype == torch.float32: + lib.cdequantize_blockwise_fp32(*args) - @register_kernel("bitsandbytes::int8_linear_matmul", "xpu") - def _(A: torch.Tensor, B: torch.Tensor): - return torch._int_mm( - A.reshape(-1, A.shape[-1]), - B.t(), - ).reshape(*A.shape[:-1], B.shape[0]) +def _gemv_4bit_impl( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + out: torch.Tensor, +) -> None: + m = ct.c_int32(1) + n = ct.c_int32(shapeB[0]) + k = ct.c_int32(shapeB[1]) -# IPEX should be faster for xpu, so at first checking if it is available. -if ipex_xpu: + lda = m + ldb = ct.c_int32((A.shape[-1] + 1) // 2) + ldc = m - @register_kernel("bitsandbytes::dequantize_nf4_ipex", "xpu") + stream = _get_tensor_stream(A) + if A.dtype == torch.float16: + lib.cgemv_4bit_inference_fp16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + elif A.dtype == torch.bfloat16: + lib.cgemv_4bit_inference_bf16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + elif A.dtype == torch.float32: + lib.cgemv_4bit_inference_fp32( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + + +# SYCL should be faster for xpu, so at first checking if it is available. +if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary): + + @register_kernel("bitsandbytes::dequantize_4bit", "xpu") def _( A: torch.Tensor, absmax: torch.Tensor, blocksize: int, + quant_type: str, shape: Sequence[int], dtype: torch.dtype, ) -> torch.Tensor: - return torch.ops.torch_ipex.dequantize_4bit(A, "nf4", shape, absmax, None, blocksize).t().to(dtype) + out = torch.empty(shape, dtype=dtype, device=A.device) + _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + return out @register_kernel("bitsandbytes::dequantize_blockwise", "xpu") + def _( + A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype + ) -> torch.Tensor: + out = torch.empty_like(A, dtype=dtype) + _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) + return out + + @register_kernel("bitsandbytes::dequantize_blockwise.out", "xpu") def _( A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, + out: torch.Tensor, + ) -> None: + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") + _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) + + @register_kernel("bitsandbytes::gemv_4bit", "xpu") + def _( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, ) -> torch.Tensor: - shape = A.shape - out = torch.empty(A.reshape(-1).shape, dtype=dtype, device=A.device) - # void cdequantize_blockwise_fp32( - # float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream) - if dtype == torch.float16: - ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_fp16(code, A, absmax, out, blocksize, A.numel()) - elif dtype == torch.bfloat16: - ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_bf16(code, A, absmax, out, blocksize, A.numel()) - elif dtype == torch.float32: - ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_fp32(code, A, absmax, out, blocksize, A.numel()) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}") + shape = (*A.shape[:-1], shapeB[0]) + out = torch.empty(shape, device=A.device, dtype=A.dtype) + _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) + return out - return out.reshape(shape) + @register_kernel("bitsandbytes::gemv_4bit.out", "xpu") + def _( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + out: torch.Tensor, + ) -> None: + torch._check( + out.shape == (*A.shape[:-1], shapeB[0]), + lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}", + ) + torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") + _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) elif triton_available: from ..triton import ops as triton_ops @@ -64,4 +211,6 @@ def _( register_kernel("bitsandbytes::dequantize_4bit", "xpu")(triton_ops.dequantize_4bit) register_kernel("bitsandbytes::gemv_4bit", "xpu")(triton_ops.gemv_4bit) else: - warnings.warn("XPU available but no ipex or triton packages found.") + warnings.warn( + "XPU available but no native library or triton packages found. Please follow the installation instructions in the documentation." + ) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index bb301e712..29101c76c 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -283,6 +283,9 @@ def get_native_library() -> BNBNativeLibrary: binary_path = cuda_binary_path + if torch._C._has_xpu: + binary_path = PACKAGE_DIR / f"libbitsandbytes_xpu{DYNAMIC_LIBRARY_SUFFIX}" + logger.debug(f"Loading bitsandbytes native library from: {binary_path}") # Try to load the library - any errors will propagate up @@ -300,16 +303,6 @@ def get_native_library() -> BNBNativeLibrary: ROCM_GPU_ARCH = get_rocm_gpu_arch() -try: - # to support Intel CPU/GPU (XPU) backend - import intel_extension_for_pytorch as ipex - - ipex_cpu = ipex if ipex._C._has_cpu() else None - ipex_xpu = ipex if ipex._C._has_xpu() else None -except BaseException: - ipex_cpu = None - ipex_xpu = None - try: if torch.version.hip: HIP_ENVIRONMENT, BNB_BACKEND = True, "ROCm" @@ -319,11 +312,10 @@ def get_native_library() -> BNBNativeLibrary: lib = get_native_library() except Exception as e: error_msg = str(e) - if not (ipex_cpu or ipex_xpu): - logger.error( - f"bitsandbytes library load error: {error_msg}\n If you are using Intel CPU/XPU, please install intel_extension_for_pytorch to enable required ops", - exc_info=True, - ) + logger.error( + f"bitsandbytes library load error: {error_msg}", + exc_info=True, + ) # create a mock with error messaging as fallback lib = ErrorHandlerMockBNBNativeLibrary(error_msg) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 9b446a2de..372632d17 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -15,7 +15,7 @@ from bitsandbytes.utils import _reverse_4bit_compress_format, pack_dict_to_tensor, unpack_tensor_to_dict -from .cextension import HIP_ENVIRONMENT, ipex_cpu, ipex_xpu, lib +from .cextension import HIP_ENVIRONMENT, lib name2qmap = {} @@ -439,6 +439,8 @@ def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]): def _get_tensor_stream(tensor: Tensor) -> ct.c_void_p: # We use the raw stream for performance reasons. + if tensor.device.type == "xpu": + return ct.c_void_p(torch._C._xpu_getCurrentRawStream(tensor.device.index)) return ct.c_void_p(torch._C._cuda_getCurrentRawStream(tensor.device.index)) @@ -2351,31 +2353,19 @@ def _enable_ipex_fusion(linear: torch.nn.Module, x: torch.Tensor): quant_state.nested = False delattr(quant_state, "state2") - if x.device.type == "cpu" and ipex_cpu: - converted_weight = _reverse_4bit_compress_format(linear.weight.data) - new_weight, new_scales, new_zeros, _, compensation = torch.ops.ipex_prepack.woq_linear_pack_weight( - converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2]), - "nf4", - quant_state.shape, # weight shape - quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize), # scales - None, # zero_points - None, # bias - None, # batch_size - quant_state.blocksize, - 2, - ) - elif x.device.type == "xpu" and ipex_xpu: - new_weight = _reverse_4bit_compress_format(linear.weight.data) - new_scales = quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize) - new_zeros = None - compensation = None - new_scales = list(new_scales) - if not linear.training and not x.requires_grad: - new_weight = new_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2]) - else: - raise ValueError( - "Please check the device and ipex version. The device should be cpu or xpu while ipex version should >= 2.7" - ) + assert x.device.type == "cpu" + converted_weight = _reverse_4bit_compress_format(linear.weight.data) + new_weight, new_scales, new_zeros, _, compensation = torch.ops.ipex_prepack.woq_linear_pack_weight( + converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2]), + "nf4", + quant_state.shape, # weight shape + quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize), # scales + None, # zero_points + None, # bias + None, # batch_size + quant_state.blocksize, + 2, + ) linear.weight.data = new_weight.data linear.weight.quant_state.ipex = True diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index ba134f52a..9015665ee 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -12,12 +12,13 @@ import bitsandbytes as bnb from bitsandbytes.cextension import HIP_ENVIRONMENT -from bitsandbytes.functional import QuantState, _enable_ipex_fusion, ipex_cpu, ipex_xpu +from bitsandbytes.functional import QuantState, _enable_ipex_fusion from bitsandbytes.optim import GlobalOptimManager from bitsandbytes.utils import ( INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, OutlierTracer, _reverse_4bit_compress_format, + ipex_cpu, ) T = TypeVar("T", bound="torch.nn.Module") @@ -476,8 +477,6 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): self.weight, "nf4", self.weight.quant_state.shape, 2 ) self.weight.data = _reverse_4bit_compress_format(original_weight.data) - elif self.weight.device.type == "xpu": - self.weight.data = _reverse_4bit_compress_format(self.weight.data.reshape(1, -1)) self.weight.quant_state.ipex = False self.ipex_linear_is_set = False @@ -494,13 +493,15 @@ def set_ipex_linear(self, x: torch.Tensor): and self.weight.data.dtype == torch.uint8 and self.weight.quant_state.shape[1] % self.weight.quant_state.blocksize == 0 and self.weight.quant_state.quant_type == "nf4" + and x.device.type == "cpu" + and not self.training + and not x.requires_grad ): - if x.device.type == "xpu" or (x.device.type == "cpu" and not self.training and x.requires_grad == False): - _enable_ipex_fusion(self, x) + _enable_ipex_fusion(self, x) def forward(self, x: torch.Tensor): # Check if ipex fusion can be used - if not self.ipex_linear_is_set and (ipex_cpu or ipex_xpu): + if not self.ipex_linear_is_set and ipex_cpu: self.set_ipex_linear(x) self.ipex_linear_is_set = True @@ -675,7 +676,7 @@ def to(self, *args, **kwargs): if device is not None and device.type != "meta" and self.data.device.type == "cpu": if device.type != "cpu" or self.data.dtype != torch.int8: return self._quantize(device) - elif self.data.dtype == torch.int8 and device.type in ("cpu", "xpu") and (ipex_cpu or ipex_xpu): + elif self.data.dtype == torch.int8 and device.type == "cpu" and ipex_cpu: self.CB = self.data new_param = Int8Params( diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py index 7920e2188..4328a241c 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -4,6 +4,14 @@ import torch +try: + # to support Intel CPU backend + import intel_extension_for_pytorch as ipex + + ipex_cpu = ipex if ipex._C._has_cpu() else None +except BaseException: + ipex_cpu = None + def outlier_hook(module, input): assert isinstance(module, torch.nn.Linear) diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 9c4cab9cc..aa577d853 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -12,6 +12,9 @@ #if BUILD_MPS // #include #endif +#if BUILD_XPU +#include +#endif #include // Compatibility between HIP/CUDA APIs @@ -308,6 +311,88 @@ void spmm_coo_very_sparse_naive_int8( } #endif +#if BUILD_XPU + +void dequantizeBlockwise_fp16( + float *code, unsigned char *A, float *absmax, sycl::half *out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_fp16_fp4( + float *code, unsigned char *A, float *absmax, sycl::half *out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_fp16_nf4( + float *code, unsigned char *A, float *absmax, sycl::half *out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_fp32( + float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_fp32_fp4( + float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_fp32_nf4( + float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_bf16( + float *code, unsigned char *A, float *absmax, sycl::ext::oneapi::bfloat16 *out, int blocksize, const int n, + sycl::queue* stream +) { + dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_bf16_fp4( + float *code, unsigned char *A, float *absmax, sycl::ext::oneapi::bfloat16 *out, int blocksize, const int n, + sycl::queue* stream +) { + dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); +} + +void dequantizeBlockwise_bf16_nf4( + float *code, unsigned char *A, float *absmax, sycl::ext::oneapi::bfloat16 *out, int blocksize, const int n, + sycl::queue* stream +) { + dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); +} + +void gemv_4bit_inference_fp16( + int m, int n, int k, sycl::half * A, unsigned char* B, float *absmax, float *datatype, sycl::half * out, + int lda, int ldb, int ldc, int blocksize, sycl::queue* stream +) { + gemv_4bit_inference(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); +} + +void gemv_4bit_inference_bf16( + int m, int n, int k, sycl::ext::oneapi::bfloat16 * A, unsigned char* B, float *absmax, float *datatype, + sycl::ext::oneapi::bfloat16 * out, int lda, int ldb, int ldc, int blocksize, sycl::queue* stream +) { + gemv_4bit_inference(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); +} + +void gemv_4bit_inference_fp32( + int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, + int ldb, int ldc, int blocksize, sycl::queue* stream +) { + gemv_4bit_inference(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); +} + +#endif + extern "C" { #if BUILD_CUDA || BUILD_HIP void cquantize(float* code, float* A, unsigned char* out, int n) { quantize(code, A, out, n); } @@ -658,6 +743,88 @@ void cgemm_4bit_inference_naive_fp32( #endif +#if BUILD_XPU + +void cdequantize_blockwise_fp16_fp4( + float *code, unsigned char *A, float *absmax, sycl::half *out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_fp16( + float *code, unsigned char *A, float *absmax, sycl::half *out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_fp16_nf4( + float *code, unsigned char *A, float *absmax, sycl::half *out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_fp32( + float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_fp32_fp4( + float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_fp32_nf4( + float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, sycl::queue* stream +) { + dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_bf16( + float *code, unsigned char *A, float *absmax, sycl::ext::oneapi::bfloat16 *out, int blocksize, const int n, + sycl::queue* stream +) { + dequantizeBlockwise_bf16(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_bf16_fp4( + float *code, unsigned char *A, float *absmax, sycl::ext::oneapi::bfloat16 *out, int blocksize, const int n, + sycl::queue* stream +) { + dequantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n, stream); +} + +void cdequantize_blockwise_bf16_nf4( + float *code, unsigned char *A, float *absmax, sycl::ext::oneapi::bfloat16 *out, int blocksize, const int n, + sycl::queue* stream +) { + dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n, stream); +} + +void cgemv_4bit_inference_fp16( + int m, int n, int k, sycl::half * A, unsigned char* B, float *absmax, float *datatype, sycl::half * out, + int lda, int ldb, int ldc, int blocksize, sycl::queue* stream +) { + gemv_4bit_inference_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); +} + +void cgemv_4bit_inference_bf16( + int m, int n, int k, sycl::ext::oneapi::bfloat16 * A, unsigned char* B, float *absmax, float *datatype, + sycl::ext::oneapi::bfloat16 * out, int lda, int ldb, int ldc, int blocksize, sycl::queue* stream +) { + gemv_4bit_inference_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); +} + +void cgemv_4bit_inference_fp32( + int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, + int ldb, int ldc, int blocksize, sycl::queue* stream +) { + gemv_4bit_inference_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); +} + +#endif + void cquantize_blockwise_cpu_fp32( float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n ) { diff --git a/csrc/xpu_kernels.cpp b/csrc/xpu_kernels.cpp new file mode 100644 index 000000000..9bdbd6e31 --- /dev/null +++ b/csrc/xpu_kernels.cpp @@ -0,0 +1,306 @@ +#include "xpu_kernels.h" +#include +#include +#include + +#include + +inline float dDequantizeFP4(unsigned char val) { + if ((val & 0b1000) == 8) + if ((val & 0b0100) == 4) + if ((val & 0b0010) == 2) + if ((val & 0b0001) == 1) + return -0.25000000f; + else + return -0.16666667f; + else if ((val & 0b0001) == 1) + return -0.50000000f; + else + return -0.33333333f; + else if ((val & 0b0010) == 2) + if ((val & 0b0001) == 1) + return -1.00000000f; + else + return -0.66666667f; + else if ((val & 0b0001) == 1) + return -5.208333333e-03f; + else + return 0.00000000f; + else if ((val & 0b0100) == 4) + if ((val & 0b0010) == 2) + if ((val & 0b0001) == 1) + return 0.25000000f; + else + return 0.16666667f; + else if ((val & 0b0001) == 1) + return 0.50000000f; + else + return 0.33333333f; + else if ((val & 0b0010) == 2) + if ((val & 0b0001) == 1) + return 1.00000000f; + else + return 0.66666667f; + else if ((val & 0b0001) == 1) + return 5.208333333e-03f; + else + return 0.00000000f; +} + +inline float dDequantizeNF4(unsigned char val) { + + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if ((val & 0b1000) == 8) + if ((val & 0b0100) == 4) // 1 + if ((val & 0b0010) == 2) // 11 + if ((val & 0b0001) == 1) // 111 + return 1.0f; //*1111 + else + return 0.7229568362236023f; //*1110 + else if ((val & 0b0001) == 1) // 110 + return 0.5626170039176941f; //*1101 + else + return 0.44070982933044434f; //*1100 + else if ((val & 0b0010) == 2) // 10 + if ((val & 0b0001) == 1) // 101 + return 0.33791524171829224f; //*1011 + else + return 0.24611230194568634f; //*1010 + else if ((val & 0b0001) == 1) // 100 + return 0.16093020141124725f; //*1001 + else + return 0.07958029955625534f; //*1000 + + else if ((val & 0b0100) == 4) // 0 + if ((val & 0b0010) == 2) // 01 + if ((val & 0b0001) == 1) // 011 + return 0.0f; //*0111 + else + return -0.09105003625154495f; //*0110 + else if ((val & 0b0001) == 1) // 010 + return -0.18477343022823334f; //*0101 + else + return -0.28444138169288635f; //*0100 + else if ((val & 0b0010) == 2) // 00 + if ((val & 0b0001) == 1) // 001 + return -0.39491748809814453f; //*0011 + else + return -0.5250730514526367f; //*0010 + else if ((val & 0b0001) == 1) // 000 + return -0.6961928009986877f; //*0001 + else + return -1.0f; //*0000 +} + +template +SYCL_EXTERNAL void +kDequantizeBlockwise::operator()( + sycl::nd_item<1> item) const { + const int base_idx = item.get_group(0) * TILE_SIZE; + size_t local_idx = item.get_local_id(0) * NUM_PER_TH; + float local_abs_max = -FLT_MAX; + int local_load_idx = 0; + int local_store_idx = 0; + + uint8_t qvals[NUM_PER_TH]; + T vals[NUM_PER_TH * ((DATA_TYPE > 0) ? 2 : 1)]; + + if (DATA_TYPE > 0) { + local_load_idx = sycl::min(TILE_SIZE, (n + 1) / 2 - base_idx); + local_store_idx = sycl::min(TILE_SIZE * 2, n - base_idx * 2); + } else { + local_load_idx = sycl::min(TILE_SIZE, n - base_idx); + local_store_idx = local_load_idx; + } + + // Avoid expensive divsion by the blocksize (as blocksize will always be a + // power-of-2) + local_abs_max = absmax[(base_idx + local_idx) >> + (31 - std::countl_zero(blocksize))]; + + if (local_idx + NUM_PER_TH < local_load_idx) { + reinterpret_cast(&)[NUM_PER_TH]>(qvals)[0] = + reinterpret_cast *>( + A)[(base_idx + local_idx) / NUM_PER_TH]; + } else { +#pragma unroll NUM_PER_TH + for (int i = 0; i < NUM_PER_TH; i++) { + if (local_idx + i < local_load_idx) { + qvals[i] = A[base_idx + local_idx + i]; + } else { + qvals[i] = (uint8_t)0; + } + } + } + + switch (DATA_TYPE) { + case General8bit: +#pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH; j++) + vals[j] = code[qvals[j]] * local_abs_max; + break; + case FP4: +#pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH; j++) { + vals[j * 2] = dDequantizeFP4(qvals[j] >> 4) * local_abs_max; + vals[j * 2 + 1] = dDequantizeFP4(qvals[j] & 0x0F) * local_abs_max; + } + break; + case NF4: +#pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH; j++) { + vals[j * 2] = dDequantizeNF4(qvals[j] >> 4) * local_abs_max; + vals[j * 2 + 1] = dDequantizeNF4(qvals[j] & 0x0F) * local_abs_max; + } + break; + } + + const int local_dst_size = (DATA_TYPE > 0) ? NUM_PER_TH * 2 : NUM_PER_TH; + int local_dst_idx = (DATA_TYPE > 0) ? local_idx * 2 : local_idx; + + if (local_dst_idx + local_dst_size < local_store_idx) { + reinterpret_cast *>( + out)[(((DATA_TYPE > 0) ? base_idx * 2 : base_idx) + local_dst_idx) / + local_dst_size] = + reinterpret_cast(&)[local_dst_size]>( + vals)[0]; + } else { +#pragma unroll NUM_PER_TH + for (int i = 0; i < local_dst_size; i++) { + if (local_dst_idx + i < local_store_idx) { + out[((DATA_TYPE > 0) ? base_idx * 2 : base_idx) + local_dst_idx + i] = + vals[i]; + } + } + } +} + +template +SYCL_EXTERNAL void +kgemv_4bit_inference::operator()(sycl::nd_item<1> item) const { + size_t idx = item.get_local_id(); + const int sg_idx = idx / SUBG_SIZE; + const int sg_lane = idx % SUBG_SIZE; + const int num_values_4bit = SUBG_SIZE; + const int row_B = NUM_PER_THREAD * item.get_group().get_group_id() + sg_idx; + const int offset_B = ldb * row_B; + const int num_values_8bit = num_values_4bit / 2; + float local_C = 0.0f; + + unsigned char local_B_4bit[num_values_8bit]; + T local_B[num_values_4bit / 4]; + T local_A[num_values_4bit / 4]; + T local_absmax = T(0.0f); + + if (idx < 16) { + quant_map[idx] = T(datatype[idx]); + } + + item.barrier(sycl::access::fence_space::local_space); + + for (int inner_idx = sg_lane * num_values_4bit; inner_idx < K; + inner_idx += SUBG_SIZE * num_values_4bit) { + const int inner_idx_halved = inner_idx / 2; + + // Avoid expensive divsion by the blocksize (as blocksize will always be a + // power-of-2) + const int absidx = ((2 * offset_B) + inner_idx) >> + (31 - std::countl_zero((unsigned int)blocksize)); + local_absmax = absmax[absidx]; + + if (row_B < N) { + if ((inner_idx_halved + num_values_8bit) < (K / 2)) { + reinterpret_cast(&)[num_values_8bit]>( + local_B_4bit)[0] = + reinterpret_cast *>( + B)[(offset_B + (inner_idx_halved)) / (num_values_8bit)]; + } else { +#pragma unroll + for (int j = 0; j < (num_values_8bit); j++) + if ((inner_idx_halved) + j < (K / 2)) + local_B_4bit[j] = B[offset_B + inner_idx_halved + j]; + else + local_B_4bit[j] = 0b01110111; + } + } else { +#pragma unroll + for (int j = 0; j < (num_values_8bit); j++) + local_B_4bit[j] = 0b01110111; + } + + for (int i = 0; i < 4; i++) { +#pragma unroll + for (int k = 0; k < num_values_8bit / 4; k++) { + local_B[k * 2] = + quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * + local_absmax; + local_B[k * 2 + 1] = + quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * + local_absmax; + } + + if (inner_idx + (num_values_4bit / 4) + (i * num_values_4bit / 4) < K) { + if (BITS == 16) { + reinterpret_cast(&)[num_values_4bit / 4]>( + local_A)[0] = + reinterpret_cast *>( + A)[inner_idx / (num_values_4bit / 4) + i]; + } else { + reinterpret_cast(&)[num_values_4bit / 4]>( + local_A)[0] = + reinterpret_cast *>( + A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 0]; + reinterpret_cast(&)[num_values_4bit / 4]>( + local_A)[1] = + reinterpret_cast *>( + A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 1]; + } + + } else { +#pragma unroll + for (int k = 0; k < num_values_4bit / 4; k++) + if (inner_idx + (i * num_values_4bit / 4) + k < K) + local_A[k] = A[inner_idx + k + (i * num_values_4bit / 4)]; + else + local_A[k] = T(0.0f); + } + +// accumulate in float for accuracy; +#pragma unroll + for (int k = 0; k < num_values_4bit / 4; k++) { + local_C += (float)(local_A[k] * local_B[k]); + } + } + } + + local_C = + sycl::reduce_over_group(item.get_sub_group(), local_C, sycl::plus<>()); + + if (row_B < N && sg_lane == 0) + out[row_B] = T(local_C); +} + +//============================================================== +// TEMPLATE DEFINITIONS +//============================================================== + +template class kDequantizeBlockwise; +template class kDequantizeBlockwise; +template class kDequantizeBlockwise; + +template class kDequantizeBlockwise; +template class kDequantizeBlockwise; +template class kDequantizeBlockwise; + +template class kDequantizeBlockwise; +template class kDequantizeBlockwise; +template class kDequantizeBlockwise; + +template class kgemv_4bit_inference; +template class kgemv_4bit_inference; +template class kgemv_4bit_inference; diff --git a/csrc/xpu_kernels.h b/csrc/xpu_kernels.h new file mode 100644 index 000000000..e5a115ced --- /dev/null +++ b/csrc/xpu_kernels.h @@ -0,0 +1,59 @@ +#include +#include + +#ifndef xpu_kernels +#define xpu_kernels + +template +class kDequantizeBlockwise { +public: + SYCL_EXTERNAL void operator()(sycl::nd_item<1> item) const; + + kDequantizeBlockwise(float *code_, uint8_t *A_, float *absmax_, T *out_, + const int blocksize_, const int n_) + : code(code_), A(A_), absmax(absmax_), out(out_), blocksize(blocksize_), + n(n_) {} + +private: + float *code; + uint8_t *A; + float *absmax; + T *out; + const int blocksize; + const int n; +}; + +template +class kgemv_4bit_inference { +public: + SYCL_EXTERNAL void operator()(sycl::nd_item<1> item) const; + + kgemv_4bit_inference(int M_, int N_, int K_, T *A_, unsigned char *B_, + float *absmax_, const float *datatype_, T *out_, + int lda_, int ldb_, int ldc_, int blocksize_) + : M(M_), N(N_), K(K_), A(A_), B(B_), absmax(absmax_), datatype(datatype_), + out(out_), lda(lda_), ldb(ldb_), ldc(ldc_), blocksize(blocksize_), + quant_map() {} + + void sycl_ker_local_memory_creation(sycl::handler &cgh) { + quant_map = sycl::local_accessor(16, cgh); + } + +private: + int M; + int N; + int K; + T *A; + unsigned char *B; + float *absmax; + const float *datatype; + T *out; + int lda; + int ldb; + int ldc; + int blocksize; + sycl::local_accessor quant_map; +}; + +#endif diff --git a/csrc/xpu_ops.cpp b/csrc/xpu_ops.cpp new file mode 100644 index 000000000..c1feb3996 --- /dev/null +++ b/csrc/xpu_ops.cpp @@ -0,0 +1,108 @@ +#include +#include +#include + +template +void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, + int blocksize, const int n, sycl::queue *stream) { + auto &queue = *stream; + const int workgroup_size = 128; + const int num_per_th = 4; + const int tile_size = workgroup_size * num_per_th; + if (DATA_TYPE > 0) { + const int workgroup_num = (n + tile_size * 2 - 1) / (tile_size * 2); + sycl::range<1> local_range{(size_t)workgroup_size}; + sycl::range<1> global_range{(size_t)workgroup_num * (size_t)workgroup_size}; + kDequantizeBlockwise kfn( + code, A, absmax, out, blocksize / 2, n); + sycl_kernel_submit( + sycl::nd_range<1>(sycl::range<1>(global_range), + sycl::range<1>(local_range)), + queue, kfn); + } else { + const int workgroup_num = (n + tile_size - 1) / tile_size; + sycl::range<1> local_range{(size_t)workgroup_size}; + sycl::range<1> global_range{(size_t)workgroup_num * (size_t)workgroup_size}; + kDequantizeBlockwise kfn( + code, A, absmax, out, blocksize, n); + sycl_kernel_submit( + sycl::nd_range<1>(sycl::range<1>(global_range), + sycl::range<1>(local_range)), + queue, kfn); + } +} + +template +void gemv_4bit_inference(int m, int n, int k, T *A, unsigned char *B, + float *absmax, float *datatype, T *out, int lda, + int ldb, int ldc, int blocksize, sycl::queue *stream) { + + auto &queue = *stream; + + const size_t GROUP_SIZE = 128; // workgroup_size + const size_t SUBG_SIZE = 32; // subgroup_size + const size_t NUM_PER_THREAD = GROUP_SIZE / SUBG_SIZE; + size_t workgroup_num = (n + NUM_PER_THREAD - 1) / NUM_PER_THREAD; + + kgemv_4bit_inference kfn( + m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); + + sycl_comp_kernel_submit( + sycl::nd_range<1>(sycl::range<1>(GROUP_SIZE * workgroup_num), + sycl::range<1>(GROUP_SIZE)), + queue, kfn); +} + +//============================================================== +// TEMPLATE DEFINITIONS +//============================================================== + +template void dequantizeBlockwise( + float *code, unsigned char *A, float *absmax, float *out, int blocksize, + const int n, sycl::queue *stream); +template void dequantizeBlockwise(float *code, unsigned char *A, + float *absmax, float *out, + int blocksize, const int n, + sycl::queue *stream); +template void dequantizeBlockwise(float *code, unsigned char *A, + float *absmax, float *out, + int blocksize, const int n, + sycl::queue *stream); + +template void dequantizeBlockwise( + float *code, unsigned char *A, float *absmax, sycl::half *out, + int blocksize, const int n, sycl::queue *stream); +template void dequantizeBlockwise( + float *code, unsigned char *A, float *absmax, sycl::half *out, + int blocksize, const int n, sycl::queue *stream); +template void dequantizeBlockwise( + float *code, unsigned char *A, float *absmax, sycl::half *out, + int blocksize, const int n, sycl::queue *stream); + +template void dequantizeBlockwise( + float *code, unsigned char *A, float *absmax, + sycl::ext::oneapi::bfloat16 *out, int blocksize, const int n, + sycl::queue *stream); +template void dequantizeBlockwise( + float *code, unsigned char *A, float *absmax, + sycl::ext::oneapi::bfloat16 *out, int blocksize, const int n, + sycl::queue *stream); +template void dequantizeBlockwise( + float *code, unsigned char *A, float *absmax, + sycl::ext::oneapi::bfloat16 *out, int blocksize, const int n, + sycl::queue *stream); + +template void gemv_4bit_inference( + int m, int n, int k, sycl::half *A, unsigned char *B, float *absmax, + float *datatype, sycl::half *out, int lda, int ldb, int ldc, int blocksize, + sycl::queue *stream); +template void gemv_4bit_inference( + int m, int n, int k, sycl::ext::oneapi::bfloat16 *A, unsigned char *B, + float *absmax, float *datatype, sycl::ext::oneapi::bfloat16 *out, int lda, + int ldb, int ldc, int blocksize, sycl::queue *stream); +template void gemv_4bit_inference(int m, int n, int k, float *A, + unsigned char *B, float *absmax, + float *datatype, float *out, + int lda, int ldb, int ldc, + int blocksize, + sycl::queue *stream); diff --git a/csrc/xpu_ops.h b/csrc/xpu_ops.h new file mode 100644 index 000000000..3045283a9 --- /dev/null +++ b/csrc/xpu_ops.h @@ -0,0 +1,49 @@ +#ifndef xpu_ops_H +#define xpu_ops_H + +#include +#include +#include +#include + +#include +#include + +#include + +template +static inline void sycl_kernel_submit(sycl::nd_range range, sycl::queue q, + ker_t ker) { + auto cgf = [&](::sycl::handler & cgh) + [[sycl::reqd_sub_group_size(subgroup_size)]] { + cgh.parallel_for(range, ker); + }; + q.submit(cgf); +} + +template +static inline void sycl_comp_kernel_submit(sycl::nd_range range, + sycl::queue q, ker_t ker) { + auto cgf = [&](::sycl::handler & cgh) + [[sycl::reqd_sub_group_size(subgroup_size)]] { + ker.sycl_ker_local_memory_creation(cgh); + cgh.parallel_for(range, ker); + }; + q.submit(cgf); +} + +typedef enum DataType_t { + General8bit = 0, + FP4 = 1, + NF4 = 2, +} DataType_t; + +template +void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, + int workgroup_size, const int n, sycl::queue *stream); +template +void gemv_4bit_inference(int m, int n, int k, T *A, unsigned char *B, + float *absmax, float *datatype, T *out, int lda, + int ldb, int ldc, int blocksize, sycl::queue *stream); + +#endif diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index e61ce4655..9b3449870 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -237,24 +237,16 @@ pip install -e . # `-e` for "editable" install, when developing BNB (otherwise #### Intel CPU + XPU - -If you are using Intel CPU on Linux or Intel XPU on Linux/Windows, please follow the [instruction](https://pytorch-extension.intel.com/) or the following command to install intel_extension_for_pytorch so you can get better performance. - -CPU: `pip install intel_extension_for_pytorch` -XPU: `pip install intel_extension_for_pytorch --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/` - -Install bitsandbytes: -CPU: Need to build CPU C++ codes +CPU needs to build CPU C++ codes, while xpu needs to build sycl codes. +Run `export bnb_device=xpu` if you are using xpu, run `export bnb_device=cpu` if you are using cpu. ``` git clone https://github.com/bitsandbytes-foundation/bitsandbytes.git && cd bitsandbytes/ -cmake -DCOMPUTE_BACKEND=cpu -S . +cmake -DCOMPUTE_BACKEND=$bnb_device -S . make -pip install . -``` -XPU: -``` -pip install git+https://github.com/bitsandbytes-foundation/bitsandbytes.git +pip install -e . ``` +Note: You can run `pip install intel_extension_for_pytorch to get better performance on CPU` + diff --git a/tests/test_functional.py b/tests/test_functional.py index b84db6502..d201bc8ec 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -142,11 +142,11 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, abserr = sum(diffs) / len(diffs) relerr = sum(reldiffs) / len(reldiffs) if signed: - threshold_abserr = 0.0036 if device in ("cpu", "xpu") and (F.ipex_cpu or F.ipex_xpu) else 0.0035 + threshold_abserr = 0.0035 assert abserr < 0.0036 assert relerr < 0.015 else: - assert abserr < 0.00175 if device in ("cpu", "xpu") and (F.ipex_cpu or F.ipex_xpu) else 0.0023 + assert abserr < 0.0023 assert relerr < 0.012 assert A2.dtype == dtype @@ -177,8 +177,8 @@ def test_blockwise_cpu_large(self, hidden, blocksize): @pytest.mark.parametrize("bits", range(2, 9), ids=id_formatter("bits")) @pytest.mark.parametrize("method", ["linear", "fp8", "dynamic"]) def test_few_bit_quant(self, device, bits, method): - if bits != 8 and (device == "cpu" or (device == "xpu" and F.ipex_xpu)): - pytest.skip("CPU/XPU implementation only supports 8 bits") + if bits != 8 and device == "cpu": + pytest.skip("CPU implementation only supports 8 bits") abserrs = [] relerrs = [] @@ -1341,13 +1341,13 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double assert err1 < 6e-5 assert relerr1 < 2e-4 assert absratio < 1.005 and absratio > 0.995 - assert relratio < 1.005 and relratio > 0.995 - assert maxratio < 1.005 and maxratio > 0.995 + assert relratio < 1.005 and relratio > 0.992 + assert maxratio < 1.005 and maxratio > 0.992 elif dtype == torch.float32: if dim <= 512: assert err1 < 5e-8 assert relerr1 < 1e-6 - assert maxerr1 < 1e-7 + assert maxerr1 < 1.05e-7 else: assert err1 < 5e-8 assert relerr1 < 8e-6 @@ -1357,16 +1357,17 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double assert maxratio < 1.005 and maxratio > 0.995 elif dtype == torch.bfloat16: if dim <= 512: + relerr_thres = 0.013 if hasattr(torch, "xpu") and torch.xpu.is_available() else 0.007 assert err1 < 6e-4 - assert relerr1 < 0.007 + assert relerr1 < relerr_thres assert maxerr1 < 0.015 else: assert err1 < 2e-4 assert relerr1 < 0.002 assert maxerr1 < 0.0012 assert absratio < 1.005 and absratio > 0.995 - assert relratio < 1.04 and relratio > 0.96 - assert maxratio < 1.02 and maxratio > 0.98 + assert relratio < 1.05 and relratio > 0.96 + assert maxratio < 1.05 and maxratio > 0.97 @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"]) diff --git a/tests/test_modules.py b/tests/test_modules.py index 8946522d3..e5682e5c8 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -143,9 +143,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold): b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16) o1 = mlp(b1) assert o1.dtype == torch.float16 - if threshold > 0: + if threshold > 0 and device not in ("cpu", "xpu"): assert mlp.fc1.state.idx is not None - if threshold > 0: assert mlp.fc2.state.idx is not None mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(device).half() @@ -156,9 +155,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold): b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16) o1 = mlp(b1) assert o1.dtype == torch.float16 - if threshold > 0: + if threshold > 0 and device not in ("cpu", "xpu"): assert mlp.fc1.state.idx is not None - if threshold > 0: assert mlp.fc2.state.idx is not None mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to(device) @@ -167,9 +165,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold): b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16) o1 = mlp(b1) assert o1.dtype == torch.float16 - if threshold > 0: + if threshold > 0 and device not in ("cpu", "xpu"): assert mlp.fc1.state.idx is not None - if threshold > 0: assert mlp.fc2.state.idx is not None assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8 @@ -189,9 +186,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold): b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16) o1 = mlp(b1) assert o1.dtype == torch.float16 - if threshold > 0: + if threshold > 0 and device not in ("cpu", "xpu"): assert mlp.fc1.state.idx is not None - if threshold > 0: assert mlp.fc2.state.idx is not None assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8 @@ -211,9 +207,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold): b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16) o1 = mlp(b1) assert o1.dtype == torch.float16 - if threshold > 0: + if threshold > 0 and device not in ("cpu", "xpu"): assert mlp.fc1.state.idx is not None - if threshold > 0: assert mlp.fc2.state.idx is not None assert mlp.fc1.weight.dtype == torch.int8 diff --git a/tests/test_ops.py b/tests/test_ops.py index 8aa0560fd..3b52bf284 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -5,7 +5,6 @@ import bitsandbytes from bitsandbytes.cextension import HIP_ENVIRONMENT -from bitsandbytes.functional import ipex_xpu from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, is_supported_on_hpu # torch.library.opcheck is only available in torch 2.4 and later. @@ -145,10 +144,6 @@ def test_dequantize_blockwise(self, device, dtype, blocksize): assert out.dtype == dtype assert out.device == A.device - # TODO: Enable it - if device == "xpu" and ipex_xpu: - pytest.skip("XPU implementation have torch.op inside torch.op, it will fail on op check") - opcheck(torch.ops.bitsandbytes.dequantize_blockwise.default, (A, absmax, code, blocksize, dtype))