[TRTLLM-8946][feat] Improved heuristics to detect shardable regions #9200
base: main
Conversation
/bot run
📝 Walkthrough

Refactors the tensor-parallel (TP) sharding pipeline to accept flexible node inputs, add subgraph-aware column sharding with fused weight detection, and integrate layer/subgraph-aware handling for SSM and attention layers. Updates public API signatures, modifies enum defaults, and adjusts layer subgraph identification logic.

Changes
Sequence Diagram

sequenceDiagram
participant Main as TP Sharding<br/>Main Flow
participant LayerDetect as Layer Subgraph<br/>Detection
participant LayerType as Layer Type<br/>Classification
participant Sharding as Specialized<br/>Sharding Logic
participant Utils as Param<br/>Updates
Main->>LayerDetect: get_all_layer_subgraphs<br/>(nodes using is_any_lin_op)
LayerDetect-->>Main: layer_subgraphs<br/>[opening, subgraph, closing]
Main->>Main: separate simple_shards vs<br/>layer_subgraph_nodes
Main->>LayerType: detect layer type<br/>(SSM vs ATTENTION vs MLP)
alt SSM Layer Detected
LayerType-->>Sharding: _process_ssm_sharding
Sharding->>Sharding: detect fused weights<br/>via split_with_sizes
Sharding->>Utils: generate param updates<br/>per fused dimension
else ATTENTION Layer Detected
LayerType-->>Sharding: _process_column_sharding<br/>(with subgraph_nodes)
Sharding->>Sharding: detect fused weights<br/>(e.g., QKV)
Sharding->>Utils: generate param updates<br/>with layer_type=ATTENTION
else MLP or Simple
LayerType-->>Sharding: _process_simple_shard<br/>or _process_row_sharding
Sharding->>Utils: generate param updates
end
Sharding-->>Utils: apply WeightShardingInfo<br/>all_gather, layer_type
Utils-->>Main: sharding complete
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45–60 minutes
Areas requiring extra attention:
Possibly related PRs
Suggested reviewers
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (1)
535-755: Partial factory-config path still checks string "ep"/"bmm" while sharding_dims is now ShardingDim

In detect_sharding_from_factory_config, the partial-config path checks:

    if sharding_config.support_partial_config:
        ...
        if "ep" in sharding_config.sharding_dims:
            ep_info = detect_ep_shard(gm, sharding_config)
        ...
        if "bmm" in sharding_config.sharding_dims:
            dp_bmm_info = detect_dp_bmm_shard(gm, sharding_config)

But ShardingConfig.sharding_dims is now a List[ShardingDim] (matching ShardingTransformConfig), and defaults to [ShardingDim.TP, ShardingDim.EP, ShardingDim.BMM]. These "ep"/"bmm" string membership checks will therefore always be false, so the EP and BMM heuristics will never run in the partial-factory-config mode. This looks like a behavior regression for users relying on support_partial_config=True.

The fix should be to compare against the enum values:

    if ShardingDim.EP in sharding_config.sharding_dims:
        ep_info = detect_ep_shard(gm, sharding_config)
    ...
    if ShardingDim.BMM in sharding_config.sharding_dims:
        dp_bmm_info = detect_dp_bmm_shard(gm, sharding_config)

to keep the behavior consistent with the heuristic-only path in Sharding._apply.
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (1)

336-390: Mamba/SSM sharding transforms double-apply _insert_sharded_mamba with incorrect node references

Verification confirms the core issues. In _process_ssm_sharding (tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py), three separate WeightShardingInfo instances are created with layer_type=LayerType.MAMBA:
- Entry node (line 340–350)
- Weight nodes like conv1d (line 379–389)
- Out-projection node (line 418–426)
When WeightShardingInfo.apply() is invoked for each, it calls _insert_sharded_mamba regardless of whether the node is the true first linear layer. This means _insert_sharded_mamba receives conv1d or out_proj as entry_node on subsequent calls, causing it to search for subgraph boundaries from the wrong starting point and mis-infer fused weight dimensions.

Additionally, there is a type contract violation: fused_weight_dims is declared as Optional[list] in WeightShardingInfo (line 572), but _insert_sharded_mamba expects Dict[str, list]. Since _process_ssm_sharding passes a list value and apply() checks isinstance(self.fused_weight_dims, dict) (line 625), the dict is never passed through; it is always None, defeating fused dimension propagation.

Recommend gating _insert_sharded_mamba to only execute for the entry node, or restructuring so that only the initial linear node carries layer_type=MAMBA with the dict, while the other weights use the default layer type for regular TP sharding.

Also note: ShardingDim.SSM has been removed from the enum (now only TP, EP, BMM), which is a breaking API change if external code references it.
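To make the fused_weight_dims type mismatch above concrete, here is a tiny self-contained reproduction of the isinstance gate described in this comment; the dict key and dimension values are invented for illustration:

```python
from typing import Dict, List, Optional, Union

def forward_fused_dims(
    fused_weight_dims: Optional[Union[list, Dict[str, List[int]]]],
) -> Optional[Dict[str, List[int]]]:
    # Mirrors the apply() gate described in the review: only a dict is forwarded to
    # _insert_sharded_mamba; a plain list (what _process_ssm_sharding passes) is dropped.
    if isinstance(fused_weight_dims, dict):
        return fused_weight_dims
    return None

assert forward_fused_dims({"in_proj": [1024, 1024, 128]}) is not None  # intended path
assert forward_fused_dims([1024, 1024, 128]) is None                   # silent regression path
```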
🧹 Nitpick comments (6)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (2)
1190-1211: Enum-based ShardingDim plus default dims change is a breaking surface; verify external usage

ShardingDim now only exposes TP, EP, and BMM, and the ShardingConfig.sharding_dims default is [ShardingDim.TP, ShardingDim.EP, ShardingDim.BMM]. Any previous code that referenced ShardingDim.SSM or relied on SSM appearing in the defaults will now fail at import or behave differently.

Given this enum is public API, please double-check:

- All in-repo references to ShardingDim.SSM (including configs and docs) have been removed or updated.
- External configs (YAML/JSON) that previously used "ssm" are either unsupported by design or migrated.

If external users are expected to depend on this, consider keeping SSM as a deprecated member that's simply ignored by the sharding pipeline rather than removing it outright.
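A sketch of that deprecation option, assuming string-valued members and a hypothetical filtering helper; the real ShardingDim in sharding_utils.py may differ:

```python
from enum import Enum
from typing import List

class ShardingDim(Enum):
    TP = "tp"
    EP = "ep"
    BMM = "bmm"
    SSM = "ssm"  # deprecated: still parseable from old configs, ignored by the pipeline

def active_sharding_dims(dims: List[ShardingDim]) -> List[ShardingDim]:
    """Hypothetical helper: filter out deprecated members so heuristics never see SSM."""
    return [d for d in dims if d is not ShardingDim.SSM]

# Old configs listing "ssm" still load, but the sharding pipeline skips it.
assert active_sharding_dims([ShardingDim.TP, ShardingDim.SSM]) == [ShardingDim.TP]
```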
476-482: _validate_sharded_shapes is no longer invoked; confirm this is intentional

The _shard_parameter_node path for add_dist=False used to call _validate_sharded_shapes to adjust hard-coded view/reshape and split params after TP sharding. That call is now commented out, and the only remaining callers for shape adjustment are the higher-level heuristics that emit ParameterUpdateInfo (e.g., _process_column_sharding and _process_ssm_sharding in sharding.py).

If that's intentional, it might be worth either:

- Removing _validate_sharded_shapes entirely to avoid dead code, or
- Wiring it back in for non-heuristic/legacy callers that still rely on _shard_parameter_node directly.

Otherwise, manually constructed WeightShardingInfo instances that don't go through the new layer/subgraph detection might silently stop updating downstream views/splits.

Also applies to: 67-80
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (2)
650-707: get_layer_after_linear_node return type and structure are out of sync

get_layer_after_linear_node is annotated as returning List[Node], but it now returns a 3-tuple-like structure ([opening_linear_nodes, backward_subgraph, terminating_linear_node]) or (None, None, None), and callers destructure it as:

    opening, layer_subgraph, closing = get_layer_after_linear_node(...)

The implementation/usage are consistent with the new calling pattern, but the type hint and docstring are now misleading.

I'd suggest updating the signature and docstring to something like:

    def get_layer_after_linear_node(
        linear_nodes: List[Node], terminating_indices: List[int]
    ) -> Tuple[Optional[List[Node]], Optional[List[Node]], Optional[Node]]:
        ...

to reflect the actual contract.
655-667: lm_head detection relies on node name; consider using weight target for robustness

The special casing for the final output embedding:

    if "lm_head" in extract_weight_node(linear_nodes[-1]).name:

        def filter_condition(node: Node) -> bool:
            return is_any_lin_op(node) and node != linear_nodes[-1]

uses weight_node.name, which is FX's autogenerated node name. In many graphs the more semantically stable identifier is weight_node.target (e.g., "lm_head.weight").

To make this heuristic more robust across exporters and naming schemes, it would be safer to check weight_node.target (or both name and target) for "lm_head".
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (2)

815-918: Layer-subgraph based TP sharding looks sound overall; minor nit on simple-shard input type
detect_column_row_shard now:

- Early-exits when no linear nodes exist.
- Uses get_all_layer_subgraphs(gm) to get (opening_nodes, layer_subgraph, closing_node) triples.
- Handles SSM/Mamba, attention, and generic MLP layers differently, setting layer_type and min_local_shape appropriately.
- Falls back to _process_simple_shard(unprocessed_linear_nodes, ...) for leftover linears.
- Logs a more detailed breakdown of simple vs. row/col vs. SSM vs. attention shards.
This is a good step toward more robust, subgraph-aware TP sharding.
One small mismatch: _process_simple_shard is annotated to accept Union[Dict[Node, List[Node]], List[Node]] but is called here with unprocessed_linear_nodes, which is a set. Runtime-wise it works (you iterate over it), but it diverges from the type hint and docstring.

If you want to keep the signature precise, consider either:

- Converting unprocessed_linear_nodes to a list before passing it, or
- Widening the type hint to Collection[Node] or similar (sketched below).

Functionally, though, the new flow looks correct.
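A sketch of the widened-signature option; the other parameters are elided and the body is a placeholder, not the real implementation:

```python
from typing import Collection

from torch.fx import Node

def _process_simple_shard(nodes_linear: Collection[Node], *args, **kwargs) -> None:
    # Collection[Node] covers both the documented List[Node] callers and the new
    # set-typed unprocessed_linear_nodes, since the body only needs to iterate.
    for node in nodes_linear:
        ...
```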
841-842: Optional: simplify concatenation of opening and closing nodes

Here:

    nodes_linear = opening + [closing]

you could adopt the Ruff suggestion and write:

    nodes_linear = [*opening, closing]

which is a bit clearer about intent (one list of opening nodes plus a single closing node) and avoids an extra list literal.
Purely stylistic; behavior is identical.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (17 hunks)
- tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (3 hunks)
- tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (3 hunks)
- tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (4 hunks)
🧰 Additional context used
🧠 Learnings (7)
📓 Common learnings
Learnt from: djns99
Repo: NVIDIA/TensorRT-LLM PR: 6915
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:4010-4012
Timestamp: 2025-08-14T23:23:27.449Z
Learning: For MOE (Mixture of Experts) code reviews in TensorRT-LLM, avoid repeatedly suggesting finalize fusion validation checks and safety assertions. The user djns99 has indicated these suggestions are repetitive and unwanted across multiple MOE-related changes.
Learnt from: amitz-nv
Repo: NVIDIA/TensorRT-LLM PR: 8063
File: tensorrt_llm/lora_manager.py:1080-1112
Timestamp: 2025-09-29T15:14:28.503Z
Learning: In tensorrt_llm/lora_manager.py, when calculating part_sizes for attn_qkv fused LoRA modules, the sizes are correctly multiplied by tp_size because model_config.num_heads and model_config.num_kv_heads are already divided by tp_size (per-TP-rank values), so multiplication is needed to get the original full concatenated dimension size. The interleave_fused_lora_weights_for_tp function provides proper validation with asserts for total size and TP divisibility.
📚 Learning: 2025-10-20T16:54:09.824Z
Learnt from: nvchenghaoz
Repo: NVIDIA/TensorRT-LLM PR: 8469
File: tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py:6-6
Timestamp: 2025-10-20T16:54:09.824Z
Learning: In tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py, the import `from ...modules.mamba.layernorm_gated import _layer_norm_fwd` is correct and should not be changed to modules.fla.layernorm_gated. The _layer_norm_fwd function exists in both modules/mamba/layernorm_gated.py and modules/fla/layernorm_gated.py, but the mamba version is the intended implementation for this use case.
Applied to files:
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
📚 Learning: 2025-11-14T11:22:03.729Z
Learnt from: nzmora-nvidia
Repo: NVIDIA/TensorRT-LLM PR: 9163
File: tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py:107-113
Timestamp: 2025-11-14T11:22:03.729Z
Learning: In TensorRT-LLM AutoDeploy custom ops, when adding hardware capability checks to select between kernel implementations (e.g., cuBLAS vs. CUDA kernel), use descriptive variable names that identify the specific GPU architectures or families being targeted (e.g., `is_blackwell_geforce_or_ada`) rather than generic names like `enable_cuda_core`. This makes it clear that the code is selecting an implementation path based on hardware capabilities, not enabling/disabling hardware features.
Applied to files:
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
📚 Learning: 2025-09-29T15:14:28.503Z
Learnt from: amitz-nv
Repo: NVIDIA/TensorRT-LLM PR: 8063
File: tensorrt_llm/lora_manager.py:1080-1112
Timestamp: 2025-09-29T15:14:28.503Z
Learning: In tensorrt_llm/lora_manager.py, when calculating part_sizes for attn_qkv fused LoRA modules, the sizes are correctly multiplied by tp_size because model_config.num_heads and model_config.num_kv_heads are already divided by tp_size (per-TP-rank values), so multiplication is needed to get the original full concatenated dimension size. The interleave_fused_lora_weights_for_tp function provides proper validation with asserts for total size and TP divisibility.
Applied to files:
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
📚 Learning: 2025-09-29T15:14:28.503Z
Learnt from: amitz-nv
Repo: NVIDIA/TensorRT-LLM PR: 8063
File: tensorrt_llm/lora_manager.py:1080-1112
Timestamp: 2025-09-29T15:14:28.503Z
Learning: In tensorrt_llm/lora_manager.py, when calculating part_sizes for attn_qkv fused LoRA modules, the sizes are correctly multiplied by tp_size because model_config.num_heads and model_config.num_kv_heads are already divided by tp_size (per-TP-rank values), so multiplication is needed to get the original full concatenated dimension size. The interleave_fused_lora_weights_for_tp function provides proper validation.
Applied to files:
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
📚 Learning: 2025-08-09T02:04:49.623Z
Learnt from: Fridah-nv
Repo: NVIDIA/TensorRT-LLM PR: 6760
File: tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py:81-98
Timestamp: 2025-08-09T02:04:49.623Z
Learning: In TensorRT-LLM's auto_deploy module, torch.dtype values in configuration dictionaries must be stored as string representations (e.g., "float16" instead of torch.float16) because OmegaConf.merge does not support torch.dtype types. These string representations are converted to actual torch.dtype objects in downstream code.
Applied to files:
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
📚 Learning: 2025-08-14T15:43:23.107Z
Learnt from: MatthiasKohl
Repo: NVIDIA/TensorRT-LLM PR: 6904
File: tensorrt_llm/_torch/attention_backend/trtllm.py:259-262
Timestamp: 2025-08-14T15:43:23.107Z
Learning: In TensorRT-LLM's attention backend, tensor parameters in the plan() method are assigned directly without validation (dtype, device, contiguity checks). This maintains consistency across all tensor inputs and follows the pattern of trusting callers to provide correctly formatted tensors.
Applied to files:
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
🧬 Code graph analysis (3)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (1)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (2)
FP8TPShardingInfo (724-756), LayerType (557-561)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (1)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
is_op(197-220)
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (2)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (3)
get_all_layer_subgraphs (390-413), filtered_nodes (223-271), is_any_lin_op (274-275)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (9)
ShardingConfig (1197-1306), add (1236-1255), WeightShardingInfo (564-640), from_node (589-594), from_node (1098-1103), SplitDimension (510-518), ShardingDim (1189-1194), LayerType (557-561), ParameterUpdateInfo (643-657)
🪛 Ruff (0.14.4)
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
841-841: Consider [*opening, closing] instead of concatenation
Replace with [*opening, closing]
(RUF005)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (3)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (2)
125-224: Distributed helper signature change looks consistent but depends on spawn_multiprocess_job contract

_run_sharding_execution_job has been extended to accept (rank, world_size) and is invoked via:

    dist_common.spawn_multiprocess_job(
        job=partial(_run_sharding_execution_job, model_cls, dist_op_expected, bias, from_config),
        size=device_count,
    )

Assuming spawn_multiprocess_job calls job(rank, world_size) (or job(rank, size)), this wiring is correct and preserves the previous behavior while making world_size explicitly available inside the helper.

Please just double-check other usages of spawn_multiprocess_job in the repo to confirm the expected argument order (rank, world_size) is consistent everywhere.

Also applies to: 362-384
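A tiny standalone illustration of the expected contract: the partial-bound arguments come first, and spawn_multiprocess_job is assumed to append (rank, world_size) when it invokes the job; the stand-in function and argument values are made up:

```python
from functools import partial

def _job(model_cls, dist_op_expected, bias, from_config, rank, world_size):
    # stand-in for _run_sharding_execution_job: bound args first, then rank/world_size
    return (model_cls, dist_op_expected, bias, from_config, rank, world_size)

bound = partial(_job, "ModelCls", "all_reduce", True, False)
# spawn_multiprocess_job is assumed to call job(rank, size) once per process:
assert bound(0, 2) == ("ModelCls", "all_reduce", True, False, 0, 2)
```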
259-283: Using layer_type=LayerType.ATTENTION in expected GQA transforms aligns tests with new TP metadata

In the GQA block pattern detection, the expected WeightShardingInfo now includes:

    layer_type=LayerType.ATTENTION,

for Q/K/V/O linears. This matches how detect_column_row_shard now tags attention vs. MLP layers via LayerType, and should keep the test comparison robust as the sharding pipeline becomes layer-aware.

No functional issues here; this looks like the right way to adapt the tests to the new metadata.
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (1)
153-155: Transform config sharding_dims default aligned with runtime config

ShardingTransformConfig.sharding_dims now defaults to:

    default_factory=lambda: [ShardingDim.TP, ShardingDim.EP, ShardingDim.BMM]

which matches the updated ShardingConfig default. This keeps the transform's configuration in sync with the underlying sharding implementation and avoids surprises where a config class and runtime config diverge.

Looks good as-is.
def _process_column_sharding(
    gm: GraphModule,
    linear_nodes: List[Node],
    subgraph_nodes: List[Node],
    sharding_config: ShardingConfig,
    rank: int,
    world_size: int,
    min_local_shape: int = 1,
    fused_weight: bool = False,
) -> None:
    """
    Parse the column sharding from the candidate nodes and update the view and split nodes accordingly.
    """
    fused_weight_dims = None
    # check if there are split nodes in the subgraph. They may indicate fused weights (e.g., QKV)
    split_nodes = list(filtered_nodes(subgraph_nodes, ops=[torch.ops.aten.split_with_sizes]))
    if len(split_nodes) > 0:
        assert len(linear_nodes) == 1
        linear_node = linear_nodes[0]
        assert len(split_nodes) == 1, "Expecting exactly one split node for fused weights"
        fused_weight_dims = split_nodes[0].args[1]
    slice_nodes = list(filtered_nodes(subgraph_nodes, ops=[torch.ops.aten.slice]))
    if len(slice_nodes) > 0:
        # we are probably in fused QKV case with single linear node and 3 slice nodes
        assert len(linear_nodes) == 1
        linear_node = linear_nodes[0]
        assert all(s.args[1] == 2 for s in linear_node.users), "Expecting slice nodes with dim=3"
        fused_weight_dims = [s.args[3] - s.args[2] for s in linear_node.users]
    weight_dim = linear_node.meta["val"].shape[2]
Potential bug in fused QKV detection via aten.slice in _process_column_sharding
In _process_column_sharding, the slice-based fused-weight path does:
    slice_nodes = list(filtered_nodes(subgraph_nodes, ops=[torch.ops.aten.slice]))
    if len(slice_nodes) > 0:
        assert len(linear_nodes) == 1
        linear_node = linear_nodes[0]
        assert all(s.args[1] == 2 for s in linear_node.users), "Expecting slice nodes with dim=3"
        fused_weight_dims = [s.args[3] - s.args[2] for s in linear_node.users]

Issues:
- The assertion and fused-dim computation iterate over all linear_node.users, not just the slice_nodes. If the linear has any non-slice users (e.g., debug ops, extra views), this will either assert incorrectly or compute bogus dims.
- The assertion message says "dim=3" while checking s.args[1] == 2, which is confusing and suggests a mismatch between expectation and condition.
- If a future graph uses a mix of slices and other consumers, this code will break even though fused QKV is still valid.
I’d suggest tightening this branch to operate only on slice_nodes, e.g.:
assert all(s.args[1] == 2 for s in slice_nodes), "Expecting slice nodes with dim=2"
fused_weight_dims = [s.args[3] - s.args[2] for s in slice_nodes]

and leave linear_node.users out of the fused-dim logic. This keeps the heuristic focused on the actual fused-QKV pattern instead of all downstream consumers.
Also applies to: 501-532
🧰 Tools
🪛 Ruff (0.14.4)
431-431: Unused function argument: gm
(ARG001)
🤖 Prompt for AI Agents
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py lines 430-457: the
slice-based fused-QKV detection iterates over linear_node.users instead of the
slice_nodes, and the assertion message/condition are inconsistent; restrict the
checks and dimension computations to the slice_nodes only and fix the assertion
to check the correct dim value/message (e.g., assert all(s.args[1] == 2 for s in
slice_nodes) with message "Expecting slice nodes with dim=2"), and compute
fused_weight_dims from slice_nodes (fused_weight_dims = [s.args[3] - s.args[2]
for s in slice_nodes]); apply the same change to the analogous block at lines
~501-532.
PR_Github #24695 [ run ] triggered by Bot. Commit:
PR_Github #24695 [ run ] completed with state

/bot run --reuse-test

PR_Github #24785 [ run ] triggered by Bot. Commit:
PR_Github #24785 [ run ] completed with state

/bot run --reuse-test

PR_Github #24905 [ run ] triggered by Bot. Commit:
govind-ramnarayan
left a comment
LGTM. Since I don't know the code super well and Lucas is on the PR I'll let him decide when it's ready to accept.
PR_Github #24905 [ run ] completed with state

/bot run --reuse-test

PR_Github #25036 [ run ] triggered by Bot. Commit:
PR_Github #25036 [ run ] completed with state
lucaslie
left a comment
Two higher-level questions that I thought of when reviewing this PR. Maybe worth filing a ticket for both of them, but was curious to get your thoughts first:
- What would it take to support something like mamba sharding using our more modern pattern matching approach where we define patterns and replacement patterns in eager and let the export + pattern matching take care of removing and inserting the new subgraph? Simple examples are rms norm fusion or a more complex one is attention pattern matching
- Can we get rid of extracting the head_dim from the config via the factory? Is this even still in use?
/bot run

PR_Github #25180 [ run ] triggered by Bot. Commit:
PR_Github #25180 [ run ] completed with state
48e7178 to 7eca0b4
/bot run

PR_Github #25205 [ run ] triggered by Bot. Commit:

/bot run

PR_Github #25763 [ run ] triggered by Bot. Commit:
PR_Github #25763 [ run ] completed with state
Signed-off-by: Lucas Liebenwein <[email protected]>
Signed-off-by: greg-kwasniewski1 <[email protected]>
Signed-off-by: Lucas Liebenwein <[email protected]>
Signed-off-by: Lucas Liebenwein <[email protected]>
Signed-off-by: greg-kwasniewski1 <[email protected]>
Signed-off-by: greg-kwasniewski1 <[email protected]>
Signed-off-by: greg-kwasniewski1 <[email protected]>
Signed-off-by: greg-kwasniewski1 <[email protected]>
Signed-off-by: greg-kwasniewski1 <[email protected]>
Signed-off-by: greg-kwasniewski1 <[email protected]>
Signed-off-by: greg-kwasniewski1 <[email protected]>
Signed-off-by: greg-kwasniewski1 <[email protected]>
Signed-off-by: greg-kwasniewski1 <[email protected]>
Signed-off-by: greg-kwasniewski1 <[email protected]>
edde9a3 to 59ebc81
/bot run

PR_Github #25857 [ run ] triggered by Bot. Commit:

Signed-off-by: greg-kwasniewski1 <[email protected]>

PR_Github #25857 [ run ] completed with state

/bot run

PR_Github #25878 [ run ] triggered by Bot. Commit:
PR_Github #25878 [ run ] completed with state

/bot run

PR_Github #25893 [ run ] triggered by Bot. Commit:
PR_Github #25893 [ run ] completed with state
The new logic in detect_column_row_shard extracts individual layers (attention/MoE/MLP/SSM) not based on residual connections, but on consecutive pairs of opening/closing linear layers. Each extracted subgraph is defined by a set of opening layers (e.g., q, k, v for attention, gate and up for MLP, etc.) and a single closing linear layer (e.g., o_proj or down_proj).
For each subgraph, the validity of column-row sharding is checked; then either Megatron-style column-row TP sharding is applied, or simple sharding if the conditions are not met.
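To illustrate the Megatron-style column-row pattern the heuristic targets, here is a small self-contained numeric check (tensor shapes are made up, and the elementwise activation between the two linears is omitted because it acts locally on the column shards):

```python
import torch

# Toy check of column-row TP for one opening/closing pair (e.g., up_proj/down_proj).
# The PR's detection logic is not reproduced here; this only shows why a valid
# column-row pair needs a single all-reduce after the closing linear.
torch.manual_seed(0)
x = torch.randn(4, 8, dtype=torch.float64)        # [tokens, hidden]
w_up = torch.randn(32, 8, dtype=torch.float64)    # opening linear, column-sharded (dim 0)
w_down = torch.randn(8, 32, dtype=torch.float64)  # closing linear, row-sharded (dim 1)

world_size = 2
partials = []
for rank in range(world_size):
    w_up_r = w_up.chunk(world_size, dim=0)[rank]      # local output-feature shard
    w_down_r = w_down.chunk(world_size, dim=1)[rank]  # matching input-feature shard
    partials.append((x @ w_up_r.T) @ w_down_r.T)      # no communication in between

y_tp = sum(partials)                  # stands in for the all-reduce after the closing linear
y_ref = (x @ w_up.T) @ w_down.T
assert torch.allclose(y_tp, y_ref)
```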
fixes #8946
fixes #8947
fixes #8949
fixes #4320
@coderabbitai summary
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.
Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.
run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug (experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

- --reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
- --disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.
- --disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.
- --skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
- --stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
- --gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
- --test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
- --only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
- --disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
- --add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.
- --post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
- --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
- --detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
- --debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.
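For example, a run that disables fail-fast and adds the multi-GPU tests (composing only the flags documented above) would be posted as a PR comment:

/bot run --disable-fail-fast --add-multi-gpu-test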
kill
kill

Kill all running builds associated with pull request.
skip
skip --comment COMMENT

Skip testing for latest commit on pull request.
--comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline
reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.