[None][feat] AutoDeploy: Remove redundant copies in mamba layers #9461
Conversation
Signed-off-by: Chenghao Zhang <[email protected]>
Signed-off-by: Chenghao Zhang <[email protected]>
📝 Walkthrough

The changes refactor a CUDA cached causal convolution operation from tensor-returning to in-place modification semantics. A wrapper function is introduced to maintain backward compatibility, while fusion logic is updated to use the new wrapper instead of the raw operator. Output tensor assembly logic in the Triton backend is also adjusted to conditionally construct results from prefill and decode paths.
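A hedged sketch of the new calling contract, mirroring the wrapper this PR introduces (the op's full argument list is elided as `*args, **kwargs`):

```python
import torch


def cuda_cached_causal_conv1d_wrapper(input, *args, **kwargs):
    # The underlying custom op now writes its result into `input` and returns None.
    torch.ops.auto_deploy.cuda_cached_causal_conv1d(input, *args, **kwargs)
    # Returning the mutated input preserves a value-producing interface for the
    # fusion pass without allocating or copying a fresh output tensor.
    return input
```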
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (2 warnings), ✅ Passed checks (1 passed)
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py (1)
207-216: The decode path has a critical bug: the return value of `causal_conv1d_update` is not captured.

The function `causal_conv1d_update` returns the convolution output (shape matching `x_decode`), but the code at lines 207-216 discards this return value. The function does not modify `x_decode` in place; instead, it copies `x_decode` into `conv_state` and returns the computed output. Without capturing the return value, the decode output is lost entirely. The call should be:

```python
x_decode = causal_conv1d_update(
    x_decode,  # [batch, dim]
    conv_state_cache,
    w2d,
    bias,
    activation=activation,
    cache_seqlens=None,
    conv_state_indices=slot_idx[num_prefill:].to(torch.int32),
    pad_slot_id=PAD_SLOT_ID,
)
```
🧹 Nitpick comments (1)
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py (1)
245-247: Consider adding a docstring to the wrapper function.

The wrapper function correctly calls the in-place op and returns the input, maintaining backward compatibility. However, adding a docstring would improve clarity:

```diff
 def cuda_cached_causal_conv1d_wrapper(input, *args, **kwargs):
+    """Wrapper for cuda_cached_causal_conv1d that returns the modified input.
+
+    The underlying op modifies input in-place; this wrapper provides
+    a functional interface for backward compatibility.
+    """
     torch.ops.auto_deploy.cuda_cached_causal_conv1d(input, *args, **kwargs)
     return input
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py (7 hunks)
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py (4 hunks)
- tensorrt_llm/_torch/auto_deploy/transform/library/fuse_causal_conv.py (4 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+
Indent Python code with 4 spaces; do not use tabs
Always maintain the namespace when importing in Python, even if only one class or function from a module is used (e.g., use `from package.subpackage import foo` and then `foo.SomeClass()` instead of `from package.subpackage.foo import SomeClass`)
Python filenames should use snake_case (e.g., `some_file.py`)
Python class names should use PascalCase (e.g., `class SomeClass`)
Python function and method names should use snake_case (e.g., `def my_awesome_function():`)
Python local variable names should use snake_case, with prefix `k` for variable names that start with a number (e.g., `k_99th_percentile = ...`)
Python global variables should use upper snake_case with prefix `G` (e.g., `G_MY_GLOBAL = ...`)
Python constants should use upper snake_case (e.g., `MY_CONSTANT = ...`)
Avoid shadowing variables declared in an outer scope in Python
Initialize all externally visible members of a Python class in the constructor
For Python interfaces that may be used outside a file, prefer docstrings over comments
Python comments should be reserved for code within a function, or interfaces that are local to a file
Use Google style docstrings for Python classes and functions, which can be parsed by Sphinx
Python attributes and variables can be documented inline with type and description (e.g., `self.x = 5` followed by `"""<type>: Description of 'x'"""`)
Avoid using reflection in Python when functionality can be easily achieved without reflection
When using try-except blocks in Python, limit the except clause to the smallest set of specific errors possible instead of catching all exceptions
When using try-except blocks in Python to handle multiple possible variable types (duck-typing), keep the body of the try as small as possible and use the else block to implement the logic
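As an illustration of the namespace-import and narrow try/except guidelines above (module and function names here are chosen for the example, not taken from the repository):

```python
from os import path  # namespace kept: call path.join(), not a bare join()


def to_int(value):
    """Convert value to int, accepting duck-typed inputs that expose .item()."""
    try:
        candidate = value.item()  # keep the try body as small as possible
    except AttributeError:        # catch only the specific expected error
        return int(value)
    else:
        return int(candidate)     # main logic lives in the else block


print(path.join("logs", "run0"), to_int(3.7))
```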
Files:
- tensorrt_llm/_torch/auto_deploy/transform/library/fuse_causal_conv.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py
**/*.{cpp,h,cu,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
All TensorRT-LLM Open Source Software code files should contain an NVIDIA copyright header that includes the current year at the top
Files:
- tensorrt_llm/_torch/auto_deploy/transform/library/fuse_causal_conv.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py
🧠 Learnings (4)
📓 Common learnings
Learnt from: djns99
Repo: NVIDIA/TensorRT-LLM PR: 6915
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:4010-4012
Timestamp: 2025-08-14T23:23:27.449Z
Learning: For MOE (Mixture of Experts) code reviews in TensorRT-LLM, avoid repeatedly suggesting finalize fusion validation checks and safety assertions. The user djns99 has indicated these suggestions are repetitive and unwanted across multiple MOE-related changes.
Learnt from: sklevtsov-nvidia
Repo: NVIDIA/TensorRT-LLM PR: 3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu:118-127
Timestamp: 2025-08-09T20:57:04.084Z
Learning: In the CUTLASS MoE finalize fusion implementation (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu), when setting `fused_finalize_epilogue.stride_final_output` with shape `(hidden_size, num_output_tokens, 1)`, the `num_rows_in_final_output` should be set to `num_output_tokens` (not `hidden_size`) because of a swap+transpose operation that maps rows of the output tensor to `hidden_size` and columns to `num_output_tokens`.
Learnt from: sklevtsov-nvidia
Repo: NVIDIA/TensorRT-LLM PR: 3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:1198-1209
Timestamp: 2025-08-08T22:03:40.707Z
Learning: In the CUTLASS MoE kernels (cpp/tensorrt_llm/cutlass_extensions), when `layout_info.fusion` is set to `TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE`, the `router_scales` parameter must be non-null by design. The fused finalize kernel epilogue does not perform nullptr checks and requires valid router scales to function correctly. This is an implicit contract that callers must satisfy when enabling the FINALIZE fusion mode.
Learnt from: nv-lschneider
Repo: NVIDIA/TensorRT-LLM PR: 7910
File: cpp/tensorrt_llm/thop/allreduceOp.cpp:352-446
Timestamp: 2025-09-23T15:12:38.312Z
Learning: In TensorRT-LLM NCCL device allreduce implementation (cpp/tensorrt_llm/thop/allreduceOp.cpp), the goto pattern in runNCCLAllReduceDeviceFusion is intentionally used for future extensibility, allowing multiple switch cases to fallback to the default handler. While not aesthetically ideal, this pattern supports adding more fusion cases later that can reuse the same fallback logic.
📚 Learning: 2025-11-14T11:22:03.729Z
Learnt from: nzmora-nvidia
Repo: NVIDIA/TensorRT-LLM PR: 9163
File: tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py:107-113
Timestamp: 2025-11-14T11:22:03.729Z
Learning: In TensorRT-LLM AutoDeploy custom ops, when adding hardware capability checks to select between kernel implementations (e.g., cuBLAS vs. CUDA kernel), use descriptive variable names that identify the specific GPU architectures or families being targeted (e.g., `is_blackwell_geforce_or_ada`) rather than generic names like `enable_cuda_core`. This makes it clear that the code is selecting an implementation path based on hardware capabilities, not enabling/disabling hardware features.
Applied to files:
- tensorrt_llm/_torch/auto_deploy/transform/library/fuse_causal_conv.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py
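A minimal sketch of the naming guidance in this learning; the capability set and the two GEMM helpers are assumptions for illustration only:

```python
import torch

# Stand-ins for the two kernel paths; purely illustrative.
cuda_core_gemm = torch.matmul
cublas_gemm = torch.matmul


def gemm_dispatch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    major, minor = torch.cuda.get_device_capability()
    # Descriptive name: it states *which* architectures are targeted,
    # rather than a generic toggle such as `enable_cuda_core`.
    is_blackwell_geforce_or_ada = (major, minor) in {(12, 0), (8, 9)}  # assumed set
    return cuda_core_gemm(a, b) if is_blackwell_geforce_or_ada else cublas_gemm(a, b)
```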
📚 Learning: 2025-10-20T16:54:09.824Z
Learnt from: nvchenghaoz
Repo: NVIDIA/TensorRT-LLM PR: 8469
File: tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py:6-6
Timestamp: 2025-10-20T16:54:09.824Z
Learning: In tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py, the import `from ...modules.mamba.layernorm_gated import _layer_norm_fwd` is correct and should not be changed to modules.fla.layernorm_gated. The _layer_norm_fwd function exists in both modules/mamba/layernorm_gated.py and modules/fla/layernorm_gated.py, but the mamba version is the intended implementation for this use case.
Applied to files:
- tensorrt_llm/_torch/auto_deploy/transform/library/fuse_causal_conv.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py
📚 Learning: 2025-09-23T15:12:38.312Z
Learnt from: nv-lschneider
Repo: NVIDIA/TensorRT-LLM PR: 7910
File: cpp/tensorrt_llm/thop/allreduceOp.cpp:352-446
Timestamp: 2025-09-23T15:12:38.312Z
Learning: In TensorRT-LLM NCCL device allreduce implementation (cpp/tensorrt_llm/thop/allreduceOp.cpp), the goto pattern in runNCCLAllReduceDeviceFusion is intentionally used for future extensibility, allowing multiple switch cases to fallback to the default handler. While not aesthetically ideal, this pattern supports adding more fusion cases later that can reuse the same fallback logic.
Applied to files:
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_causal_conv.py
🧬 Code graph analysis (2)
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_causal_conv.py (2)
- tests/unittest/_torch/thop/parallel/test_custom_ops.py (1): custom_ops (37-42)
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py (1): cuda_cached_causal_conv1d_wrapper (245-247)

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py (1)
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (1): args (288-290)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (9)
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_causal_conv.py (2)
1-15: LGTM! Copyright header properly added.

The NVIDIA copyright header with Apache 2.0 license has been correctly added to this file, as required by the coding guidelines.
104-119: LGTM! Wrapper integration is clean.

The switch from using the raw operator to `cuda_cached_causal_conv1d_wrapper` is implemented correctly. The wrapper is consistently used in both pattern matching (line 109) and fusion (line 119), which aligns with the in-place modification semantics introduced in the CUDA backend.

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py (3)
1-15: LGTM! Copyright header properly added.

The NVIDIA copyright header with Apache 2.0 license has been correctly added to this file, as required by the coding guidelines.
207-208: LGTM! Clean initialization pattern.

Initializing `y_prefill` and `y_decode` to `None` enables the dispatch-based return logic and helps eliminate redundant tensor allocations, aligning with the PR objectives.
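A minimal sketch of the dispatch-based assembly this enables (tensor names follow the review comment; shapes and the concatenation dimension are assumptions):

```python
import torch


def assemble_output(y_prefill, y_decode):
    # Only the paths that actually ran contribute to the final output,
    # avoiding a pre-allocated buffer plus copies when one path is empty.
    if y_prefill is not None and y_decode is not None:
        return torch.cat([y_prefill, y_decode], dim=0)
    return y_prefill if y_prefill is not None else y_decode
```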
282-294: I'll help you verify the shape compatibility in the dispatch return logic. Let me start by examining the mamba triton backend file and locating the relevant function definitions and shape flow:

```bash
# Find the file and check its size
wc -l tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py

# Find mamba_chunk_scan_combined and selective_state_update signatures
rg -n "def mamba_chunk_scan_combined|def selective_state_update" tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py

# Read the critical section around lines 282-294 and some context before/after
sed -n '250,310p' tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py

# Search for where y_prefill and y_decode are created
rg -n "y_prefill|y_decode" tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py | head -30

# Get broader context on y_prefill and y_decode
sed -n '200,295p' tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py

# Search across the entire codebase for the function definitions
rg -n "def mamba_chunk_scan_combined" --type py
rg -n "def selective_state_update" --type py

# Look for the function calls in the triton_backend_mamba file
rg -B5 -A5 "mamba_chunk_scan_combined|selective_state_update" tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py

# Read the entire file to understand context better
wc -l tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py && echo "---" && cat -n tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
```

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py (4)
1-15: LGTM! Copyright header properly added.

The NVIDIA copyright header with Apache 2.0 license has been correctly added to this file, as required by the coding guidelines.
112-140: LGTM! In-place semantics properly declared.

The changes correctly implement in-place mutation semantics:
- `mutates_args={"input"}` properly declares the mutation to PyTorch
- The return type changed to `None`, consistent with in-place operations
- Documentation clearly states the in-place behavior

This refactoring aligns with the PR objective of removing redundant copies.
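For reference, a hedged sketch of declaring an in-place custom op this way with PyTorch's custom-op API (the op name, signature, and depthwise-conv body below are illustrative, not the actual registration in this file):

```python
import torch
import torch.nn.functional as F


@torch.library.custom_op("demo::cached_causal_conv1d_", mutates_args={"inp"})
def cached_causal_conv1d_(inp: torch.Tensor, weight: torch.Tensor) -> None:
    # inp: [batch, dim, seqlen]; a causal depthwise conv written back into inp.
    dim, k = weight.shape
    out = F.conv1d(F.pad(inp, (k - 1, 0)), weight.unsqueeze(1), groups=dim)
    inp.copy_(out)


@cached_causal_conv1d_.register_fake
def _(inp: torch.Tensor, weight: torch.Tensor) -> None:
    return None  # fake returns None, consistent with the in-place semantics
```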
198-199: LGTM! Prefill path correctly implements in-place modification.

The scatter operation `inp_flat[:total_prefill_tokens] = y_varlen.transpose(0, 1)` correctly writes the results back to the input buffer. Since `inp_flat` is a view of `input` (line 156), the modifications properly propagate.
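A tiny self-contained illustration of that view-based write-back (tensor shapes are arbitrary for the example):

```python
import torch

inp = torch.zeros(2, 4, 8)
inp_flat = inp.flatten(0, 1)     # a view over inp's storage
inp_flat[:3] = 1.0               # scatter into the view ...
assert inp[0, :3].eq(1.0).all()  # ... propagates to the original tensor
```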
242-242: LGTM! Fake registration and wrapper exposure are correct.

- Line 242: The fake function correctly returns `None`, consistent with the in-place operation
- Line 273: The wrapper is correctly returned by `get_cached_attention_op`, providing the public API while encapsulating the in-place semantics

Also applies to: 273-273
/bot run

PR_Github #25781 [ run ] triggered by Bot. Commit:

PR_Github #25781 [ run ] completed with state

/bot run

PR_Github #25823 [ run ] triggered by Bot. Commit:

PR_Github #25823 [ run ] completed with state

/bot run

PR_Github #25867 [ run ] triggered by Bot. Commit:

PR_Github #25867 [ run ] completed with state
Signed-off-by: Chenghao Zhang <[email protected]>
/bot run

PR_Github #25883 [ run ] triggered by Bot. Commit:

PR_Github #25883 [ run ] completed with state
Part 2 for #9344
This PR removes the redundant copies after the causal conv and after the SSM.
Added the NVIDIA copyright header to the files that I touched.
Summary by CodeRabbit
- Performance Improvements
- Refactor