-
Notifications
You must be signed in to change notification settings - Fork 58
Fix deterministic indexing with broadcast #1705
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR enhances the index_put
implementation on XPU by ensuring deterministic indexing, centralizing shape logic, and bolstering test coverage.
- Introduce a
valsShape
helper to compute expanded-value shapes. - Extend
computeLinearIndex
andmakeLinearIndex
to returndims_before
anddims_indexed
. - Simplify value expansion in
index_put_deterministic_kernel
viavalsShape
. - Add new deterministic tests for
index_put
with optional tensors and shape-mismatch checks.
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.
File | Description |
---|---|
test/xpu/test_indexing_xpu.py | Added deterministic index_put tests, including 0D/1D/2D values and mismatch assertions |
src/ATen/native/xpu/sycl/IndexingUtils.h | Extended computeLinearIndex /makeLinearIndex to return two new dimension counts |
src/ATen/native/xpu/sycl/Indexing.cpp | Added valsShape helper and replaced manual expansion in the deterministic kernel |
Comments suppressed due to low confidence (2)
test/xpu/test_indexing_xpu.py:18
- [nitpick] The helper names
func
andfunc1
are ambiguous—consider renaming them to clearly reflect their purpose (e.g.,index_put_with_guard
andsimple_index_put
).
def func(x, i, v):
test/xpu/test_indexing_xpu.py:35
- [nitpick] Variable
values2d
does not match thevalue0d
/value1d
pattern—rename tovalue2d
for consistency.
values2d = torch.randn(n, 1)
@chunhuanMeng Pls remove |
done |
@sys_pytorchxpubot triage result for run 15864776461Triage bot UT analaysis result for reference only, please note unique error message only report once: |
This PR is mirror of pytorch/pytorch#154296 |
Introduces enhancements to the
index_put
implementation for XPU tensors, focusing on deterministic behavior, improved shape handling, and expanded test coverage. Key changes include adding new helper functions, extending themakeLinearIndex
andcomputeLinearIndex
methods, and updating the associated test suite.Enhancements to
index_put
Implementation:New Helper Function for Shape Handling:
valsShape
to compute the target shape for expanded values duringindex_put
operations. This simplifies and centralizes shape manipulation logic. (src/ATen/native/xpu/sycl/Indexing.cpp
)Extended
makeLinearIndex
andcomputeLinearIndex
:dims_before
anddims_indexed
to track dimensions before and during indexing. These are now returned as part of the tuple fromcomputeLinearIndex
and propagated throughmakeLinearIndex
. (src/ATen/native/xpu/sycl/IndexingUtils.h
)Simplified Value Expansion in
index_put_deterministic_kernel
:valsShape
. This makes the code more concise and reduces duplication. (src/ATen/native/xpu/sycl/Indexing.cpp
)Test Suite Enhancements:
test_index_put_deterministic_with_optional_tensors
, to validate deterministic behavior ofindex_put
with various tensor shapes and scenarios. This includes checks for shape mismatches and proper handling of 0D, 1D, and 2D values. (test/xpu/test_indexing_xpu.py
)These changes collectively improve the robustness, maintainability, and test coverage of the
index_put
functionality for XPU tensors.