Skip to content

Commit 116dc65

Browse files
committed
address feedback and fix test
Signed-off-by: Anthony Chang <[email protected]>
1 parent c1e69eb commit 116dc65

File tree

3 files changed

+73
-73
lines changed

3 files changed

+73
-73
lines changed

tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py

Lines changed: 2 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
from dataclasses import dataclass, replace
2-
from enum import Enum
32
from functools import lru_cache
43
from typing import List, Optional, Tuple, Union
54

65
import torch
7-
from torch.distributions import Normal
86

9-
from tensorrt_llm._torch.utils import (Fp4QuantizedTensor, fp4_utils,
7+
from tensorrt_llm._torch.utils import (Fp4QuantizedTensor, TopkIdsGenMethod,
8+
create_dummy_topk_ids, fp4_utils,
109
get_last_power_of_2_num_tokens_buckets,
1110
last_positive_power_of_2,
1211
next_positive_power_of_2)
@@ -15,75 +14,6 @@
1514
OptimizationProfile, TunableRunner, TuningConfig)
1615

1716

18-
class TopkIdsGenMethod(Enum):
19-
"""
20-
Methods for generating dummy topk_ids for autotuning.
21-
22-
- UNIFORM: Uniform distribution; this performs the worst as it does not reflect real runs
23-
- RANDINT: Uniform with duplicates; this performs better than UNIFORM and GAUSSIAN
24-
- GAUSSIAN: Gaussian distribution
25-
"""
26-
27-
UNIFORM = "uniform"
28-
RANDINT = "randint"
29-
GAUSSIAN = "gaussian"
30-
31-
32-
def create_dummy_topk_ids(
33-
num_tokens: int,
34-
num_experts: int,
35-
top_k: int,
36-
device: torch.device,
37-
method: TopkIdsGenMethod,
38-
) -> torch.Tensor:
39-
"""
40-
Factory function to create dummy topk_ids for autotuning.
41-
42-
Args:
43-
num_tokens: Number of tokens (batch dimension)
44-
num_experts: Number of experts to choose from
45-
top_k: Number of experts to select per token
46-
device: Device to create tensor on
47-
method: Generation method (see TopkIdsGenMethod)
48-
49-
Returns:
50-
topk_ids tensor of shape (num_tokens, top_k) with dtype int32
51-
"""
52-
# Note: RANDINT is uniform distribution with replacement which can cause duplicates. However we
53-
# settle with RANDINT for the moment because, in practice, MoE tuned with RANDINT performs better
54-
# than both GAUSSIAN and UNIFORM. In the future, we should adopt GAUSSIAN(mu, sigma) because the
55-
# topk_id for each token is guaranteed to be unique.
56-
57-
if method == TopkIdsGenMethod.UNIFORM:
58-
rand_scores = torch.rand(num_tokens, num_experts, device=device)
59-
topk_ids = rand_scores.argsort(dim=1)[:, :top_k]
60-
61-
elif method == TopkIdsGenMethod.RANDINT:
62-
topk_ids = torch.randint(0,
63-
num_experts, (num_tokens, top_k),
64-
device=device)
65-
66-
elif method == TopkIdsGenMethod.GAUSSIAN:
67-
# Make variance proportional to num_experts
68-
sigma = num_experts / 3.0
69-
# Off-center mean to get slightly long-tail distribution
70-
mean = 2 * num_experts / 3
71-
normal = Normal(loc=mean, scale=sigma)
72-
73-
expert_indices = torch.arange(num_experts,
74-
device=device,
75-
dtype=torch.float32)
76-
77-
weights = torch.exp(normal.log_prob(expert_indices))
78-
79-
weights_expanded = weights.unsqueeze(0).expand(num_tokens, -1)
80-
topk_ids = torch.multinomial(weights_expanded,
81-
num_samples=top_k,
82-
replacement=False)
83-
84-
return topk_ids.to(torch.int32).to(device)
85-
86-
8717
def prepare_dummy_topk_and_hook(
8818
topk_weights: Optional[torch.Tensor],
8919
topk_ids: Optional[torch.Tensor],

tensorrt_llm/_torch/utils.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Dict, List
77

88
import torch
9+
from torch.distributions import Normal
910

1011
from tensorrt_llm._utils import TensorWrapper, convert_to_torch_tensor
1112
from tensorrt_llm.mapping import Mapping
@@ -370,3 +371,72 @@ def wrapper(*args, **kwargs):
370371
return wrapper
371372

372373
return decorator(func) if func else decorator
374+
375+
376+
class TopkIdsGenMethod(Enum):
377+
"""
378+
Methods for generating dummy topk_ids for autotuning.
379+
380+
- UNIFORM: Uniform distribution; this performs the worst as it does not reflect real runs
381+
- RANDINT: Uniform with duplicates; this performs better than UNIFORM and GAUSSIAN
382+
- GAUSSIAN: Gaussian distribution
383+
"""
384+
385+
UNIFORM = "uniform"
386+
RANDINT = "randint"
387+
GAUSSIAN = "gaussian"
388+
389+
390+
def create_dummy_topk_ids(
391+
num_tokens: int,
392+
num_experts: int,
393+
top_k: int,
394+
device: torch.device,
395+
method: TopkIdsGenMethod,
396+
) -> torch.Tensor:
397+
"""
398+
Factory function to create dummy topk_ids for autotuning.
399+
400+
Args:
401+
num_tokens: Number of tokens (batch dimension)
402+
num_experts: Number of experts to choose from
403+
top_k: Number of experts to select per token
404+
device: Device to create tensor on
405+
method: Generation method (see TopkIdsGenMethod)
406+
407+
Returns:
408+
topk_ids tensor of shape (num_tokens, top_k) with dtype int32
409+
"""
410+
# Note: RANDINT is uniform distribution with replacement which can cause duplicates. However we
411+
# settle with RANDINT for the moment because, in practice, MoE tuned with RANDINT performs better
412+
# than both GAUSSIAN and UNIFORM. In the future, we should adopt GAUSSIAN(mu, sigma) because the
413+
# topk_id for each token is guaranteed to be unique.
414+
415+
if method == TopkIdsGenMethod.UNIFORM:
416+
rand_scores = torch.rand(num_tokens, num_experts, device=device)
417+
topk_ids = rand_scores.argsort(dim=1)[:, :top_k]
418+
419+
elif method == TopkIdsGenMethod.RANDINT:
420+
topk_ids = torch.randint(0,
421+
num_experts, (num_tokens, top_k),
422+
device=device)
423+
424+
elif method == TopkIdsGenMethod.GAUSSIAN:
425+
# Make variance proportional to num_experts
426+
sigma = num_experts / 3.0
427+
# Off-center mean to get slightly long-tail distribution
428+
mean = 2 * num_experts / 3
429+
normal = Normal(loc=mean, scale=sigma)
430+
431+
expert_indices = torch.arange(num_experts,
432+
device=device,
433+
dtype=torch.float32)
434+
435+
weights = torch.exp(normal.log_prob(expert_indices))
436+
437+
weights_expanded = weights.unsqueeze(0).expand(num_tokens, -1)
438+
topk_ids = torch.multinomial(weights_expanded,
439+
num_samples=top_k,
440+
replacement=False)
441+
442+
return topk_ids.to(torch.int32).to(device)

tests/unittest/_torch/misc/test_autotuner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ def test_multiple_dynamic_shapes_cache():
318318
# We also test the cache serialization and deserialization here.
319319
AutoTuner.get().profiling_cache.clear()
320320
AutoTuner.get().profiling_cache.load_cache(
321-
os.path.join(temp_dir.name, "test_multiple_dynamic_shapes.rank0.json"))
321+
os.path.join(temp_dir.name, "test_multiple_dynamic_shapes.json"))
322322
cache_entries = tuner.profiling_cache.get_specific_custom_op(
323323
"test_multiple_dynamic_shapes")
324324

0 commit comments

Comments
 (0)