
Commit 538487a

router and infer parallel. (#965)
Co-authored-by: baishihao <[email protected]>
1 parent 8ed97c7 commit 538487a

File tree: 70 files changed, +3112 −2568 lines


lightllm/common/basemodel/basemodel.py

Lines changed: 34 additions & 1 deletion
@@ -17,10 +17,11 @@
 from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
 from lightllm.common.basemodel.cuda_graph import CudaGraph
 from lightllm.common.quantization import Quantcfg
+from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.dist_utils import get_dp_world_size
 from lightllm.utils.envs_utils import get_env_start_args
-from lightllm.distributed.communication_op import CustomProcessGroup, dist_group_manager
+from lightllm.distributed.communication_op import dist_group_manager
 from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
 from lightllm.utils.custom_kernel_utis import pad2dim_tensor_to_new_batch
 from lightllm.utils.envs_utils import set_model_init_status

@@ -237,6 +238,7 @@ def _init_custom(self):

     @torch.no_grad()
     def forward(self, model_input: ModelInput):
+        model_input.to_cuda()
         assert model_input.mem_indexes.is_cuda

         if model_input.is_prefill:

@@ -346,6 +348,14 @@ def _decode(
         self,
         model_input: ModelInput,
     ) -> ModelOutput:
+        # for overlap mode
+        if model_input.input_ids is None:
+            model_input.input_ids = gather_token(
+                self.req_manager.req_sampling_params_manager.req_to_next_token_ids,
+                model_input.b_req_idx,
+                model_input.b_mtp_index,
+            )
+
         if self.graph is not None and self.graph.can_run(model_input.batch_size, model_input.max_len_in_batch):
             find_graph_batch_size = self.graph.find_closest_graph_batch_size(model_input.batch_size)
             padded_model_input = self._create_padded_decode_model_input(model_input, find_graph_batch_size)

@@ -453,6 +463,9 @@ def _token_forward(self, input_ids, infer_state: InferStateInfo):

     @torch.no_grad()
     def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: ModelInput):
+        model_input0.to_cuda()
+        model_input1.to_cuda()
+
         assert model_input0.mem_indexes.is_cuda
         assert model_input1.mem_indexes.is_cuda
         input_ids0, input_ids1 = model_input0.input_ids, model_input1.input_ids

@@ -490,6 +503,22 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod

     @torch.no_grad()
     def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: ModelInput):
+        model_input0.to_cuda()
+        model_input1.to_cuda()
+
+        if model_input0.input_ids is None:
+            model_input0.input_ids = gather_token(
+                self.req_manager.req_sampling_params_manager.req_to_next_token_ids,
+                model_input0.b_req_idx,
+                model_input0.b_mtp_index,
+            )
+        if model_input1.input_ids is None:
+            model_input1.input_ids = gather_token(
+                self.req_manager.req_sampling_params_manager.req_to_next_token_ids,
+                model_input1.b_req_idx,
+                model_input1.b_mtp_index,
+            )
+
         assert model_input0.batch_size == model_input1.batch_size
         assert model_input0.mem_indexes.is_cuda
         assert model_input1.mem_indexes.is_cuda

@@ -659,6 +688,7 @@ def _check_max_len_infer(self):
         b_seq_len[:] = self.batch_max_tokens
         b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
         total_token_num = self.batch_max_tokens
+        b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
         model_input = ModelInput(
             batch_size=1,
             total_token_num=total_token_num,

@@ -667,6 +697,7 @@ def _check_max_len_infer(self):
             mem_indexes=mem_indexes,
             b_req_idx=b_req_idx,
             b_seq_len=b_seq_len,
+            b_mtp_index=b_mtp_index,
             is_prefill=True,
             b_ready_cache_len=b_ready_cache_len,
         )

@@ -714,13 +745,15 @@ def _init_padded_req(self):
         b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda")
         b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
         total_token_num = prefill_input_len * batch_size
+        b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
         model_input = ModelInput(
             batch_size=batch_size,
             total_token_num=total_token_num,
             max_len_in_batch=prefill_input_len,
             input_ids=dummy_input_ids,
             mem_indexes=mem_indexes,
             b_req_idx=b_req_idx,
+            b_mtp_index=b_mtp_index,
             b_seq_len=b_seq_len,
             b_ready_cache_len=b_ready_cache_len,
             is_prefill=True,
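
To make the overlap decode convention above concrete, here is a rough sketch of the step-to-step flow it enables (the `decode_step` and `sampler` names are hypothetical; only `forward`, `scatter_token`, and the `input_ids=None` convention come from this commit): after sampling, the next token ids are scattered into `req_to_next_token_ids` on the GPU, so the following decode step can start with `input_ids=None` and gather them back inside `_decode()` without a GPU-to-CPU round trip.

# Hypothetical driver loop, sketched for illustration only.
from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token


def decode_step(model, model_input, sampler):
    # model_input.input_ids may be None here; _decode() gathers the previous
    # step's sampled tokens from req_to_next_token_ids on the GPU (see diff above).
    model_output = model.forward(model_input)
    # sampler is a stand-in that returns a CUDA int tensor of shape (batch_size,).
    next_token_ids = sampler(model_output)
    # Publish the sampled tokens for the next step without copying them to the CPU.
    scatter_token(
        next_token_ids,
        model.req_manager.req_sampling_params_manager.req_to_next_token_ids,
        model_input.b_req_idx,
        model_input.b_mtp_index,
    )
    return model_output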

lightllm/common/basemodel/batch_objs.py

Lines changed: 20 additions & 1 deletion
@@ -1,6 +1,7 @@
 import torch
 from dataclasses import dataclass, field
 from typing import Optional
+from typing import List


 @dataclass

@@ -10,20 +11,38 @@ class ModelInput:
     total_token_num: int
     max_len_in_batch: int
     input_ids: torch.Tensor
-    mem_indexes: torch.Tensor
     b_req_idx: torch.Tensor
+    b_mtp_index: torch.Tensor
     b_seq_len: torch.Tensor
+    mem_indexes: torch.Tensor = None
     is_prefill: bool = False
     b_ready_cache_len: torch.Tensor = None
     multimodal_params: list = field(default_factory=list)

+    # CPU-side tensors
+    mem_indexes_cpu: torch.Tensor = None
+    # Parameters used during the prefill phase. They are not consumed by the
+    # inference computation itself, but by the resource management that runs
+    # outside of inference.
+    b_prefill_has_output_cpu: List[bool] = None  # marks whether each prefill request produces output
+
     # Special-purpose variables used by certain models in certain modes to
     # pass extra inputs. They are only used and take effect in those special
     # model modes.

     # deepseekv3_mtp_draft_input_hiddens is the input of the draft model when
     # the deepseekv3 model runs in mtp mode
     deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None

+    def to_cuda(self):
+        if self.input_ids is not None:
+            self.input_ids = self.input_ids.cuda(non_blocking=True)
+        if self.mem_indexes is None:
+            self.mem_indexes = self.mem_indexes_cpu.cuda(non_blocking=True)
+        self.b_req_idx = self.b_req_idx.cuda(non_blocking=True)
+        self.b_seq_len = self.b_seq_len.cuda(non_blocking=True)
+        self.b_mtp_index = self.b_mtp_index.cuda(non_blocking=True)
+        if self.b_ready_cache_len is not None:
+            self.b_ready_cache_len = self.b_ready_cache_len.cuda(non_blocking=True)
+

 @dataclass
 class ModelOutput:
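
A usage note on `to_cuda()` (an assumption about intended use, not part of the commit): the copies are issued with `non_blocking=True`, which only overlaps with other GPU work when the source tensors live in pinned host memory, so router-side code would typically allocate the CPU-side tensors with `pin_memory`. A minimal sketch:

# Minimal sketch, assuming a CUDA device is available; the pinned-memory
# allocations illustrate the intended feeding pattern for ModelInput.to_cuda().
import torch

from lightllm.common.basemodel.batch_objs import ModelInput

batch_size = 4
model_input = ModelInput(
    batch_size=batch_size,
    total_token_num=batch_size * 16,
    max_len_in_batch=16,
    input_ids=None,  # overlap decode mode: filled later on the GPU by gather_token()
    b_req_idx=torch.arange(batch_size, dtype=torch.int32).pin_memory(),
    b_mtp_index=torch.zeros(batch_size, dtype=torch.int32).pin_memory(),
    b_seq_len=torch.full((batch_size,), 16, dtype=torch.int32).pin_memory(),
    mem_indexes_cpu=torch.arange(batch_size, dtype=torch.int32).pin_memory(),
    is_prefill=False,
)
model_input.to_cuda()  # non-blocking H2D copies; mem_indexes is filled from mem_indexes_cpu
assert model_input.mem_indexes.is_cuda and model_input.b_seq_len.is_cuda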

lightllm/common/basemodel/cuda_graph.py

Lines changed: 4 additions & 0 deletions
@@ -202,6 +202,7 @@ def warmup(self, model):
             )
             b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
             b_seq_len.fill_(seq_len)
+            b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")

             model_input = ModelInput(
                 batch_size=batch_size,

@@ -211,6 +212,7 @@ def warmup(self, model):
                 mem_indexes=mem_indexes,
                 b_req_idx=b_req_idx,
                 b_seq_len=b_seq_len,
+                b_mtp_index=b_mtp_index,
                 is_prefill=False,
                 **model._gen_special_model_input(batch_size),
             )

@@ -256,13 +258,15 @@ def warmup_overlap(self, model):
             )
             b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
             b_seq_len.fill_(seq_len)
+            b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")

             micro_batch = ModelInput(
                 is_prefill=False,
                 batch_size=batch_size,
                 total_token_num=total_token_num,
                 max_len_in_batch=max_len_in_batch,
                 input_ids=input_ids,
+                b_mtp_index=b_mtp_index,
                 mem_indexes=mem_indexes,
                 b_req_idx=b_req_idx,
                 b_seq_len=b_seq_len,

lightllm/common/basemodel/infer_struct.py

Lines changed: 5 additions & 3 deletions
@@ -6,6 +6,7 @@
 from .triton_kernel.gen_prefill_params import gen_prefill_params
 from .triton_kernel.gen_decode_params import gen_decode_params
 from .triton_kernel.multimodal_emb import mark_multimodal_obj
+from .batch_objs import ModelInput


 class InferStateInfo:

@@ -87,9 +88,10 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                 self.b_kv_seq_len,
                 self.b1_cu_kv_seq_len,
                 self.position_ids,
-                self.max_q_seq_len,
-                self.max_kv_seq_len,
-            ) = gen_decode_params(b_seq_len=self.b_seq_len)
+            ) = gen_decode_params(self.b_seq_len)
+            self.max_q_seq_len = 1
+            # TODO: check the correctness
+            self.max_kv_seq_len = self.max_len_in_batch
             self.b_start_loc = self.b1_cu_kv_seq_len[0:-1]

     def copy_for_cuda_graph(self, new_infer_state: "InferStateInfo"):
lightllm/common/basemodel/triton_kernel/gather_token_id.py (new file)

Lines changed: 165 additions & 0 deletions
import torch

import triton
import triton.language as tl


@triton.jit
def _fwd_kernel_scatter(
    next_token_ids,
    req_to_next_token_ids,
    b_req_idx,
    b_mtp_index,
    b_has_out,
    req_to_next_token_ids_stride,
    req_to_next_token_ids_stride_1,
    num_size,
    HAS_OUT_IS_NONE: tl.constexpr,
    BLOCK: tl.constexpr,
):
    block_index = tl.program_id(0)
    block_range = block_index * BLOCK + tl.arange(0, BLOCK)
    block_mask = block_range < num_size

    cur_req_idx = tl.load(b_req_idx + block_range, mask=block_mask)
    cur_mtp_index = tl.load(b_mtp_index + block_range, mask=block_mask)
    cur_next_token_id = tl.load(next_token_ids + block_range, mask=block_mask)

    if not HAS_OUT_IS_NONE:
        cur_has_out = tl.load(b_has_out + block_range, mask=block_mask, other=False)
        tl.store(
            req_to_next_token_ids + cur_req_idx * req_to_next_token_ids_stride + cur_mtp_index,
            cur_next_token_id,
            mask=cur_has_out & block_mask,
        )
    else:
        tl.store(
            req_to_next_token_ids + cur_req_idx * req_to_next_token_ids_stride + cur_mtp_index,
            cur_next_token_id,
            mask=block_mask,
        )

    return


@torch.no_grad()
def scatter_token(
    next_token_ids: torch.Tensor,
    req_to_next_token_ids: torch.Tensor,
    b_req_idx: torch.Tensor,
    b_mtp_index: torch.Tensor,
    b_has_out: torch.Tensor = None,
):
    """
    This function is used to scatter the token_info(GPU tensor) to the req_to_token_info(CPU tensor).
    Args:
        next_token_ids: (batch_size,)
        req_to_next_token_ids: (max_req_num, max_mtp_step)
        b_req_idx: (batch_size,)
        b_mtp_index: (batch_size,)
    """
    assert next_token_ids.shape[0] == b_req_idx.shape[0]
    batch_size = b_req_idx.shape[0]
    BLOCK = 256

    grid = (triton.cdiv(batch_size, BLOCK),)
    num_warps = 1

    _fwd_kernel_scatter[grid](
        next_token_ids=next_token_ids,
        req_to_next_token_ids=req_to_next_token_ids,
        b_req_idx=b_req_idx,
        b_mtp_index=b_mtp_index,
        b_has_out=b_has_out,
        req_to_next_token_ids_stride=req_to_next_token_ids.stride(0),
        req_to_next_token_ids_stride_1=req_to_next_token_ids.stride(1),
        num_size=batch_size,
        HAS_OUT_IS_NONE=b_has_out is None,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )
    return


@triton.jit
def _fwd_kernel_gather(
    req_to_next_token_ids,
    req_to_next_token_ids_stride,
    req_to_next_token_ids_stride_1,
    output,
    b_req_idx,
    b_mtp_index,
    num_size,
    BLOCK: tl.constexpr,
):
    block_index = tl.program_id(0)
    block_range = block_index * BLOCK + tl.arange(0, BLOCK)
    block_mask = block_range < num_size
    cur_req_idx = tl.load(b_req_idx + block_range, mask=block_mask)
    cur_mtp_index = tl.load(b_mtp_index + block_range, mask=block_mask)
    cur_next_token_id = tl.load(
        req_to_next_token_ids + cur_req_idx * req_to_next_token_ids_stride + cur_mtp_index, mask=block_mask
    )
    tl.store(output + block_range, cur_next_token_id, mask=block_mask)
    return


def gather_token(req_to_next_token_ids: torch.Tensor, b_req_idx: torch.Tensor, b_mtp_index: torch.Tensor):
    """
    This function is used to gather the token_info(CPU tensor) to the token_info(GPU tensor).
    Args:
        req_to_token_info: (max_req_num, max_mtp_step)
        b_req_idx: (batch_size,)
        b_mtp_index: (batch_size,)
    Returns:
        output: (batch_size,)
    """
    batch_size = b_req_idx.shape[0]
    output = torch.empty(batch_size, dtype=req_to_next_token_ids.dtype, device="cuda")
    BLOCK = 256
    grid = (triton.cdiv(batch_size, BLOCK),)
    num_warps = 1
    _fwd_kernel_gather[grid](
        req_to_next_token_ids=req_to_next_token_ids,
        req_to_next_token_ids_stride=req_to_next_token_ids.stride(0),
        req_to_next_token_ids_stride_1=req_to_next_token_ids.stride(1),
        output=output,
        b_req_idx=b_req_idx,
        b_mtp_index=b_mtp_index,
        num_size=batch_size,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )
    return output


def test_scatter_token_to_cpu():
    batch_size = 30
    req_to_token_info = torch.zeros((1000, 1), dtype=torch.float32, pin_memory=True)
    token_info = torch.randn((batch_size,)).cuda()
    req_ids = torch.arange(20, 20 + batch_size, dtype=torch.int32).cuda()
    mtp_index = torch.zeros((batch_size,), dtype=torch.int32).cuda()
    scatter_token(token_info, req_to_token_info, req_ids, mtp_index)
    diff = (req_to_token_info[20 : 20 + batch_size].cuda().view(-1) - token_info).abs().max()
    assert diff < 1e-6
    print("test_scatter_token_to_cpu passed")


def test_gather_token():
    batch_size = 30
    req_to_token_info = torch.zeros((1000, 1), dtype=torch.float32, pin_memory=True)
    token_info = torch.randn((batch_size,)).cuda()
    req_ids = torch.arange(20, 20 + batch_size, dtype=torch.int32).cuda()
    mtp_index = torch.zeros((batch_size,), dtype=torch.int32).cuda()
    scatter_token(token_info, req_to_token_info, req_ids, mtp_index)
    output = gather_token(req_to_token_info, req_ids, mtp_index)
    diff = (token_info - output).abs().max()
    assert diff < 1e-6
    print("test_gather_token passed")


if __name__ == "__main__":
    test_scatter_token_to_cpu()
    test_gather_token()
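
For intuition, the two Triton kernels above behave like the following plain PyTorch advanced indexing, leaving aside the pinned-host/CUDA device mix they are written to handle; this reference is illustrative and is not used anywhere in the code.

# Reference semantics only, assuming every tensor lives on the same device.
# Shapes follow the docstrings: req_to_next_token_ids is (max_req_num, max_mtp_step);
# next_token_ids, b_req_idx and b_mtp_index are (batch_size,).
def scatter_token_reference(next_token_ids, req_to_next_token_ids, b_req_idx, b_mtp_index, b_has_out=None):
    if b_has_out is None:
        req_to_next_token_ids[b_req_idx.long(), b_mtp_index.long()] = next_token_ids
    else:
        keep = b_has_out.bool()
        req_to_next_token_ids[b_req_idx[keep].long(), b_mtp_index[keep].long()] = next_token_ids[keep]


def gather_token_reference(req_to_next_token_ids, b_req_idx, b_mtp_index):
    return req_to_next_token_ids[b_req_idx.long(), b_mtp_index.long()]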

lightllm/common/basemodel/triton_kernel/gen_decode_params.py

Lines changed: 2 additions & 3 deletions
@@ -10,6 +10,5 @@ def gen_decode_params(b_seq_len: torch.Tensor):
     position_ids = b_seq_len - 1
     b_q_seq_len = torch.ones_like(b_seq_len)
     b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len)
-    max_q_seq_len = b_q_seq_len.max().item()
-    max_kv_seq_len = b_kv_seq_len.max().item()
-    return b_q_seq_len, b1_cu_q_seq_len, b_kv_seq_len, b1_cu_kv_seq_len, position_ids, max_q_seq_len, max_kv_seq_len
+
+    return b_q_seq_len, b1_cu_q_seq_len, b_kv_seq_len, b1_cu_kv_seq_len, position_ids
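
The two deleted `.max().item()` calls are the likely motivation for the infer_struct.py change above: `.item()` forces the host to wait for the GPU and copy a scalar back, which would serialize exactly the router/inference overlap this commit introduces. A standalone illustration (not from the commit):

import torch

b_seq_len = torch.randint(1, 4096, (64,), dtype=torch.int32, device="cuda")

# .item() blocks the host until the GPU has produced b_seq_len.max() and the
# scalar has been copied back, so the host cannot enqueue further work meanwhile.
max_kv_seq_len_synced = b_seq_len.max().item()

# The decode path now sidesteps the sync: every decode query has length 1, and
# the kv length is bounded by max_len_in_batch, which the scheduler already
# tracks on the CPU (see init_some_extra_state in the infer_struct.py diff).
max_q_seq_len = 1
max_kv_seq_len = 4096  # stand-in for model_input.max_len_in_batch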
