1 | 1 | from dataclasses import dataclass, replace |
2 | | -from enum import Enum |
3 | 2 | from functools import lru_cache |
4 | 3 | from typing import List, Optional, Tuple, Union |
5 | 4 |
6 | 5 | import torch |
7 | | -from torch.distributions import Normal |
8 | 6 |
9 | | -from tensorrt_llm._torch.utils import (Fp4QuantizedTensor, fp4_utils, |
| 7 | +from tensorrt_llm._torch.utils import (Fp4QuantizedTensor, TopkIdsGenMethod, |
| 8 | + create_dummy_topk_ids, fp4_utils, |
10 | 9 | get_last_power_of_2_num_tokens_buckets, |
11 | 10 | last_positive_power_of_2, |
12 | 11 | next_positive_power_of_2) |
15 | 14 | OptimizationProfile, TunableRunner, TuningConfig) |
16 | 15 |
17 | 16 |
18 | | -class TopkIdsGenMethod(Enum): |
19 | | - """ |
20 | | - Methods for generating dummy topk_ids for autotuning. |
21 | | -
22 | | - - UNIFORM: Uniform distribution; this performs the worst as it does not reflect real runs |
23 | | - - RANDINT: Uniform with duplicates; this performs better than UNIFORM and GAUSSIAN |
24 | | - - GAUSSIAN: Gaussian distribution |
25 | | - """ |
26 | | - |
27 | | - UNIFORM = "uniform" |
28 | | - RANDINT = "randint" |
29 | | - GAUSSIAN = "gaussian" |
30 | | - |
31 | | - |
32 | | -def create_dummy_topk_ids( |
33 | | - num_tokens: int, |
34 | | - num_experts: int, |
35 | | - top_k: int, |
36 | | - device: torch.device, |
37 | | - method: TopkIdsGenMethod, |
38 | | -) -> torch.Tensor: |
39 | | - """ |
40 | | - Factory function to create dummy topk_ids for autotuning. |
41 | | -
42 | | - Args: |
43 | | - num_tokens: Number of tokens (batch dimension) |
44 | | - num_experts: Number of experts to choose from |
45 | | - top_k: Number of experts to select per token |
46 | | - device: Device to create tensor on |
47 | | - method: Generation method (see TopkIdsGenMethod) |
48 | | -
49 | | - Returns: |
50 | | - topk_ids tensor of shape (num_tokens, top_k) with dtype int32 |
51 | | - """ |
52 | | - # Note: RANDINT is uniform distribution with replacement which can cause duplicates. However we |
53 | | - # settle with RANDINT for the moment because, in practice, MoE tuned with RANDINT performs better |
54 | | - # than both GAUSSIAN and UNIFORM. In the future, we should adopt GAUSSIAN(mu, sigma) because the |
55 | | - # topk_id for each token is guaranteed to be unique. |
56 | | - |
57 | | - if method == TopkIdsGenMethod.UNIFORM: |
58 | | - rand_scores = torch.rand(num_tokens, num_experts, device=device) |
59 | | - topk_ids = rand_scores.argsort(dim=1)[:, :top_k] |
60 | | - |
61 | | - elif method == TopkIdsGenMethod.RANDINT: |
62 | | - topk_ids = torch.randint(0, |
63 | | - num_experts, (num_tokens, top_k), |
64 | | - device=device) |
65 | | - |
66 | | - elif method == TopkIdsGenMethod.GAUSSIAN: |
67 | | - # Make variance proportional to num_experts |
68 | | - sigma = num_experts / 3.0 |
69 | | - # Off-center mean to get slightly long-tail distribution |
70 | | - mean = 2 * num_experts / 3 |
71 | | - normal = Normal(loc=mean, scale=sigma) |
72 | | - |
73 | | - expert_indices = torch.arange(num_experts, |
74 | | - device=device, |
75 | | - dtype=torch.float32) |
76 | | - |
77 | | - weights = torch.exp(normal.log_prob(expert_indices)) |
78 | | - |
79 | | - weights_expanded = weights.unsqueeze(0).expand(num_tokens, -1) |
80 | | - topk_ids = torch.multinomial(weights_expanded, |
81 | | - num_samples=top_k, |
82 | | - replacement=False) |
83 | | - |
84 | | - return topk_ids.to(torch.int32).to(device) |
85 | | - |
86 | | - |
87 | 17 | def prepare_dummy_topk_and_hook( |
88 | 18 | topk_weights: Optional[torch.Tensor], |
89 | 19 | topk_ids: Optional[torch.Tensor], |
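For context, a minimal sketch of calling the relocated helper through its new location, assuming `TopkIdsGenMethod` and `create_dummy_topk_ids` are re-exported from `tensorrt_llm._torch.utils` as the updated import above suggests; the CUDA device and the token/expert/top-k sizes are illustrative assumptions, not values from this change:

```python
import torch

# New import location introduced by this change (previously defined in this file).
from tensorrt_llm._torch.utils import TopkIdsGenMethod, create_dummy_topk_ids

# Illustrative sizes: 128 tokens routed over 64 experts with top-8 selection.
topk_ids = create_dummy_topk_ids(
    num_tokens=128,
    num_experts=64,
    top_k=8,
    device=torch.device("cuda"),  # assumes a CUDA device is available
    method=TopkIdsGenMethod.RANDINT,  # the method the in-code note prefers for MoE tuning
)

# Per the docstring, the result has shape (num_tokens, top_k) and dtype int32.
assert topk_ids.shape == (128, 8) and topk_ids.dtype == torch.int32
```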