diff --git a/bitsandbytes/backends/triton/__init__.py b/bitsandbytes/backends/triton/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bitsandbytes/backends/triton/ops.py b/bitsandbytes/backends/triton/ops.py new file mode 100644 index 000000000..1e2802ab5 --- /dev/null +++ b/bitsandbytes/backends/triton/ops.py @@ -0,0 +1,166 @@ +from collections.abc import Sequence + +import torch + +from . import triton_kernels + +# currently codes unused, kept for reference +# Should be the same for quant/dequant +# from bitsandbytes.functional import get_4bit_type +# _FP4_QUANT_TABLE = get_4bit_type("fp4", device="xpu") +# _NF4_QUANT_TABLE = get_4bit_type("nf4", device="xpu") + + +def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: + torch._check_is_size(blocksize) + # torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on xpu, got {A.dtype}") + + n = A.numel() + blocks = -(n // -blocksize) + + absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype) + out = torch.empty_like(A.flatten(), dtype=torch.uint8) + + triton_kernels.quantize_blockwise_triton(A, blocksize, code, blocks, absmax, out) + out = out.reshape(A.shape) + + return out, absmax.float() + + +def dequantize_blockwise( + A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype +) -> torch.Tensor: + torch._check_is_size(blocksize) + torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") + # torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on xpu, got {dtype}") + + out = torch.empty_like(A, dtype=dtype, device=A.device) + triton_kernels.dequant_int8_blockwise( + A, + code, + absmax, + out, + blocksize, + ) + + return out + + +def dequantize_blockwise_inplace( + A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor +) -> None: + torch._check_is_size(blocksize) + torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") + torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") + torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}") + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + + triton_kernels.dequant_int8_blockwise( + A, + code, + absmax, + out, + blocksize, + ) + + +def quantize_4bit( + A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype +) -> tuple[torch.Tensor, torch.Tensor]: + torch._check_is_size(blocksize) + # torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}") + torch._check( + A.dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", + ) + + n = A.numel() + + # TODO: Support when weight matrix is not divisible by blocksize + # torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}") + + blocks = -(n // -(blocksize * 2)) + + absmax = torch.empty((blocks * 2,), device=A.device, dtype=A.dtype) + out = torch.empty((n // 2, 1), device=A.device, dtype=torch.uint8) + + triton_kernels.quantize_4bit_blockwise_triton( + A, blocksize, quant_type, blocks, absmax, num_elements=n, quantized_out=out + ) + packed = out + + if quant_storage != torch.uint8: + packed = out.squeeze().view(quant_storage).unsqueeze(1) + + return packed, absmax.float() + + 
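As I read the kernels referenced here, the 4-bit path packs two code indices per uint8 byte: the first element of each pair goes to the high nibble, the second to the low nibble, and dequantization interleaves them back in that order. A minimal pure-PyTorch sketch of that packing convention (the helper names are illustrative only, not part of this PR):

import torch

def pack_pairs(idx: torch.Tensor) -> torch.Tensor:
    # idx: flat uint8 tensor of 4-bit code indices, even number of elements
    pairs = idx.reshape(-1, 2)
    # first element of each pair -> high nibble, second element -> low nibble
    return ((pairs[:, 0].to(torch.int32) << 4) | (pairs[:, 1] & 0xF)).to(torch.uint8)

def unpack_pairs(packed: torch.Tensor) -> torch.Tensor:
    first = packed >> 4     # high nibble holds the first element of the pair
    second = packed & 0xF   # low nibble holds the second element
    return torch.stack([first, second], dim=1).flatten()

idx = torch.randint(0, 16, (128,), dtype=torch.uint8)
assert torch.equal(unpack_pairs(pack_pairs(idx)), idx)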
+def dequantize_4bit( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, +) -> torch.Tensor: + torch._check_is_size(blocksize) + # torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on XPU, got {quant_type}") + torch._check( + dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", + ) + # torch._check( + # A.dtype == torch.uint8, + # lambda: f"Blockwise 4bit dequantization on XPU only supports uint8 storage, got {A.dtype}", + # ) + # Check if this is fine and fast + if A.dtype != torch.uint8: + A = A.squeeze().view(torch.uint8).unsqueeze(1) + + out = torch.empty(shape, dtype=dtype, device=A.device) + + triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + return out + + +def dequantize_4bit_inplace( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + + +def gemv_4bit( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, +) -> torch.Tensor: + if B.dtype != torch.uint8: + B = B.squeeze().view(torch.uint8).unsqueeze(1) + + B_dq_triton = torch.empty(shapeB, dtype=A.dtype, device=A.device) + + triton_kernels._dequantize_4bit_impl_passing_code( + B, + absmax, + blocksize, + code, + dtype=A.dtype, + out=B_dq_triton, + ) + + return torch.nn.functional.linear( + A, + B_dq_triton, + bias=None, + ) diff --git a/bitsandbytes/backends/triton/triton_kernels.py b/bitsandbytes/backends/triton/triton_kernels.py new file mode 100644 index 000000000..03ffa187d --- /dev/null +++ b/bitsandbytes/backends/triton/triton_kernels.py @@ -0,0 +1,713 @@ +import torch + +import triton +import triton.language as tl + + +# @triton.autotune( +# configs=[ +# # triton.Config({'SPLIT_SIZE': 64}), +# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=2, num_warps=32), +# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=2, num_warps=32), +# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=4, num_warps=32), +# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=4, num_warps=32), +# # triton.Config({'SPLIT_SIZE': 128}), +# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=2, num_warps=32), +# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=2, num_warps=32), +# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=4, num_warps=32), +# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=4, num_warps=32), +# triton.Config({"SPLIT_SIZE": 256}), +# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'large'}, num_stages=2, num_warps=32), +# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'auto'}, num_stages=2, num_warps=32), +# triton.Config({"SPLIT_SIZE": 512}), +# # triton.Config({'SPLIT_SIZE': 1024}), +# ], +# key=["num_paired_elements", "QUANT_BLOCK"], +# ) +@triton.jit +def dequant_8bit_kernel( + a_ptr, + c_ptr, + quant_ptr, + absmax_ptr, + num_paired_elements, + QUANT_BLOCK: tl.constexpr, + SPLIT_SIZE: tl.constexpr, +): + pid = 
tl.program_id(axis=0) + block_start = pid * SPLIT_SIZE + offsets = block_start + tl.arange(0, SPLIT_SIZE) + mask = offsets < num_paired_elements + + a = tl.load(a_ptr + offsets, mask) + a = a.to(tl.uint8) + + # apply conversion + scaled_int8 = tl.load(quant_ptr + a, mask) + + abs_blocks_lim = (num_paired_elements // QUANT_BLOCK) * QUANT_BLOCK + num_paired_elements % QUANT_BLOCK + abs_offsets = offsets // QUANT_BLOCK + mask_blocked = offsets < abs_blocks_lim + + absmax = tl.load(absmax_ptr + abs_offsets, mask_blocked) + # apply scales + out_dq = scaled_int8 * absmax + + offs = block_start + tl.arange(0, SPLIT_SIZE) + mask = offs < num_paired_elements + tl.store(c_ptr + offs, out_dq, mask) + + +def dequant_int8_blockwise( + A_nf4: torch.Tensor, + quant_state_code: torch.Tensor, + absmax: torch.Tensor, + out: torch.Tensor, + quant_blocksize: int = 64, +): + number_of_paired_elements = A_nf4.numel() + + SPLIT_SIZE = 256 + # grid = lambda META: (triton.cdiv(number_of_paired_elements, META["SPLIT_SIZE"]),) + grid = (triton.cdiv(number_of_paired_elements, SPLIT_SIZE),) + dequant_8bit_kernel[grid]( + A_nf4, + out, + quant_state_code, + absmax, + number_of_paired_elements, + quant_blocksize, + SPLIT_SIZE, + ) + return out + + +# @triton.autotune( +# configs=[ +# triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32), +# triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32), +# triton.Config({"SPLIT_NUM_BLOCKS": 1}), +# triton.Config({"SPLIT_NUM_BLOCKS": 2}), +# ], +# key=["n_elements"], +# ) +@triton.jit +def quantize_blockwise_kernel( + A_ptr, + code_ptr, + absmax_ptr, + out_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + CODE_SIZE: tl.constexpr, + SPLIT_NUM_BLOCKS: tl.constexpr, +): + block_start_idx = tl.program_id(0) * SPLIT_NUM_BLOCKS + thread_idx = tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE) + + offsets = block_start_idx * BLOCK_SIZE + thread_idx + mask = offsets < n_elements + + A = tl.load(A_ptr + offsets, mask=mask, other=0.0) + + # To be able process several blocks -> (BLOCK_SIZE, SPLIT_NUM_BLOCKS) + A_reshaped = tl.reshape(A, (SPLIT_NUM_BLOCKS, BLOCK_SIZE)) + + # Calculating absamax for each block + absmax = tl.max(tl.abs(A_reshaped), axis=1) + tl.store(absmax_ptr + block_start_idx + tl.arange(0, SPLIT_NUM_BLOCKS), absmax) + + A_normalized = A_reshaped / absmax[:, None] + A_normalized = tl.clamp(A_normalized, -1.0, 1.0) + + lower_pivot = tl.zeros((SPLIT_NUM_BLOCKS, BLOCK_SIZE), dtype=tl.int32) + upper_pivot = tl.full((SPLIT_NUM_BLOCKS, BLOCK_SIZE), CODE_SIZE - 1, dtype=tl.int32) + + for _ in range(8): # ceil(log2(code_size)) = 8, actually, in general case should be input parameter + pivot = (lower_pivot + upper_pivot) // 2 + val = tl.load(code_ptr + pivot) + is_higher = A_normalized > val # code[pivot] + lower_pivot = tl.where(is_higher, pivot, lower_pivot) + upper_pivot = tl.where(is_higher, upper_pivot, pivot) + + # Choose closest level + lower_val = tl.load(code_ptr + lower_pivot) + upper_val = tl.load(code_ptr + upper_pivot) + lower_dist = tl.abs(A_normalized - lower_val) + upper_dist = tl.abs(A_normalized - upper_val) + quantized = tl.where(lower_dist <= upper_dist, lower_pivot, upper_pivot).to(tl.uint8) + + # too slow approach + # diff = tl.abs(A_normalized[:, :, None] - code[None, None, :]) + # quantized = tl.argmin(diff, axis=2).to(tl.uint8) + + quantized_flat = tl.reshape(quantized, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,)) + tl.store(out_ptr + offsets, quantized_flat, mask=mask) + + +def quantize_blockwise_triton(A, 
blocksize, code, blocks, absmax, quantized_out): + n = A.numel() + + split_num_blocks = 1 + grid = (triton.cdiv(blocks, split_num_blocks),) + # grid = lambda META: (triton.cdiv(blocks, META["SPLIT_NUM_BLOCKS"]),) + quantize_blockwise_kernel[grid]( + A_ptr=A, + code_ptr=code, + absmax_ptr=absmax, + out_ptr=quantized_out, + n_elements=n, + BLOCK_SIZE=blocksize, + CODE_SIZE=code.numel(), + SPLIT_NUM_BLOCKS=split_num_blocks, + ) + + return quantized_out, absmax + + +# Triton implementation of similar CUDA kernel to avoid loading code from csrc/kernels.cu::dQuantizeFP4 +# @triton.autotune( +# configs=[ +# triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32), +# triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32), +# triton.Config({"SPLIT_NUM_BLOCKS": 1}), +# triton.Config({"SPLIT_NUM_BLOCKS": 2}), +# triton.Config({"SPLIT_NUM_BLOCKS": 4}), +# triton.Config({"SPLIT_NUM_BLOCKS": 8}), +# ], +# key=["n_elements"], +# ) +@triton.jit +def quantize_fp4_blockwise_kernel( + A_ptr, + absmax_ptr, + out_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + SPLIT_NUM_BLOCKS: tl.constexpr, +): + PAIRED_SPLIT_NUM_BLOCKS: tl.constexpr = SPLIT_NUM_BLOCKS * 2 + block_start_idx = tl.program_id(0) * PAIRED_SPLIT_NUM_BLOCKS + thread_idx = tl.arange(0, PAIRED_SPLIT_NUM_BLOCKS * BLOCK_SIZE) + + offsets = block_start_idx * BLOCK_SIZE + thread_idx + mask = offsets < n_elements + + A = tl.load(A_ptr + offsets, mask=mask, other=0.0) + + # To be able process several blocks -> (PAIRED_SPLIT_NUM_BLOCKS, BLOCK_SIZE) + A_reshaped = tl.reshape(A, (PAIRED_SPLIT_NUM_BLOCKS, BLOCK_SIZE)) + + # Calculating absamax for each block + absmax = tl.max(tl.abs(A_reshaped), axis=1) + tl.store(absmax_ptr + block_start_idx + tl.arange(0, PAIRED_SPLIT_NUM_BLOCKS), absmax) + + A_normalized = A_reshaped / absmax[:, None] + A_normalized = tl.clamp(A_normalized, -1.0, 1.0) + + sign = tl.where(A_normalized < 0, 0b1000, 0b0000) + A_absf = tl.abs(A_normalized) + + result = tl.where( + A_absf > 0.29166667, + tl.where( + A_absf > 0.583333, tl.where(A_absf > 0.8333333, 0b011, 0b010), tl.where(A_absf > 0.4166667, 0b101, 0b100) + ), + tl.where( + A_absf > 0.0859375, + tl.where(A_absf > 0.20833333, 0b0111, 0b0110), + tl.where(A_absf > 0.00260417, 0b0001, 0b0000), + ), + ) + quantized = (result ^ sign).to(tl.uint8) + + quantized = quantized.reshape((PAIRED_SPLIT_NUM_BLOCKS, BLOCK_SIZE // 2, 2)) + left, right = quantized.split() + packed = left << 4 | (right & 0xF) + + packed_flat = tl.reshape(packed, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,)) + out_offsets = block_start_idx * BLOCK_SIZE // 2 + tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE) + out_mask = out_offsets < n_elements // 2 + tl.store(out_ptr + out_offsets, packed_flat, mask=out_mask) + + +# Triton implementation of similar CUDA kernel to avoid loading code from csrc/kernels.cu::dQuantizeNF4 +# @triton.autotune( +# configs=[ +# triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32), +# triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32), +# triton.Config({"SPLIT_NUM_BLOCKS": 1}), +# triton.Config({"SPLIT_NUM_BLOCKS": 2}), +# triton.Config({"SPLIT_NUM_BLOCKS": 4}), +# triton.Config({"SPLIT_NUM_BLOCKS": 8}), +# ], +# key=["n_elements"], +# ) +@triton.jit +def quantize_nf4_blockwise_kernel( + A_ptr, + absmax_ptr, + out_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + SPLIT_NUM_BLOCKS: tl.constexpr, +): + PAIRED_SPLIT_NUM_BLOCKS: tl.constexpr = SPLIT_NUM_BLOCKS * 2 + block_start_idx = 
tl.program_id(0) * PAIRED_SPLIT_NUM_BLOCKS + thread_idx = tl.arange(0, PAIRED_SPLIT_NUM_BLOCKS * BLOCK_SIZE) + + offsets = block_start_idx * BLOCK_SIZE + thread_idx + mask = offsets < n_elements + + A = tl.load(A_ptr + offsets, mask=mask, other=0.0) + + # To be able process several blocks -> (PAIRED_SPLIT_NUM_BLOCKS, BLOCK_SIZE) + A_reshaped = tl.reshape(A, (PAIRED_SPLIT_NUM_BLOCKS, BLOCK_SIZE)) + + # Calculating absamax for each block + absmax = tl.max(tl.abs(A_reshaped), axis=1) + tl.store(absmax_ptr + block_start_idx + tl.arange(0, PAIRED_SPLIT_NUM_BLOCKS), absmax) + + A_normalized = A_reshaped / absmax[:, None] + A_normalized = tl.clamp(A_normalized, -1.0, 1.0) + + result = tl.where( + A_normalized > 0.03979014977812767, + tl.where( + A_normalized > 0.3893125355243683, + tl.where( + A_normalized > 0.6427869200706482, + tl.where(A_normalized > 0.8614784181118011, 0b1111, 0b1110), + tl.where(A_normalized > 0.5016634166240692, 0b1101, 0b1100), + ), + tl.where( + A_normalized > 0.2035212516784668, + tl.where(A_normalized > 0.2920137718319893, 0b1011, 0b1010), + tl.where(A_normalized > 0.1202552504837513, 0b1001, 0b1000), + ), + ), + tl.where( + A_normalized > -0.33967943489551544, + tl.where( + A_normalized > -0.13791173323988914, + tl.where(A_normalized > -0.045525018125772476, 0b0111, 0b0110), + tl.where(A_normalized > -0.23460740596055984, 0b0101, 0b0100), + ), + tl.where( + A_normalized > -0.6106329262256622, + tl.where(A_normalized > -0.4599952697753906, 0b0011, 0b0010), + tl.where(A_normalized > -0.8480964004993439, 0b0001, 0b0000), + ), + ), + ) + quantized = result.to(tl.uint8) + + quantized = quantized.reshape((PAIRED_SPLIT_NUM_BLOCKS, BLOCK_SIZE // 2, 2)) + + left, right = quantized.split() + packed = left << 4 | (right & 0xF) + + packed_flat = tl.reshape(packed, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,)) + out_offsets = block_start_idx * BLOCK_SIZE // 2 + tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE) + out_mask = out_offsets < n_elements // 2 + tl.store(out_ptr + out_offsets, packed_flat, mask=out_mask) + + +def quantize_4bit_blockwise_triton(A, blocksize, quant_type, blocks, absmax, num_elements, quantized_out): + # grid = lambda META: (triton.cdiv(blocks, META["SPLIT_NUM_BLOCKS"]),) + split_num_blocks = 4 + grid = (triton.cdiv(blocks, split_num_blocks),) + if quant_type == "fp4": + quantize_fp4_blockwise_kernel[grid]( + A_ptr=A, + absmax_ptr=absmax, + out_ptr=quantized_out, + n_elements=num_elements, + BLOCK_SIZE=blocksize, + SPLIT_NUM_BLOCKS=split_num_blocks, + ) + else: + quantize_nf4_blockwise_kernel[grid]( + A_ptr=A, + absmax_ptr=absmax, + out_ptr=quantized_out, + n_elements=num_elements, + BLOCK_SIZE=blocksize, + SPLIT_NUM_BLOCKS=split_num_blocks, + ) + return quantized_out, absmax + + +@triton.jit +def dequant_4bit_body_util(a, offsets, quant_ptr, absmax_ptr, n_elems, QUANT_BLOCK: tl.constexpr): + PAIRED_QUANT_BLOCK: tl.constexpr = QUANT_BLOCK // 2 + mask = offsets < n_elems + higher = a & 0xF + # lower 4bits + lower = a >> 4 + + abs_offsets = offsets // PAIRED_QUANT_BLOCK + absmax = tl.load(absmax_ptr + abs_offsets, mask=mask, other=1.0, eviction_policy="evict_last") + + # apply conversion + lower_4 = tl.load(quant_ptr + lower, eviction_policy="evict_last") + higher_4 = tl.load(quant_ptr + higher, eviction_policy="evict_last") + + mul_high = higher_4 * absmax + mul_low = lower_4 * absmax + out_dq = tl.interleave(mul_low, mul_high) + return out_dq + + +# Triton implementation of similar CUDA kernel to avoid loading code from csrc/kernels.cu::dDequantizeFP4Tree +@triton.jit +def 
dequantize_fp4_tree(val, absmax): + # val: tl.tensor (uint8) + # absmax: tl.tensor (float32/float16) + # 00001100 00001011 00001001 00001111 + sign = tl.where((val & 0b1000) == 0b1000, -1.0, 1.0) # -1 + third_bit = (val & 0b0100) == 0b0100 # True + second_bit = (val & 0b0010) == 0b0010 # False + first_bit = (val & 0b0001) == 0b0001 # False + + branch1 = tl.where( + second_bit, + tl.where(first_bit, 0.25, 0.16666667), # 1111, 1110 + tl.where(first_bit, 0.5, 0.33333333), # 1101, 1100 + ) + branch2 = tl.where( + second_bit, + tl.where(first_bit, 1.0, 0.66666667), # 1011, 1010 + tl.where(first_bit, 0.00520833, 0.0), # 1001, 1000 + ) + out = tl.where(third_bit, branch1, branch2) + return out * sign * absmax + + +@triton.jit +def dequant_fp4_body_util(a, offsets, absmax_ptr, n_elems, QUANT_BLOCK: tl.constexpr): + PAIRED_QUANT_BLOCK: tl.constexpr = QUANT_BLOCK // 2 + mask = offsets < n_elems + higher = a & 0xF + lower = a >> 4 + + abs_offsets = offsets // PAIRED_QUANT_BLOCK + absmax = tl.load(absmax_ptr + abs_offsets, mask=mask, other=1.0, eviction_policy="evict_last") + mul_high = dequantize_fp4_tree(higher, absmax) + mul_low = dequantize_fp4_tree(lower, absmax) + out_dq = tl.interleave(mul_low, mul_high) + return out_dq + + +# Triton implementation of similar CUDA kernel to avoid loading code from csrc/kernels.cu::dDequantizeNF4 +@triton.jit +def dequantize_nf4_tree(val): + # val: tl.tensor (uint8) + cond0 = (val & 0b1000) == 0b1000 + cond1 = (val & 0b0100) == 0b0100 + cond2 = (val & 0b0010) == 0b0010 + cond3 = (val & 0b0001) == 0b0001 + + # Positive branch (val & 0b1000) == 8 + branch_pos = tl.where( + cond1, + tl.where( + cond2, + tl.where(cond3, 1.0, 0.7229568362236023), # 1111, 1110 + tl.where(cond3, 0.5626170039176941, 0.44070982933044434), # 1101, 1100 + ), + tl.where( + cond2, + tl.where(cond3, 0.33791524171829224, 0.24611230194568634), # 1011, 1010 + tl.where(cond3, 0.16093020141124725, 0.07958029955625534), # 1001, 1000 + ), + ) + + # Negative branch (val & 0b1000) == 0 + branch_neg = tl.where( + cond1, + tl.where( + cond2, + tl.where(cond3, 0.0, -0.09105003625154495), # 0111, 0110 + tl.where(cond3, -0.18477343022823334, -0.28444138169288635), # 0101, 0100 + ), + tl.where( + cond2, + tl.where(cond3, -0.39491748809814453, -0.5250730514526367), # 0011, 0010 + tl.where(cond3, -0.6961928009986877, -1.0), # 0001, 0000 + ), + ) + return tl.where(cond0, branch_pos, branch_neg) + + +@triton.jit +def dequant_nf4_body_util(a, offsets, absmax_ptr, n_elems, QUANT_BLOCK: tl.constexpr): + PAIRED_QUANT_BLOCK: tl.constexpr = QUANT_BLOCK // 2 + mask = offsets < n_elems + higher = a & 0xF + # lower 4bits + lower = a >> 4 + + abs_offsets = offsets // PAIRED_QUANT_BLOCK + absmax = tl.load(absmax_ptr + abs_offsets, mask=mask, other=1.0, eviction_policy="evict_last") + mul_high = dequantize_nf4_tree(higher) * absmax + mul_low = dequantize_nf4_tree(lower) * absmax + out_dq = tl.interleave(mul_low, mul_high) + return out_dq + + +# All such kernels are similar, so maybe code can be generalised. 
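The nested tl.where trees above hard-code the FP4/NF4 code levels and their decision thresholds (midpoints between adjacent levels). A host-side sketch of how the NF4 tree could be sanity-checked against the code table, assuming get_4bit_type (the helper already referenced in ops.py) returns the 16 NF4 levels in ascending order; the helper names below are illustrative only:

import torch
from bitsandbytes.functional import get_4bit_type

code = get_4bit_type("nf4", device="cpu")   # 16 levels from -1.0 to 1.0
midpoints = (code[1:] + code[:-1]) / 2      # thresholds used by the comparison tree

def quant_reference(x: torch.Tensor) -> torch.Tensor:
    # index of the nearest code level; mirrors walking the tl.where tree
    return torch.bucketize(x, midpoints).to(torch.uint8)

def dequant_reference(nibbles: torch.Tensor) -> torch.Tensor:
    return code[nibbles.long()]

# every code level should quantize back to its own index
nibbles = torch.arange(16)
assert torch.equal(quant_reference(dequant_reference(nibbles)), nibbles.to(torch.uint8))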
+# @triton.autotune( +# configs=[ +# # # triton.Config({'SPLIT_SIZE': 64}), +# # # # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=2, num_warps=32), +# # # # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=2, num_warps=32), +# # # # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=4, num_warps=32), +# # # # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=4, num_warps=32), +# triton.Config({'SPLIT_SIZE': 128}), +# triton.Config({'SPLIT_SIZE': 128}, num_warps = 32, num_stages = 2), +# # # triton.Config({'SPLIT_SIZE': 128}, num_warps = 4, num_stages = 4), +# # # # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=2, num_warps=32), +# # # # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=2, num_warps=32), +# # # # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=4, num_warps=32), +# # # # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=4, num_warps=32), +# triton.Config({'SPLIT_SIZE': 256}), +# triton.Config({'SPLIT_SIZE': 256}, num_warps = 32, num_stages = 2), +# # triton.Config({'SPLIT_SIZE': 256}, num_warps = 4, num_stages = 4), +# triton.Config({'SPLIT_SIZE': 512}), +# triton.Config({'SPLIT_SIZE': 512}, num_warps = 32, num_stages = 2), +# # triton.Config({'SPLIT_SIZE': 512}, num_warps = 4, num_stages = 4), +# # # # triton.Config({'SPLIT_SIZE': 512, 'grf_mode': 'large'}, num_stages=2, num_warps=32), +# # # # triton.Config({'SPLIT_SIZE': 512, 'grf_mode': 'auto'}, num_stages=2, num_warps=32), +# # # # triton.Config({'SPLIT_SIZE': 512, 'grf_mode': 'large'}, num_stages=4, num_warps=32), +# # # # triton.Config({'SPLIT_SIZE': 512, 'grf_mode': 'auto'}, num_stages=4, num_warps=32), +# # # triton.Config({'SPLIT_SIZE': 1024}), +# # # # triton.Config({'SPLIT_SIZE': 2048}), +# # # # triton.Config({'SPLIT_SIZE': 4096}), +# # # # triton.Config({'SPLIT_SIZE': 8192}), +# # # # triton.Config({'SPLIT_SIZE': 16384}), +# ], +# key=['num_paired_elements'], +# ) +@triton.jit +def dequant_4bit_kernel( + a_ptr, c_ptr, quant_ptr, absmax_ptr, num_paired_elements, QUANT_BLOCK: tl.constexpr, SPLIT_SIZE: tl.constexpr +): + pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. + block_start = pid * SPLIT_SIZE + offsets = block_start + tl.arange(0, SPLIT_SIZE) + mask = offsets < num_paired_elements + + a = tl.load(a_ptr + offsets, mask, eviction_policy="evict_first") + + out_dq = dequant_4bit_body_util( + a=a, + offsets=offsets, + quant_ptr=quant_ptr, + absmax_ptr=absmax_ptr, + n_elems=num_paired_elements, + QUANT_BLOCK=QUANT_BLOCK, + ) + + out_block_start = pid * SPLIT_SIZE * 2 + offs = out_block_start + tl.arange(0, SPLIT_SIZE * 2) + mask = offs < num_paired_elements * 2 + tl.store(c_ptr + offs, out_dq, mask) + + +# @triton.autotune( +# configs=[ +# triton.Config({'SPLIT_SIZE': 128}, num_warps = 32, num_stages = 2), +# triton.Config({'SPLIT_SIZE': 256}), +# triton.Config({'SPLIT_SIZE': 256}, num_warps = 32, num_stages = 2), +# triton.Config({'SPLIT_SIZE': 512}), +# triton.Config({'SPLIT_SIZE': 512}, num_warps = 32, num_stages = 2), +# triton.Config({'SPLIT_SIZE': 1024}, num_warps = 32, num_stages = 2), +# ], +# key=['num_paired_elements'], +# ) +@triton.jit +def dequant_fp4_kernel( + a_ptr, c_ptr, absmax_ptr, num_paired_elements, QUANT_BLOCK: tl.constexpr, SPLIT_SIZE: tl.constexpr +): + pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. 
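+    # NOTE: each program instance reads SPLIT_SIZE packed bytes and, after unpacking
+    # two nibbles per byte, writes 2 * SPLIT_SIZE dequantized values, so the output
+    # offsets below are scaled by two relative to the input offsets.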
+ block_start = pid * SPLIT_SIZE + offsets = block_start + tl.arange(0, SPLIT_SIZE) + mask = offsets < num_paired_elements + + a = tl.load(a_ptr + offsets, mask, eviction_policy="evict_first") + + out_dq = dequant_fp4_body_util( + a=a, + offsets=offsets, + absmax_ptr=absmax_ptr, + n_elems=num_paired_elements, + QUANT_BLOCK=QUANT_BLOCK, + ) + + out_block_start = pid * SPLIT_SIZE * 2 + offs = out_block_start + tl.arange(0, SPLIT_SIZE * 2) + mask = offs < num_paired_elements * 2 + tl.store(c_ptr + offs, out_dq, mask) + + +# @triton.autotune( +# configs=[ +# triton.Config({'SPLIT_SIZE': 128}, num_warps = 32, num_stages = 2), +# triton.Config({'SPLIT_SIZE': 256}), +# triton.Config({'SPLIT_SIZE': 256}, num_warps = 32, num_stages = 2), +# triton.Config({'SPLIT_SIZE': 512}), +# triton.Config({'SPLIT_SIZE': 512}, num_warps = 32, num_stages = 2), +# triton.Config({'SPLIT_SIZE': 1024}, num_warps = 32, num_stages = 2), +# ], +# key=['num_paired_elements'], +# ) +@triton.jit +def dequant_nf4_kernel( + a_ptr, c_ptr, absmax_ptr, num_paired_elements, QUANT_BLOCK: tl.constexpr, SPLIT_SIZE: tl.constexpr +): + pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. + block_start = pid * SPLIT_SIZE + offsets = block_start + tl.arange(0, SPLIT_SIZE) + mask = offsets < num_paired_elements + + a = tl.load(a_ptr + offsets, mask, eviction_policy="evict_first") + + out_dq = dequant_nf4_body_util( + a=a, + offsets=offsets, + absmax_ptr=absmax_ptr, + n_elems=num_paired_elements, + QUANT_BLOCK=QUANT_BLOCK, + ) + + out_block_start = pid * SPLIT_SIZE * 2 + offs = out_block_start + tl.arange(0, SPLIT_SIZE * 2) + mask = offs < num_paired_elements * 2 + tl.store(c_ptr + offs, out_dq, mask) + + +def _dequantize_4bit_impl( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + # It's will be processed as an array, so + # actual length is row * col + # Elements are in uint8 format, so interleaved + # so total amount of data is 2 * elem_count + number_of_paired_elements = A.numel() + # we assume that split_size > quant_blocksize + + SPLIT_SIZE = 256 + # grid = lambda META: (triton.cdiv(number_of_paired_elements, META['SPLIT_SIZE']), ) + grid = (triton.cdiv(number_of_paired_elements, SPLIT_SIZE),) + if quant_type == "fp4": + dequant_fp4_kernel[grid](A, out, absmax, number_of_paired_elements, blocksize, SPLIT_SIZE) + else: + dequant_nf4_kernel[grid](A, out, absmax, number_of_paired_elements, blocksize, SPLIT_SIZE) + + +def _dequantize_4bit_impl_passing_code( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + code: torch.Tensor, + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + number_of_paired_elements = A.numel() + # we assume that split_size > quant_blocksize + + SPLIT_SIZE = 256 + # grid = lambda META: (triton.cdiv(number_of_paired_elements, META['SPLIT_SIZE']), ) + grid = (triton.cdiv(number_of_paired_elements, SPLIT_SIZE),) + dequant_4bit_kernel[grid](A, out, code, absmax, number_of_paired_elements, blocksize, SPLIT_SIZE) + + +######################### Fallback dequantization functions ######################### +## for debug ## + + +# @triton.autotune( +# configs=[ +# # triton.Config({'SPLIT_NUM_BLOCKS': 1, 'grf_mode': 'large'}, num_stages=2, num_warps=32), +# # triton.Config({'SPLIT_NUM_BLOCKS': 1, 'grf_mode': 'auto'}, num_stages=2, num_warps=32), +# # triton.Config({'SPLIT_NUM_BLOCKS': 1, 'grf_mode': 'large'}, num_stages=4, num_warps=32), +# # # +# # triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, 
num_stages=4, num_warps=32), +# # +# triton.Config({"SPLIT_NUM_BLOCKS": 2}), +# # triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "large"}, num_stages=2, num_warps=32), +# # # triton.Config({'SPLIT_NUM_BLOCKS': 2, 'grf_mode': 'large'}, num_stages=4, num_warps=32), +# # triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=2, num_warps=32), +# # triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32), +# # triton.Config({"SPLIT_NUM_BLOCKS": 4, "grf_mode": "large"}, num_stages=2, num_warps=32), +# # triton.Config({"SPLIT_NUM_BLOCKS": 4, "grf_mode": "large"}, num_stages=4, num_warps=32), +# # triton.Config({'SPLIT_NUM_BLOCKS': 8, 'grf_mode': 'large'}, num_stages=2, num_warps=32), +# ], +# key=["n_elements", "BLOCK_SIZE"], +# ) +@triton.jit +def quantize_4bit_blockwise_kernel( + A_ptr, + code_ptr, + absmax_ptr, + out_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + CODE_SIZE: tl.constexpr, + SPLIT_NUM_BLOCKS: tl.constexpr, +): + PAIRED_SPLIT_NUM_BLOCKS: tl.constexpr = SPLIT_NUM_BLOCKS * 2 + block_start_idx = tl.program_id(0) * PAIRED_SPLIT_NUM_BLOCKS + thread_idx = tl.arange(0, PAIRED_SPLIT_NUM_BLOCKS * BLOCK_SIZE) + + offsets = block_start_idx * BLOCK_SIZE + thread_idx + mask = offsets < n_elements + + A = tl.load(A_ptr + offsets, mask=mask, other=0.0) + + # To be able process several blocks -> (PAIRED_SPLIT_NUM_BLOCKS, BLOCK_SIZE) + A_reshaped = tl.reshape(A, (PAIRED_SPLIT_NUM_BLOCKS, BLOCK_SIZE)) + + # Calculating absamax for each block + absmax = tl.max(tl.abs(A_reshaped), axis=1) + tl.store(absmax_ptr + block_start_idx + tl.arange(0, PAIRED_SPLIT_NUM_BLOCKS), absmax) + + A_normalized = A_reshaped / absmax[:, None] + A_normalized = tl.clamp(A_normalized, -1.0, 1.0) + + lower_pivot = tl.zeros((PAIRED_SPLIT_NUM_BLOCKS, BLOCK_SIZE), dtype=tl.int32) + upper_pivot = tl.full((PAIRED_SPLIT_NUM_BLOCKS, BLOCK_SIZE), CODE_SIZE - 1, dtype=tl.int32) + + for _ in range(4): # ceil(log2(code_size)) = 4, actually, in general case should be input parameter + pivot = (lower_pivot + upper_pivot) // 2 + val = tl.load(code_ptr + pivot) + is_higher = A_normalized > val # code[pivot] + lower_pivot = tl.where(is_higher, pivot, lower_pivot) + upper_pivot = tl.where(is_higher, upper_pivot, pivot) + + # Choose closest level + lower_val = tl.load(code_ptr + lower_pivot) + upper_val = tl.load(code_ptr + upper_pivot) + lower_dist = tl.abs(A_normalized - lower_val) + upper_dist = tl.abs(A_normalized - upper_val) + quantized = tl.where(lower_dist <= upper_dist, lower_pivot, upper_pivot).to(tl.uint8) + + quantized = quantized.reshape((PAIRED_SPLIT_NUM_BLOCKS, BLOCK_SIZE // 2, 2)) + quantized = quantized.to(tl.uint8, bitcast=True) + left, right = quantized.split() + packed = left << 4 | (right & 0xF) + + # Reduce don't guarantee the order of the elements passed to unite_2_int4 + # packed = tl.reduce(quantized, axis=2, combine_fn=unite_2_int4) + # packed = packed.to(tl.uint8, bitcast=True) + + packed_flat = tl.reshape(packed, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,)) + out_offsets = block_start_idx * BLOCK_SIZE // 2 + tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE) + out_mask = out_offsets < n_elements // 2 + tl.store(out_ptr + out_offsets, packed_flat, mask=out_mask) diff --git a/bitsandbytes/backends/utils.py b/bitsandbytes/backends/utils.py index 01d316285..1543f3474 100755 --- a/bitsandbytes/backends/utils.py +++ b/bitsandbytes/backends/utils.py @@ -13,6 +13,15 @@ ipex_cpu = None ipex_xpu = None +try: + import triton # noqa: F401 + import triton.language as tl # noqa: F401 + + 
triton_available = True +except ImportError: + triton_available = False + + _NF4_QUANT_TABLE = torch.tensor( [ -1.0, diff --git a/bitsandbytes/backends/xpu/ops.py b/bitsandbytes/backends/xpu/ops.py index 47a3bd009..999116c97 100755 --- a/bitsandbytes/backends/xpu/ops.py +++ b/bitsandbytes/backends/xpu/ops.py @@ -1,11 +1,14 @@ from collections.abc import Sequence +import warnings import torch from ..._ops import register_kernel -from ..utils import ipex_xpu +from ..utils import ipex_xpu, triton_available -if torch.__version__ >= (2, 7): +# _int_mm is available in torch starting from version 2.7, +# but it does not currently have an XPU implementation. +if ipex_xpu and torch.__version__ >= (2, 7): @register_kernel("bitsandbytes::int8_linear_matmul", "xpu") def _(A: torch.Tensor, B: torch.Tensor): @@ -15,6 +18,7 @@ def _(A: torch.Tensor, B: torch.Tensor): ).reshape(*A.shape[:-1], B.shape[0]) +# IPEX should be faster on XPU, so check for it first. if ipex_xpu: @register_kernel("bitsandbytes::dequantize_nf4_ipex", "xpu") @@ -49,3 +53,15 @@ def _( raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}") return out.reshape(shape) +elif triton_available: + from ..triton import ops as triton_ops + + register_kernel("bitsandbytes::quantize_blockwise", "xpu")(triton_ops.quantize_blockwise) + register_kernel("bitsandbytes::dequantize_blockwise.out", "xpu")(triton_ops.dequantize_blockwise_inplace) + register_kernel("bitsandbytes::dequantize_blockwise", "xpu")(triton_ops.dequantize_blockwise) + register_kernel("bitsandbytes::quantize_4bit", "xpu")(triton_ops.quantize_4bit) + register_kernel("bitsandbytes::dequantize_4bit.out", "xpu")(triton_ops.dequantize_4bit_inplace) + register_kernel("bitsandbytes::dequantize_4bit", "xpu")(triton_ops.dequantize_4bit) + register_kernel("bitsandbytes::gemv_4bit", "xpu")(triton_ops.gemv_4bit) +else: + warnings.warn("XPU available but no ipex or triton packages found.") diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index a9cc60dc1..1aed09219 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -671,7 +671,7 @@ def to(self, *args, **kwargs): if device is not None and device.type != "meta" and self.data.device.type == "cpu": if device.type != "cpu" or self.data.dtype != torch.int8: return self._quantize(device) - elif self.data.dtype == torch.int8 and device.type in ("cpu", "xpu"): + elif self.data.dtype == torch.int8 and device.type in ("cpu", "xpu") and (ipex_cpu or ipex_xpu): self.CB = self.data new_param = Int8Params( diff --git a/tests/test_functional.py b/tests/test_functional.py index 1aa2e1d37..6706df138 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -137,11 +137,11 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, abserr = sum(diffs) / len(diffs) relerr = sum(reldiffs) / len(reldiffs) if signed: - threshold_abserr = 0.0036 if device in ("cpu", "xpu") else 0.0035 + threshold_abserr = 0.0036 if device in ("cpu", "xpu") and (F.ipex_cpu or F.ipex_xpu) else 0.0035 assert abserr < 0.0036 assert relerr < 0.015 else: - assert abserr < 0.00175 if device in ("cpu", "xpu") else 0.0023 + assert abserr < (0.00175 if device in ("cpu", "xpu") and (F.ipex_cpu or F.ipex_xpu) else 0.0023) assert relerr < 0.012 assert A2.dtype == dtype @@ -172,7 +172,7 @@ def test_blockwise_cpu_large(self, hidden, blocksize): @pytest.mark.parametrize("bits", range(2, 9), ids=id_formatter("bits")) @pytest.mark.parametrize("method", ["linear",
"fp8", "dynamic"]) def test_few_bit_quant(self, device, bits, method): - if device in ("cpu", "xpu") and bits != 8: + if device in ("cpu", "xpu") and bits != 8 and (F.ipex_cpu or F.ipex_xpu): pytest.skip("CPU/XPU implementation only supports 8 bits") abserrs = [] diff --git a/tests/test_ops.py b/tests/test_ops.py index 7da19c012..60c47a250 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -4,6 +4,7 @@ import torch import bitsandbytes +from bitsandbytes.functional import ipex_xpu from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter # torch.library.opcheck is only available in torch 2.4 and later. @@ -144,7 +145,7 @@ def test_dequantize_blockwise(self, device, dtype, blocksize): assert out.device == A.device # TODO: Enable it - if device == "xpu": + if device == "xpu" and ipex_xpu: pytest.skip("XPU implementation have torch.op inside torch.op, it will fail on op check") opcheck(torch.ops.bitsandbytes.dequantize_blockwise.default, (A, absmax, code, blocksize, dtype)) @@ -170,7 +171,7 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize if storage_dtype != torch.uint8: pytest.xfail("opcheck fails for storage_dtype != torch.uint8") - opcheck(torch.ops.bitsandbytes.quantize_4bit, (A, blocksize, quant_type, storage_dtype)) + opcheck(torch.ops.bitsandbytes.quantize_4bit.default, (A, blocksize, quant_type, storage_dtype)) @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))