Skip to content

Commit 64ec271

Browse files
committed
short-circuit autotune logic when not in autotune mode; address feedback; last-minute fix for expert weight dtype
Signed-off-by: Anthony Chang <[email protected]>
1 parent ae7eace commit 64ec271

File tree

3 files changed

+9
-3
lines changed

3 files changed

+9
-3
lines changed

.clangd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ CompileFlags:
2323
- cuda
2424
# Allow variadic CUDA functions
2525
- "-Xclang=-fcuda-allow-variadic-functions"
26+
- "-I/mnt/trtllm-gen/amodel/cuda/gpgpu_internal/include"
2627

2728
---
2829

tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,11 @@ def prepare_dummy_topk_and_hook(
5656
Tuple of (routing_logits_for_tuner, topk_weights_for_tuner, topk_ids_for_tuner, tuning_config_with_hook)
5757
"""
5858

59+
# NOTE: This prevents auto-tuning related code from being executed in actual runs
60+
tuner = AutoTuner.get()
61+
if not tuner.is_tuning_mode:
62+
return routing_logits, topk_weights, topk_ids, base_tuning_config
63+
5964
if routing_logits is None:
6065
routing_logits_for_tuner = torch.randn(hidden_states.shape[0],
6166
num_experts,
@@ -91,6 +96,7 @@ def prepare_dummy_topk_and_hook(
9196
# Attention DP: topk is pre-computed, no routing needed
9297
topk_ids_for_tuner, topk_weights_for_tuner = routing_method.apply(
9398
routing_logits_for_tuner)
99+
topk_weights_for_tuner = topk_weights_for_tuner.to(torch.bfloat16)
94100
# Don't pass routing_logits to avoid C++ warning about all three being provided
95101
routing_logits_for_tuner = None
96102
else:
@@ -122,7 +128,7 @@ def recreate_dummy_topk_if_needed(
122128
topk_ids_for_tuner, topk_weights_for_tuner = routing_method.apply(
123129
routing_logits_for_tuner)
124130
inputs[-1] = topk_ids_for_tuner
125-
inputs[-2] = topk_weights_for_tuner
131+
inputs[-2] = topk_weights_for_tuner.to(torch.bfloat16)
126132
# Note: routing_logits is None in attention DP, no need to adjust
127133
assert inputs[0] is None
128134

tensorrt_llm/_torch/modules/fused_moe/routing.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def routing_method_type(self):
219219
return RoutingMethodType.Default
220220

221221

222-
class Deepseekv3RoutingImpl():
222+
class Deepseekv3RoutingImpl:
223223

224224
def __init__(
225225
self,
@@ -556,7 +556,6 @@ def routing_method_type(self) -> RoutingMethodType:
556556
return RoutingMethodType.RenormalizeNaive
557557

558558

559-
# Mapping from RoutingMethodType to the corresponding class
560559
ROUTING_METHOD_TYPE_TO_CLASS: Dict[RoutingMethodType,
561560
Type[BaseMoeRoutingMethod]] = {
562561
RoutingMethodType.Default:

0 commit comments

Comments (0)