181 changes: 181 additions & 0 deletions tripy/nvtripy/frontend/ops/attention.py
@@ -0,0 +1,181 @@
#
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Optional

from nvtripy import export
from nvtripy.common import datatype
from nvtripy.frontend.ops import utils as op_utils
from nvtripy.trace.ops.attention import Attention
from nvtripy.utils import wrappers


@export.public_api(document_under="operations/functions")
@wrappers.interface(
dtype_constraints={
"query": "T1",
"key": "T1",
"value": "T1",
wrappers.RETURN_VALUE: "T1",
},
dtype_variables={
"T1": ["float32", "float16", "bfloat16"],
},
)
def attention(
query: "nvtripy.Tensor",
key: "nvtripy.Tensor",
value: "nvtripy.Tensor",
*,
mask: Optional["nvtripy.Tensor"] = None,
normalization_quantize_scale: Optional["nvtripy.Tensor"] = None,
normalization_operation: str = "softmax",
causal: bool = False,
decomposable: bool = False,
normalization_quantize_to_type: Optional[datatype.dtype] = None,
) -> "nvtripy.Tensor":
r"""
Performs a fused multi-head attention operation.
This operation implements the attention mechanism:
.. math::
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
The operation consists of:
1. Matrix multiplication between query and transposed key (BMM1)
2. Optional masking
3. Normalization (typically softmax)
4. Optional quantization of the normalized output
5. Matrix multiplication with value (BMM2)
Args:
query: The query tensor with shape ``[batch_size, num_heads_query, sequence_length_query, dim_head]``.
key: The key tensor with shape ``[batch_size, num_heads_key, sequence_length_key, dim_head]``.
value: The value tensor with shape ``[batch_size, num_heads_value, sequence_length_value, dim_head]``.
mask: Optional mask tensor with shape
``[batch_size, num_heads_query, sequence_length_query, sequence_length_key]``.
For boolean masks (dtype=bool), ``True`` indicates positions that are allowed to attend.
For float masks, the values are added to the attention scores before normalization.
normalization_quantize_scale: Optional scale tensor for quantizing the normalization output.
Required if ``normalization_quantize_to_type`` is specified.
normalization_operation: The normalization operation to use. Must be one of "softmax" or "none".
Defaults to ``"softmax"``.
causal: If ``True``, applies causal (autoregressive) masking where each position can only
attend to earlier positions. Cannot be used together with explicit ``mask``. Defaults to ``False``.
decomposable: If ``True``, allows the operation to be decomposed into multiple kernels if
no fused kernel is available. Defaults to ``False``.
normalization_quantize_to_type: Optional data type for quantizing the normalization output.
Must be either :class:`nvtripy.float8` or :class:`nvtripy.int8`.
Requires ``normalization_quantize_scale`` to be provided.
Returns:
The attention output tensor with shape ``[batch_size, num_heads_query, sequence_length_query, dim_head]``.
.. code-block:: python
:linenos:
:caption: Basic Attention
batch_size, num_heads, seq_len, head_dim = 2, 8, 128, 64
query = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)
key = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)
value = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)
output = tp.attention(query, key, value)
assert output.shape == (batch_size, num_heads, seq_len, head_dim)

Collaborator
Since the inputs to all 3 examples are the same, can we omit the input initialization in the docs so that it is easier to tell what is changing between the samples? Also, can we have the quantization sample omit the mask?

Collaborator
I'm conflicted on this - on one hand, it will make the examples much cleaner, but on the other, it'll mean that you can't just copy-paste the example code and have it work.

If all the tensors are the same shape, maybe a compromise could be:

query = key = value = tp.iota(...)

although we would need to clarify that it's only being done for the sake of brevity and they don't all need to be the same tensor.

    .. code-block:: python
        :linenos:
        :caption: Attention with Quantization

        batch_size, num_heads, seq_len, head_dim = 2, 8, 128, 64
        query = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)
        key = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)
        value = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)

        mask = tp.ones((batch_size, num_heads, seq_len, seq_len), dtype=tp.bool)
        # Quantize the normalization output (softmax) to float8
        scale = tp.Tensor([1.0], dtype=tp.float16)
        output = tp.attention(
            query,
            key,
            value,
            mask=mask,
            normalization_quantize_scale=scale,
            normalization_quantize_to_type=tp.float8,
        )
        assert output.shape == (batch_size, num_heads, seq_len, head_dim)
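
As a concrete sketch of the compromise floated in the thread above (purely illustrative, not part of this diff), the shared setup could be collapsed like this, with a comment making the aliasing explicit:

    # For brevity only: query, key, and value do NOT need to be the same tensor;
    # they are aliased here just to shorten the example.
    query = key = value = tp.iota((2, 8, 128, 64), dtype=tp.float16)
    output = tp.attention(query, key, value)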

    .. code-block:: python
        :linenos:
        :caption: Attention with Mask

        batch_size, num_heads, seq_len, head_dim = 2, 8, 128, 64
        query = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)
        key = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)
        value = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)

        # Create a boolean mask where True indicates positions that can attend
        mask = tp.ones((batch_size, num_heads, seq_len, seq_len), dtype=tp.bool)
        output = tp.attention(query, key, value, mask=mask)
        assert output.shape == (batch_size, num_heads, seq_len, head_dim)
    """
from nvtripy.common.exception import raise_error

    # Validation checks
    if normalization_operation not in ("softmax", "none"):
        raise_error(
            f"Invalid normalization operation: {normalization_operation}. Must be one of 'softmax' or 'none'.",
        )

    if causal and mask is not None:
raise_error(
"Cannot use both `causal` and `mask` at the same time.",
details=[
"The `causal` parameter applies implicit causal masking.",
"Please use either `causal=True` or provide an explicit `mask`.",
],
)

    if normalization_quantize_to_type is not None:
        if normalization_quantize_scale is None:
            raise_error(
                "`normalization_quantize_scale` must be provided when `normalization_quantize_to_type` is specified.",
            )

        if normalization_quantize_to_type not in (datatype.float8, datatype.int8):
            raise_error(
                "`normalization_quantize_to_type` must be either float8 or int8.",
                details=[f"Got: {normalization_quantize_to_type}"],
            )
    elif normalization_quantize_scale is not None:
        raise_error(
            "`normalization_quantize_to_type` must be specified when `normalization_quantize_scale` is provided.",
            details=["A quantization scale has no effect unless a quantization target type is also given."],
        )

# Collect inputs based on what's provided
inputs = [query, key, value]
if mask is not None:
inputs.append(mask)
if normalization_quantize_scale is not None:
inputs.append(normalization_quantize_scale)

return op_utils.create_op(
Attention,
inputs,
normalization_operation=normalization_operation,
causal=causal,
decomposable=decomposable,
normalization_quantize_to_type=normalization_quantize_to_type,
)
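
For intuition, here is a minimal NumPy reference for the decomposition described in the docstring (softmax normalization, optional boolean/float mask, causal masking; quantization omitted). This is an illustrative sketch, not the library's implementation, which lowers to a fused TensorRT kernel:

import numpy as np


def attention_reference(query, key, value, mask=None, causal=False):
    """Decomposed attention: BMM1 -> masking -> softmax -> BMM2 (no quantization)."""
    d_k = query.shape[-1]
    # 1. BMM1: scaled dot product between query and transposed key.
    scores = query @ key.swapaxes(-1, -2) / np.sqrt(d_k)
    # 2. Optional masking.
    if causal:
        s_q, s_k = scores.shape[-2], scores.shape[-1]
        allowed = np.tril(np.ones((s_q, s_k), dtype=bool))  # attend to self and earlier
        scores = np.where(allowed, scores, -np.inf)
    if mask is not None:
        if mask.dtype == bool:
            scores = np.where(mask, scores, -np.inf)  # True = allowed to attend
        else:
            scores = scores + mask  # float masks are added to the scores
    # 3. Normalization (numerically stable softmax over the key dimension).
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # 5. BMM2: weighted sum of values. (Step 4, quantization, is omitted here.)
    return weights @ value


q = k = v = np.random.rand(2, 8, 16, 4).astype(np.float32)
out = attention_reference(q, k, v, causal=True)
assert out.shape == (2, 8, 16, 4)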
79 changes: 79 additions & 0 deletions tripy/nvtripy/trace/ops/attention.py
@@ -0,0 +1,79 @@
#
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from dataclasses import dataclass
from typing import Optional

from mlir_tensorrt.compiler import ir
from mlir_tensorrt.compiler.dialects import tensorrt
from nvtripy.common import datatype
from nvtripy.trace.ops import utils as op_utils
from nvtripy.trace.ops.base import TraceOp


@dataclass(repr=False)
class Attention(TraceOp):
    # Stored in frontend form (e.g., "softmax"); to_mlir() converts this to the
    # TensorRT enum form (e.g., "kSOFTMAX").
    normalization_operation: Optional[str] = "softmax"
causal: bool = False
decomposable: bool = False
normalization_quantize_to_type: Optional[datatype.dtype] = None

infer_rank = op_utils.InferRankPolicies.same_shape_as_input(0)

def infer_dtypes(self):
self.outputs[0].dtype = self.inputs[0].dtype

def to_mlir(self, inputs, outputs):
assert len(inputs) >= 3, "Attention operation should have at least 3 inputs!"

        query, key, value = inputs[0], inputs[1], inputs[2]

        # Optional inputs are appended in order: `mask` (if provided), then
        # `normalization_quantize_scale` (if provided). When only one optional
        # input is present, disambiguate using `normalization_quantize_to_type`:
        # the frontend guarantees the scale is present iff a quantize type is
        # set, so in that case the last extra input is the scale, not the mask.
        extra_inputs = list(inputs[3:])
        normalization_quantize_scale = None
        if self.normalization_quantize_to_type is not None and extra_inputs:
            normalization_quantize_scale = extra_inputs.pop()
        mask = extra_inputs[0] if extra_inputs else None

# Create attributes
normalization_operation_attr = None
if self.normalization_operation:
norm_op_str = "k" + self.normalization_operation.upper()
normalization_operation_attr = tensorrt.AttentionNormalizationOpAttr.get(norm_op_str)

causal_attr = None
if self.causal:
causal_attr = ir.BoolAttr.get(self.causal)

decomposable_attr = None
if self.decomposable:
decomposable_attr = ir.BoolAttr.get(self.decomposable)

normalization_quantize_to_type_attr = None
if self.normalization_quantize_to_type:
trt_dtype_str = op_utils.get_trt_dtype_enum_str(self.normalization_quantize_to_type)
normalization_quantize_to_type_attr = tensorrt.DataTypeAttr.get(trt_dtype_str)

return [
tensorrt.attention(
outputs[0],
query,
key,
value,
mask=mask,
normalization_quantize_scale=normalization_quantize_scale,
normalization_operation=normalization_operation_attr,
causal=causal_attr,
decomposable=decomposable_attr,
normalization_quantize_to_type=normalization_quantize_to_type_attr,
)
]
31 changes: 31 additions & 0 deletions tripy/nvtripy/trace/ops/utils.py
@@ -70,3 +70,34 @@ def get_broadcast_in_dim(input_rank: int, output_rank: int) -> List[int]:

assert len(broadcast_dimensions) == input_rank
return broadcast_dimensions


##
## Datatype conversion
##


def get_trt_dtype_enum_str(dtype: "nvtripy.dtype") -> str:

Collaborator
I think we should make this a property of dtype so we don't have to update multiple places when adding new dtypes.

"""
Converts a tripy datatype to its corresponding TensorRT DataType enum string.

Args:
dtype: A tripy datatype

Returns:
The TensorRT DataType enum string (e.g., "kFP8", "kINT8")
"""
from nvtripy.common import datatype

TRIPY_DTYPE_TO_TRT_ENUM = {
datatype.float32: "kFLOAT",
datatype.float16: "kHALF",
datatype.int8: "kINT8",
datatype.int32: "kINT32",
datatype.bool: "kBOOL",
datatype.float8: "kFP8",
datatype.bfloat16: "kBF16",
datatype.int64: "kINT64",
datatype.int4: "kINT4",
}
return TRIPY_DTYPE_TO_TRT_ENUM[dtype]
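
One possible shape for the reviewer's suggestion above, sketched under the assumption that tripy dtypes are ordinary objects that attributes can be attached to; `trt_enum_str` is an invented name, not an existing nvtripy API:

# Hypothetical sketch of the review suggestion: carry the TensorRT enum string
# on the dtype itself so new dtypes are registered in exactly one place.
# `DType` and `trt_enum_str` are invented names, not part of nvtripy today.
class DType:
    def __init__(self, name: str, trt_enum_str: str):
        self.name = name
        self.trt_enum_str = trt_enum_str


float32 = DType("float32", "kFLOAT")
float8 = DType("float8", "kFP8")
int8 = DType("int8", "kINT8")

# Call sites collapse to an attribute access instead of a lookup table:
assert float8.trt_enum_str == "kFP8"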