[#7532][feat] AutoDeploy: Move the logit selection before the final Gemm #9681
base: main
Conversation
Signed-off-by: Chenghao Zhang <[email protected]>
Signed-off-by: Chenghao Zhang <[email protected]>
Signed-off-by: Chenghao Zhang <[email protected]>
/bot run
@nvpohanh, we discussed this a bit in the meeting and Yoco mentioned he might work on that. I added the transform here because it is a necessary perf optimization for the Nemotron MoE model.
📝 Walkthrough
This PR introduces a new "gather_logits_before_lm_head" transform that optimizes LM head computation by inserting a gather operation to fetch selected hidden states before the linear head layer. Changes span configuration files, the sequence interface to track gather indices, executor logic for conditional handling, a new transform module with graph manipulation and custom operations, and comprehensive unit tests.
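In effect, the transform replaces "run the LM head over every token, then pick the last-token logits" with "pick the last-token hidden states, then run the LM head". A minimal PyTorch sketch of the idea (shapes and the logit_gather_ids name follow the PR description; the actual transform rewrites the exported FX graph rather than eager code):

```python
import torch
import torch.nn as nn

hidden_size, vocab_size = 256, 32_000
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

# Packed prefill batch: three sequences of lengths 3, 5, 2 flattened into one token axis.
total_tokens = 3 + 5 + 2
hidden_states = torch.randn(1, total_tokens, hidden_size)

# One index per sequence: position of its last token in the packed stream.
logit_gather_ids = torch.tensor([2, 7, 9])  # cumulative lengths minus one

# Before the transform: the large GEMM runs over every token, then logits are selected.
logits_before = lm_head(hidden_states)[:, logit_gather_ids]

# After the transform: gather first, so the GEMM only sees one row per sequence.
logits_after = lm_head(hidden_states[:, logit_gather_ids])

assert torch.allclose(logits_before, logits_after, rtol=1e-4, atol=1e-5)
```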
Sequence Diagram
sequenceDiagram
participant Executor as ADEngine Executor
participant SeqInfo as SequenceInfo
participant Model as Model Graph
participant Transform as GatherLogitsBeforeLmHead<br/>Transform
participant LMHead as LM Head
rect rgb(230, 245, 255)
Note over Executor,Transform: Transform Application Phase
Executor->>Transform: Apply transform to model graph
activate Transform
Transform->>Model: Locate LM head node
Transform->>Model: Add logit_gather_ids input if missing
Transform->>Model: Insert gather operation before LM head
Transform->>Model: Rewrite LM head inputs to use gathered hidden states
Transform->>Model: Mark model as transformed
deactivate Transform
end
rect rgb(240, 255, 240)
Note over Executor,LMHead: Forward Pass Phase
Executor->>SeqInfo: Prepare sequence data with logit_gather_ids
SeqInfo->>SeqInfo: Store logit_gather_ids for current batch
Executor->>Model: Forward pass (context or generate)
activate Model
Model->>Model: Gather hidden states using logit_gather_ids
Model->>LMHead: Pass gathered hidden states
LMHead->>Model: Compute logits from gathered states
deactivate Model
Model-->>Executor: Return logits
end
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
Actionable comments posted: 1
🧹 Nitpick comments (6)
tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (2)
306-317: Consider adding a brief docstring explaining the detection logic.
The property checks two locations for the flag: directly on the model and on a wrapped model.model. While this handles both unwrapped GraphModule and wrapper cases (e.g., CapturedGraph), a brief inline comment explaining when each case applies would improve maintainability.

 @property
 def _gather_before_lm_head_in_graph(self) -> bool:
     """Check if gather_logits_before_lm_head transform was applied to the model."""
-    # Check on the model itself (for unwrapped GraphModule)
+    # Check on model itself (unwrapped GraphModule case, e.g., torch-simple backend)
     if hasattr(self.model, "_gather_logits_before_lm_head_applied"):
         return self.model._gather_logits_before_lm_head_applied
-    # Check on wrapped model (for CapturedGraph or other wrappers)
+    # Check on wrapped model (e.g., CapturedGraph from torch-cudagraph backend)
     if hasattr(self.model, "model") and hasattr(
         self.model.model, "_gather_logits_before_lm_head_applied"
     ):
         return self.model.model._gather_logits_before_lm_head_applied
     return False
340-356: Add strict=True to zip() for safety.
The static analyzer flags zip() on line 353 as missing an explicit strict= parameter. Since logits and last_logit_only should always have matching lengths (both are derived from the same sequence list), adding strict=True will catch any mismatch bugs early.

 else:
     logits = list(torch.split(logits, self.cache_seq_interface.info.seq_len))
     # gather+cat logits
     logits_flat = torch.cat(
         [
             ls_one_seq[-last_only:]
-            for ls_one_seq, last_only in zip(logits, last_logit_only)
+            for ls_one_seq, last_only in zip(logits, last_logit_only, strict=True)
         ],
         dim=0,
     )

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gather_logits_before_lm_head.py (1)
33-46: Silence unused logit_gather_ids / seq_len parameters in SimpleLMHeadModel.forward.
The extra inputs are needed for export and transform tests, but they're intentionally unused, which triggers Ruff (ARG002). Consider explicitly marking them unused to keep linters quiet:

-    def forward(self, hidden_states, logit_gather_ids=None, seq_len=None):
-        # Simulate transformer output
-        hidden_states = self.linear1(hidden_states)
-        # LM head
-        logits = self.lm_head(hidden_states)
-        return logits
+    def forward(self, hidden_states, logit_gather_ids=None, seq_len=None):
+        # Inputs are part of the exported signature but not used in this toy model.
+        del logit_gather_ids, seq_len
+
+        # Simulate transformer output
+        hidden_states = self.linear1(hidden_states)
+        logits = self.lm_head(hidden_states)
+        return logits

tensorrt_llm/_torch/auto_deploy/transform/library/gather_logits_before_lm_head.py (3)
225-239: Avoid mutating graph inputs when the transform is going to be skipped.
logit_gather_ids is added as a new graph input before checking for seq_len:

    logit_gather_ids_node = _find_input_node(gm, "logit_gather_ids")
    ...
    if logit_gather_ids_node is None:
        logit_gather_ids_node = add_graph_input(...)
    ...
    if seq_len_node is None:
        ...  # log + early return
        return gm, TransformInfo(skipped=True, num_matches=0)

If seq_len is missing, the transform reports skipped=True but the module's signature has still changed (new input with no corresponding graph changes). That can surprise callers which weren't prepared to pass logit_gather_ids in this "skipped" path.
Consider checking seq_len_node first and only adding the logit_gather_ids input after you know the transform will proceed, e.g.:

-        logit_gather_ids_node = _find_input_node(gm, "logit_gather_ids")
-        seq_len_node = _find_input_node(gm, "seq_len")
-
-        # Add logit_gather_ids as input if it doesn't exist
-        if logit_gather_ids_node is None:
-            ...
-            logit_gather_ids_node = add_graph_input(...)
-
-        if seq_len_node is None:
+        seq_len_node = _find_input_node(gm, "seq_len")
+        if seq_len_node is None:
             ad_logger.warning(...)
             return gm, TransformInfo(skipped=True, num_matches=0)
+
+        logit_gather_ids_node = _find_input_node(gm, "logit_gather_ids")
+        if logit_gather_ids_node is None:
+            ...
+            logit_gather_ids_node = add_graph_input(...)

This keeps "skipped" truly non-mutating for the graph interface.
193-199: Mark unused _apply parameters as intentionally unused.
factory and shared_config are required by the interface but unused here, which triggers Ruff (ARG002) and slightly obscures intent.
A small tweak keeps the signature intact while silencing the warning:

-    def _apply(
-        self,
-        gm: GraphModule,
-        cm,
-        factory,
-        shared_config: SharedConfig,
-    ) -> Tuple[GraphModule, TransformInfo]:
+    def _apply(
+        self,
+        gm: GraphModule,
+        cm,
+        factory,
+        shared_config: SharedConfig,
+    ) -> Tuple[GraphModule, TransformInfo]:
@@
-        # Return early if disabled
+        # Parameters `factory` and `shared_config` are unused for this transform.
+        del factory, shared_config
+
+        # Return early if disabled

Alternatively, adding brief # unused comments or renaming to _factory, _shared_config would also work if consistent with other transforms.
83-90: Optional: Remove or use _get_model_device to avoid dead code.
_get_model_device isn't referenced in this module. If it's not part of a planned follow-up, consider removing it or wiring it into the transform where device information is actually needed. This keeps the transform file lean and focused.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- tensorrt_llm/_torch/auto_deploy/config/default.yaml (1 hunks)
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (7 hunks)
- tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (6 hunks)
- tensorrt_llm/_torch/auto_deploy/transform/library/gather_logits_before_lm_head.py (1 hunks)
- tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py (2 hunks)
- tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py (1 hunks)
- tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gather_logits_before_lm_head.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+
Indent Python code with 4 spaces; do not use tabs
Always maintain the namespace when importing in Python, even if only one class or function from a module is used (e.g., use from package.subpackage import foo and then foo.SomeClass() instead of from package.subpackage.foo import SomeClass)
Python filenames should use snake_case (e.g., some_file.py)
Python class names should use PascalCase (e.g., class SomeClass)
Python function and method names should use snake_case (e.g., def my_awesome_function():)
Python local variable names should use snake_case, with prefix k for variable names that start with a number (e.g., k_99th_percentile = ...)
Python global variables should use upper snake_case with prefix G (e.g., G_MY_GLOBAL = ...)
Python constants should use upper snake_case (e.g., MY_CONSTANT = ...)
Avoid shadowing variables declared in an outer scope in Python
Initialize all externally visible members of a Python class in the constructor
For Python interfaces that may be used outside a file, prefer docstrings over comments
Python comments should be reserved for code within a function, or interfaces that are local to a file
Use Google style docstrings for Python classes and functions, which can be parsed by Sphinx
Python attributes and variables can be documented inline with type and description (e.g., self.x = 5 followed by """<type>: Description of 'x'""")
Avoid using reflection in Python when functionality can be easily achieved without reflection
When using try-except blocks in Python, limit the except clause to the smallest set of specific errors possible instead of catching all exceptions
When using try-except blocks in Python to handle multiple possible variable types (duck-typing), keep the body of the try as small as possible and use the else block to implement the logic
Files:
- tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gather_logits_before_lm_head.py
- tensorrt_llm/_torch/auto_deploy/transform/library/gather_logits_before_lm_head.py
- tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
**/*.{cpp,h,cu,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
All TensorRT-LLM Open Source Software code files should contain an NVIDIA copyright header that includes the current year at the top
Files:
- tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gather_logits_before_lm_head.py
- tensorrt_llm/_torch/auto_deploy/transform/library/gather_logits_before_lm_head.py
- tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
🧠 Learnings (4)
📚 Learning: 2025-07-28T17:06:08.621Z
Learnt from: moraxu
Repo: NVIDIA/TensorRT-LLM PR: 6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.
Applied to files:
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gather_logits_before_lm_head.py
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
Repo: NVIDIA/TensorRT-LLM PR: 6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.
Applied to files:
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gather_logits_before_lm_head.py
📚 Learning: 2025-08-19T12:45:11.997Z
Learnt from: amitz-nv
Repo: NVIDIA/TensorRT-LLM PR: 7033
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:0-0
Timestamp: 2025-08-19T12:45:11.997Z
Learning: In tensorrt_llm/_torch/pyexecutor/model_engine.py, DoRA (Delta Orthogonal Rank Adaptation) functionality was removed from the PyTorch flow to eliminate issues with inverted DoRA detection logic. The original is_dora condition was checking if scaling_vec_pointer == 0, which was potentially incorrect.
Applied to files:
tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
📚 Learning: 2025-11-07T09:18:04.997Z
Learnt from: Funatiq
Repo: NVIDIA/TensorRT-LLM PR: 8587
File: tensorrt_llm/_torch/pyexecutor/llm_request.py:129-139
Timestamp: 2025-11-07T09:18:04.997Z
Learning: In `LogitsStorage.get()` method in `tensorrt_llm/_torch/pyexecutor/llm_request.py`, when `exclude_last=True`, there is an invariant that at least 2 chunks must have been appended to `_logits_indices`. The parameter is designed to drop the entire last chunk (not just the last token), which is expected behavior for the overlap scheduler that generates one extra token in a separate chunk.
Applied to files:
tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
🧬 Code graph analysis (2)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gather_logits_before_lm_head.py (5)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (1)
- SequenceInfo (65-758)
tensorrt_llm/_torch/auto_deploy/export/export.py (1)
- torch_export_to_gm (276-344)
tensorrt_llm/_torch/auto_deploy/transform/library/gather_logits_before_lm_head.py (2)
- gather_logits_before_lm_head (43-64)
- gather_logits_before_lm_head_fake (68-80)
tensorrt_llm/_torch/auto_deploy/transform/optimizer.py (1)
- InferenceOptimizer (23-78)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
- is_op (198-221)
tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (3)
tensorrt_llm/_torch/attention_backend/interface.py (1)
- num_ctx_tokens (272-273)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (3)
- named_args (278-287)
- is_generate (373-374)
- seq_len (349-350)
tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)
- named_args (33-35)
🪛 Ruff (0.14.7)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
330-330: Consider iterable unpacking instead of concatenation
Replace with iterable unpacking
(RUF005)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gather_logits_before_lm_head.py
41-41: Unused method argument: logit_gather_ids
(ARG002)
41-41: Unused method argument: seq_len
(ARG002)
tensorrt_llm/_torch/auto_deploy/transform/library/gather_logits_before_lm_head.py
197-197: Unused method argument: factory
(ARG002)
198-198: Unused method argument: shared_config
(ARG002)
250-250: Consider (gather_node, *tuple(lm_head_node.args[1:])) instead of concatenation
Replace with (gather_node, *tuple(lm_head_node.args[1:]))
(RUF005)
tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
353-353: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (15)
tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py (2)
74-83: LGTM! Handling for new logits format is appropriate.
The added logic correctly handles the case where logits may be returned as a 3D tensor with a batch dimension from the new gather-before-LM-head path. Squeezing the batch dimension when dim() == 3 ensures compatibility with the downstream comparison against original_logits.
112-115: LGTM! Consistent handling with test_engine.
The same logits shape handling pattern is applied here, maintaining consistency across test functions.
tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py (1)
168-168: LGTM! Test coverage for disabled transform path.
Adding "gather_logits_before_lm_head": {"enabled": False} provides explicit test coverage for the case when the new transform is disabled, complementing other tests that likely exercise the enabled path.

tensorrt_llm/_torch/auto_deploy/config/default.yaml (1)
173-175: LGTM! Transform configuration is well-positioned.
The gather_logits_before_lm_head transform is correctly placed in the compile stage before compile_model, ensuring the logit gathering optimization is applied before final model compilation.

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (5)
209-209: LGTM! Initialization for logit gather indices.
234-236: LGTM! Context request logit gathering index calculation.
The index num_ctx_tokens - 1 correctly points to the last token position in the cumulative context token stream for each sequence.
270-272: LGTM! Generation request logit gathering index calculation.
For generation requests, num_ctx_tokens + num_generation_tokens - 1 correctly computes the cumulative offset to the current generation token.
279-286: LGTM! Passing logit_gather_ids to sequence interface.
The new parameter is correctly passed through to nest_sequences.
334-334: LGTM! Assertion guards against unsupported feature.
The explicit assertion provides a clear error message when gather_context_logits is requested, preventing silent incorrect behavior.

tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (6)
101-103: LGTM! Clear documentation for the new field.
The docstring clearly explains the purpose of logit_gather_ids: storing the index of the last token in each sequence for logit gathering.
207-207: LGTM! Storage allocation for logit_gather_ids.
Using max_batch_size for the allocation size is correct since there's one gather index per sequence.
216-223: LGTM! Adding logit_gather_ids to cached argument names.
This ensures the new field is properly tracked and passed through as a model input via update_in_out_nodes.
324-330: LGTM! Correct exclusion from prepare_metadata args.
The comment clearly explains that logit_gather_ids is excluded from args_for_prepare_metadata because it's not needed for attention metadata preparation, while still being included in _cached_arg_names to ensure it gets added as a model input.
Regarding the static analyzer hint (RUF005) about using iterable unpacking: this is a stylistic preference, and the current tuple concatenation is perfectly clear and readable.
658-658: LGTM! Function signature extended correctly.
The optional parameter with a None default maintains backward compatibility.
705-709: Verify the default zeros behavior for logit_gather_ids.
When logit_gather_ids is None, the code defaults to a list of zeros for max_batch_size sequences. This differs from other fields like slot_idx, which use a computed "free" value. Since gather indices of 0 would point to the first token, verify this default is safe for cases where the transform is applied but logit_gather_ids isn't explicitly provided.
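To make the index arithmetic in these comments concrete, here is a small, hypothetical helper that mirrors the bookkeeping described above (context requests use num_ctx_tokens - 1, generate requests use num_ctx_tokens + num_generation_tokens - 1); the function and its parameter names are illustrative, not the actual executor code:

```python
def compute_logit_gather_ids(ctx_seq_lens: list[int], num_gen_requests: int) -> list[int]:
    """Return one index per sequence: its last token's position in the packed token stream."""
    gather_ids: list[int] = []
    num_ctx_tokens = 0
    # Context (prefill) requests contribute all of their tokens; gather the last one.
    for seq_len in ctx_seq_lens:
        num_ctx_tokens += seq_len
        gather_ids.append(num_ctx_tokens - 1)
    # Generate requests contribute one token each, appended after all context tokens.
    for num_generation_tokens in range(1, num_gen_requests + 1):
        gather_ids.append(num_ctx_tokens + num_generation_tokens - 1)
    return gather_ids


# Two prefill sequences (lengths 3 and 5) followed by two single-token generate requests.
assert compute_logit_gather_ids([3, 5], 2) == [2, 7, 8, 9]
```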
PR_Github #26833 [ run ] triggered by Bot. Commit:
PR_Github #26833 [ run ] completed with state
    Gathered hidden states [batch, hidden] for generate, [1, max_batch_size, hidden] for packed
    """
    # Generate format: [batch, 1, hidden] -> seq_len == 1
    # Packed format: [1, total_tokens, hidden] -> seq_len > 1
Q: Just for my understanding, what is the "packed" format? Is this for a "context"/"prefill" run?
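For context, a small sketch of the two layouts named in the docstring above, under the assumption that "packed" is the flattened layout used when seq_len > 1 (i.e., context/prefill or mixed batches) while "generate" carries one token per sequence; the tensors and index values are made up for illustration:

```python
import torch

hidden = 8

# Generate format: one new token per sequence -> [batch, 1, hidden], seq_len == 1.
gen_states = torch.randn(4, 1, hidden)
gen_last = gen_states[:, 0]                 # [batch, hidden]: every row already is a "last token"

# Packed format: all tokens of all sequences flattened -> [1, total_tokens, hidden], seq_len > 1.
packed_states = torch.randn(1, 10, hidden)
logit_gather_ids = torch.tensor([2, 7, 9])  # last-token position of each sequence
packed_last = packed_states[:, logit_gather_ids]  # [1, num_seqs, hidden]

assert gen_last.shape == (4, hidden)
assert packed_last.shape == (1, 3, hidden)
```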
@@ -165,6 +165,7 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs):
     "transforms": {
         "insert_cached_attention": {"backend": "flashinfer"},
         "compile_model": {"backend": "torch-cudagraph"},
+        "gather_logits_before_lm_head": {"enabled": False},
Why is it not enabled by default?
It should be before compile_model, btw.
lucaslie
left a comment
a little too much vibe code ;)
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SequenceInfo
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.library.gather_logits_before_lm_head import (
    gather_logits_before_lm_head,
don't use registered custom ops directly. Instead use the correct dispatch, e.g., torch.ops.auto_deploy.gather_logits_before_lm_head.default
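As a toy illustration of that review point (using a made-up demo::gather_last_tokens op, not the PR's actual op or signature), registering a custom op and then calling it through the dispatcher rather than the registered Python function looks roughly like this; for the PR's op, the dispatcher path would be torch.ops.auto_deploy.gather_logits_before_lm_head.default as noted above:

```python
import torch


# Hypothetical op mimicking the gather's role; requires PyTorch >= 2.4 for torch.library.custom_op.
@torch.library.custom_op("demo::gather_last_tokens", mutates_args=())
def gather_last_tokens(hidden_states: torch.Tensor, gather_ids: torch.Tensor) -> torch.Tensor:
    return hidden_states[:, gather_ids]


@gather_last_tokens.register_fake
def _(hidden_states: torch.Tensor, gather_ids: torch.Tensor) -> torch.Tensor:
    # Shape-only implementation so the op can be traced/exported.
    return hidden_states.new_empty(
        hidden_states.shape[0], gather_ids.shape[0], hidden_states.shape[-1]
    )


hs = torch.randn(1, 10, 8)
ids = torch.tensor([2, 7, 9])

# Preferred in transforms/tests: go through the dispatcher overload, not the Python callable.
out = torch.ops.demo.gather_last_tokens.default(hs, ids)
assert out.shape == (1, 3, 8)
```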
Signed-off-by: Lucas Liebenwein <[email protected]>
Signed-off-by: Lucas Liebenwein <[email protected]>
Signed-off-by: Chenghao Zhang <[email protected]>
/bot run
PR_Github #27026 [ run ] triggered by Bot. Commit:
PR_Github #27026 [ run ] completed with state
This PR fixes #7532. The goal is to move the logit selection before the final GEMM to reduce the overall computation of the final GEMM during the prefill / mixed stage.
It also integrates code to handle the case where the transform is not enabled.
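For reference, toggling the transform from the AutoDeploy transforms config looks roughly like the sketch below (key names are taken from the test diff and default.yaml discussion above; the surrounding config is abridged and illustrative):

```python
transforms_config = {
    "insert_cached_attention": {"backend": "flashinfer"},
    # Runs in the compile stage, before compile_model.
    "gather_logits_before_lm_head": {"enabled": True},  # set False to keep post-GEMM logit selection
    "compile_model": {"backend": "torch-cudagraph"},
}
```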
Summary by CodeRabbit
Release Notes
New Features
Bug Fixes
Tests