diff --git a/vllm_hpu_extension/flags.py b/vllm_hpu_extension/flags.py
index 3dc9f885e..1cd94496c 100644
--- a/vllm_hpu_extension/flags.py
+++ b/vllm_hpu_extension/flags.py
@@ -11,6 +11,7 @@
 
 from vllm_hpu_extension.environment import get_environment
 from vllm_hpu_extension.kernels import fsdpa
+from vllm_hpu_extension.kernels import block_softmax_adjustment
 
 detected = None
 
@@ -160,6 +161,11 @@ def enabled_flags():
                           & ModelType("llama")
                           & Not(EnvFlag("VLLM_PROMPT_USE_FUSEDSDPA", "false"))
                           & EnvFlag("VLLM_PROMPT_USE_FLEX_ATTENTION", "false")),
+        "fused_block_softmax_adjustment": (Not(Hardware("cpu"))
+                                           & VersionRange(">=1.22.0.101")
+                                           & Kernel(block_softmax_adjustment)
+                                           & EnvFlag("VLLM_FUSED_BLOCK_SOFTMAX_ADJUSTMENT",
+                                                     Not(ModelType('qwen2')) & Hardware("gaudi3"))),
     }
     environment = get_environment()
     detected = Flags(supported_flags, environment)
diff --git a/vllm_hpu_extension/kernels.py b/vllm_hpu_extension/kernels.py
index 77164985a..5aadd3c1f 100644
--- a/vllm_hpu_extension/kernels.py
+++ b/vllm_hpu_extension/kernels.py
@@ -5,24 +5,37 @@
 # LICENSE file in the root directory of this source tree.
 ###############################################################################
 
-from .utils import logger
 from functools import cache
 
 
-@cache
+def _kernel(name):
+    def loader(fn):
+        @cache
+        def loader_impl():
+            try:
+                print("Load", name, fn)
+                return fn()
+            except (ImportError, AttributeError):
+                from .utils import logger
+                logger().warning(f"Could not import HPU {name} kernel. "
+                                 "vLLM will use native implementation")
+        return loader_impl
+    return loader
+
+
+@_kernel("FusedSDPA")
 def fsdpa():
-    try:
-        from habana_frameworks.torch.hpex.kernels import FusedSDPA
-        return FusedSDPA
-    except ImportError:
-        logger().warning("Could not import HPU FusedSDPA kernel. "
-                         "vLLM will use native implementation.")
-
-@cache
+    from habana_frameworks.torch.hpex.kernels import FusedSDPA
+    return FusedSDPA
+
+
+@_kernel("FusedRMSNorm")
 def rms_norm():
-    try:
-        from habana_frameworks.torch.hpex.normalization import FusedRMSNorm
-        return FusedRMSNorm
-    except ImportError:
-        logger().warning("Could not import HPU FusedRMSNorm kernel. "
-                         "vLLM will use forward_native implementation of RMSNorm.")
+    from habana_frameworks.torch.hpex.normalization import FusedRMSNorm
+    return FusedRMSNorm
+
+
+@_kernel("block_softmax_adjustment")
+def block_softmax_adjustment():
+    import torch
+    return torch.ops.hpu.block_softmax_adjustment
diff --git a/vllm_hpu_extension/ops.py b/vllm_hpu_extension/ops.py
index dbc7c4457..d17885fbf 100644
--- a/vllm_hpu_extension/ops.py
+++ b/vllm_hpu_extension/ops.py
@@ -65,31 +65,39 @@ def pipelined_pa(attn, value, block_groups, block_mapping, block_scales, batch_s
     adjustment_target_shape = block_max.shape
     attn = attn.sub(block_max)
     attn = attn.exp()
-    attn = attn.to(value.dtype)
+    if attn.dtype == torch.float32:
+        attn = attn.to(value.dtype)
     block_sums = attn.sum(dim=-1, keepdim=True)
     attn = matmul_av_op(attn, value)
-    block_max = block_max.squeeze()
-    block_sums = block_sums.squeeze()
-
-    # Calculate maximum of blocks that belong to the same sequences
-    # and cast adjustments to native dtype
-    group_max = grouped_max(block_max, batch_size, block_groups)
-    block_adjustment = (block_max - group_max).exp()
-    block_adjustment = block_adjustment.to(value.dtype)
-    sum_adjusted = block_sums.mul(block_adjustment)
-
-    # Sum block's sums that belongs to the same sequences
-    group_sum_adjusted = block2batch(sum_adjusted, block_mapping, block2batch_matmul_op)
-    group_sum_adjusted = batch2block(group_sum_adjusted, block_mapping, batch2block_matmul_op)
-    sum_adjusted = sum_adjusted.view(*adjustment_target_shape)
-    group_sum_adjusted = group_sum_adjusted.view(*adjustment_target_shape)
-    block_adjustment = block_adjustment.view(*adjustment_target_shape)
-
-    # For stability in case some of the sums have been zeroed out during block aggretation
-    group_sum_adjusted = torch.maximum(group_sum_adjusted, sum_adjusted)
-
-    # Post processing for the attention scores
-    rescale = block_adjustment.div(group_sum_adjusted)
+
+    if 'fused_block_softmax_adjustment' in enabled_flags() and block_max.dtype != torch.float16:
+        rescale = torch.ops.hpu.block_softmax_adjustment(block_max,
+                                                         block_sums.to(block_max.dtype),
+                                                         block_groups,
+                                                         batch_size).to(attn.dtype)
+    else:
+        block_max = block_max.squeeze()
+        block_sums = block_sums.squeeze()
+
+        # Calculate maximum of blocks that belong to the same sequences
+        # and cast adjustments to native dtype
+        group_max = grouped_max(block_max, batch_size, block_groups)
+        block_adjustment = (block_max - group_max).exp()
+        if block_adjustment.dtype == torch.float32:
+            block_adjustment = block_adjustment.to(value.dtype)
+        sum_adjusted = block_sums.mul(block_adjustment)
+
+        # Sum block's sums that belongs to the same sequences
+        group_sum_adjusted = block2batch(sum_adjusted, block_mapping, block2batch_matmul_op)
+        group_sum_adjusted = batch2block(group_sum_adjusted, block_mapping, batch2block_matmul_op)
+        sum_adjusted = sum_adjusted.view(*adjustment_target_shape)
+        group_sum_adjusted = group_sum_adjusted.view(*adjustment_target_shape)
+        block_adjustment = block_adjustment.view(*adjustment_target_shape)
+
+        # For stability in case some of the sums have been zeroed out during block aggretation
+        group_sum_adjusted = torch.maximum(group_sum_adjusted, sum_adjusted)
+        # Post processing for the attention scores
+        rescale = block_adjustment.div(group_sum_adjusted)
 
     attn = attn.mul(rescale)
     return attn
@@ -405,8 +413,8 @@ def forward(self, hidden_states, score, topk):
         htorch.core.mark_step()
         routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
         routing_weights, selected_experts = torch.topk(routing_weights,
-                                                   topk,
-                                                   dim=-1)
+                                                       topk,
+                                                       dim=-1)
         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
         routing_weights = routing_weights.to(hidden_states.dtype)
 
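
Note on the new _kernel loader factory in kernels.py: each decorated function (fsdpa, rms_norm, block_softmax_adjustment) becomes a cached zero-argument loader that returns the kernel object on success, or logs a warning and returns None when the import or attribute lookup fails. Below is a minimal standalone sketch of that behavior, assuming nothing beyond the diff above; my_kernel and some_missing_module are made-up names, and a plain print stands in for the repo's logger.

from functools import cache


def _kernel(name):
    # Decorator factory: wrap a kernel-import function in a cached loader
    # that returns the kernel, or None with a warning when it is unavailable.
    def loader(fn):
        @cache
        def loader_impl():
            try:
                return fn()
            except (ImportError, AttributeError):
                print(f"Could not import HPU {name} kernel. "
                      "vLLM will use native implementation")
        return loader_impl
    return loader


@_kernel("MyKernel")      # hypothetical kernel name, for illustration only
def my_kernel():
    from some_missing_module import MyKernel  # fails here -> loader returns None
    return MyKernel


print(my_kernel())  # warns once and prints None; the None result is cached afterwards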
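
Note on the pipelined_pa change in ops.py: the unfused else-branch recombines per-block softmax statistics (block maxima and partial exponent sums) into per-block rescale factors, which appears to be the quantity the fused torch.ops.hpu.block_softmax_adjustment op produces in a single call. The following is a self-contained sanity sketch of that recombination in plain torch for a single sequence; it does not use the repo's grouped_max/block2batch/batch2block helpers, and the shapes are made up for illustration.

import torch

torch.manual_seed(0)
scores = torch.randn(1, 4, 8)                     # one sequence split into 4 blocks of 8 keys
block_max = scores.amax(dim=-1, keepdim=True)     # per-block maxima
block_exp = (scores - block_max).exp()            # per-block exp(scores - block_max)
block_sums = block_exp.sum(dim=-1, keepdim=True)  # per-block partial denominators

group_max = block_max.amax(dim=1, keepdim=True)   # max over the sequence's blocks
block_adjustment = (block_max - group_max).exp()  # exp(m_i - M)
group_sum = (block_sums * block_adjustment).sum(dim=1, keepdim=True)
rescale = block_adjustment / group_sum            # per-block rescale factor

# Rescaled per-block exponents reproduce the softmax over the full sequence.
combined = (block_exp * rescale).reshape(1, -1)
reference = torch.softmax(scores.reshape(1, -1), dim=-1)
assert torch.allclose(combined, reference, atol=1e-6)

In pipelined_pa the same factors multiply the per-block attention outputs (attn.mul(rescale)); grouped_max and block2batch/batch2block handle the block-to-sequence bookkeeping across the whole batch.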