Add tp.attention op #709
Draft: yizhuoz004 wants to merge 1 commit into main from tp-attention.
New file (181 lines), defining the public `tp.attention` frontend API:

```python
#
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Optional

from nvtripy import export
from nvtripy.common import datatype
from nvtripy.frontend.ops import utils as op_utils
from nvtripy.trace.ops.attention import Attention
from nvtripy.utils import wrappers


@export.public_api(document_under="operations/functions")
@wrappers.interface(
    dtype_constraints={
        "query": "T1",
        "key": "T1",
        "value": "T1",
        wrappers.RETURN_VALUE: "T1",
    },
    dtype_variables={
        "T1": ["float32", "float16", "bfloat16"],
    },
)
def attention(
    query: "nvtripy.Tensor",
    key: "nvtripy.Tensor",
    value: "nvtripy.Tensor",
    *,
    mask: Optional["nvtripy.Tensor"] = None,
    normalization_quantize_scale: Optional["nvtripy.Tensor"] = None,
    normalization_operation: str = "softmax",
    causal: bool = False,
    decomposable: bool = False,
    normalization_quantize_to_type: Optional[datatype.dtype] = None,
) -> "nvtripy.Tensor":
    r"""
    Performs a fused multi-head attention operation.

    This operation implements the attention mechanism:

    .. math::

        \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

    The operation consists of:

    1. Matrix multiplication between query and transposed key (BMM1)
    2. Optional masking
    3. Normalization (typically softmax)
    4. Optional quantization of the normalized output
    5. Matrix multiplication with value (BMM2)

    Args:
        query: The query tensor with shape ``[batch_size, num_heads_query, sequence_length_query, dim_head]``.
        key: The key tensor with shape ``[batch_size, num_heads_key, sequence_length_key, dim_head]``.
        value: The value tensor with shape ``[batch_size, num_heads_value, sequence_length_value, dim_head]``.
        mask: Optional mask tensor with shape
            ``[batch_size, num_heads_query, sequence_length_query, sequence_length_key]``.
            For boolean masks (dtype=bool), ``True`` indicates positions that are allowed to attend.
            For float masks, the values are added to the attention scores before normalization.
        normalization_quantize_scale: Optional scale tensor for quantizing the normalization output.
            Required if ``normalization_quantize_to_type`` is specified.
        normalization_operation: The normalization operation to use. Must be one of ``"softmax"`` or ``"none"``.
            Defaults to ``"softmax"``.
        causal: If ``True``, applies causal (autoregressive) masking where each position can only
            attend to earlier positions. Cannot be used together with an explicit ``mask``. Defaults to ``False``.
        decomposable: If ``True``, allows the operation to be decomposed into multiple kernels if
            no fused kernel is available. Defaults to ``False``.
        normalization_quantize_to_type: Optional data type for quantizing the normalization output.
            Must be either :class:`nvtripy.float8` or :class:`nvtripy.int8`.
            Requires ``normalization_quantize_scale`` to be provided.

    Returns:
        The attention output tensor with shape ``[batch_size, num_heads_query, sequence_length_query, dim_head]``.

    .. code-block:: python
        :linenos:
        :caption: Basic Attention

        batch_size, num_heads, seq_len, head_dim = 2, 8, 128, 64
        query = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)
        key = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)
        value = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)

        output = tp.attention(query, key, value)

        assert output.shape == (batch_size, num_heads, seq_len, head_dim)

    .. code-block:: python
        :linenos:
        :caption: Attention with Quantization

        batch_size, num_heads, seq_len, head_dim = 2, 8, 128, 64
        query = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)
        key = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)
        value = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)

        # Quantize the normalization output (softmax) to float8
        mask = tp.ones((batch_size, num_heads, seq_len, seq_len), dtype=tp.bool)
        scale = tp.Tensor([1.0], dtype=tp.float16)

        output = tp.attention(query, key, value, mask=mask,
                              normalization_quantize_scale=scale,
                              normalization_quantize_to_type=tp.float8)

        assert output.shape == (batch_size, num_heads, seq_len, head_dim)

    .. code-block:: python
        :linenos:
        :caption: Attention with Mask

        batch_size, num_heads, seq_len, head_dim = 2, 8, 128, 64
        query = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)
        key = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)
        value = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)

        # Create a boolean mask where True indicates positions that can attend
        mask = tp.ones((batch_size, num_heads, seq_len, seq_len), dtype=tp.bool)

        output = tp.attention(query, key, value, mask=mask)

        assert output.shape == (batch_size, num_heads, seq_len, head_dim)
    """
    from nvtripy.common.exception import raise_error

    if normalization_operation not in ("softmax", "none"):
        raise_error(
            f"Invalid normalization operation: {normalization_operation}. Must be one of 'softmax' or 'none'.",
        )

    # Validation checks
    if causal and mask is not None:
        raise_error(
            "Cannot use both `causal` and `mask` at the same time.",
            details=[
                "The `causal` parameter applies implicit causal masking.",
                "Please use either `causal=True` or provide an explicit `mask`.",
            ],
        )

    if normalization_quantize_to_type is not None:
        if normalization_quantize_scale is None:
            raise_error(
                "`normalization_quantize_scale` must be provided when `normalization_quantize_to_type` is specified.",
            )

        if normalization_quantize_to_type not in (datatype.float8, datatype.int8):
            raise_error(
                f"`normalization_quantize_to_type` must be either float8 or int8.",
                details=[f"Got: {normalization_quantize_to_type}"],
            )

    # Collect inputs based on what's provided
    inputs = [query, key, value]
    if mask is not None:
        inputs.append(mask)
    if normalization_quantize_scale is not None:
        inputs.append(normalization_quantize_scale)

    return op_utils.create_op(
        Attention,
        inputs,
        normalization_operation=normalization_operation,
        causal=causal,
        decomposable=decomposable,
        normalization_quantize_to_type=normalization_quantize_to_type,
    )
```
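The docstring exemplifies the basic, masked, and quantized paths; the `causal` flag is documented but not shown in an example. A call using it would look like the following sketch, based only on the documented parameters (assuming `import nvtripy as tp`, as in the docstring examples; this is not an example from the PR):

```python
batch_size, num_heads, seq_len, head_dim = 2, 8, 128, 64
query = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)
key = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)
value = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)

# Causal (autoregressive) masking; per the validation logic above,
# `causal=True` cannot be combined with an explicit `mask`.
output = tp.attention(query, key, value, causal=True)

assert output.shape == (batch_size, num_heads, seq_len, head_dim)
```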
New file (79 lines), defining the `Attention` trace op and its lowering to the TensorRT dialect:

```python
#
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from dataclasses import dataclass
from typing import Optional

from mlir_tensorrt.compiler import ir
from mlir_tensorrt.compiler.dialects import tensorrt
from nvtripy.common import datatype
from nvtripy.trace.ops import utils as op_utils
from nvtripy.trace.ops.base import TraceOp


@dataclass(repr=False)
class Attention(TraceOp):
    normalization_operation: Optional[str] = "kSOFTMAX"
    causal: bool = False
    decomposable: bool = False
    normalization_quantize_to_type: Optional[datatype.dtype] = None

    infer_rank = op_utils.InferRankPolicies.same_shape_as_input(0)

    def infer_dtypes(self):
        self.outputs[0].dtype = self.inputs[0].dtype

    def to_mlir(self, inputs, outputs):
        assert len(inputs) >= 3, "Attention operation should have at least 3 inputs!"

        query, key, value = inputs[0], inputs[1], inputs[2]
        mask = inputs[3] if len(inputs) > 3 else None
        normalization_quantize_scale = inputs[4] if len(inputs) > 4 else None

        # Create attributes
        normalization_operation_attr = None
        if self.normalization_operation:
            norm_op_str = "k" + self.normalization_operation.upper()
            normalization_operation_attr = tensorrt.AttentionNormalizationOpAttr.get(norm_op_str)

        causal_attr = None
        if self.causal:
            causal_attr = ir.BoolAttr.get(self.causal)

        decomposable_attr = None
        if self.decomposable:
            decomposable_attr = ir.BoolAttr.get(self.decomposable)

        normalization_quantize_to_type_attr = None
        if self.normalization_quantize_to_type:
            trt_dtype_str = op_utils.get_trt_dtype_enum_str(self.normalization_quantize_to_type)
            normalization_quantize_to_type_attr = tensorrt.DataTypeAttr.get(trt_dtype_str)

        return [
            tensorrt.attention(
                outputs[0],
                query,
                key,
                value,
                mask=mask,
                normalization_quantize_scale=normalization_quantize_scale,
                normalization_operation=normalization_operation_attr,
                causal=causal_attr,
                decomposable=decomposable_attr,
                normalization_quantize_to_type=normalization_quantize_to_type_attr,
            )
        ]
```
Changes to an existing utils module (hunk `@@ -70,3 +70,34 @@ def get_broadcast_in_dim(input_rank: int, output_rank: int) -> List[int]:`), adding a datatype conversion helper after the end of `get_broadcast_in_dim`:

```python
    assert len(broadcast_dimensions) == input_rank
    return broadcast_dimensions


##
## Datatype conversion
##


def get_trt_dtype_enum_str(dtype: "nvtripy.dtype") -> str:
    """
    Converts a tripy datatype to its corresponding TensorRT DataType enum string.

    Args:
        dtype: A tripy datatype

    Returns:
        The TensorRT DataType enum string (e.g., "kFP8", "kINT8")
    """
    from nvtripy.common import datatype

    TRIPY_DTYPE_TO_TRT_ENUM = {
        datatype.float32: "kFLOAT",
        datatype.float16: "kHALF",
        datatype.int8: "kINT8",
        datatype.int32: "kINT32",
        datatype.bool: "kBOOL",
        datatype.float8: "kFP8",
        datatype.bfloat16: "kBF16",
        datatype.int64: "kINT64",
        datatype.int4: "kINT4",
    }
    return TRIPY_DTYPE_TO_TRT_ENUM[dtype]
```

Collaborator comment on `get_trt_dtype_enum_str`: I think we should make this a property of ...
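The target of the collaborator's suggestion is cut off above. One plausible reading is to expose the TensorRT enum name as a property of the dtype objects themselves rather than via a free-standing lookup; a hypothetical sketch of that idea follows. The `DataType` class and its fields here are illustrative stand-ins, not nvtripy's actual internals:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DataType:
    name: str
    _trt_enum_str: str  # e.g. "kHALF" for float16

    @property
    def trt_dtype_enum_str(self) -> str:
        # Callers would write `dtype.trt_dtype_enum_str` instead of
        # calling get_trt_dtype_enum_str(dtype).
        return self._trt_enum_str


float16 = DataType("float16", "kHALF")
float8 = DataType("float8", "kFP8")

print(float16.trt_dtype_enum_str)  # kHALF
print(float8.trt_dtype_enum_str)   # kFP8
```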
Review comment:
Since the inputs to all 3 examples are the same, can we omit the input initialization in the docs so that it is easier to tell what is changing between the samples? Also, can we have the quantization sample omit the mask?
Reply:
I'm conflicted on this - on one hand, it will make the examples much cleaner, but on the other, it'll mean that you can't just copy-paste the example code and have it work.
If all the tensors are the same shape, maybe a compromise could be:
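A sketch of such a compromise, assuming the idea is to reuse a single tensor for query, key, and value (this is illustrative, not the reviewer's original snippet):

```python
batch_size, num_heads, seq_len, head_dim = 2, 8, 128, 64
# One tensor is reused purely for brevity; query, key, and value
# do not need to be the same tensor in real use.
query = key = value = tp.iota((batch_size, num_heads, seq_len, head_dim), dtype=tp.float16)

output = tp.attention(query, key, value)
```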
although we would need to clarify that it's only being done for the sake of brevity and they don't all need to be the same tensor.