Commit a58916a
Add tp.attention op
1 parent b344e42 commit a58916a

3 files changed: +290 -0 lines changed
Lines changed: 181 additions & 0 deletions
#
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Optional

from nvtripy import export
from nvtripy.common import datatype
from nvtripy.frontend.ops import utils as op_utils
from nvtripy.trace.ops.attention import Attention
from nvtripy.utils import wrappers


@export.public_api(document_under="operations/functions")
@wrappers.interface(
    dtype_constraints={
        "query": "T1",
        "key": "T1",
        "value": "T1",
        wrappers.RETURN_VALUE: "T1",
    },
    dtype_variables={
        "T1": ["float32", "float16", "bfloat16"],
    },
)
def attention(
    query: "nvtripy.Tensor",
    key: "nvtripy.Tensor",
    value: "nvtripy.Tensor",
    *,
    mask: Optional["nvtripy.Tensor"] = None,
    normalization_quantize_scale: Optional["nvtripy.Tensor"] = None,
    normalization_operation: str = "softmax",
    causal: bool = False,
    decomposable: bool = False,
    normalization_quantize_to_type: Optional[datatype.dtype] = None,
) -> "nvtripy.Tensor":
    r"""
    Performs a fused multi-head attention operation.

    This operation implements the attention mechanism:

    .. math::
        \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

    The operation consists of:

    1. Matrix multiplication between query and transposed key (BMM1)
    2. Optional masking
    3. Normalization (typically softmax)
    4. Optional quantization of the normalized output
    5. Matrix multiplication with value (BMM2)

    Args:
        query: The query tensor with shape ``[batch_size, num_heads_query, sequence_length_query, dim_head]``.
        key: The key tensor with shape ``[batch_size, num_heads_key, sequence_length_key, dim_head]``.
        value: The value tensor with shape ``[batch_size, num_heads_value, sequence_length_value, dim_head]``.
        mask: Optional mask tensor with shape
            ``[batch_size, num_heads_query, sequence_length_query, sequence_length_key]``.
            For boolean masks (dtype=bool), ``True`` indicates positions that are allowed to attend.
            For float masks, the values are added to the attention scores before normalization.
        normalization_quantize_scale: Optional scale tensor for quantizing the normalization output.
            Required if ``normalization_quantize_to_type`` is specified.
        normalization_operation: The normalization operation to use. Must be one of "softmax" or "none".
            Defaults to ``"softmax"``.
        causal: If ``True``, applies causal (autoregressive) masking where each position can only
            attend to earlier positions. Cannot be used together with an explicit ``mask``. Defaults to ``False``.
        decomposable: If ``True``, allows the operation to be decomposed into multiple kernels if
            no fused kernel is available. Defaults to ``False``.
        normalization_quantize_to_type: Optional data type for quantizing the normalization output.
            Must be either :class:`nvtripy.float8` or :class:`nvtripy.int8`.
            Requires ``normalization_quantize_scale`` to be provided.

    Returns:
        The attention output tensor with shape ``[batch_size, num_heads_query, sequence_length_query, dim_head]``.

    .. code-block:: python
        :linenos:
        :caption: Basic Attention

        batch_size, num_heads, seq_len, head_dim = 2, 8, 128, 64
        query = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)
        key = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)
        value = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)

        output = tp.attention(query, key, value)

        assert output.shape == (batch_size, num_heads, seq_len, head_dim)

    .. code-block:: python
        :linenos:
        :caption: Attention with Quantization

        batch_size, num_heads, seq_len, head_dim = 2, 8, 128, 64
        query = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)
        key = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)
        value = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)

        # Quantize the normalization output (softmax) to float8
        mask = tp.ones((batch_size, num_heads, seq_len, seq_len), dtype=tp.bool)
        scale = tp.Tensor([1.0], dtype=tp.float16)

        output = tp.attention(query, key, value, mask=mask,
                              normalization_quantize_scale=scale,
                              normalization_quantize_to_type=tp.float8)

        assert output.shape == (batch_size, num_heads, seq_len, head_dim)

    .. code-block:: python
        :linenos:
        :caption: Attention with Mask

        batch_size, num_heads, seq_len, head_dim = 2, 8, 128, 64
        query = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)
        key = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)
        value = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)

        # Create a boolean mask where True indicates positions that can attend
        mask = tp.ones((batch_size, num_heads, seq_len, seq_len), dtype=tp.bool)

        output = tp.attention(query, key, value, mask=mask)

        assert output.shape == (batch_size, num_heads, seq_len, head_dim)
    """
    from nvtripy.common.exception import raise_error

    if normalization_operation not in ("softmax", "none"):
        raise_error(
            f"Invalid normalization operation: {normalization_operation}. Must be one of 'softmax' or 'none'.",
        )

    # Validation checks
    if causal and mask is not None:
        raise_error(
            "Cannot use both `causal` and `mask` at the same time.",
            details=[
                "The `causal` parameter applies implicit causal masking.",
                "Please use either `causal=True` or provide an explicit `mask`.",
            ],
        )

    if normalization_quantize_to_type is not None:
        if normalization_quantize_scale is None:
            raise_error(
                "`normalization_quantize_scale` must be provided when `normalization_quantize_to_type` is specified.",
            )

        if normalization_quantize_to_type not in (datatype.float8, datatype.int8):
            raise_error(
                "`normalization_quantize_to_type` must be either float8 or int8.",
                details=[f"Got: {normalization_quantize_to_type}"],
            )

    # Collect optional inputs in a fixed order: mask first, then the quantization scale.
    inputs = [query, key, value]
    if mask is not None:
        inputs.append(mask)
    if normalization_quantize_scale is not None:
        inputs.append(normalization_quantize_scale)

    return op_utils.create_op(
        Attention,
        inputs,
        normalization_operation=normalization_operation,
        causal=causal,
        decomposable=decomposable,
        normalization_quantize_to_type=normalization_quantize_to_type,
        # Record whether a mask was provided so the trace op can unambiguously
        # unpack the trailing optional inputs (see Attention.to_mlir).
        has_mask=mask is not None,
    )
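
For reference, here is a minimal NumPy sketch of the decomposed semantics documented above (BMM1, optional masking, softmax, BMM2). It illustrates the math only, not the fused TensorRT kernel; the function name and the use of NumPy are illustrative assumptions.

import numpy as np

def attention_reference(query, key, value, mask=None):
    # BMM1: scaled dot product between query and transposed key.
    d_k = query.shape[-1]
    scores = query @ key.swapaxes(-1, -2) / np.sqrt(d_k)
    # Optional masking: boolean masks gate which positions may attend;
    # float masks are added to the attention scores.
    if mask is not None:
        if mask.dtype == np.bool_:
            scores = np.where(mask, scores, -np.inf)
        else:
            scores = scores + mask
    # Normalization: numerically stable softmax over the key dimension.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    # BMM2: weighted sum of value rows.
    return weights @ value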
Lines changed: 78 additions & 0 deletions
#
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from dataclasses import dataclass
from typing import Optional

from mlir_tensorrt.compiler import ir
from mlir_tensorrt.compiler.dialects import tensorrt
from nvtripy.common import datatype
from nvtripy.trace.ops import utils as op_utils
from nvtripy.trace.ops.base import TraceOp


@dataclass(repr=False)
class Attention(TraceOp):
    # Lowercase name as passed by the frontend ("softmax" or "none");
    # to_mlir converts it to the dialect's enum-style string (e.g. "kSOFTMAX").
    normalization_operation: Optional[str] = "softmax"
    causal: bool = False
    decomposable: bool = False
    normalization_quantize_to_type: Optional[datatype.dtype] = None
    # Whether a mask tensor is present in the inputs; used to disambiguate
    # the trailing optional inputs (mask, quantization scale).
    has_mask: bool = False

    infer_rank = op_utils.InferRankPolicies.same_shape_as_input(0)

    def infer_dtypes(self):
        self.outputs[0].dtype = self.inputs[0].dtype

    def to_mlir(self, inputs, outputs):
        assert len(inputs) >= 3, "Attention operation should have at least 3 inputs!"

        query, key, value = inputs[0], inputs[1], inputs[2]
        # Optional inputs follow in a fixed order: mask (if supplied), then the
        # quantization scale. Indexing off `has_mask` avoids misreading a lone
        # scale tensor as a mask.
        mask = inputs[3] if self.has_mask else None
        scale_index = 3 + int(self.has_mask)
        normalization_quantize_scale = inputs[scale_index] if len(inputs) > scale_index else None

        # Create attributes
        normalization_operation_attr = None
        if self.normalization_operation:
            norm_op_str = "k" + self.normalization_operation.upper()
            normalization_operation_attr = tensorrt.AttentionNormalizationOpAttr.get(norm_op_str)

        causal_attr = None
        if self.causal:
            causal_attr = ir.BoolAttr.get(self.causal)

        decomposable_attr = None
        if self.decomposable:
            decomposable_attr = ir.BoolAttr.get(self.decomposable)

        normalization_quantize_to_type_attr = None
        if self.normalization_quantize_to_type:
            trt_dtype_str = op_utils.get_trt_dtype_enum_str(self.normalization_quantize_to_type)
            normalization_quantize_to_type_attr = tensorrt.DataTypeAttr.get(trt_dtype_str)

        return [
            tensorrt.attention(
                query,
                key,
                value,
                mask=mask,
                normalization_quantize_scale=normalization_quantize_scale,
                normalization_operation=normalization_operation_attr,
                causal=causal_attr,
                decomposable=decomposable_attr,
                normalization_quantize_to_type=normalization_quantize_to_type_attr,
            )
        ]
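
To make the optional-input convention concrete, here is a small self-contained sketch of the unpacking performed in Attention.to_mlir; `unpack` is a hypothetical stand-in written for illustration:

def unpack(inputs, has_mask):
    # Inputs are packed as [query, key, value, mask?, quantize_scale?],
    # so has_mask is needed to disambiguate a length-4 input list.
    mask = inputs[3] if has_mask else None
    scale_index = 3 + int(has_mask)
    scale = inputs[scale_index] if len(inputs) > scale_index else None
    return mask, scale

assert unpack(["q", "k", "v", "scale"], has_mask=False) == (None, "scale")
assert unpack(["q", "k", "v", "mask"], has_mask=True) == ("mask", None)
assert unpack(["q", "k", "v", "mask", "scale"], has_mask=True) == ("mask", "scale")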

tripy/nvtripy/trace/ops/utils.py

Lines changed: 31 additions & 0 deletions
@@ -70,3 +70,34 @@ def get_broadcast_in_dim(input_rank: int, output_rank: int) -> List[int]:

    assert len(broadcast_dimensions) == input_rank
    return broadcast_dimensions


##
## Datatype conversion
##


def get_trt_dtype_enum_str(dtype: "nvtripy.dtype") -> str:
    """
    Converts a tripy datatype to its corresponding TensorRT DataType enum string.

    Args:
        dtype: A tripy datatype

    Returns:
        The TensorRT DataType enum string (e.g., "kFP8", "kINT8")
    """
    from nvtripy.common import datatype

    TRIPY_DTYPE_TO_TRT_ENUM = {
        datatype.float32: "kFLOAT",
        datatype.float16: "kHALF",
        datatype.int8: "kINT8",
        datatype.int32: "kINT32",
        datatype.bool: "kBOOL",
        datatype.float8: "kFP8",
        datatype.bfloat16: "kBF16",
        datatype.int64: "kINT64",
        datatype.int4: "kINT4",
    }
    return TRIPY_DTYPE_TO_TRT_ENUM[dtype]
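
A quick usage sketch of the new helper (assuming an nvtripy environment; dtypes without a TensorRT equivalent raise KeyError, since the lookup is a plain dict access):

from nvtripy.common import datatype
from nvtripy.trace.ops import utils as op_utils

assert op_utils.get_trt_dtype_enum_str(datatype.float8) == "kFP8"
assert op_utils.get_trt_dtype_enum_str(datatype.bfloat16) == "kBF16"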
