[nvbug 5325284][fix] Increase Nemotron-H warmup request robustness #4954
… fails
Signed-off-by: Tomer Asida
…failed or succeeded (2) don't add BOS token to match expected outputs
…y state_indices during forward pass. Now LLM API test passes
Pull Request Overview
This pull request increases the robustness of the Mamba2Mixer forward pass during warmup runs by checking the validity of state indices and supplying dummy values if necessary. It also adds new LLM API unit tests for Nemotron-H to catch similar issues in the future.
- Updated unit tests to use function-based test definitions and to cover LLM API usage.
- Modified state_indices initialization in MambaHybridCacheManager to properly set the device and dtype.
- Refactored Mamba2Mixer to determine warmup cases by checking if state_indices is empty and to generate dummy indices accordingly.
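The warmup handling in the last bullet can be sketched roughly as follows. This is a minimal pure-Python illustration, not the PR's actual code: plain lists stand in for the `state_indices` tensor, and `resolve_state_indices` is a hypothetical helper name. It assumes the cache has at least `batch_size` slots, so `0..batch_size-1` are valid dummy indices.

```python
from typing import List


def resolve_state_indices(state_indices: List[int], batch_size: int) -> List[int]:
    """Hypothetical helper: pick the cache-slot indices for a forward pass.

    In a regular run, prepare_resources() has already filled `state_indices`,
    so they are returned unchanged. Warmup runs skip prepare_resources(),
    leaving the list empty; in that case dummy-but-valid indices
    0..batch_size-1 are substituted (assumes >= batch_size cache slots).
    """
    is_warmup = len(state_indices) == 0
    if is_warmup:
        return list(range(batch_size))
    return state_indices


# Regular forward pass: indices come from the cache manager.
print(resolve_state_indices([5, 2, 7], batch_size=3))  # [5, 2, 7]
# Warmup forward pass with two requests and no prepared indices.
print(resolve_state_indices([], batch_size=2))  # [0, 1]
```

Deciding "warmup" from the indices themselves, rather than from request IDs, keeps the check independent of how many requests a warmup run contains.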
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| tests/unittest/_torch/modeling/test_modeling_nemotron_h.py | Refactored tests to adopt function-based style and updated KvCacheConfig usage. |
| tensorrt_llm/_torch/pyexecutor/resource_manager.py | Updated the initialization of state_indices with explicit device and dtype. |
| tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py | Revised warmup request handling by checking state_indices emptiness and creating fallback indices. |
Comments suppressed due to low confidence (3)
tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py:162
- Ensure that attn_metadata.kv_cache_manager is always non-null in warmup runs, as the previous check for None was removed; consider adding validation if there's a chance it might be None.
state_indices = attn_metadata.kv_cache_manager.get_state_indices()
tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py:278
- [nitpick] Add a comment explaining why ssm_states is updated unconditionally during warmup runs to aid maintainers in understanding the relaxed behavior in this code path.
ssm_states[indices] = current_ssm_states
tests/unittest/_torch/modeling/test_modeling_nemotron_h.py:214
- [nitpick] Consider clearly distinguishing between KvCacheConfig and KvCacheConfigCpp in naming and usage to reduce potential confusion in the test configuration.
kv_cache_config = KvCacheConfigCpp(max_tokens=num_blocks * tokens_per_block,
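The first suppressed comment suggests adding validation in case `attn_metadata.kv_cache_manager` is `None`. One way such a guard could look is sketched below, assuming a pure-Python setting: `FakeCacheManager` and `fetch_state_indices` are hypothetical stand-ins, and only the `get_state_indices()` method name comes from the diff excerpt quoted in that comment.

```python
from typing import List, Optional


class FakeCacheManager:
    """Hypothetical stand-in for the cache manager; only get_state_indices() is modeled."""

    def __init__(self, indices: List[int]):
        self._indices = indices

    def get_state_indices(self) -> List[int]:
        return self._indices


def fetch_state_indices(kv_cache_manager: Optional[FakeCacheManager]) -> List[int]:
    # Fail with an explicit message instead of an AttributeError on None.
    if kv_cache_manager is None:
        raise RuntimeError(
            "Mamba2Mixer forward pass requires a kv_cache_manager; got None")
    return kv_cache_manager.get_state_indices()


print(fetch_state_indices(FakeCacheManager([3, 1])))  # [3, 1]
```

Raising early makes a missing cache manager surface as a clear error rather than a confusing failure deeper in the forward pass.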
vegaluisjose left a comment
Great work @tomeras91 !
The `Mamba2Mixer` block forward pass in Nemotron-H assumes that `state_indices` in `MambaCacheManager` is valid. These indices are prepared during the call to `MambaHybridCacheManager.prepare_resources()`. However, warmup runs don't call `prepare_resources()`, so the `Mamba2Mixer` forward pass needs special handling for them. Previously, warmup runs were assumed to consist of a single request with `request_id=0` and were identified by that condition.

PR #4466 changed the behavior of warmup runs, so they no longer always consist of a single request with `id=0`. Consequently, the `Mamba2Mixer` forward pass broke during warmup runs, preventing Nemotron-H from being initialized via the LLM API.

This PR makes `Mamba2Mixer` more robust during warmup runs by directly checking the validity of `state_indices`. If `state_indices` is invalid, it is filled with valid dummy values, which also reduces the differences between regular and warmup forward passes.

Additionally, PR #4466 broke Nemotron-H silently because no unit tests exercised Nemotron-H through the LLM API, so this PR also adds such a test to prevent similar issues in the future.
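To make the failure mode concrete, the two warmup-detection conditions can be compared on a warmup batch with more than one request. This is a hypothetical pure-Python sketch: the function names and the request IDs `[0, 1]` are illustrative assumptions, and plain lists stand in for the real request-ID and `state_indices` data.

```python
from typing import List


def is_warmup_old(request_ids: List[int]) -> bool:
    # Earlier heuristic: warmup assumed to be exactly one request with id 0.
    return request_ids == [0]


def is_warmup_new(state_indices: List[int]) -> bool:
    # Check described in this PR: warmup skips prepare_resources(),
    # so no state indices have been prepared.
    return len(state_indices) == 0


# Hypothetical warmup batch with several requests,
# none of which went through prepare_resources().
warmup_request_ids = [0, 1]
warmup_state_indices: List[int] = []

print(is_warmup_old(warmup_request_ids))    # False -> warmup not detected
print(is_warmup_new(warmup_state_indices))  # True
```

The old condition misses the multi-request warmup batch, while the emptiness check does not depend on how many requests the warmup run contains.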