diff --git a/tripy/nvtripy/frontend/ops/attention.py b/tripy/nvtripy/frontend/ops/attention.py
new file mode 100644
index 000000000..ec8352717
--- /dev/null
+++ b/tripy/nvtripy/frontend/ops/attention.py
@@ -0,0 +1,181 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Optional
+
+from nvtripy import export
+from nvtripy.common import datatype
+from nvtripy.frontend.ops import utils as op_utils
+from nvtripy.trace.ops.attention import Attention
+from nvtripy.utils import wrappers
+
+
+@export.public_api(document_under="operations/functions")
+@wrappers.interface(
+    dtype_constraints={
+        "query": "T1",
+        "key": "T1",
+        "value": "T1",
+        wrappers.RETURN_VALUE: "T1",
+    },
+    dtype_variables={
+        "T1": ["float32", "float16", "bfloat16"],
+    },
+)
+def attention(
+    query: "nvtripy.Tensor",
+    key: "nvtripy.Tensor",
+    value: "nvtripy.Tensor",
+    *,
+    mask: Optional["nvtripy.Tensor"] = None,
+    normalization_quantize_scale: Optional["nvtripy.Tensor"] = None,
+    normalization_operation: str = "softmax",
+    causal: bool = False,
+    decomposable: bool = False,
+    normalization_quantize_to_type: Optional[datatype.dtype] = None,
+) -> "nvtripy.Tensor":
+    r"""
+    Performs a fused multi-head attention operation.
+
+    This operation implements the attention mechanism:
+
+    .. math::
+        \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
+
+    The operation consists of:
+
+    1. Matrix multiplication between query and transposed key (BMM1)
+    2. Optional masking
+    3. Normalization (typically softmax)
+    4. Optional quantization of the normalized output
+    5. Matrix multiplication with value (BMM2)
+
+    Args:
+        query: The query tensor with shape ``[batch_size, num_heads_query, sequence_length_query, dim_head]``.
+        key: The key tensor with shape ``[batch_size, num_heads_key, sequence_length_key, dim_head]``.
+        value: The value tensor with shape ``[batch_size, num_heads_value, sequence_length_value, dim_head]``.
+        mask: Optional mask tensor with shape
+            ``[batch_size, num_heads_query, sequence_length_query, sequence_length_key]``.
+            For boolean masks (dtype=bool), ``True`` indicates positions that are allowed to attend.
+            For float masks, the values are added to the attention scores before normalization.
+        normalization_quantize_scale: Optional scale tensor for quantizing the normalization output.
+            Required if ``normalization_quantize_to_type`` is specified.
+        normalization_operation: The normalization operation to use. Must be one of ``"softmax"`` or ``"none"``.
+            Defaults to ``"softmax"``.
+        causal: If ``True``, applies causal (autoregressive) masking where each position can only
+            attend to earlier positions. Cannot be used together with an explicit ``mask``. Defaults to ``False``.
+        decomposable: If ``True``, allows the operation to be decomposed into multiple kernels if
+            no fused kernel is available. Defaults to ``False``.
+        normalization_quantize_to_type: Optional data type for quantizing the normalization output.
+            Must be either :class:`nvtripy.float8` or :class:`nvtripy.int8`.
+            Requires ``normalization_quantize_scale`` to be provided.
+
+    Returns:
+        The attention output tensor with shape ``[batch_size, num_heads_query, sequence_length_query, dim_head]``.
+
+    .. code-block:: python
+        :linenos:
+        :caption: Basic Attention
+
+        batch_size, num_heads, seq_len, head_dim = 2, 8, 128, 64
+        query = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)
+        key = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)
+        value = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)
+
+        output = tp.attention(query, key, value)
+
+        assert output.shape == (batch_size, num_heads, seq_len, head_dim)
+
+    .. code-block:: python
+        :linenos:
+        :caption: Attention with Mask
+
+        batch_size, num_heads, seq_len, head_dim = 2, 8, 128, 64
+        query = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)
+        key = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)
+        value = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)
+
+        # Create a boolean mask where True indicates positions that can attend
+        mask = tp.ones((batch_size, num_heads, seq_len, seq_len), dtype=tp.bool)
+
+        output = tp.attention(query, key, value, mask=mask)
+
+        assert output.shape == (batch_size, num_heads, seq_len, head_dim)
+
+    .. code-block:: python
+        :linenos:
+        :caption: Attention with Quantization
+
+        batch_size, num_heads, seq_len, head_dim = 2, 8, 128, 64
+        query = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)
+        key = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)
+        value = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)
+
+        mask = tp.ones((batch_size, num_heads, seq_len, seq_len), dtype=tp.bool)
+        # Quantize the normalization output (softmax) to float8
+        scale = tp.Tensor([1.0], dtype=tp.float16)
+
+        output = tp.attention(query, key, value, mask=mask,
+                              normalization_quantize_scale=scale,
+                              normalization_quantize_to_type=tp.float8)
+
+        assert output.shape == (batch_size, num_heads, seq_len, head_dim)
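+
+    .. code-block:: python
+        :linenos:
+        :caption: Causal Attention
+
+        # A minimal sketch of the `causal` flag, mirroring the examples above:
+        # causal masking is implicit, so no `mask` tensor is passed.
+        batch_size, num_heads, seq_len, head_dim = 2, 8, 128, 64
+        query = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)
+        key = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)
+        value = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)
+
+        output = tp.attention(query, key, value, causal=True)
+
+        assert output.shape == (batch_size, num_heads, seq_len, head_dim)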
+    """
+    from nvtripy.common.exception import raise_error
+
+    if normalization_operation not in ("softmax", "none"):
+        raise_error(
+            f"Invalid normalization operation: {normalization_operation}. Must be one of 'softmax' or 'none'.",
+        )
+
+    # Validation checks
+    if causal and mask is not None:
+        raise_error(
+            "Cannot use both `causal` and `mask` at the same time.",
+            details=[
+                "The `causal` parameter applies implicit causal masking.",
+                "Please use either `causal=True` or provide an explicit `mask`.",
+            ],
+        )
+
+    if normalization_quantize_to_type is not None:
+        if normalization_quantize_scale is None:
+            raise_error(
+                "`normalization_quantize_scale` must be provided when `normalization_quantize_to_type` is specified.",
+            )
+
+        if normalization_quantize_to_type not in (datatype.float8, datatype.int8):
+            raise_error(
+                "`normalization_quantize_to_type` must be either float8 or int8.",
+                details=[f"Got: {normalization_quantize_to_type}"],
+            )
+
+    # Collect inputs based on what's provided
+    inputs = [query, key, value]
+    if mask is not None:
+        inputs.append(mask)
+    if normalization_quantize_scale is not None:
+        inputs.append(normalization_quantize_scale)
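+
+    # The remaining options travel as attributes on the trace op, which lowers
+    # them to `tensorrt.attention` attributes during MLIR conversion.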
+    return op_utils.create_op(
+        Attention,
+        inputs,
+        normalization_operation=normalization_operation,
+        causal=causal,
+        decomposable=decomposable,
+        normalization_quantize_to_type=normalization_quantize_to_type,
+    )
diff --git a/tripy/nvtripy/trace/ops/attention.py b/tripy/nvtripy/trace/ops/attention.py
new file mode 100644
index 000000000..7786d1562
--- /dev/null
+++ b/tripy/nvtripy/trace/ops/attention.py
@@ -0,0 +1,79 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from dataclasses import dataclass
+from typing import Optional
+
+from mlir_tensorrt.compiler import ir
+from mlir_tensorrt.compiler.dialects import tensorrt
+from nvtripy.common import datatype
+from nvtripy.trace.ops import utils as op_utils
+from nvtripy.trace.ops.base import TraceOp
+
+
+@dataclass(repr=False)
+class Attention(TraceOp):
+    # Stored in frontend form ("softmax"/"none"); `to_mlir` converts it to the
+    # TensorRT enum spelling (e.g. "kSOFTMAX").
+    normalization_operation: Optional[str] = "softmax"
+    causal: bool = False
+    decomposable: bool = False
+    normalization_quantize_to_type: Optional[datatype.dtype] = None
+
+    infer_rank = op_utils.InferRankPolicies.same_shape_as_input(0)
+
+    def infer_dtypes(self):
+        self.outputs[0].dtype = self.inputs[0].dtype
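+
+    # Optional inputs (mask, quantization scale) are appended positionally by
+    # the frontend, so `to_mlir` uses `normalization_quantize_to_type` to
+    # decide which trailing inputs are present.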
+    def to_mlir(self, inputs, outputs):
+        assert len(inputs) >= 3, "Attention operation should have at least 3 inputs!"
+
+        query, key, value = inputs[0], inputs[1], inputs[2]
+
+        # The quantization scale is only appended when
+        # `normalization_quantize_to_type` is set, and it always follows the
+        # mask (if any), so a lone fourth input can be disambiguated.
+        mask = None
+        normalization_quantize_scale = None
+        if self.normalization_quantize_to_type is not None:
+            normalization_quantize_scale = inputs[-1]
+            if len(inputs) > 4:
+                mask = inputs[3]
+        elif len(inputs) > 3:
+            mask = inputs[3]
+
+        # Create attributes
+        normalization_operation_attr = None
+        if self.normalization_operation:
+            norm_op_str = "k" + self.normalization_operation.upper()
+            normalization_operation_attr = tensorrt.AttentionNormalizationOpAttr.get(norm_op_str)
+
+        causal_attr = None
+        if self.causal:
+            causal_attr = ir.BoolAttr.get(self.causal)
+
+        decomposable_attr = None
+        if self.decomposable:
+            decomposable_attr = ir.BoolAttr.get(self.decomposable)
+
+        normalization_quantize_to_type_attr = None
+        if self.normalization_quantize_to_type:
+            trt_dtype_str = op_utils.get_trt_dtype_enum_str(self.normalization_quantize_to_type)
+            normalization_quantize_to_type_attr = tensorrt.DataTypeAttr.get(trt_dtype_str)
+
+        return [
+            tensorrt.attention(
+                outputs[0],
+                query,
+                key,
+                value,
+                mask=mask,
+                normalization_quantize_scale=normalization_quantize_scale,
+                normalization_operation=normalization_operation_attr,
+                causal=causal_attr,
+                decomposable=decomposable_attr,
+                normalization_quantize_to_type=normalization_quantize_to_type_attr,
+            )
+        ]
diff --git a/tripy/nvtripy/trace/ops/utils.py b/tripy/nvtripy/trace/ops/utils.py
index 68e01f1a1..c78379116 100644
--- a/tripy/nvtripy/trace/ops/utils.py
+++ b/tripy/nvtripy/trace/ops/utils.py
@@ -70,3 +70,34 @@ def get_broadcast_in_dim(input_rank: int, output_rank: int) -> List[int]:
     assert len(broadcast_dimensions) == input_rank
 
     return broadcast_dimensions
+
+
+##
+## Datatype conversion
+##
+
+
+def get_trt_dtype_enum_str(dtype: "nvtripy.dtype") -> str:
+    """
+    Converts a tripy datatype to its corresponding TensorRT DataType enum string.
+
+    Args:
+        dtype: A tripy datatype.
+
+    Returns:
+        The TensorRT DataType enum string (e.g., "kFP8", "kINT8").
+    """
+    from nvtripy.common import datatype
+
+    TRIPY_DTYPE_TO_TRT_ENUM = {
+        datatype.float32: "kFLOAT",
+        datatype.float16: "kHALF",
+        datatype.int8: "kINT8",
+        datatype.int32: "kINT32",
+        datatype.bool: "kBOOL",
+        datatype.float8: "kFP8",
+        datatype.bfloat16: "kBF16",
+        datatype.int64: "kINT64",
+        datatype.int4: "kINT4",
+    }
+    return TRIPY_DTYPE_TO_TRT_ENUM[dtype]