Conversation

@greg-kwasniewski1
Collaborator

@greg-kwasniewski1 greg-kwasniewski1 commented Nov 16, 2025

The new logic in detect_column_row_shard extracts individual layers (attention/MoE/MLP/SSM) based not on residual connections, but on consecutive pairs of opening/closing linear layers. Each extracted subgraph is defined by a set of opening linear layers (e.g., q, k, v for attention, gate and up for MLP, etc.) and a single closing linear layer (e.g., o_proj or down_proj).

For each subgraph, the validity of column-row sharding is checked; if the conditions are met, Megatron-style column-row TP sharding is applied, otherwise the subgraph falls back to simple sharding.
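
As a rough illustration of the per-subgraph decision, here is a self-contained Python sketch (made-up layer dicts and a hypothetical plan_tp_sharding helper, not the actual FX-based code in this PR):

# Hypothetical, self-contained sketch of the opening/closing heuristic.
# The real code walks torch.fx nodes; here each "layer" is just a dict.
def plan_tp_sharding(layers, world_size):
    """Return a per-layer plan: 'column-row' (Megatron-style) or 'simple'."""
    plan = {}
    for layer in layers:
        # Opening linears (q/k/v or gate/up) all feed a single closing linear.
        out_dims = [lin["out_features"] for lin in layer["opening"]]
        heads = layer.get("num_heads")
        # Column-row TP is only valid if every opening output dim splits evenly
        # across ranks and, for attention, no head gets cut in half.
        valid = all(d % world_size == 0 for d in out_dims) and (
            heads is None or heads % world_size == 0
        )
        plan[layer["name"]] = "column-row" if valid else "simple"
    return plan

layers = [
    {"name": "attn_0", "num_heads": 32, "opening": [{"out_features": 4096}] * 3},  # q, k, v
    {"name": "attn_1", "num_heads": 6, "opening": [{"out_features": 768}] * 3},    # small GQA layer
    {"name": "mlp_0", "opening": [{"out_features": 11008}] * 2},                   # gate, up
]
print(plan_tp_sharding(layers, world_size=4))
# -> {'attn_0': 'column-row', 'attn_1': 'simple', 'mlp_0': 'column-row'}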

fixes #8946
fixes #8947
fixes #8949
fixes #4320

@coderabbitai summary

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.
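
For example, an illustrative invocation combining the flags documented above (the pipeline id is only a placeholder, taken from a run later in this thread):

/bot run --reuse-test 24695 --gpu-type "H100_PCIe" --disable-fail-fast

This reuses build artifacts and successful test results from pipeline 24695, runs only the H100_PCIe test stages, and keeps running even if an individual stage fails.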

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@greg-kwasniewski1 greg-kwasniewski1 requested a review from a team as a code owner November 16, 2025 21:49
@greg-kwasniewski1 greg-kwasniewski1 added the AutoDeploy <NV> AutoDeploy Backend label Nov 16, 2025
@greg-kwasniewski1 greg-kwasniewski1 moved this from Backlog to In review in AutoDeploy Board Nov 16, 2025
@greg-kwasniewski1
Collaborator Author

/bot run

@coderabbitai
Contributor

coderabbitai bot commented Nov 16, 2025

📝 Walkthrough

Walkthrough

Refactors the tensor-parallel (TP) sharding pipeline to accept flexible node inputs, add subgraph-aware column sharding with fused weight detection, and integrate layer/subgraph-aware handling for SSM and attention layers. Updates public API signatures, modifies enum defaults, and adjusts layer subgraph identification logic.

Changes

Cohort / File(s) Summary
TP sharding pipeline refactoring
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
Refactored to accept flexible node inputs (dict or list) in _process_simple_shard. Added subgraph-aware column sharding with fused weight detection via _process_column_sharding(subgraph_nodes). Integrated layer-type awareness and SSM/MAMBA support via _process_ssm_sharding. Extended main flow to use layer_subgraphs and unprocessed_linear_nodes. Updated logging to track simple TP, row-column, SSM, and attention shards separately.
Layer subgraph and linear node collection
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
Updated get_all_layer_subgraphs to use is_any_lin_op instead of is_linear_op, changed accumulation to triples [opening, layer_subgraph, closing]. Modified get_layer_after_linear_node with boundary and filter conditions, added termination guards. Restructured layer subgraph construction to handle lm_head special cases and improve robustness.
Sharding utilities and enum updates
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
Removed SSM member from ShardingDim enum. Updated ShardingConfig default sharding_dims from [SSM, TP, EP, BMM] to [TP, EP, BMM]. Narrowed split node filtering to only split_with_sizes. Disabled runtime validation call in _shard_parameter_node when add_dist is False.
Test updates
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
Renamed _run_job to _run_sharding_execution_job with added rank and world_size parameters. Added LayerType import. Updated expected transformations to include layer_type=LayerType.ATTENTION in WeightShardingInfo. Removed standalone __main__ invocation.

Sequence Diagram

sequenceDiagram
    participant Main as TP Sharding<br/>Main Flow
    participant LayerDetect as Layer Subgraph<br/>Detection
    participant LayerType as Layer Type<br/>Classification
    participant Sharding as Specialized<br/>Sharding Logic
    participant Utils as Param<br/>Updates

    Main->>LayerDetect: get_all_layer_subgraphs<br/>(nodes using is_any_lin_op)
    LayerDetect-->>Main: layer_subgraphs<br/>[opening, subgraph, closing]
    
    Main->>Main: separate simple_shards vs<br/>layer_subgraph_nodes
    
    Main->>LayerType: detect layer type<br/>(SSM vs ATTENTION vs MLP)
    
    alt SSM Layer Detected
        LayerType-->>Sharding: _process_ssm_sharding
        Sharding->>Sharding: detect fused weights<br/>via split_with_sizes
        Sharding->>Utils: generate param updates<br/>per fused dimension
    else ATTENTION Layer Detected
        LayerType-->>Sharding: _process_column_sharding<br/>(with subgraph_nodes)
        Sharding->>Sharding: detect fused weights<br/>(e.g., QKV)
        Sharding->>Utils: generate param updates<br/>with layer_type=ATTENTION
    else MLP or Simple
        LayerType-->>Sharding: _process_simple_shard<br/>or _process_row_sharding
        Sharding->>Utils: generate param updates
    end
    
    Sharding-->>Utils: apply WeightShardingInfo<br/>all_gather, layer_type
    Utils-->>Main: sharding complete

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45–60 minutes

  • Structural complexity: Interconnected changes across sharding pipeline, node utils, and config that require understanding the new layer-aware flow and how SSM/attention layers integrate with existing TP logic.
  • API changes: Public signature updates to _process_simple_shard, _process_column_sharding, enum modifications (ShardingDim), and ShardingConfig defaults that cascade through the codebase.
  • New logic branches: SSM detection and specialized sharding, fused weight handling via split/slice inspection, and layer-type-aware parameter generation add decision points requiring careful validation.
  • Test updates: Changed test helper signature and expected outputs (layer_type annotation) need verification against actual behavior.

Areas requiring extra attention:

  • Verification that is_any_lin_op replacement in node_utils.py correctly identifies all linear-like operations without over-matching or missing critical nodes.
  • Fused weight detection logic (split_with_sizes and slice inspection) and consistency checks for fused_weight_dims.
  • SSM-specific sharding path and its interaction with existing TP/EP sharding logic.
  • Correctness of the layer subgraph triple structure [opening, layer_subgraph, closing] and termination index handling.
  • Disabled validation call in _shard_parameter_node when add_dist is False—ensure this doesn't hide issues.

Possibly related PRs

  • [TRTLLM-8201][feat] Nemotron H MoE Sharding #8744: Overlapping changes to sharding.py, sharding_utils.py, and node_utils.py that introduce WeightShardingInfo, LayerType, and ShardingDim enum modifications affecting the same sharding infrastructure.

Suggested reviewers

  • Fridah-nv
  • suyoggupta
  • MrGeva

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 45.45% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The title accurately describes the main change: improved heuristics for detecting shardable regions in the TP sharding pipeline, which is the core focus of the changeset.
Description check ✅ Passed The PR description explains the key technical change (layer extraction by consecutive opening/closing linear pairs instead of residual connections) and the conditional sharding logic applied to each subgraph, meeting the core requirements.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Tip

📝 Customizable high-level summaries are now available!

You can now customize how CodeRabbit generates the high-level summary in your pull requests — including its content, structure, tone, and formatting.

  • Provide custom instructions to shape the summary (bullet lists, tables, contributor stats, etc.).
  • Use high_level_summary_in_walkthrough to move the summary from the description to the walkthrough section.

Example:

"Create a concise high-level summary as a bullet-point list. Then include a Markdown table showing lines added and removed by each contributing author."


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (1)

535-755: Partial factory-config path still checks string "ep"/"bmm" while sharding_dims is now ShardingDim

In detect_sharding_from_factory_config, the partial-config path checks:

if sharding_config.support_partial_config:
    ...
    if "ep" in sharding_config.sharding_dims:
        ep_info = detect_ep_shard(gm, sharding_config)
    ...
    if "bmm" in sharding_config.sharding_dims:
        dp_bmm_info = detect_dp_bmm_shard(gm, sharding_config)

But ShardingConfig.sharding_dims is now a List[ShardingDim] (matching ShardingTransformConfig), and defaults to [ShardingDim.TP, ShardingDim.EP, ShardingDim.BMM]. These "ep"/"bmm" string membership checks will therefore always be false, so EP and BMM heuristics will never run in the partial-factory-config mode.

This looks like a behavior regression for users relying on support_partial_config=True.

The fix should be to compare against the enum values:

if ShardingDim.EP in sharding_config.sharding_dims:
    ep_info = detect_ep_shard(gm, sharding_config)
...
if ShardingDim.BMM in sharding_config.sharding_dims:
    dp_bmm_info = detect_dp_bmm_shard(gm, sharding_config)

to keep the behavior consistent with the heuristic-only path in Sharding._apply.

tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (1)

336-390: Mamba/SSM sharding transforms double-apply _insert_sharded_mamba with incorrect node references

Verification confirms the core issues. In _process_ssm_sharding (tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py), three separate WeightShardingInfo instances are created with layer_type=LayerType.MAMBA:

  • Entry node (line 340–350)
  • Weight nodes like conv1d (line 379–389)
  • Out-projection node (line 418–426)

When WeightShardingInfo.apply() is invoked for each, it calls _insert_sharded_mamba regardless of whether the node is the true first linear layer. This means _insert_sharded_mamba receives conv1d or out_proj as entry_node on subsequent calls, causing it to search for subgraph boundaries from the wrong starting point and mis-infer fused weight dimensions.

Additionally, there's a type contract violation: fused_weight_dims is declared as Optional[list] in WeightShardingInfo (line 572), but _insert_sharded_mamba expects Dict[str, list]. Since _process_ssm_sharding passes a list value and apply() checks isinstance(self.fused_weight_dims, dict) (line 625), the dict is never passed through—it's always None, defeating fused dimension propagation.

Recommend gating _insert_sharded_mamba to only execute for the entry node, or restructure so only the initial linear node carries layer_type=MAMBA with the dict, while other weights use default layer type for regular TP sharding.
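
A minimal sketch of the gating option, assuming the attribute names used above (layer_type, node, fused_weight_dims); the helpers _is_entry_node and _default_tp_shard are hypothetical:

# Illustrative only: run the Mamba-specific rewrite from the entry linear,
# and let every other weight in the layer fall back to regular TP sharding.
def apply(self, gm):
    if self.layer_type == LayerType.MAMBA and self._is_entry_node(gm, self.node):
        _insert_sharded_mamba(gm, self.node, self.fused_weight_dims or {})
        return
    self._default_tp_shard(gm)  # plain column/row weight sharding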

Also note: ShardingDim.SSM has been removed from the enum (now only TP, EP, BMM), which is a breaking API change if external code references it.

🧹 Nitpick comments (6)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (2)

1190-1211: Enum-based ShardingDim plus default dims change is a breaking surface; verify external usage

ShardingDim now only exposes TP, EP, and BMM, and ShardingConfig.sharding_dims default is [ShardingDim.TP, ShardingDim.EP, ShardingDim.BMM]. Any previous code that referenced ShardingDim.SSM or relied on SSM appearing in defaults will now fail at import or behave differently.

Given this enum is public API, please double-check:

  • All in-repo references to ShardingDim.SSM (including configs and docs) have been removed or updated.
  • External configs (YAML/JSON) that previously used "ssm" are either unsupported by design or migrated.

If external users are expected to depend on this, consider keeping SSM as a deprecated member that’s simply ignored by the sharding pipeline rather than removing it outright.
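
One possible shape of such a soft deprecation (illustrative only; the real ShardingDim may differ, and normalize_sharding_dims is a hypothetical helper):

import warnings
from enum import Enum

class ShardingDim(str, Enum):
    TP = "tp"
    EP = "ep"
    BMM = "bmm"
    SSM = "ssm"  # deprecated: still accepted in old configs, ignored by the pipeline

def normalize_sharding_dims(dims):
    """Drop deprecated members with a warning instead of failing on old configs."""
    if ShardingDim.SSM in dims:
        warnings.warn(
            "ShardingDim.SSM is deprecated and ignored; SSM layers are handled "
            "by the TP heuristics.",
            DeprecationWarning,
        )
    return [d for d in dims if d is not ShardingDim.SSM]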


476-482: _validate_sharded_shapes is no longer invoked; confirm this is intentional

The _shard_parameter_node path for add_dist=False used to call _validate_sharded_shapes to adjust hard-coded view/reshape and split params after TP sharding. That call is now commented out, and the only remaining callers for shape adjustment are the higher-level heuristics that emit ParameterUpdateInfo (e.g., _process_column_sharding and _process_ssm_sharding in sharding.py).

If that’s intentional, it might be worth either:

  • Removing _validate_sharded_shapes entirely to avoid dead code, or
  • Wiring it back in for non-heuristic/legacy callers that still rely on _shard_parameter_node directly.

Otherwise, manually constructed WeightShardingInfo instances that don’t go through the new layer/subgraph detection might silently stop updating downstream views/splits.

Also applies to: 67-80

tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (2)

650-707: get_layer_after_linear_node return type and structure are out of sync

get_layer_after_linear_node is annotated as returning List[Node], but it now returns a 3‑tuple-like structure ([opening_linear_nodes, backward_subgraph, terminating_linear_node]) or (None, None, None), and callers destructure it as:

opening, layer_subgraph, closing = get_layer_after_linear_node(...)

The implementation/usage are consistent with the new calling pattern, but the type hint and docstring are now misleading.

I’d suggest updating the signature and docstring to something like:

def get_layer_after_linear_node(
    linear_nodes: List[Node], terminating_indices: List[int]
) -> Tuple[Optional[List[Node]], Optional[List[Node]], Optional[Node]]:
    ...

to reflect the actual contract.


655-667: lm_head detection relies on node name; consider using weight target for robustness

The special casing for the final output embedding:

if "lm_head" in extract_weight_node(linear_nodes[-1]).name:
    def filter_condition(node: Node) -> bool:
        return is_any_lin_op(node) and node != linear_nodes[-1]

uses weight_node.name, which is FX’s autogenerated node name. In many graphs the more semantically stable identifier is weight_node.target (e.g., "lm_head.weight").

To make this heuristic more robust across exporters and naming schemes, it would be safer to check weight_node.target (or both name and target) for "lm_head".
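
For instance, a slightly more defensive check could look like this (sketch only; extract_weight_node, is_any_lin_op, and linear_nodes are the existing names from the snippet above):

weight_node = extract_weight_node(linear_nodes[-1])
# target (e.g. "lm_head.weight") is usually more stable than the autogenerated node name
if "lm_head" in str(weight_node.target) or "lm_head" in weight_node.name:
    def filter_condition(node: Node) -> bool:
        return is_any_lin_op(node) and node != linear_nodes[-1]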

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (2)

815-918: Layer-subgraph based TP sharding looks sound overall; minor nit on simple-shard input type

detect_column_row_shard now:

  • Early-exits when no linear nodes exist.
  • Uses get_all_layer_subgraphs(gm) to get (opening_nodes, layer_subgraph, closing_node) triples.
  • Handles SSM/Mamba, attention, and generic MLP layers differently, setting layer_type and min_local_shape appropriately.
  • Falls back to _process_simple_shard(unprocessed_linear_nodes, ...) for leftover linears.
  • Logs a more detailed breakdown of simple vs. row/col vs. SSM vs. attention shards.

This is a good step toward more robust, subgraph-aware TP sharding.

One small mismatch: _process_simple_shard is annotated to accept Union[Dict[Node, List[Node]], List[Node]] but is called here with unprocessed_linear_nodes, which is a set. Runtime-wise it works (you iterate over it), but it diverges from the type hint and docstring.

If you want to keep the signature precise, consider either:

  • Converting unprocessed_linear_nodes to a list before passing it, or
  • Widening the type hint to Collection[Node] or similar.

Functionally, though, the new flow looks correct.


841-842: Optional: simplify concatenation of opening and closing nodes

Here:

nodes_linear = opening + [closing]

you could adopt the Ruff suggestion and write:

nodes_linear = [*opening, closing]

which is a bit clearer about intent (one list of opening nodes plus a single closing node) and avoids an extra list literal.

Purely stylistic; behavior is identical.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e0f6965 and cac7fe4.

📒 Files selected for processing (4)
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (17 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (3 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (3 hunks)
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (4 hunks)
🧰 Additional context used
🧠 Learnings (7)
📓 Common learnings
Learnt from: djns99
Repo: NVIDIA/TensorRT-LLM PR: 6915
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:4010-4012
Timestamp: 2025-08-14T23:23:27.449Z
Learning: For MOE (Mixture of Experts) code reviews in TensorRT-LLM, avoid repeatedly suggesting finalize fusion validation checks and safety assertions. The user djns99 has indicated these suggestions are repetitive and unwanted across multiple MOE-related changes.
Learnt from: amitz-nv
Repo: NVIDIA/TensorRT-LLM PR: 8063
File: tensorrt_llm/lora_manager.py:1080-1112
Timestamp: 2025-09-29T15:14:28.503Z
Learning: In tensorrt_llm/lora_manager.py, when calculating part_sizes for attn_qkv fused LoRA modules, the sizes are correctly multiplied by tp_size because model_config.num_heads and model_config.num_kv_heads are already divided by tp_size (per-TP-rank values), so multiplication is needed to get the original full concatenated dimension size. The interleave_fused_lora_weights_for_tp function provides proper validation with asserts for total size and TP divisibility.
📚 Learning: 2025-10-20T16:54:09.824Z
Learnt from: nvchenghaoz
Repo: NVIDIA/TensorRT-LLM PR: 8469
File: tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py:6-6
Timestamp: 2025-10-20T16:54:09.824Z
Learning: In tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py, the import `from ...modules.mamba.layernorm_gated import _layer_norm_fwd` is correct and should not be changed to modules.fla.layernorm_gated. The _layer_norm_fwd function exists in both modules/mamba/layernorm_gated.py and modules/fla/layernorm_gated.py, but the mamba version is the intended implementation for this use case.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
📚 Learning: 2025-11-14T11:22:03.729Z
Learnt from: nzmora-nvidia
Repo: NVIDIA/TensorRT-LLM PR: 9163
File: tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py:107-113
Timestamp: 2025-11-14T11:22:03.729Z
Learning: In TensorRT-LLM AutoDeploy custom ops, when adding hardware capability checks to select between kernel implementations (e.g., cuBLAS vs. CUDA kernel), use descriptive variable names that identify the specific GPU architectures or families being targeted (e.g., `is_blackwell_geforce_or_ada`) rather than generic names like `enable_cuda_core`. This makes it clear that the code is selecting an implementation path based on hardware capabilities, not enabling/disabling hardware features.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
📚 Learning: 2025-09-29T15:14:28.503Z
Learnt from: amitz-nv
Repo: NVIDIA/TensorRT-LLM PR: 8063
File: tensorrt_llm/lora_manager.py:1080-1112
Timestamp: 2025-09-29T15:14:28.503Z
Learning: In tensorrt_llm/lora_manager.py, when calculating part_sizes for attn_qkv fused LoRA modules, the sizes are correctly multiplied by tp_size because model_config.num_heads and model_config.num_kv_heads are already divided by tp_size (per-TP-rank values), so multiplication is needed to get the original full concatenated dimension size. The interleave_fused_lora_weights_for_tp function provides proper validation with asserts for total size and TP divisibility.

Applied to files:

  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
📚 Learning: 2025-09-29T15:14:28.503Z
Learnt from: amitz-nv
Repo: NVIDIA/TensorRT-LLM PR: 8063
File: tensorrt_llm/lora_manager.py:1080-1112
Timestamp: 2025-09-29T15:14:28.503Z
Learning: In tensorrt_llm/lora_manager.py, when calculating part_sizes for attn_qkv fused LoRA modules, the sizes are correctly multiplied by tp_size because model_config.num_heads and model_config.num_kv_heads are already divided by tp_size (per-TP-rank values), so multiplication is needed to get the original full concatenated dimension size. The interleave_fused_lora_weights_for_tp function provides proper validation.

Applied to files:

  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
📚 Learning: 2025-08-09T02:04:49.623Z
Learnt from: Fridah-nv
Repo: NVIDIA/TensorRT-LLM PR: 6760
File: tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py:81-98
Timestamp: 2025-08-09T02:04:49.623Z
Learning: In TensorRT-LLM's auto_deploy module, torch.dtype values in configuration dictionaries must be stored as string representations (e.g., "float16" instead of torch.float16) because OmegaConf.merge does not support torch.dtype types. These string representations are converted to actual torch.dtype objects in downstream code.

Applied to files:

  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
📚 Learning: 2025-08-14T15:43:23.107Z
Learnt from: MatthiasKohl
Repo: NVIDIA/TensorRT-LLM PR: 6904
File: tensorrt_llm/_torch/attention_backend/trtllm.py:259-262
Timestamp: 2025-08-14T15:43:23.107Z
Learning: In TensorRT-LLM's attention backend, tensor parameters in the plan() method are assigned directly without validation (dtype, device, contiguity checks). This maintains consistency across all tensor inputs and follows the pattern of trusting callers to provide correctly formatted tensors.

Applied to files:

  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
🧬 Code graph analysis (3)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (1)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (2)
  • FP8TPShardingInfo (724-756)
  • LayerType (557-561)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (1)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
  • is_op (197-220)
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (2)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (3)
  • get_all_layer_subgraphs (390-413)
  • filtered_nodes (223-271)
  • is_any_lin_op (274-275)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (9)
  • ShardingConfig (1197-1306)
  • add (1236-1255)
  • WeightShardingInfo (564-640)
  • from_node (589-594)
  • from_node (1098-1103)
  • SplitDimension (510-518)
  • ShardingDim (1189-1194)
  • LayerType (557-561)
  • ParameterUpdateInfo (643-657)
🪛 Ruff (0.14.4)
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py

841-841: Consider [*opening, closing] instead of concatenation

Replace with [*opening, closing]

(RUF005)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (3)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (2)

125-224: Distributed helper signature change looks consistent but depends on spawn_multiprocess_job contract

_run_sharding_execution_job has been extended to accept (rank, world_size) and is invoked via:

dist_common.spawn_multiprocess_job(
    job=partial(_run_sharding_execution_job, model_cls, dist_op_expected, bias, from_config),
    size=device_count,
)

Assuming spawn_multiprocess_job calls job(rank, world_size) (or job(rank, size)), this wiring is correct and preserves the previous behavior while making world_size explicitly available inside the helper.

Please just double‑check other usages of spawn_multiprocess_job in the repo to confirm the expected argument order (rank, world_size) is consistent everywhere.

Also applies to: 362-384


259-283: Using layer_type=LayerType.ATTENTION in expected GQA transforms aligns tests with new TP metadata

In the GQA block pattern detection, the expected WeightShardingInfo now includes:

layer_type=LayerType.ATTENTION,

for Q/K/V/O linears. This matches how detect_column_row_shard now tags attention vs. MLP layers via LayerType, and should keep the test comparison robust as the sharding pipeline becomes layer‑aware.

No functional issues here; this looks like the right way to adapt the tests to the new metadata.

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (1)

153-155: Transform config sharding_dims default aligned with runtime config

ShardingTransformConfig.sharding_dims now defaults to:

default_factory=lambda: [ShardingDim.TP, ShardingDim.EP, ShardingDim.BMM]

which matches the updated ShardingConfig default. This keeps the transform’s configuration in sync with the underlying sharding implementation and avoids surprises where a config class and runtime config diverge.

Looks good as-is.

Comment on lines 430 to 479
def _process_column_sharding(
    gm: GraphModule,
    linear_nodes: List[Node],
    subgraph_nodes: List[Node],
    sharding_config: ShardingConfig,
    rank: int,
    world_size: int,
    min_local_shape: int = 1,
    fused_weight: bool = False,
) -> None:
    """
    Parse the column sharding from the candidate nodes and update the view and split nodes accordingly.
    """
    fused_weight_dims = None
    # check if there are split nodes in the subgraph. They may indicate fused weights (e.g., QKV)
    split_nodes = list(filtered_nodes(subgraph_nodes, ops=[torch.ops.aten.split_with_sizes]))
    if len(split_nodes) > 0:
        assert len(linear_nodes) == 1
        linear_node = linear_nodes[0]
        assert len(split_nodes) == 1, "Expecting exactly one split node for fused weights"
        fused_weight_dims = split_nodes[0].args[1]
    slice_nodes = list(filtered_nodes(subgraph_nodes, ops=[torch.ops.aten.slice]))
    if len(slice_nodes) > 0:
        # we are probably in fused QKV case with single linear node and 3 slice nodes
        assert len(linear_nodes) == 1
        linear_node = linear_nodes[0]
        assert all(s.args[1] == 2 for s in linear_node.users), "Expecting slice nodes with dim=3"
        fused_weight_dims = [s.args[3] - s.args[2] for s in linear_node.users]
        weight_dim = linear_node.meta["val"].shape[2]
Contributor


⚠️ Potential issue | 🟠 Major

Potential bug in fused QKV detection via aten.slice in _process_column_sharding

In _process_column_sharding, the slice-based fused-weight path does:

slice_nodes = list(filtered_nodes(subgraph_nodes, ops=[torch.ops.aten.slice]))
if len(slice_nodes) > 0:
    assert len(linear_nodes) == 1
    linear_node = linear_nodes[0]
    assert all(s.args[1] == 2 for s in linear_node.users), "Expecting slice nodes with dim=3"
    fused_weight_dims = [s.args[3] - s.args[2] for s in linear_node.users]

Issues:

  • The assertion and fused-dim computation iterate over all linear_node.users, not just the slice_nodes. If the linear has any non‑slice users (e.g., debug ops, extra views), this will either assert incorrectly or compute bogus dims.
  • The assertion message says “dim=3” while checking s.args[1] == 2, which is confusing and suggests a mismatch between expectation and condition.
  • If a future graph uses a mix of slices and other consumers, this code will break even though fused QKV is still valid.

I’d suggest tightening this branch to operate only on slice_nodes, e.g.:

assert all(s.args[1] == 2 for s in slice_nodes), "Expecting slice nodes with dim=2"
fused_weight_dims = [s.args[3] - s.args[2] for s in slice_nodes]

and leave linear_node.users out of the fused-dim logic. This keeps the heuristic focused on the actual fused-QKV pattern instead of all downstream consumers.

Also applies to: 501-532

🧰 Tools
🪛 Ruff (0.14.4)

431-431: Unused function argument: gm

(ARG001)

🤖 Prompt for AI Agents
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py lines 430-457: the
slice-based fused-QKV detection iterates over linear_node.users instead of the
slice_nodes, and the assertion message/condition are inconsistent; restrict the
checks and dimension computations to the slice_nodes only and fix the assertion
to check the correct dim value/message (e.g., assert all(s.args[1] == 2 for s in
slice_nodes) with message "Expecting slice nodes with dim=2"), and compute
fused_weight_dims from slice_nodes (fused_weight_dims = [s.args[3] - s.args[2]
for s in slice_nodes]); apply the same change to the analogous block at lines
~501-532.

@tensorrt-cicd
Collaborator

PR_Github #24695 [ run ] triggered by Bot. Commit: cac7fe4

@tensorrt-cicd
Collaborator

PR_Github #24695 [ run ] completed with state SUCCESS. Commit: cac7fe4
/LLM/main/L0_MergeRequest_PR pipeline #18650 completed with status: 'FAILURE'

@greg-kwasniewski1
Collaborator Author

/bot run --reuse-test

@tensorrt-cicd
Collaborator

PR_Github #24785 [ run ] triggered by Bot. Commit: cac7fe4

@tensorrt-cicd
Collaborator

PR_Github #24785 [ run ] completed with state SUCCESS. Commit: cac7fe4
/LLM/main/L0_MergeRequest_PR pipeline #18702 completed with status: 'FAILURE'

@greg-kwasniewski1
Collaborator Author

/bot run --reuse-test

@tensorrt-cicd
Collaborator

PR_Github #24905 [ run ] triggered by Bot. Commit: cac7fe4

Collaborator

@govind-ramnarayan govind-ramnarayan left a comment


LGTM. Since I don't know the code super well and Lucas is on the PR, I'll let him decide when it's ready to accept.

@tensorrt-cicd
Collaborator

PR_Github #24905 [ run ] completed with state SUCCESS. Commit: cac7fe4
/LLM/main/L0_MergeRequest_PR pipeline #18806 completed with status: 'FAILURE'

@greg-kwasniewski1
Collaborator Author

/bot run --reuse-test

@tensorrt-cicd
Collaborator

PR_Github #25036 [ run ] triggered by Bot. Commit: 48e7178

@tensorrt-cicd
Collaborator

PR_Github #25036 [ run ] completed with state SUCCESS. Commit: 48e7178
/LLM/main/L0_MergeRequest_PR pipeline #18918 completed with status: 'FAILURE'

Member

@lucaslie lucaslie left a comment


Two higher-level questions that I thought of when reviewing this PR. Maybe worth filing a ticket for both of them, but I was curious to get your thoughts first:

  1. What would it take to support something like mamba sharding using our more modern pattern matching approach, where we define patterns and replacement patterns in eager and let the export + pattern matching take care of removing and inserting the new subgraph? A simple example is rms norm fusion; a more complex one is attention pattern matching.
  2. Can we get rid of extracting the head_dim from the config via the factory? Is this even still in use?

@greg-kwasniewski1
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #25180 [ run ] triggered by Bot. Commit: 48e7178

@tensorrt-cicd
Collaborator

PR_Github #25180 [ run ] completed with state FAILURE. Commit: 48e7178
/LLM/main/L0_MergeRequest_PR pipeline #19039 completed with status: 'FAILURE'

@greg-kwasniewski1 greg-kwasniewski1 force-pushed the gk/improved_sharding_heuristics branch from 48e7178 to 7eca0b4 on November 20, 2025 13:12
@greg-kwasniewski1
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #25205 [ run ] triggered by Bot. Commit: 7eca0b4

@greg-kwasniewski1
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #25763 [ run ] triggered by Bot. Commit: b1949cb

@tensorrt-cicd
Collaborator

PR_Github #25763 [ run ] completed with state FAILURE. Commit: b1949cb
/LLM/main/L0_MergeRequest_PR pipeline #19538 completed with status: 'FAILURE'

lucaslie and others added 14 commits November 26, 2025 07:16
Signed-off-by: Lucas Liebenwein <[email protected]>
Signed-off-by: Lucas Liebenwein <[email protected]>
Signed-off-by: greg-kwasniewski1 <[email protected]>
Signed-off-by: greg-kwasniewski1 <[email protected]>
Signed-off-by: greg-kwasniewski1 <[email protected]>
Signed-off-by: greg-kwasniewski1 <[email protected]>
Signed-off-by: greg-kwasniewski1 <[email protected]>
Signed-off-by: greg-kwasniewski1 <[email protected]>
Signed-off-by: greg-kwasniewski1 <[email protected]>
Signed-off-by: greg-kwasniewski1 <[email protected]>
Signed-off-by: greg-kwasniewski1 <[email protected]>
Signed-off-by: greg-kwasniewski1 <[email protected]>
@greg-kwasniewski1 greg-kwasniewski1 force-pushed the gk/improved_sharding_heuristics branch from edde9a3 to 59ebc81 on November 26, 2025 15:27
@greg-kwasniewski1
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #25857 [ run ] triggered by Bot. Commit: 59ebc81

Signed-off-by: greg-kwasniewski1 <[email protected]>
@tensorrt-cicd
Collaborator

PR_Github #25857 [ run ] completed with state SUCCESS. Commit: 59ebc81
/LLM/main/L0_MergeRequest_PR pipeline #19604 completed with status: 'FAILURE'

@greg-kwasniewski1
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #25878 [ run ] triggered by Bot. Commit: 0dfd2ed

@tensorrt-cicd
Collaborator

PR_Github #25878 [ run ] completed with state SUCCESS. Commit: 0dfd2ed
/LLM/main/L0_MergeRequest_PR pipeline #19623 completed with status: 'FAILURE'

@greg-kwasniewski1
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #25893 [ run ] triggered by Bot. Commit: 0dfd2ed

@tensorrt-cicd
Collaborator

PR_Github #25893 [ run ] completed with state FAILURE. Commit: 0dfd2ed
/LLM/main/L0_MergeRequest_PR pipeline #19637 completed with status: 'FAILURE'


Labels

AutoDeploy <NV> AutoDeploy Backend

Projects

Status: In review

4 participants