Fp8 deepseek #975

Status: Open. Wants to merge 6 commits into base: main.
Changes from all commits
4 changes: 3 additions & 1 deletion lightllm/common/basemodel/basemodel.py
@@ -683,10 +683,12 @@ def _check_max_len_infer(self):
logger.info("begin check max_len infer")
dummy_input_ids = torch.ones(self.batch_max_tokens, dtype=torch.int32, device="cuda")
b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda")
mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda()
b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
b_seq_len[:] = self.batch_max_tokens
b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
mem_indexes = self.mem_manager.alloc(
len(dummy_input_ids), b_req_idx, b_seq_len, b_ready_cache_len, True
).cuda()
total_token_num = self.batch_max_tokens
b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
model_input = ModelInput(
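The max-length check builds a prefill-style batch, so `alloc` now receives the per-request metadata a paged manager needs to place tokens. A minimal sketch of the new calling convention, using a hypothetical `DummyMemManager` stand-in rather than the real `MemoryManager`:

```python
import torch

class DummyMemManager:
    """Hypothetical stand-in that only mirrors the widened alloc() signature."""
    def alloc(self, need_size, b_req_idx=None, b_seq_len=None,
              b_ready_cache_len=None, is_prefill=False):
        return torch.arange(need_size, dtype=torch.int32)

mem_manager = DummyMemManager()
batch_max_tokens = 16  # illustrative value

dummy_input_ids = torch.ones(batch_max_tokens, dtype=torch.int32)
b_req_idx = torch.tensor([0], dtype=torch.int32)
b_seq_len = torch.full((1,), batch_max_tokens, dtype=torch.int32)
b_ready_cache_len = torch.zeros(1, dtype=torch.int32)

# Prefill path: ready-cache lengths are known and is_prefill=True.
mem_indexes = mem_manager.alloc(len(dummy_input_ids), b_req_idx, b_seq_len, b_ready_cache_len, True)
```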
4 changes: 2 additions & 2 deletions lightllm/common/basemodel/cuda_graph.py
@@ -196,13 +196,13 @@ def warmup(self, model):
total_token_num = batch_size * seq_len
max_len_in_batch = self.graph_max_len_in_batch
input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda")
mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda()
b_req_idx = torch.tensor(
[model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
)
b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
b_seq_len.fill_(seq_len)
b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
mem_indexes = model.mem_manager.alloc(len(input_ids), b_req_idx, b_seq_len).cuda()

model_input = ModelInput(
batch_size=batch_size,
@@ -252,13 +252,13 @@ def warmup_overlap(self, model):
total_token_num = batch_size * seq_len
max_len_in_batch = self.graph_max_len_in_batch
input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda")
mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda()
b_req_idx = torch.tensor(
[model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
)
b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
b_seq_len.fill_(seq_len)
b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
mem_indexes = model.mem_manager.alloc(len(input_ids), b_req_idx, b_seq_len).cuda()

micro_batch = ModelInput(
is_prefill=False,
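For the CUDA-graph warm-ups the allocation happens per decode step, so only `b_req_idx` and `b_seq_len` are passed and the remaining parameters keep their defaults (`b_ready_cache_len=None`, `is_prefill=False`). A sketch of that decode-side call, with the same hypothetical stand-in re-declared so the snippet is self-contained:

```python
import torch

class DummyMemManager:
    """Hypothetical stand-in mirroring the widened alloc() signature."""
    def alloc(self, need_size, b_req_idx=None, b_seq_len=None,
              b_ready_cache_len=None, is_prefill=False):
        return torch.arange(need_size, dtype=torch.int32)

mem_manager = DummyMemManager()
batch_size, seq_len = 4, 32  # illustrative warm-up sizes

input_ids = torch.ones(batch_size, dtype=torch.int32)
b_req_idx = torch.zeros(batch_size, dtype=torch.int32)
b_seq_len = torch.full((batch_size,), seq_len, dtype=torch.int32)

# Decode path: one new token per request, defaults for the cache-related arguments.
mem_indexes = mem_manager.alloc(len(input_ids), b_req_idx, b_seq_len)
```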
25 changes: 25 additions & 0 deletions lightllm/common/deepseek2_page_size_variable_mem_manager.py
@@ -0,0 +1,25 @@
import torch
import numpy as np
from .deepseek2_mem_manager import Deepseek2MemoryManager
from .page_size_variable_mem_manager import PageSizeVariableMemoryManager
from lightllm.utils.log_utils import init_logger
from lightllm.utils.envs_utils import get_page_size


def cdiv(a, b):
return (a + b - 1) // b


logger = init_logger(__name__)


class Deepseek2PageSizeVariableMemoryManager(PageSizeVariableMemoryManager, Deepseek2MemoryManager):
def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction)

def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
self.kv_buffer = torch.empty(
(layer_num, cdiv(size, get_page_size()) * get_page_size(), head_num, head_dim),
dtype=dtype,
device="cuda",
)
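The only difference from the base DeepSeek2 buffer is that the token dimension is rounded up to a whole number of pages, so the last page is always fully backed. A small arithmetic sketch with illustrative numbers (not the PR's defaults):

```python
def cdiv(a, b):
    return (a + b - 1) // b

size, page_size = 1000, 64          # illustrative token budget and page size
padded_tokens = cdiv(size, page_size) * page_size
print(cdiv(size, page_size))        # 16 pages
print(padded_tokens)                # 1024 token slots, i.e. 24 slots of padding
```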
17 changes: 16 additions & 1 deletion lightllm/common/mem_manager.py
@@ -52,6 +52,7 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
layer_num,
)
self.HOLD_TOKEN_MEMINDEX = self.size
self.req_to_token_indexs = None

def get_cell_size(self):
return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype)
@@ -243,7 +244,9 @@ def _write_kv_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: to
def _free_buffers(self):
self.kv_buffer = None

def alloc(self, need_size) -> torch.Tensor:
def alloc(
self, need_size, b_req_idx=None, b_seq_len=None, b_ready_cache_len=None, is_prefill=False
) -> torch.Tensor:
if need_size > self.mark_end - self.mark_start:
logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}")
assert False, "error alloc state"
@@ -257,6 +260,9 @@ def alloc(self, need_size) -> torch.Tensor:
self.shared_can_use_token_num.set_value(self.can_use_mem_size)
return ans

def set_prefix_cache_to_req(self, req_idx: int, start: int, end: int, values: torch.Tensor):
self.req_to_token_indexs[req_idx, start:end] = values

def free(self, free_index: Union[torch.Tensor, List[int]]):
"""_summary_

@@ -335,8 +341,17 @@ def __init__(self) -> None:
SharedInt(f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}")
for rank_in_node in range(0, self.node_world_size, self.dp_world_size)
]
self.shared_tp_info_pages = [
SharedInt(f"{get_unique_server_name()}_mem_manger_can_use_page_num_{rank_in_node}")
for rank_in_node in range(0, self.node_world_size, self.dp_world_size)
]

def get_unrefed_token_num(self, dp_rank_in_node: int):
if self.is_multinode_tp:
return self.shared_tp_infos[0].get_value()
return self.shared_tp_infos[dp_rank_in_node].get_value()

def get_unrefed_page_num(self, dp_rank_in_node: int):
if self.is_multinode_tp:
return self.shared_tp_info_pages[0].get_value()
return self.shared_tp_info_pages[dp_rank_in_node].get_value()
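The widened `alloc` signature is intended to stay backward compatible: the base manager accepts but ignores the extra batch metadata, while a paged subclass consumes it. The sketch below illustrates that contract with simplified stand-in classes (not the PR's exact implementation):

```python
import torch

class TokenManagerSketch:
    """Simplified stand-in for the base MemoryManager's token-granular alloc()."""
    def __init__(self, size: int):
        self.free_tokens = torch.arange(size, dtype=torch.int32)

    def alloc(self, need_size, b_req_idx=None, b_seq_len=None,
              b_ready_cache_len=None, is_prefill=False) -> torch.Tensor:
        # Extra arguments are accepted but unused, so old call sites keep working.
        ans, self.free_tokens = self.free_tokens[:need_size], self.free_tokens[need_size:]
        return ans

class PagedManagerSketch(TokenManagerSketch):
    """Simplified stand-in for a page-granular subclass that needs the metadata."""
    def alloc(self, need_size, b_req_idx=None, b_seq_len=None,
              b_ready_cache_len=None, is_prefill=False) -> torch.Tensor:
        assert b_req_idx is not None and b_seq_len is not None, "paged alloc needs batch metadata"
        return super().alloc(need_size)

tokens = TokenManagerSketch(64).alloc(8)                                      # old style still valid
paged = PagedManagerSketch(64).alloc(8, torch.tensor([0]), torch.tensor([8]))  # new style
```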
4 changes: 4 additions & 0 deletions lightllm/common/mem_utils.py
@@ -4,6 +4,7 @@
from lightllm.common.export_calibration_mem_manager import ExportCalibrationMemoryManager
from lightllm.common.ppl_int8kv_mem_manager import PPLINT8KVMemoryManager
from lightllm.common.ppl_int4kv_mem_manager import PPLINT4KVMemoryManager
from lightllm.common.page_size_variable_mem_manager import PageSizeVariableMemoryManager
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)
@@ -28,6 +29,9 @@ def select_mem_manager_class(mode):
elif "export_fp8kv_calibration" in mode:
memory_manager_class = ExportCalibrationMemoryManager
logger.info("Using mode export fp8kv calibration")
elif "page_size_variable" in mode:
memory_manager_class = PageSizeVariableMemoryManager
logger.info("Page size will be variable")
else:
memory_manager_class = MemoryManager
logger.info("Model kv cache using mode normal")
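Selecting the paged manager is driven purely by the mode list, following the existing dispatch style of `select_mem_manager_class`. A usage sketch (only the mode value is taken from this diff; how the mode list reaches this function depends on the launch configuration):

```python
from lightllm.common.mem_utils import select_mem_manager_class

# Any mode list containing "page_size_variable" picks the paged manager;
# an empty list falls through to the normal MemoryManager.
paged_cls = select_mem_manager_class(["page_size_variable"])
default_cls = select_mem_manager_class([])
print(paged_cls.__name__, default_cls.__name__)
```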
184 changes: 184 additions & 0 deletions lightllm/common/page_size_variable_mem_manager.py
@@ -0,0 +1,184 @@
import torch
import numpy as np
from .mem_manager import MemoryManager
from typing import List, Union
from lightllm.utils.log_utils import init_logger
from lightllm.utils.envs_utils import get_unique_server_name, get_page_size
from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
from lightllm.utils.dist_utils import get_current_rank_in_node


def cdiv(a, b):
return (a + b - 1) // b


logger = init_logger(__name__)


class PageSizeVariableMemoryManager(MemoryManager):
def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction)
self.req_to_page_indexs = None
page_size = get_page_size()
self.page_idx_pool = torch.arange(
0, cdiv(self.size, page_size), dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
)
self.mark_page_start = 0
self.can_use_page_size = cdiv(self.size, page_size)

rank_in_node = get_current_rank_in_node()
self.shared_can_use_page_num = SharedInt(
f"{get_unique_server_name()}_mem_manger_can_use_page_num_{rank_in_node}"
)
self.shared_can_use_page_num.set_value(self.can_use_page_size)

def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
self.kv_buffer = torch.empty(
(layer_num, cdiv(size, get_page_size()) * get_page_size(), 2 * head_num, head_dim),
dtype=dtype,
device="cuda",
)

# Requires the length to be an integer multiple of page_size, and token indices within a page must be contiguous
def check_cache_page_valid(self, values: torch.Tensor):
end = len(values)
assert end % self.page_size == 0, "Values length must be a multiple of page size"
total_pages = end // self.page_size
for page_idx in range(total_pages):
values_start = page_idx * self.page_size
values_end = min((page_idx + 1) * self.page_size, end)
page_token_idxs = values[values_start:values_end]
if len(page_token_idxs) > 1:
expected_idxs = torch.arange(
page_token_idxs[0],
page_token_idxs[0] + len(page_token_idxs),
dtype=page_token_idxs.dtype,
device=page_token_idxs.device,
)
if not torch.equal(page_token_idxs, expected_idxs):
return False
return True

def set_prefix_cache_to_req(self, req_idx: int, start: int, end: int, values: torch.Tensor):
# assert self.check_cache_page_valid(values), "Values must be valid for page size"
page_size = get_page_size()
self.req_to_page_indexs[req_idx, start // page_size : end // page_size] = values[::page_size] // page_size
self.req_to_token_indexs[req_idx, start:end] = values

def expand_by_page_size(self, b_token_len, page_size):
# Expand seq_len into per-page token counts, e.g. seq_len = [9,9,9] with page_size = 4 -> page_len = [4,4,1,4,4,1,4,4,1]
b_page_len = cdiv(b_token_len, page_size)
need_pages_num = b_page_len.sum()
p_token_len = torch.full((need_pages_num,), page_size, dtype=b_token_len.dtype, device=b_token_len.device)
cumsum_pages = torch.cumsum(b_page_len, dim=0)
last_page_positions = cumsum_pages - 1
remainders = b_token_len - (b_page_len - 1) * page_size
p_token_len[last_page_positions] = remainders
return need_pages_num, b_page_len, p_token_len

def get_paged_token_indexs(self, b_req_idx, page_size, b_seq_len, b_ready_cache_len, is_prefill):
if is_prefill:
b_req_idx = b_req_idx.cuda()
b_seq_len = b_seq_len.cuda()
b_ready_cache_len = b_ready_cache_len.cuda()

b_token_len = b_seq_len - b_ready_cache_len
total_pages_needed, b_page_len, p_token_len = self.expand_by_page_size(b_token_len, page_size)
if self.can_use_page_size < total_pages_needed:
raise RuntimeError(
f"No available pages for alloc. remaining: {self.can_use_page_size}, needed: {total_pages_needed}"
)

allocated_pages = self.page_idx_pool[
self.mark_page_start : self.mark_page_start + total_pages_needed
].cuda()

def get_offsets_by_length(b_len, max_len):
# e.g. b_len = [3,4,5] -> [0,1,2,0,1,2,3,0,1,2,3,4]
offsets = torch.arange(max_len, dtype=b_len.dtype, device=b_len.device)
offset_mask = offsets.unsqueeze(0) < b_len.unsqueeze(1)
return torch.masked_select(offsets, offset_mask)

page_offsets = get_offsets_by_length(b_page_len, b_page_len.max())
token_offsets = get_offsets_by_length(p_token_len, page_size)

# Update req_to_page_indexs; b_ready_cache_len is always an exact multiple of page_size
page_starts = b_ready_cache_len // page_size
req_id = torch.repeat_interleave(
torch.arange(len(b_req_idx), dtype=b_token_len.dtype, device=b_token_len.device), b_page_len
)
self.req_to_page_indexs[b_req_idx[req_id], page_starts[req_id] + page_offsets] = allocated_pages

self.mark_page_start += total_pages_needed
self.can_use_page_size -= total_pages_needed
page_bases = allocated_pages * page_size
return torch.repeat_interleave(page_bases, p_token_len) + token_offsets
else:
b_seq_len = b_seq_len.cuda()
b_req_idx = b_req_idx.cuda()
need_new_page_mask = (b_seq_len - 1) % page_size == 0
new_pages_num = need_new_page_mask.sum()
if self.can_use_page_size < new_pages_num:
raise RuntimeError(
f"No available pages for alloc. remaining: {self.can_use_page_size}, needed: {new_pages_num}"
)

token_idxs = torch.zeros_like(b_seq_len, device=b_seq_len.device)
if new_pages_num > 0:
new_pages = self.page_idx_pool[self.mark_page_start : self.mark_page_start + new_pages_num].cuda()
self.mark_page_start += new_pages_num
self.can_use_page_size -= new_pages_num
token_idxs[need_new_page_mask] = new_pages * page_size

# req_to_page_indexs also needs to be updated
new_page_req_indices = b_req_idx[need_new_page_mask]
page_positions = (b_seq_len[need_new_page_mask] - 1) // page_size
self.req_to_page_indexs[new_page_req_indices, page_positions] = new_pages

mask = ~need_new_page_mask
if mask.any():
seq_lens = b_seq_len[mask]
token_idxs[mask] = (
self.req_to_token_indexs[b_req_idx[mask], seq_lens - 2] // page_size * page_size
+ (seq_lens - 1) % page_size
)
return token_idxs

def alloc(self, need_size, b_req_idx, b_seq_len, b_ready_cache_len=None, is_prefill=False) -> torch.Tensor:
page_size = get_page_size()
token_idxs = self.get_paged_token_indexs(b_req_idx, page_size, b_seq_len, b_ready_cache_len, is_prefill)
self.can_use_mem_size -= need_size
self.shared_can_use_token_num.set_value(self.can_use_mem_size)
self.shared_can_use_page_num.set_value(self.can_use_page_size)
return token_idxs

def free(self, free_index: Union[torch.Tensor, List[int]]):
self.can_use_mem_size += len(free_index)
self.shared_can_use_token_num.set_value(self.can_use_mem_size)

page_size = get_page_size()
if isinstance(free_index, list):
free_index = torch.tensor(free_index, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True)

if len(free_index) == 0:
return

base_free_index = free_index[free_index % page_size == 0]
page_indices = base_free_index // page_size
for page_idx in sorted(page_indices, reverse=True):  # put pages back in reverse order to preserve the pool's relative ordering
self.mark_page_start -= 1
self.page_idx_pool[self.mark_page_start] = page_idx
self.can_use_page_size += 1
self.shared_can_use_page_num.set_value(self.can_use_page_size)

return

def free_all(self):
super().free_all()
page_size = get_page_size()
self.mark_page_start = 0
self.can_use_page_size = cdiv(self.size, page_size)
self.shared_can_use_page_num.set_value(self.can_use_page_size)
self.page_idx_pool = torch.arange(
0, cdiv(self.size, page_size), dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
)
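To make the paging arithmetic concrete, here is a standalone sketch (CPU tensors only) that re-implements the two helpers exactly as commented above and reproduces the `seq_len = [9,9,9]`, `page_size = 4` example; the function bodies mirror the diff but are re-declared so the snippet runs on its own:

```python
import torch

def cdiv(a, b):
    return (a + b - 1) // b

def expand_by_page_size(b_token_len, page_size):
    # Per-request token counts -> per-page token counts (last page may be partial).
    b_page_len = cdiv(b_token_len, page_size)
    need_pages_num = int(b_page_len.sum())
    p_token_len = torch.full((need_pages_num,), page_size, dtype=b_token_len.dtype)
    last_page_positions = (torch.cumsum(b_page_len, dim=0) - 1).long()
    p_token_len[last_page_positions] = b_token_len - (b_page_len - 1) * page_size
    return need_pages_num, b_page_len, p_token_len

def get_offsets_by_length(b_len, max_len):
    # e.g. b_len = [3, 4] -> [0, 1, 2, 0, 1, 2, 3]
    offsets = torch.arange(max_len, dtype=b_len.dtype)
    offset_mask = offsets.unsqueeze(0) < b_len.unsqueeze(1)
    return torch.masked_select(offsets, offset_mask)

b_token_len = torch.tensor([9, 9, 9], dtype=torch.int32)
need_pages, b_page_len, p_token_len = expand_by_page_size(b_token_len, page_size=4)
print(need_pages)    # 9 pages in total (3 per request)
print(b_page_len)    # tensor([3, 3, 3], dtype=torch.int32)
print(p_token_len)   # tensor([4, 4, 1, 4, 4, 1, 4, 4, 1], dtype=torch.int32)

# Offset of each page within its request, and of each token within its page:
print(get_offsets_by_length(b_page_len, int(b_page_len.max())))  # [0, 1, 2, 0, 1, 2, 0, 1, 2]
print(get_offsets_by_length(p_token_len, 4))                     # [0,1,2,3, 0,1,2,3, 0, ...]
```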
11 changes: 10 additions & 1 deletion lightllm/common/req_manager.py
@@ -5,7 +5,7 @@
from typing import List, Optional
from lightllm.common.basemodel.triton_kernel.gen_sampling_params import token_id_counter
from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter
from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args
from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, get_page_size
from lightllm.utils.config_utils import get_vocab_size

logger = init_logger(__name__)
@@ -62,6 +62,15 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana
self.req_to_token_indexs = torch.zeros(
(max_request_num + 1, max_sequence_length), dtype=torch.int32, device="cuda"
)
mem_manager.req_to_token_indexs = self.req_to_token_indexs
if hasattr(mem_manager, "req_to_page_indexs"):
page_size = get_page_size()
self.req_to_page_indexs = torch.zeros(
(max_request_num + 1, (max_sequence_length + page_size - 1) // page_size),
dtype=torch.int32,
device="cuda",
)
mem_manager.req_to_page_indexs = self.req_to_page_indexs
self.mem_manager = mem_manager
self.req_sampling_params_manager = ReqSamplingParamsManager(max_request_num)
self.max_request_num = max_request_num
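The page table mirrors `req_to_token_indexs` but holds one slot per page instead of per token, so its second dimension is the ceiling of `max_sequence_length / page_size`. A quick sizing check with illustrative numbers:

```python
import torch

max_request_num, max_sequence_length, page_size = 8, 130, 64     # illustrative values
page_slots = (max_sequence_length + page_size - 1) // page_size  # ceil(130 / 64) = 3
req_to_page_indexs = torch.zeros((max_request_num + 1, page_slots), dtype=torch.int32)
print(req_to_page_indexs.shape)  # torch.Size([9, 3])
```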
33 changes: 21 additions & 12 deletions lightllm/models/deepseek2/flashattention_infer_struct.py
@@ -4,13 +4,19 @@
import torch.distributed as dist
from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
from lightllm.utils.dist_utils import get_current_device_id
from lightllm.utils.envs_utils import get_page_size


def cdiv(a, b):
return (a + b - 1) // b


class Deepseek2FlashAttentionStateInfo(Deepseek2InferStateInfo):
_shared_page_table_buffer = None

def __init__(self):
super().__init__()
self.page_size = get_page_size()

@classmethod
def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int):
@@ -39,19 +45,22 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
self.cu_seqlens_k = self.b1_cu_kv_seq_len
max_seq_len_k = self.max_kv_seq_len
if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch:
page_buffer = Deepseek2FlashAttentionStateInfo.get_page_table_buffer(
model.graph_max_batch_size, model.graph_max_len_in_batch
length = cdiv(model.graph_max_len_in_batch, self.page_size)
page_buffer = Deepseek2FlashAttentionStateInfo.get_page_table_buffer(model.graph_max_batch_size, length)
self.page_table = page_buffer[self.microbatch_index][: self.batch_size * length].reshape(
self.batch_size, length
)
self.page_table = page_buffer[self.microbatch_index][
: self.batch_size * model.graph_max_len_in_batch
].reshape(self.batch_size, model.graph_max_len_in_batch)
else:
self.page_table = torch.empty((self.batch_size, self.max_len_in_batch), dtype=torch.int32).to(
input_ids.device
)
length = cdiv(self.max_len_in_batch, self.page_size)
self.page_table = torch.empty((self.batch_size, length), dtype=torch.int32).to(input_ids.device)

self.page_table[:, :max_seq_len_k].copy_(
model.req_manager.req_to_token_indexs[self.b_req_idx, :max_seq_len_k]
)
self.page_table[:, max_seq_len_k:].fill_(0)
if "page_size_variable" in model.mode:
length = cdiv(max_seq_len_k, self.page_size)
self.page_table[:, :length].copy_(model.req_manager.req_to_page_indexs[self.b_req_idx, :length])
self.page_table[:, length:].fill_(0)
else:
self.page_table[:, :max_seq_len_k].copy_(
model.req_manager.req_to_token_indexs[self.b_req_idx, :max_seq_len_k]
)
self.page_table[:, max_seq_len_k:].fill_(0)
return
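In `page_size_variable` mode each row of `page_table` holds page indices taken from `req_to_page_indexs`, and only the first `cdiv(max_seq_len_k, page_size)` columns carry data; the rest are zero-filled. A CPU-only sketch of that fill with made-up page indices:

```python
import torch

def cdiv(a, b):
    return (a + b - 1) // b

page_size, batch_size, max_seq_len_k = 4, 2, 10
length = cdiv(max_seq_len_k, page_size)              # 3 meaningful page columns

# Hypothetical per-request page indices (rows indexed by request id).
req_to_page_indexs = torch.tensor([[7, 3, 9, 0], [2, 5, 6, 0]], dtype=torch.int32)
b_req_idx = torch.tensor([0, 1])

page_table = torch.empty((batch_size, length + 1), dtype=torch.int32)
page_table[:, :length].copy_(req_to_page_indexs[b_req_idx, :length])
page_table[:, length:].fill_(0)
print(page_table)
# tensor([[7, 3, 9, 0],
#         [2, 5, 6, 0]], dtype=torch.int32)
```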