Conversation

@yizhuoz004
Collaborator

No description provided.

Comment on lines 68 to 70
query: The query tensor with shape ``[batchSize, numHeadsQuery, sequenceLengthQuery, dimHead]``.
key: The key tensor with shape ``[batchSize, numHeadsKeyValue, sequenceLengthKeyValue, dimHead]``.
value: The value tensor with shape ``[batchSize, numHeadsKeyValue, sequenceLengthKeyValue, dimHead]``.
Collaborator Author

I need to verify why in TRT we use numHeadsQuery and numHeadsKeyValue separately.
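One likely reason, for context, is grouped-query / multi-query attention, where several query heads share a single key/value head, so the two head counts can differ. A minimal sketch of that shape combination, following the docstring above (the exact tp.attention call signature here is an assumption for illustration):

import nvtripy as tp

# Hypothetical grouped-query setup: 8 query heads share 2 key/value heads.
batch, seq_q, seq_kv, head_dim = 1, 128, 128, 64
query = tp.ones((batch, 8, seq_q, head_dim))   # [batch, num_heads_query, seq_q, head_dim]
key = tp.ones((batch, 2, seq_kv, head_dim))    # [batch, num_heads_key_value, seq_kv, head_dim]
value = tp.ones((batch, 2, seq_kv, head_dim))  # [batch, num_heads_key_value, seq_kv, head_dim]

output = tp.attention(query, key, value)       # expected shape: [batch, 8, seq_q, head_dim]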

@yizhuoz004 changed the title from "Add AttentionOp" to "Add tp.attention op" on Nov 4, 2025
5. Matrix multiplication with value (BMM2)
Args:
query: The query tensor with shape ``[batchSize, numHeadsQuery, sequenceLengthQuery, dimHead]``.
Collaborator

nit: can we use snake_case to be consistent with the rest of the documentation?
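For example, the quoted shape annotation would then read:

query: The query tensor with shape ``[batch_size, num_heads_query, sequence_length_query, dim_head]``.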



def get_trt_dtype_enum_str(dtype: "nvtripy.dtype") -> str:
Collaborator

I think we should make this a property of dtype so we don't have to update multiple places when adding new dtypes.
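One way this could look, as a rough sketch only (the names and enum strings below are placeholders, not the actual nvtripy dtype internals or TensorRT values), is to carry the enum string on the dtype object itself:

# Illustrative sketch: store the TRT enum string on each dtype at definition
# time so that adding a new dtype only touches one place.
class dtype:
    def __init__(self, name: str, trt_enum_str: str):
        self.name = name
        self.trt_enum_str = trt_enum_str

float32 = dtype("float32", "FLOAT")
float16 = dtype("float16", "HALF")
int8 = dtype("int8", "INT8")

def get_trt_dtype_enum_str(dt: "dtype") -> str:
    # The existing helper can then simply forward to the attribute.
    return dt.trt_enum_str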

assert output.shape == (batch_size, num_heads, seq_len, head_dim)
.. code-block:: python
Collaborator

Since the inputs to all 3 examples are the same, can we omit the input initialization in the docs so that it is easier to tell what is changing between the samples? Also, can we have the quantization sample omit the mask?

Collaborator

I'm conflicted on this - on one hand, it will make the examples much cleaner, but on the other, it'll mean that you can't just copy-paste the example code and have it work.

If all the tensors are the same shape, maybe a compromise could be:

query = key = value = tp.iota(...)

although we would need to clarify that it's only being done for the sake of brevity and they don't all need to be the same tensor.
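Concretely, the docstring example could then look something like this (the shapes and the tp.attention signature are assumptions for illustration only):

import nvtripy as tp

# query, key, and value do not need to be the same tensor; they are aliased
# here only to keep the example short.
query = key = value = tp.iota((1, 8, 128, 64), dtype=tp.float32)

output = tp.attention(query, key, value)
assert output.shape == (1, 8, 128, 64)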
