Skip to content

Conversation

@jhalakpatel
Copy link
Collaborator

@jhalakpatel jhalakpatel commented Oct 4, 2024

Add optional stride validation in MemRefValue::create to compute canonical stride and compare against given strides while creaing a memref view from DLPack tensors. We need to handle special cases for zero-sized and unit-sized dimensions since frameworks deal with them arbitrarily while converting to the corresponding DLPack tensor. Add Python tests to verify both canonical and non-canonical stride validation.

@jhalakpatel jhalakpatel force-pushed the jhalakp-memref-stride-check branch 2 times, most recently from b7059ca to 238e993 Compare October 4, 2024 00:33
@jhalakpatel jhalakpatel changed the title [API/MemRefValue] Validate strides against canonical strides for non-empty shapes [API/MemRef] Implement canonical stride validation for MemRefValue creation Oct 4, 2024
@jhalakpatel jhalakpatel force-pushed the jhalakp-memref-stride-check branch from 238e993 to ab28976 Compare October 4, 2024 01:00
jhalakpatel added a commit to jhalakpatel/TensorRT-Incubator that referenced this pull request Oct 4, 2024
MLIR-TensorRT requires strides for function arguments and results in canonical order.

NVIDIA#252 adds a check to validate memref stride against a canonical stride order. In Tripy, memref strides are derived from framework DL Pack tensors. Creating a memref with a non-canonical DL Pack tensor stride throws an exception.

Add a try-catch block to catch such an exception and augment with suggestions on creating a DL Pack tensor with canonical stride for Tripy-supported frameworks.

Add unit tests to create a non-canonical stride tensor to validate exceptions and suggestions.
jhalakpatel added a commit to jhalakpatel/TensorRT-Incubator that referenced this pull request Oct 4, 2024
MLIR-TensorRT requires strides for function arguments and results in canonical order.

NVIDIA#252 adds a check to validate memref stride against a canonical stride order. In Tripy, memref strides are derived from framework DL Pack tensors. Creating a memref with a non-canonical DL Pack tensor stride throws an exception.

Add a try-catch block to catch such an exception and augment with suggestions on creating a DL Pack tensor with canonical stride for Tripy-supported frameworks.

Add unit tests to create a non-canonical stride tensor to validate exceptions and suggestions.
@jhalakpatel jhalakpatel force-pushed the jhalakp-memref-stride-check branch 2 times, most recently from 6204f3a to ec624e7 Compare October 4, 2024 21:08
Copy link
Collaborator

@christopherbate christopherbate left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The assertion has to be a parameter to the Python function as well, not just the C++ function.

@jhalakpatel jhalakpatel force-pushed the jhalakp-memref-stride-check branch 2 times, most recently from d26730f to d67986e Compare October 8, 2024 18:01
Copy link
Collaborator

@christopherbate christopherbate left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some minor comments, otherwise LGTM.

Add optional stride validation in `MemRefValue::create` to compute canonical stride and compare against given strides while creaing a memref view from DLPack tensors. We need to handle special cases for zero-sized and unit-sized dimensions since frameworks deal with them arbitrarily while converting to the corresponding DLPack tensor. Add Python tests to verify both canonical and non-canonical stride validation.
@jhalakpatel jhalakpatel force-pushed the jhalakp-memref-stride-check branch from d67986e to 4a91333 Compare October 8, 2024 18:33
@jhalakpatel jhalakpatel merged commit 4dbc5cc into main Oct 8, 2024
1 check failed
@jhalakpatel jhalakpatel deleted the jhalakp-memref-stride-check branch October 8, 2024 18:34
jhalakpatel added a commit to jhalakpatel/TensorRT-Incubator that referenced this pull request Oct 15, 2024
MLIR-TensorRT requires strides for function arguments and results in canonical order.

NVIDIA#252 adds a check to validate memref stride against a canonical stride order. In Tripy, memref strides are derived from framework DL Pack tensors. Creating a memref with a non-canonical DL Pack tensor stride throws an exception.

Add a try-catch block to catch such an exception and augment with suggestions on creating a DL Pack tensor with canonical stride for Tripy-supported frameworks.

Add unit tests to create a non-canonical stride tensor to validate exceptions and suggestions.
jhalakpatel added a commit that referenced this pull request Oct 15, 2024
MLIR-TensorRT requires strides for function arguments and results in
canonical order.

#252 adds a check to
validate memref stride against a canonical stride order. In Tripy,
memref strides are derived from framework DL Pack tensors. Creating a
memref with a non-canonical DL Pack tensor stride throws an exception.

Add a try-catch block to catch such an exception and augment with
suggestions on creating a DL Pack tensor with canonical stride for
Tripy-supported frameworks.

Add unit tests to create a non-canonical stride tensor to validate
exceptions and suggestions.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants