Add tp.attention op #709
base: main
Conversation
Force-pushed from 198b12c to cf6e04e
    query: The query tensor with shape ``[batchSize, numHeadsQuery, sequenceLengthQuery, dimHead]``.
    key: The key tensor with shape ``[batchSize, numHeadsKeyValue, sequenceLengthKeyValue, dimHead]``.
    value: The value tensor with shape ``[batchSize, numHeadsKeyValue, sequenceLengthKeyValue, dimHead]``.
I need to verify why in TRT we use numHeadsQuery and numHeadsKeyValue separately.
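Separate query and key/value head counts typically exist to support grouped-query or multi-query attention, where several query heads share one KV head. A minimal NumPy sketch of that layout, using the shapes from the docstring above (the variable names are illustrative, not part of the nvtripy API):

```python
import numpy as np

# Illustrative shapes: 8 query heads grouped over 2 KV heads (GQA).
batch, num_heads_q, num_heads_kv = 2, 8, 2
seq_q, seq_kv, dim_head = 4, 6, 16

rng = np.random.default_rng(0)
query = rng.standard_normal((batch, num_heads_q, seq_q, dim_head))
key = rng.standard_normal((batch, num_heads_kv, seq_kv, dim_head))
value = rng.standard_normal((batch, num_heads_kv, seq_kv, dim_head))

# Expand KV heads so each group of query heads attends to its shared KV head.
group = num_heads_q // num_heads_kv
key_e = np.repeat(key, group, axis=1)      # -> (batch, num_heads_q, seq_kv, dim_head)
value_e = np.repeat(value, group, axis=1)

scores = query @ key_e.swapaxes(-1, -2) / np.sqrt(dim_head)  # BMM1 + scale
weights = np.exp(scores - scores.max(-1, keepdims=True))
weights /= weights.sum(-1, keepdims=True)                    # softmax
out = weights @ value_e                                      # BMM2

assert out.shape == (batch, num_heads_q, seq_q, dim_head)
```

When ``numHeadsQuery == numHeadsKeyValue`` this reduces to standard multi-head attention, so the split signature subsumes the common case.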
    5. Matrix multiplication with value (BMM2)

    Args:
        query: The query tensor with shape ``[batchSize, numHeadsQuery, sequenceLengthQuery, dimHead]``.
nit: can we use snake_case here to stay consistent with the rest of the documentation?
    def get_trt_dtype_enum_str(dtype: "nvtripy.dtype") -> str:
I think we should make this a property of dtype so we don't have to update multiple places when adding new dtypes.
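One way to realize the reviewer's suggestion is to carry the TRT enum name on the dtype object itself, so new dtypes are registered in exactly one place. A hypothetical sketch (the class and field names are invented for illustration, not nvtripy's actual dtype implementation):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DType:
    """Illustrative dtype descriptor carrying its TRT enum name as data."""

    name: str
    trt_enum_str: str  # e.g. "DataType.HALF"; read via attribute, no lookup table


# Adding a new dtype is a single-site change: define it with its enum string.
FLOAT32 = DType("float32", "DataType.FLOAT")
FLOAT16 = DType("float16", "DataType.HALF")

# Callers no longer route through a standalone get_trt_dtype_enum_str():
assert FLOAT16.trt_enum_str == "DataType.HALF"
```

Compared with a free function holding an if/elif chain, the property keeps the dtype-to-enum mapping next to the dtype definition, which is the "update one place" benefit the comment is after.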
Force-pushed from cf6e04e to a58916a
    assert output.shape == (batch_size, num_heads, seq_len, head_dim)

    .. code-block:: python
Since the inputs to all three examples are the same, can we omit the input initialization in the docs so that it's easier to tell what changes between the samples? Also, can we have the quantization sample omit the mask?
I'm conflicted on this - on one hand, it will make the examples much cleaner, but on the other, it'll mean that you can't just copy-paste the example code and have it work.
If all the tensors are the same shape, maybe a compromise could be:
``query = key = value = tp.iota(...)``, although we would need to clarify that it's only being done for the sake of brevity and they don't all need to be the same tensor.
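The chained-assignment compromise would look something like the following sketch, written with NumPy rather than nvtripy since the final tp.attention signature isn't settled in this thread. The comment carries the caveat the reviewer asks for:

```python
import numpy as np

# For brevity only: bind one tensor to all three names. In real usage the
# query, key, and value are normally distinct tensors; they only need to
# have compatible shapes, not be the same object.
query = key = value = np.arange(2 * 4 * 8 * 16, dtype=np.float32).reshape(2, 4, 8, 16)

assert query is key is value          # one object, three names
assert query.shape == (2, 4, 8, 16)   # [batch, num_heads, seq_len, dim_head]
```

This keeps each example copy-pasteable while making the per-example deltas easy to spot, at the cost of needing that one clarifying comment.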
Force-pushed from a58916a to d3949d7
No description provided.