Skip to content

Fix CuTe composition stride-divisibility check (#3177)#3181

Open
jduprat wants to merge 2 commits intoNVIDIA:mainfrom
jduprat:main
Open

Fix CuTe composition stride-divisibility check (#3177)#3181
jduprat wants to merge 2 commits intoNVIDIA:mainfrom
jduprat:main

Conversation

@jduprat
Copy link
Copy Markdown

@jduprat jduprat commented Apr 21, 2026

composition_impl() used a strict weakening of the divisibility condition:
it accepted any rhs stride smaller than the current lhs mode shape,
regardless of whether the shape was actually divisible by the stride.

For A=(4,6,8):(2,3,5), B=6:3, this lets composition(A,B) compile and
return (_2,_3):(_6,_3), but C(2)=3 != A(B(2))=7.

Replace the weak check with the stronger condition used by
pycute (layout.py:211).

Fixes #3177

jduprat added 2 commits April 30, 2026 21:21
composition_impl() used a strict weakening of the divisibility condition:
it accepted any rhs stride smaller than the current lhs mode shape,
regardless of whether the shape was actually divisible by the stride.

For A=(4,6,8):(2,3,5), B=6:3, this lets composition(A,B) compile and
return (_2,_3):(_6,_3), but C(2)=3 != A(B(2))=7.

Replace the weak check with the stronger condition used by
pycute (layout.py:211).

Fixes NVIDIA#3177
The strong divisibility check from the previous commit fixes the
wrong-answer composition from NVIDIA#3177, but rejects the paper's §3.3.3
"apparent violation" cases that produce well-defined results, e.g.

    A = (4,2,8):(3,12,97), B = 3:3   ->   3:9

After the public composition() coalesces A to (8,8):(3,97), the
strong check sees `8 % 3 != 0` and refuses to compile, even though
A(0)=0, A(3)=9, A(6)=18 is well-defined.

Add a third disjunct that accepts the safe-truncation pattern: when
B's entire image fits inside the current LHS mode, higher modes are
unreachable and cannot perturb the result. This is the §3.3.3
distinction between "apparent" and "real" divisibility violations.

Predicate now accepts iff at least one of:
  (a) (rest_stride % curr_shape) == 0   -- skip mode entirely
  (b) (curr_shape % rest_stride) == 0   -- partial traversal
  (c) (rest_shape - 1) * rest_stride < curr_shape
                                        -- safe truncation: B's image
                                           stays within the current mode

Verification matrix:

  Case                                  Pre-coalesce LHS         Decision
  ----------------------------------    ---------------------    --------
  paper §3.3.3 ok (returns 3:9)         (8,8):(3,97) o 3:3       accept
  paper §3.3.3 fail-left                (8,8):(3,97) o 4:3       reject
  paper §3.3.3 fail-right               (4,2,8):(3,15,97) o 3:3  reject
  wrong-answer bug NVIDIA#3177                (4,6,8):(2,3,5) o 6:3    reject
  CuTe test (8,8):(8,1) o 2:3           (8,8):(8,1) o 2:3        accept
  CuTe test (8,8):(8,1) o 3:3           (8,8):(8,1) o 3:3        accept
  CuTe test (8,8):(8,1) o 4:3           (8,8):(8,1) o 4:3        reject

Reference: arXiv:2603.02298 §3.3.3.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[BUG] Discrepancy between CuTe C++ and pycute

1 participant