
Conversation

@nvchenghaoz
Collaborator

@nvchenghaoz nvchenghaoz commented Dec 3, 2025

This PR fixes #7532. The goal is to move logit selection before the final GEMM (the LM head) so that the final GEMM only computes logits for the selected tokens, reducing overall computation in the prefill / mixed stage.

It also integrates code to handle the case where the transform is not enabled.
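As a rough sketch of the optimization in plain PyTorch (all shapes and names below are invented for illustration, not the PR's actual code): instead of running the LM head over every prefill token and then selecting the last-token logits, the last-token hidden states are gathered first so the final GEMM only runs on one row per sequence.

import torch
import torch.nn as nn

vocab_size, hidden_size = 1000, 64
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

# Packed prefill: hidden states for all tokens of all sequences, [total_tokens, hidden]
hidden = torch.randn(300, hidden_size)
# Index of the last token of each sequence in the packed token stream
logit_gather_ids = torch.tensor([99, 199, 299])

# Before: GEMM over all 300 tokens, then select -> large [total_tokens, vocab] intermediate
logits_before = lm_head(hidden)[logit_gather_ids]

# After: gather first, then GEMM over 3 rows only -> [num_seqs, vocab]
logits_after = lm_head(hidden[logit_gather_ids])

assert torch.allclose(logits_before, logits_after, atol=1e-5)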

Summary by CodeRabbit

Release Notes

  • New Features

    • Added gather_logits_before_lm_head transform that gathers hidden states for selected logits before the language model head computation, enabling new inference execution paths.
  • Bug Fixes

    • Fixed logits tensor handling in the inference engine to properly squeeze batch dimensions when logits are already in tensor format.
  • Tests

    • Added comprehensive unit tests for the gather_logits_before_lm_head transform supporting both generate and packed input formats.


Signed-off-by: Chenghao Zhang <[email protected]>
@nvchenghaoz nvchenghaoz requested a review from a team as a code owner December 3, 2025 22:19
@nvchenghaoz
Collaborator Author

/bot run

@nvchenghaoz
Collaborator Author

nvchenghaoz commented Dec 3, 2025

@nvpohanh, we discussed this a bit in the meeting and Yoco mentioned he might work on it. I added the transform here because it is a necessary perf optimization for the Nemotron MoE model.

@coderabbitai
Contributor

coderabbitai bot commented Dec 3, 2025

📝 Walkthrough

Walkthrough

This PR introduces a new "gather_logits_before_lm_head" transform that optimizes LM head computation by inserting a gather operation to fetch selected hidden states before the linear head layer. Changes span configuration files, the sequence interface to track gather indices, executor logic for conditional handling, a new transform module with graph manipulation and custom operations, and comprehensive unit tests.
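For the "custom operations" part, here is a minimal sketch (op name, namespace, and signature are invented for illustration, not taken from the PR) of registering a gather op with both a real implementation and a fake/meta implementation, which is what keeps it traceable during export and compatible with CUDA graph capture:

import torch

# Hypothetical op name; the PR registers its own op under its own namespace.
@torch.library.custom_op("demo::gather_before_lm_head", mutates_args=())
def gather_before_lm_head(hidden: torch.Tensor, gather_ids: torch.Tensor) -> torch.Tensor:
    """Gather hidden states at the given token indices along the token dimension."""
    return hidden[..., gather_ids, :].contiguous()

@gather_before_lm_head.register_fake
def _(hidden: torch.Tensor, gather_ids: torch.Tensor) -> torch.Tensor:
    # Shape-only implementation used during tracing / fake-tensor propagation.
    shape = list(hidden.shape)
    shape[-2] = gather_ids.shape[0]
    return hidden.new_empty(shape)

# Dispatch through torch.ops rather than calling the Python function directly.
out = torch.ops.demo.gather_before_lm_head(torch.randn(1, 8, 16), torch.tensor([3, 7]))
print(out.shape)  # torch.Size([1, 2, 16])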

Changes

  • Configuration
    tensorrt_llm/_torch/auto_deploy/config/default.yaml: Added the new gather_logits_before_lm_head transform under transforms.compile with stage and enabled flags.
  • Sequence Interface
    tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py: Extended SequenceInfo to track logit_gather_ids as a per-sequence input; updated the nest_sequences() method signature to accept an optional logit_gather_ids parameter with default zero initialization.
  • Executor Logic
    tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py: Added the _gather_before_lm_head_in_graph property; modified _compute_logits() to return model output directly; updated forward() with conditional logits-gathering logic based on generation mode and transform state; integrated logit_gather_ids preparation and passing to the sequence interface.
  • Transform Implementation
    tensorrt_llm/_torch/auto_deploy/transform/library/gather_logits_before_lm_head.py: New module with a custom gather operation (real and fake implementations), helper utilities for graph analysis (_get_model_device, _find_input_node, _find_lm_head_node), and the GatherLogitsBeforeLmHeadConfig and GatherLogitsBeforeLmHeadTransform classes for graph manipulation and LM head integration.
  • Test Configuration
    tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py: Added a gather_logits_before_lm_head transform entry (disabled) to the test configuration for the Mistral model build.
  • Engine Tests
    tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py: Updated logits handling to conditionally squeeze 3D tensor output to 2D when logits is already a tensor.
  • Transform Unit Tests
    tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gather_logits_before_lm_head.py: New test module with SimpleLMHeadModel, direct operation tests (TestGatherLogitsBeforeLmHeadOp), and integration tests (TestGatherLogitsBeforeLmHeadTransform) covering generate/packed formats, transform application, and graceful degradation scenarios.

Sequence Diagram

sequenceDiagram
    participant Executor as ADEngine Executor
    participant SeqInfo as SequenceInfo
    participant Model as Model Graph
    participant Transform as GatherLogitsBeforeLmHead<br/>Transform
    participant LMHead as LM Head

    rect rgb(230, 245, 255)
    Note over Executor,Transform: Transform Application Phase
    Executor->>Transform: Apply transform to model graph
    activate Transform
    Transform->>Model: Locate LM head node
    Transform->>Model: Add logit_gather_ids input if missing
    Transform->>Model: Insert gather operation before LM head
    Transform->>Model: Rewrite LM head inputs to use gathered hidden states
    Transform->>Model: Mark model as transformed
    deactivate Transform
    end

    rect rgb(240, 255, 240)
    Note over Executor,LMHead: Forward Pass Phase
    Executor->>SeqInfo: Prepare sequence data with logit_gather_ids
    SeqInfo->>SeqInfo: Store logit_gather_ids for current batch
    Executor->>Model: Forward pass (context or generate)
    activate Model
    Model->>Model: Gather hidden states using logit_gather_ids
    Model->>LMHead: Pass gathered hidden states
    LMHead->>Model: Compute logits from gathered states
    deactivate Model
    Model-->>Executor: Return logits
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Graph manipulation logic in transform module (_find_lm_head_node, node rewiring, dynamic input addition)
  • Conditional execution paths in executor based on transform detection and generation mode
  • Integration complexity: logit_gather_ids threading across sequence interface, executor, and model graph
  • Custom operation handling with real and fake implementations for CUDA graph compatibility
  • State tracking: ensuring logit_gather_ids is properly populated and passed through multiple layers

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 78.79%, below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
  • Description check: ❓ Inconclusive. The description explains the issue (#7532) and the solution but is missing key sections from the template, such as Test Coverage details and PR Checklist confirmation. Resolution: add an explicit Test Coverage section listing the relevant test files (test_gather_logits_before_lm_head.py, test_ad_build_small_single.py, test_engine.py) and confirm the PR Checklist items.
✅ Passed checks (1 passed)
  • Title check: ✅ Passed. The title clearly summarizes the main change: moving logit selection before the final GEMM operation in AutoDeploy, which aligns with the PR objectives and changeset.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (6)
tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (2)

306-317: Consider adding a brief docstring explaining the detection logic.

The property checks two locations for the flag: directly on the model and on a wrapped model.model. While this handles both unwrapped GraphModule and wrapper cases (e.g., CapturedGraph), a brief inline comment explaining when each case applies would improve maintainability.

     @property
     def _gather_before_lm_head_in_graph(self) -> bool:
         """Check if gather_logits_before_lm_head transform was applied to the model."""
-        # Check on the model itself (for unwrapped GraphModule)
+        # Check on model itself (unwrapped GraphModule case, e.g., torch-simple backend)
         if hasattr(self.model, "_gather_logits_before_lm_head_applied"):
             return self.model._gather_logits_before_lm_head_applied
-        # Check on wrapped model (for CapturedGraph or other wrappers)
+        # Check on wrapped model (e.g., CapturedGraph from torch-cudagraph backend)
         if hasattr(self.model, "model") and hasattr(
             self.model.model, "_gather_logits_before_lm_head_applied"
         ):
             return self.model.model._gather_logits_before_lm_head_applied
         return False

340-356: Add strict=True to zip() for safety.

The static analyzer flags zip() on line 353 as missing an explicit strict= parameter. Since logits and last_logit_only should always have matching lengths (both are derived from the same sequence list), adding strict=True will catch any mismatch bugs early.

             else:
                 logits = list(torch.split(logits, self.cache_seq_interface.info.seq_len))
                 # gather+cat logits
                 logits_flat = torch.cat(
                     [
                         ls_one_seq[-last_only:]
-                        for ls_one_seq, last_only in zip(logits, last_logit_only)
+                        for ls_one_seq, last_only in zip(logits, last_logit_only, strict=True)
                     ],
                     dim=0,
                 )
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gather_logits_before_lm_head.py (1)

33-46: Silence unused logit_gather_ids / seq_len parameters in SimpleLMHeadModel.forward

The extra inputs are needed for export and transform tests, but they’re intentionally unused, which triggers Ruff (ARG002). Consider explicitly marking them unused to keep linters quiet:

-    def forward(self, hidden_states, logit_gather_ids=None, seq_len=None):
-        # Simulate transformer output
-        hidden_states = self.linear1(hidden_states)
-        # LM head
-        logits = self.lm_head(hidden_states)
-        return logits
+    def forward(self, hidden_states, logit_gather_ids=None, seq_len=None):
+        # Inputs are part of the exported signature but not used in this toy model.
+        del logit_gather_ids, seq_len
+
+        # Simulate transformer output
+        hidden_states = self.linear1(hidden_states)
+        logits = self.lm_head(hidden_states)
+        return logits
tensorrt_llm/_torch/auto_deploy/transform/library/gather_logits_before_lm_head.py (3)

225-239: Avoid mutating graph inputs when transform is going to be skipped

logit_gather_ids is added as a new graph input before checking for seq_len:

logit_gather_ids_node = _find_input_node(gm, "logit_gather_ids")
...
if logit_gather_ids_node is None:
    logit_gather_ids_node = add_graph_input(...)
...
if seq_len_node is None:
    ...  # log + early return
    return gm, TransformInfo(skipped=True, num_matches=0)

If seq_len is missing, the transform reports skipped=True but the module’s signature has still changed (new input with no corresponding graph changes). That can surprise callers which weren’t prepared to pass logit_gather_ids in this “skipped” path.

Consider checking seq_len_node first and only adding the logit_gather_ids input after you know the transform will proceed, e.g.:

-        logit_gather_ids_node = _find_input_node(gm, "logit_gather_ids")
-        seq_len_node = _find_input_node(gm, "seq_len")
-
-        # Add logit_gather_ids as input if it doesn't exist
-        if logit_gather_ids_node is None:
-            ...
-            logit_gather_ids_node = add_graph_input(...)
-
-        if seq_len_node is None:
+        seq_len_node = _find_input_node(gm, "seq_len")
+        if seq_len_node is None:
             ad_logger.warning(...)
             return gm, TransformInfo(skipped=True, num_matches=0)
+
+        logit_gather_ids_node = _find_input_node(gm, "logit_gather_ids")
+        if logit_gather_ids_node is None:
+            ...
+            logit_gather_ids_node = add_graph_input(...)

This keeps “skipped” truly non-mutating for the graph interface.


193-199: Mark unused _apply parameters as intentionally unused

factory and shared_config are required by the interface but unused here, which triggers Ruff (ARG002) and slightly obscures intent.

A small tweak keeps the signature intact while silencing the warning:

-    def _apply(
-        self,
-        gm: GraphModule,
-        cm,
-        factory,
-        shared_config: SharedConfig,
-    ) -> Tuple[GraphModule, TransformInfo]:
+    def _apply(
+        self,
+        gm: GraphModule,
+        cm,
+        factory,
+        shared_config: SharedConfig,
+    ) -> Tuple[GraphModule, TransformInfo]:
@@
-        # Return early if disabled
+        # Parameters `factory` and `shared_config` are unused for this transform.
+        del factory, shared_config
+
+        # Return early if disabled

Alternatively, adding brief # unused comments or renaming to _factory, _shared_config would also work if consistent with other transforms.


83-90: Optional: Remove or use _get_model_device to avoid dead code

_get_model_device isn’t referenced in this module. If it’s not part of a planned follow-up, consider removing it or wiring it into the transform where device information is actually needed. This keeps the transform file lean and focused.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 098b9ff and 5434b6a.

📒 Files selected for processing (7)
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (7 hunks)
  • tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (6 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/gather_logits_before_lm_head.py (1 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py (2 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py (1 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gather_logits_before_lm_head.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+
Indent Python code with 4 spaces; do not use tabs
Always maintain the namespace when importing in Python, even if only one class or function from a module is used (e.g., use from package.subpackage import foo and then foo.SomeClass() instead of from package.subpackage.foo import SomeClass)
Python filenames should use snake_case (e.g., some_file.py)
Python class names should use PascalCase (e.g., class SomeClass)
Python function and method names should use snake_case (e.g., def my_awesome_function():)
Python local variable names should use snake_case, with prefix k for variable names that start with a number (e.g., k_99th_percentile = ...)
Python global variables should use upper snake_case with prefix G (e.g., G_MY_GLOBAL = ...)
Python constants should use upper snake_case (e.g., MY_CONSTANT = ...)
Avoid shadowing variables declared in an outer scope in Python
Initialize all externally visible members of a Python class in the constructor
For Python interfaces that may be used outside a file, prefer docstrings over comments
Python comments should be reserved for code within a function, or interfaces that are local to a file
Use Google style docstrings for Python classes and functions, which can be parsed by Sphinx
Python attributes and variables can be documented inline with type and description (e.g., self.x = 5 followed by """<type>: Description of 'x'""" )
Avoid using reflection in Python when functionality can be easily achieved without reflection
When using try-except blocks in Python, limit the except clause to the smallest set of specific errors possible instead of catching all exceptions
When using try-except blocks in Python to handle multiple possible variable types (duck-typing), keep the body of the try as small as possible and use the else block to implement the logic

Files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gather_logits_before_lm_head.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/gather_logits_before_lm_head.py
  • tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
**/*.{cpp,h,cu,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code files should contain an NVIDIA copyright header that includes the current year at the top

Files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gather_logits_before_lm_head.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/gather_logits_before_lm_head.py
  • tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
🧠 Learnings (4)
📚 Learning: 2025-07-28T17:06:08.621Z
Learnt from: moraxu
Repo: NVIDIA/TensorRT-LLM PR: 6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.

Applied to files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gather_logits_before_lm_head.py
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
Repo: NVIDIA/TensorRT-LLM PR: 6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.

Applied to files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gather_logits_before_lm_head.py
📚 Learning: 2025-08-19T12:45:11.997Z
Learnt from: amitz-nv
Repo: NVIDIA/TensorRT-LLM PR: 7033
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:0-0
Timestamp: 2025-08-19T12:45:11.997Z
Learning: In tensorrt_llm/_torch/pyexecutor/model_engine.py, DoRA (Delta Orthogonal Rank Adaptation) functionality was removed from the PyTorch flow to eliminate issues with inverted DoRA detection logic. The original is_dora condition was checking if scaling_vec_pointer == 0, which was potentially incorrect.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
📚 Learning: 2025-11-07T09:18:04.997Z
Learnt from: Funatiq
Repo: NVIDIA/TensorRT-LLM PR: 8587
File: tensorrt_llm/_torch/pyexecutor/llm_request.py:129-139
Timestamp: 2025-11-07T09:18:04.997Z
Learning: In `LogitsStorage.get()` method in `tensorrt_llm/_torch/pyexecutor/llm_request.py`, when `exclude_last=True`, there is an invariant that at least 2 chunks must have been appended to `_logits_indices`. The parameter is designed to drop the entire last chunk (not just the last token), which is expected behavior for the overlap scheduler that generates one extra token in a separate chunk.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
🧬 Code graph analysis (2)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gather_logits_before_lm_head.py (5)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (1)
  • SequenceInfo (65-758)
tensorrt_llm/_torch/auto_deploy/export/export.py (1)
  • torch_export_to_gm (276-344)
tensorrt_llm/_torch/auto_deploy/transform/library/gather_logits_before_lm_head.py (2)
  • gather_logits_before_lm_head (43-64)
  • gather_logits_before_lm_head_fake (68-80)
tensorrt_llm/_torch/auto_deploy/transform/optimizer.py (1)
  • InferenceOptimizer (23-78)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
  • is_op (198-221)
tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (3)
tensorrt_llm/_torch/attention_backend/interface.py (1)
  • num_ctx_tokens (272-273)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (3)
  • named_args (278-287)
  • is_generate (373-374)
  • seq_len (349-350)
tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)
  • named_args (33-35)
🪛 Ruff (0.14.7)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

330-330: Consider iterable unpacking instead of concatenation

Replace with iterable unpacking

(RUF005)

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gather_logits_before_lm_head.py

41-41: Unused method argument: logit_gather_ids

(ARG002)


41-41: Unused method argument: seq_len

(ARG002)

tensorrt_llm/_torch/auto_deploy/transform/library/gather_logits_before_lm_head.py

197-197: Unused method argument: factory

(ARG002)


198-198: Unused method argument: shared_config

(ARG002)


250-250: Consider (gather_node, *tuple(lm_head_node.args[1:])) instead of concatenation

Replace with (gather_node, *tuple(lm_head_node.args[1:]))

(RUF005)

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

353-353: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (15)
tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py (2)

74-83: LGTM! Handling for new logits format is appropriate.

The added logic correctly handles the case where logits may be returned as a 3D tensor with a batch dimension from the new gather-before-LM-head path. Squeezing the batch dimension when dim() == 3 ensures compatibility with the downstream comparison against original_logits.
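A minimal sketch of that conditional squeeze (tensor sizes here are invented):

import torch

logits = torch.randn(1, 6, 32000)  # possible 3D output from the gather-before-lm-head path
if isinstance(logits, torch.Tensor) and logits.dim() == 3:
    logits = logits.squeeze(0)     # drop the batch dimension -> [num_tokens, vocab]
print(logits.shape)                # torch.Size([6, 32000])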


112-115: LGTM! Consistent handling with test_engine.

The same logits shape handling pattern is applied here, maintaining consistency across test functions.

tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py (1)

168-168: LGTM! Test coverage for disabled transform path.

Adding "gather_logits_before_lm_head": {"enabled": False} provides explicit test coverage for the case when the new transform is disabled, complementing other tests that likely exercise the enabled path.

tensorrt_llm/_torch/auto_deploy/config/default.yaml (1)

173-175: LGTM! Transform configuration is well-positioned.

The gather_logits_before_lm_head transform is correctly placed in the compile stage before compile_model, ensuring the logit gathering optimization is applied before final model compilation.

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (5)

209-209: LGTM! Initialization for logit gather indices.


234-236: LGTM! Context request logit gathering index calculation.

The index num_ctx_tokens - 1 correctly points to the last token position in the cumulative context token stream for each sequence.


270-272: LGTM! Generation request logit gathering index calculation.

For generation requests, num_ctx_tokens + num_generation_tokens - 1 correctly computes the cumulative offset to the current generation token.
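Putting the two index formulas above together, a small standalone sketch (variable names invented, not the executor's actual code):

# Context (prefill) requests: sequences of lengths 4, 7, 2 packed back to back.
ctx_seq_lens = [4, 7, 2]
logit_gather_ids = []
num_ctx_tokens = 0
for seq_len in ctx_seq_lens:
    num_ctx_tokens += seq_len
    # Last token of this sequence in the packed token stream.
    logit_gather_ids.append(num_ctx_tokens - 1)
# -> [3, 10, 12]

# Generation requests: one new token each, appended after all context tokens.
num_generation_tokens = 0
for _ in range(2):
    num_generation_tokens += 1
    logit_gather_ids.append(num_ctx_tokens + num_generation_tokens - 1)

print(logit_gather_ids)  # [3, 10, 12, 13, 14]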


279-286: LGTM! Passing logit_gather_ids to sequence interface.

The new parameter is correctly passed through to nest_sequences.


334-334: LGTM! Assertion guards against unsupported feature.

The explicit assertion provides a clear error message when gather_context_logits is requested, preventing silent incorrect behavior.

tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (6)

101-103: LGTM! Clear documentation for the new field.

The docstring clearly explains the purpose of logit_gather_ids - storing the index of the last token in each sequence for logit gathering.


207-207: LGTM! Storage allocation for logit_gather_ids.

Using max_batch_size for the allocation size is correct since there's one gather index per sequence.


216-223: LGTM! Adding logit_gather_ids to cached argument names.

This ensures the new field is properly tracked and passed through as a model input via update_in_out_nodes.


324-330: LGTM! Correct exclusion from prepare_metadata args.

The comment clearly explains that logit_gather_ids is excluded from args_for_prepare_metadata because it's not needed for attention metadata preparation, while still being included in _cached_arg_names to ensure it gets added as a model input.

Regarding the static analyzer hint (RUF005) about using iterable unpacking: this is a stylistic preference and the current tuple concatenation is perfectly clear and readable.


658-658: LGTM! Function signature extended correctly.

The optional parameter with None default maintains backward compatibility.


705-709: Verify the default zeros behavior for logit_gather_ids.

When logit_gather_ids is None, the code defaults to a list of zeros for max_batch_size sequences. This differs from other fields like slot_idx which use a computed "free" value. Since gather indices of 0 would point to the first token, verify this default is safe for cases where the transform is applied but logit_gather_ids isn't explicitly provided.
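To make the concern concrete, a toy illustration (shapes invented) of what a zero default would select:

import torch

hidden = torch.randn(5, 8)                       # 5 tokens of one sequence
default_id = torch.zeros(1, dtype=torch.long)    # the zero-filled default
last_id = torch.tensor([4])                      # index of the last token

# With the default, the gather returns the first token's hidden state rather than
# the last one, so the LM head would compute logits for the wrong position.
print(torch.equal(hidden[default_id], hidden[last_id]))  # False (with overwhelming probability)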

@tensorrt-cicd
Collaborator

PR_Github #26833 [ run ] triggered by Bot. Commit: 5434b6a

@tensorrt-cicd
Collaborator

PR_Github #26833 [ run ] completed with state SUCCESS. Commit: 5434b6a
/LLM/main/L0_MergeRequest_PR pipeline #20438 completed with status: 'FAILURE'

Gathered hidden states [batch, hidden] for generate, [1, max_batch_size, hidden] for packed
"""
# Generate format: [batch, 1, hidden] -> seq_len == 1
# Packed format: [1, total_tokens, hidden] -> seq_len > 1
Collaborator


Q: Just for my understanding, what is the "packed" format? Is this for a "context"/"prefill" run?
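For reference, a small sketch of the two layouts the quoted comments describe; shapes and names here are illustrative only, not the op's actual contract. The generate path passes one token per sequence as [batch, 1, hidden], while the packed path flattens all sequences' tokens into [1, total_tokens, hidden] and relies on logit_gather_ids to pick each sequence's last token.

import torch

hidden_size = 16

# Generate format: one token per sequence, seq_len == 1.
gen_hidden = torch.randn(3, 1, hidden_size)             # [batch, 1, hidden]
gen_gather_ids = torch.tensor([0, 0, 0])                # last (only) token of each sequence

# Packed format: all tokens of all sequences flattened, seq_len > 1.
packed_hidden = torch.randn(1, 4 + 7 + 2, hidden_size)  # [1, total_tokens, hidden]
packed_gather_ids = torch.tensor([3, 10, 12])           # last token of each sequence

# In both cases, gathering along the token dimension yields one row per sequence
# before the LM head GEMM.
gen_selected = gen_hidden[torch.arange(3), gen_gather_ids]   # [batch, hidden]
packed_selected = packed_hidden[0, packed_gather_ids]        # [num_seqs, hidden]
print(gen_selected.shape, packed_selected.shape)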

@@ -165,6 +165,7 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs):
"transforms": {
"insert_cached_attention": {"backend": "flashinfer"},
"compile_model": {"backend": "torch-cudagraph"},
"gather_logits_before_lm_head": {"enabled": False},
Member


why is it not enabled by default?

it should be before compile_model, btw

Member

@lucaslie lucaslie left a comment


a little too much vibe code ;)

from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SequenceInfo
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.library.gather_logits_before_lm_head import (
gather_logits_before_lm_head,
Member


don't use registered custom ops directly. Instead, use the correct dispatch, e.g., torch.ops.auto_deploy.gather_logits_before_lm_head.default

lucaslie and others added 3 commits December 4, 2025 11:39
Signed-off-by: Lucas Liebenwein <[email protected]>
Signed-off-by: Lucas Liebenwein <[email protected]>
@nvchenghaoz
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #27026 [ run ] triggered by Bot. Commit: f5532f2

@tensorrt-cicd
Collaborator

PR_Github #27026 [ run ] completed with state SUCCESS. Commit: f5532f2
/LLM/main/L0_MergeRequest_PR pipeline #20608 completed with status: 'FAILURE'

