[TRANSFORMATIONS] Fix SDPAReshapeFusion with reshape output shape #32261
base: master
Pull Request Overview
This PR fixes an issue in SDPAReshapeFusion where the replaced reshape operation didn't match the output shape of the original reshape operation.
- Removes rank constraints from input patterns to make fusion more flexible
- Adds logic to check output shape compatibility and adjust the replacement node accordingly
- Includes a new test case to verify the fix for reshape output shape handling
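The shape check described above can be sketched without any OpenVINO dependency. This is a minimal illustration based only on the review discussion of this PR, not the actual code in sdpa_fusion.cpp: the `Node` struct and `pick_replacement_target` are hypothetical names, and shapes are modelled as static `std::vector<std::int64_t>` values rather than partial shapes.

```cpp
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-in for a graph node; this is NOT an OpenVINO type.
struct Node {
    std::string name;
    std::vector<std::int64_t> output_shape;
    std::shared_ptr<Node> input;  // single producer, enough for this sketch
};

// Decide which node the fused SDPA should replace (hypothetical helper).
// If the fused SDPA already produces the shape of the trailing reshape,
// the reshape itself can be replaced. Otherwise fall back to the reshape's
// input, so the reshape stays in the graph and still yields its original shape.
std::shared_ptr<Node> pick_replacement_target(const std::shared_ptr<Node>& post_sdpa_node,
                                              const std::vector<std::int64_t>& fused_output_shape) {
    if (post_sdpa_node->output_shape == fused_output_shape) {
        return post_sdpa_node;
    }
    return post_sdpa_node->input;
}
```

With a reshape whose output is `{8, 16, 64}` fed by an SDPA node producing `{1, 8, 16, 64}`, a fused shape of `{8, 16, 64}` selects the reshape, while `{1, 8, 16, 64}` selects the SDPA node and leaves the reshape untouched.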
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| sdpa_fusion.cpp | Removes rank constraints and adds output shape validation logic |
| sdpa_fusion_test.cpp | Adds test case for reshape optimization with output reshape |
auto new_sdpa_node = sdpa_node->clone_with_new_inputs(
    {q_node, k_node, v_node, sdpa_node->input(3).get_source_output(), sdpa_node->input(4).get_source_output()});
The logic for adjusting post_sdpa_node when output shapes don't match is unclear. Consider adding a comment explaining why we're reassigning post_sdpa_node to its input node, and what this achieves in terms of shape compatibility.
// If the output shape of the new SDPA node does not match the expected output shape,
// reassign post_sdpa_node to its input node. This ensures that we replace the correct node
// in the graph, maintaining shape compatibility after the transformation.
}
new_sdpa_node->set_friendly_name(post_sdpa_node->get_friendly_name());
ov::copy_runtime_info(m.get_matched_nodes(), new_sdpa_node);
ov::copy_runtime_info(post_sdpa_node, new_sdpa_node);
The runtime info copying has changed from copying from all matched nodes to only copying from post_sdpa_node. This could result in loss of runtime information from other matched nodes. Consider whether this change is intentional and if all necessary runtime info is preserved.
Suggested change:
- ov::copy_runtime_info(post_sdpa_node, new_sdpa_node);
+ ov::copy_runtime_info(m.get_matched_nodes(), new_sdpa_node);
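The concern about lost runtime information can be illustrated with a toy model. The sketch below treats runtime info as a plain string map and uses a hypothetical `merge_rt_info` helper with a simple "first source wins" rule. It does not reproduce OpenVINO's actual merge rules for `ov::copy_runtime_info`; it only shows why a single source node can carry fewer entries than the full set of matched nodes.

```cpp
#include <map>
#include <string>
#include <vector>

// Toy model of per-node runtime info; NOT the OpenVINO representation.
using RtInfo = std::map<std::string, std::string>;

// Hypothetical helper: collect entries from every source node.
// On a key conflict the first source wins, because std::map::insert
// does not overwrite keys that are already present.
RtInfo merge_rt_info(const std::vector<RtInfo>& sources) {
    RtInfo merged;
    for (const auto& src : sources) {
        merged.insert(src.begin(), src.end());
    }
    return merged;
}
```

If one matched node carries a key that `post_sdpa_node` lacks, `merge_rt_info({post_info})` omits it, while `merge_rt_info({post_info, other_info})` keeps it. The key names used in any such experiment are placeholders, not real OpenVINO attributes.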
Details:
Tickets: