Skip to content

Conversation

vidavakil
Copy link

test_selective_scan() fails when is_variable_B is False.

Turns out selective_scan_fwd_kernel incorporates an optimization of not multiplying the state by B if B is not variable. This does not impact MambaInnerFn, because MambaInnerFn never returns the state. But SelectiveScanFn may need to return the last_state. The changes to the code fix this problem, by multiplying the last_state by B before returning it when B is not variable.

…function has to return

the last_state. # The cuda kernel does a peculiar optimization of not multiplying the state
by B if B is not variable! This does not impact MambaInnerFn, because it never returns the
state. But SelectiveScanFn may needd to return the last state! Hence the following is needed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant