Fix CuTe composition stride-divisibility check (#3177) #3181
Open
jduprat wants to merge 2 commits into NVIDIA:main
composition_impl() used a strictly weaker form of the divisibility condition: it accepted any rhs stride smaller than the current lhs mode shape, regardless of whether the shape was actually divisible by the stride.

For A=(4,6,8):(2,3,5), B=6:3, this lets composition(A,B) compile and return (_2,_3):(_6,_3), but C(2)=3 != A(B(2))=7.

Replace the weak check with the stronger condition used by pycute (layout.py:211).

Fixes NVIDIA#3177
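The mismatch quoted above can be reproduced by hand. The following is a minimal sketch, not CuTe or pycute code: `layout_eval` is a hypothetical helper that assumes CuTe's leftmost-fastest (colexicographic) index-to-coordinate convention and only handles flat layouts with in-range indices.

```python
def layout_eval(shape, stride, idx):
    """Map a linear index to an offset for a flat layout shape:stride.

    Hypothetical helper: assumes the leftmost mode varies fastest and that
    idx is within the layout's size.
    """
    offset = 0
    for s, d in zip(shape, stride):
        offset += (idx % s) * d   # coordinate in this mode times its stride
        idx //= s                 # carry the remainder to the next mode
    return offset


A = ((4, 6, 8), (2, 3, 5))   # A = (4,6,8):(2,3,5)
B = ((6,), (3,))             # B = 6:3
C = ((2, 3), (6, 3))         # result returned before the fix: (_2,_3):(_6,_3)

# What the composition should produce: A(B(i)) for every i in B's domain.
expected = [layout_eval(*A, layout_eval(*B, i)) for i in range(6)]
# What the returned layout C actually produces.
actual = [layout_eval(*C, i) for i in range(6)]

print(expected)  # [0, 6, 7, 8, 9, 15]
print(actual)    # [0, 6, 3, 9, 6, 12]
```

The two lists already disagree at index 2 (7 versus 3), which is the C(2) != A(B(2)) discrepancy cited in the commit message.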
The strong divisibility check from the previous commit fixes the wrong-answer composition from NVIDIA#3177, but rejects the paper's §3.3.3 "apparent violation" cases that produce well-defined results, e.g.

A = (4,2,8):(3,12,97), B = 3:3 -> 3:9

After the public composition() coalesces A to (8,8):(3,97), the strong check sees `8 % 3 != 0` and refuses to compile, even though A(0)=0, A(3)=9, A(6)=18 is well-defined.

Add a third disjunct that accepts the safe-truncation pattern: when B's entire image fits inside the current LHS mode, higher modes are unreachable and cannot perturb the result. This is the §3.3.3 distinction between "apparent" and "real" divisibility violations.

Predicate now accepts iff at least one of:

(a) `(rest_stride % curr_shape) == 0` -- skip mode entirely

(b) `(curr_shape % rest_stride) == 0` -- partial traversal

(c) `(rest_shape - 1) * rest_stride < curr_shape` -- safe truncation: B's image stays within the current mode

Verification matrix:

| Case | Pre-coalesce LHS | Decision |
|------|------------------|----------|
| paper §3.3.3 ok (returns 3:9) | (8,8):(3,97) o 3:3 | accept |
| paper §3.3.3 fail-left | (8,8):(3,97) o 4:3 | reject |
| paper §3.3.3 fail-right | (4,2,8):(3,15,97) o 3:3 | reject |
| wrong-answer bug NVIDIA#3177 | (4,6,8):(2,3,5) o 6:3 | reject |
| CuTe test (8,8):(8,1) o 2:3 | (8,8):(8,1) o 2:3 | accept |
| CuTe test (8,8):(8,1) o 3:3 | (8,8):(8,1) o 3:3 | accept |
| CuTe test (8,8):(8,1) o 4:3 | (8,8):(8,1) o 4:3 | reject |

Reference: arXiv:2603.02298 §3.3.3.
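The three-way predicate can be checked against the verification matrix with a short sketch. This is an illustration, not the actual `composition_impl()` code: `accepts` is a hypothetical function, it assumes positive integer shapes and nonzero strides, and it applies the predicate only to the first LHS mode with B's full shape and stride as the "rest" values. The real implementation folds over all modes and updates those values, which is not modeled here.

```python
def accepts(curr_shape, rest_shape, rest_stride):
    """Hypothetical model of the predicate for a single LHS mode."""
    skip_mode = rest_stride % curr_shape == 0                     # (a)
    partial_traversal = curr_shape % rest_stride == 0             # (b)
    safe_truncation = (rest_shape - 1) * rest_stride < curr_shape  # (c)
    return skip_mode or partial_traversal or safe_truncation


# (first LHS mode shape, B shape, B stride) -> decision listed in the matrix
matrix = [
    ("(8,8):(3,97) o 3:3",      (8, 3, 3), True),
    ("(8,8):(3,97) o 4:3",      (8, 4, 3), False),
    ("(4,2,8):(3,15,97) o 3:3", (4, 3, 3), False),
    ("(4,6,8):(2,3,5) o 6:3",   (4, 6, 3), False),
    ("(8,8):(8,1) o 2:3",       (8, 2, 3), True),
    ("(8,8):(8,1) o 3:3",       (8, 3, 3), True),
    ("(8,8):(8,1) o 4:3",       (8, 4, 3), False),
]

for name, args, decision in matrix:
    verdict = "accept" if accepts(*args) else "reject"
    print(f"{name:26s} {verdict}")
    assert accepts(*args) == decision
```

In every accepted row only disjunct (c) holds, since 8 is neither a multiple nor a divisor of 3. Those are the cases the strong check alone would have refused.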