diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 4516e18c3..dbd4f780c 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -683,10 +683,12 @@ def _check_max_len_infer(self): logger.info("begin check max_len infer") dummy_input_ids = torch.ones(self.batch_max_tokens, dtype=torch.int32, device="cuda") b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda") - mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda() b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda") b_seq_len[:] = self.batch_max_tokens b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda") + mem_indexes = self.mem_manager.alloc( + len(dummy_input_ids), b_req_idx, b_seq_len, b_ready_cache_len, True + ).cuda() total_token_num = self.batch_max_tokens b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda") model_input = ModelInput( diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 07792865e..db77298c7 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -196,13 +196,13 @@ def warmup(self, model): total_token_num = batch_size * seq_len max_len_in_batch = self.graph_max_len_in_batch input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda") - mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda() b_req_idx = torch.tensor( [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda" ) b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda") b_seq_len.fill_(seq_len) b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + mem_indexes = model.mem_manager.alloc(len(input_ids), b_req_idx, b_seq_len).cuda() model_input = ModelInput( batch_size=batch_size, @@ -252,13 +252,13 @@ def warmup_overlap(self, model): total_token_num = batch_size * seq_len max_len_in_batch = self.graph_max_len_in_batch input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda") - mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda() b_req_idx = torch.tensor( [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda" ) b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda") b_seq_len.fill_(seq_len) b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + mem_indexes = model.mem_manager.alloc(len(input_ids), b_req_idx, b_seq_len).cuda() micro_batch = ModelInput( is_prefill=False, diff --git a/lightllm/common/deepseek2_page_size_variable_mem_manager.py b/lightllm/common/deepseek2_page_size_variable_mem_manager.py new file mode 100755 index 000000000..6c3cd7014 --- /dev/null +++ b/lightllm/common/deepseek2_page_size_variable_mem_manager.py @@ -0,0 +1,25 @@ +import torch +import numpy as np +from .deepseek2_mem_manager import Deepseek2MemoryManager +from .page_size_variable_mem_manager import PageSizeVariableMemoryManager +from lightllm.utils.log_utils import init_logger +from lightllm.utils.envs_utils import get_page_size + + +def cdiv(a, b): + return (a + b - 1) // b + + +logger = init_logger(__name__) + + +class Deepseek2PageSizeVariableMemoryManager(PageSizeVariableMemoryManager, Deepseek2MemoryManager): + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): + super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) + 
+ def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): + self.kv_buffer = torch.empty( + (layer_num, cdiv(size, get_page_size()) * get_page_size(), head_num, head_dim), + dtype=dtype, + device="cuda", + ) diff --git a/lightllm/common/mem_manager.py b/lightllm/common/mem_manager.py index 4142ce4aa..483d10fd3 100755 --- a/lightllm/common/mem_manager.py +++ b/lightllm/common/mem_manager.py @@ -52,6 +52,7 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False layer_num, ) self.HOLD_TOKEN_MEMINDEX = self.size + self.req_to_token_indexs = None def get_cell_size(self): return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype) @@ -243,7 +244,9 @@ def _write_kv_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: to def _free_buffers(self): self.kv_buffer = None - def alloc(self, need_size) -> torch.Tensor: + def alloc( + self, need_size, b_req_idx=None, b_seq_len=None, b_ready_cache_len=None, is_prefill=False + ) -> torch.Tensor: if need_size > self.mark_end - self.mark_start: logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}") assert False, "error alloc state" @@ -257,6 +260,9 @@ def alloc(self, need_size) -> torch.Tensor: self.shared_can_use_token_num.set_value(self.can_use_mem_size) return ans + def set_prefix_cache_to_req(self, req_idx: int, start: int, end: int, values: torch.Tensor): + self.req_to_token_indexs[req_idx, start:end] = values + def free(self, free_index: Union[torch.Tensor, List[int]]): """_summary_ @@ -335,8 +341,17 @@ def __init__(self) -> None: SharedInt(f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}") for rank_in_node in range(0, self.node_world_size, self.dp_world_size) ] + self.shared_tp_info_pages = [ + SharedInt(f"{get_unique_server_name()}_mem_manger_can_use_page_num_{rank_in_node}") + for rank_in_node in range(0, self.node_world_size, self.dp_world_size) + ] def get_unrefed_token_num(self, dp_rank_in_node: int): if self.is_multinode_tp: return self.shared_tp_infos[0].get_value() return self.shared_tp_infos[dp_rank_in_node].get_value() + + def get_unrefed_page_num(self, dp_rank_in_node: int): + if self.is_multinode_tp: + return self.shared_tp_info_pages[0].get_value() + return self.shared_tp_info_pages[dp_rank_in_node].get_value() diff --git a/lightllm/common/mem_utils.py b/lightllm/common/mem_utils.py index dfb8e849d..5f3ee6164 100644 --- a/lightllm/common/mem_utils.py +++ b/lightllm/common/mem_utils.py @@ -4,6 +4,7 @@ from lightllm.common.export_calibration_mem_manager import ExportCalibrationMemoryManager from lightllm.common.ppl_int8kv_mem_manager import PPLINT8KVMemoryManager from lightllm.common.ppl_int4kv_mem_manager import PPLINT4KVMemoryManager +from lightllm.common.page_size_variable_mem_manager import PageSizeVariableMemoryManager from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -28,6 +29,9 @@ def select_mem_manager_class(mode): elif "export_fp8kv_calibration" in mode: memory_manager_class = ExportCalibrationMemoryManager logger.info("Using mode export fp8kv calibration") + elif "page_size_variable" in mode: + memory_manager_class = PageSizeVariableMemoryManager + logger.info("Page size will be variable") else: memory_manager_class = MemoryManager logger.info("Model kv cache using mode normal") diff --git a/lightllm/common/page_size_variable_mem_manager.py b/lightllm/common/page_size_variable_mem_manager.py new file mode 100755 index 000000000..8456f2902 --- 
/dev/null +++ b/lightllm/common/page_size_variable_mem_manager.py @@ -0,0 +1,184 @@ +import torch +import numpy as np +from .mem_manager import MemoryManager +from typing import List, Union +from lightllm.utils.log_utils import init_logger +from lightllm.utils.envs_utils import get_unique_server_name, get_page_size +from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt +from lightllm.utils.dist_utils import get_current_rank_in_node + + +def cdiv(a, b): + return (a + b - 1) // b + + +logger = init_logger(__name__) + + +class PageSizeVariableMemoryManager(MemoryManager): + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): + super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) + self.req_to_page_indexs = None + page_size = get_page_size() + self.page_idx_pool = torch.arange( + 0, cdiv(self.size, page_size), dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) + self.mark_page_start = 0 + self.can_use_page_size = cdiv(self.size, page_size) + + rank_in_node = get_current_rank_in_node() + self.shared_can_use_page_num = SharedInt( + f"{get_unique_server_name()}_mem_manger_can_use_page_num_{rank_in_node}" + ) + self.shared_can_use_page_num.set_value(self.can_use_page_size) + + def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): + self.kv_buffer = torch.empty( + (layer_num, cdiv(size, get_page_size()) * get_page_size(), 2 * head_num, head_dim), + dtype=dtype, + device="cuda", + ) + + # 要求长度必须是page_size的整数倍,page内token索引必须连续 + def check_cache_page_valid(self, values: torch.Tensor): + end = len(values) + assert end % self.page_size == 0, "Values length must be a multiple of page size" + total_pages = end // self.page_size + for page_idx in range(total_pages): + values_start = page_idx * self.page_size + values_end = min((page_idx + 1) * self.page_size, end) + page_token_idxs = values[values_start:values_end] + if len(page_token_idxs) > 1: + expected_idxs = torch.arange( + page_token_idxs[0], + page_token_idxs[0] + len(page_token_idxs), + dtype=page_token_idxs.dtype, + device=page_token_idxs.device, + ) + if not torch.equal(page_token_idxs, expected_idxs): + return False + return True + + def set_prefix_cache_to_req(self, req_idx: int, start: int, end: int, values: torch.Tensor): + # assert self.check_cache_page_valid(values), "Values must be valid for page size" + page_size = get_page_size() + self.req_to_page_indexs[req_idx, start // page_size : end // page_size] = values[::page_size] // page_size + self.req_to_token_indexs[req_idx, start:end] = values + + def expand_by_page_size(self, b_token_len, page_size): + # 将seq_len按page整数倍展开,例如seq_len = [9,9,9] -> page_len = [4,4,1,4,4,1,4,4,1], page_size = 4 + b_page_len = cdiv(b_token_len, page_size) + need_pages_num = b_page_len.sum() + p_token_len = torch.full((need_pages_num,), page_size, dtype=b_token_len.dtype, device=b_token_len.device) + cumsum_pages = torch.cumsum(b_page_len, dim=0) + last_page_positions = cumsum_pages - 1 + remainders = b_token_len - (b_page_len - 1) * page_size + p_token_len[last_page_positions] = remainders + return need_pages_num, b_page_len, p_token_len + + def get_paged_token_indexs(self, b_req_idx, page_size, b_seq_len, b_ready_cache_len, is_prefill): + if is_prefill: + b_req_idx = b_req_idx.cuda() + b_seq_len = b_seq_len.cuda() + b_ready_cache_len = b_ready_cache_len.cuda() + + b_token_len = b_seq_len - b_ready_cache_len + total_pages_needed, b_page_len, p_token_len = 
self.expand_by_page_size(b_token_len, page_size) + if self.can_use_page_size < total_pages_needed: + raise RuntimeError( + f"No available pages for alloc. remaining: {self.can_use_page_size}, needed: {total_pages_needed}" + ) + + allocated_pages = self.page_idx_pool[ + self.mark_page_start : self.mark_page_start + total_pages_needed + ].cuda() + + def get_offsets_by_length(b_len, max_len): + # 例:b_len = [3,4,5] -> [0,1,2,0,1,2,3,0,1,2,3,4] + offsets = torch.arange(max_len, dtype=b_len.dtype, device=b_len.device) + offset_mask = offsets.unsqueeze(0) < b_len.unsqueeze(1) + return torch.masked_select(offsets, offset_mask) + + page_offsets = get_offsets_by_length(b_page_len, b_page_len.max()) + token_offsets = get_offsets_by_length(p_token_len, page_size) + + # 更新req_to_page_indexs, b_ready_cache_len必整除page_size + page_starts = b_ready_cache_len // page_size + req_id = torch.repeat_interleave( + torch.arange(len(b_req_idx), dtype=b_token_len.dtype, device=b_token_len.device), b_page_len + ) + self.req_to_page_indexs[b_req_idx[req_id], page_starts[req_id] + page_offsets] = allocated_pages + + self.mark_page_start += total_pages_needed + self.can_use_page_size -= total_pages_needed + page_bases = allocated_pages * page_size + return torch.repeat_interleave(page_bases, p_token_len) + token_offsets + else: + b_seq_len = b_seq_len.cuda() + b_req_idx = b_req_idx.cuda() + need_new_page_mask = (b_seq_len - 1) % page_size == 0 + new_pages_num = need_new_page_mask.sum() + if self.can_use_page_size < new_pages_num: + raise RuntimeError( + f"No available pages for alloc. remaining: {self.can_use_page_size}, needed: {new_pages_num}" + ) + + token_idxs = torch.zeros_like(b_seq_len, device=b_seq_len.device) + if new_pages_num > 0: + new_pages = self.page_idx_pool[self.mark_page_start : self.mark_page_start + new_pages_num].cuda() + self.mark_page_start += new_pages_num + self.can_use_page_size -= new_pages_num + token_idxs[need_new_page_mask] = new_pages * page_size + + # 需要更新req_to_page_indexs + new_page_req_indices = b_req_idx[need_new_page_mask] + page_positions = (b_seq_len[need_new_page_mask] - 1) // page_size + self.req_to_page_indexs[new_page_req_indices, page_positions] = new_pages + + mask = ~need_new_page_mask + if mask.any(): + seq_lens = b_seq_len[mask] + token_idxs[mask] = ( + self.req_to_token_indexs[b_req_idx[mask], seq_lens - 2] // page_size * page_size + + (seq_lens - 1) % page_size + ) + return token_idxs + + def alloc(self, need_size, b_req_idx, b_seq_len, b_ready_cache_len=None, is_prefill=False) -> torch.Tensor: + page_size = get_page_size() + token_idxs = self.get_paged_token_indexs(b_req_idx, page_size, b_seq_len, b_ready_cache_len, is_prefill) + self.can_use_mem_size -= need_size + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + self.shared_can_use_page_num.set_value(self.can_use_page_size) + return token_idxs + + def free(self, free_index: Union[torch.Tensor, List[int]]): + self.can_use_mem_size += len(free_index) + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + + page_size = get_page_size() + if isinstance(free_index, list): + free_index = torch.tensor(free_index, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True) + + if len(free_index) == 0: + return + + base_free_index = free_index[free_index % page_size == 0] + page_indices = base_free_index // page_size + for page_idx in sorted(page_indices, reverse=True): # 逆序放回,保持池的相对顺序 + self.mark_page_start -= 1 + self.page_idx_pool[self.mark_page_start] = page_idx + 
self.can_use_page_size += 1 + self.shared_can_use_page_num.set_value(self.can_use_page_size) + + return + + def free_all(self): + super().free_all() + page_size = get_page_size() + self.mark_page_start = 0 + self.can_use_page_size = cdiv(self.size, page_size) + self.shared_can_use_page_num.set_value(self.can_use_page_size) + self.page_idx_pool = torch.arange( + 0, cdiv(self.size, page_size), dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 59f607a01..0786bbb08 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -5,7 +5,7 @@ from typing import List, Optional from lightllm.common.basemodel.triton_kernel.gen_sampling_params import token_id_counter from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter -from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args +from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, get_page_size from lightllm.utils.config_utils import get_vocab_size logger = init_logger(__name__) @@ -62,6 +62,15 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana self.req_to_token_indexs = torch.zeros( (max_request_num + 1, max_sequence_length), dtype=torch.int32, device="cuda" ) + mem_manager.req_to_token_indexs = self.req_to_token_indexs + if hasattr(mem_manager, "req_to_page_indexs"): + page_size = get_page_size() + self.req_to_page_indexs = torch.zeros( + (max_request_num + 1, (max_sequence_length + page_size - 1) // page_size), + dtype=torch.int32, + device="cuda", + ) + mem_manager.req_to_page_indexs = self.req_to_page_indexs self.mem_manager = mem_manager self.req_sampling_params_manager = ReqSamplingParamsManager(max_request_num) self.max_request_num = max_request_num diff --git a/lightllm/models/deepseek2/flashattention_infer_struct.py b/lightllm/models/deepseek2/flashattention_infer_struct.py index d2ae055ce..52ba3beb4 100644 --- a/lightllm/models/deepseek2/flashattention_infer_struct.py +++ b/lightllm/models/deepseek2/flashattention_infer_struct.py @@ -4,6 +4,11 @@ import torch.distributed as dist from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo from lightllm.utils.dist_utils import get_current_device_id +from lightllm.utils.envs_utils import get_page_size + + +def cdiv(a, b): + return (a + b - 1) // b class Deepseek2FlashAttentionStateInfo(Deepseek2InferStateInfo): @@ -11,6 +16,7 @@ class Deepseek2FlashAttentionStateInfo(Deepseek2InferStateInfo): def __init__(self): super().__init__() + self.page_size = get_page_size() @classmethod def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int): @@ -39,19 +45,22 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): self.cu_seqlens_k = self.b1_cu_kv_seq_len max_seq_len_k = self.max_kv_seq_len if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch: - page_buffer = Deepseek2FlashAttentionStateInfo.get_page_table_buffer( - model.graph_max_batch_size, model.graph_max_len_in_batch + length = cdiv(model.graph_max_len_in_batch, self.page_size) + page_buffer = Deepseek2FlashAttentionStateInfo.get_page_table_buffer(model.graph_max_batch_size, length) + self.page_table = page_buffer[self.microbatch_index][: self.batch_size * length].reshape( + self.batch_size, length ) - self.page_table = page_buffer[self.microbatch_index][ - : self.batch_size * 
model.graph_max_len_in_batch - ].reshape(self.batch_size, model.graph_max_len_in_batch) else: - self.page_table = torch.empty((self.batch_size, self.max_len_in_batch), dtype=torch.int32).to( - input_ids.device - ) + length = cdiv(self.max_len_in_batch, self.page_size) + self.page_table = torch.empty((self.batch_size, length), dtype=torch.int32).to(input_ids.device) - self.page_table[:, :max_seq_len_k].copy_( - model.req_manager.req_to_token_indexs[self.b_req_idx, :max_seq_len_k] - ) - self.page_table[:, max_seq_len_k:].fill_(0) + if "page_size_variable" in model.mode: + length = cdiv(max_seq_len_k, self.page_size) + self.page_table[:, :length].copy_(model.req_manager.req_to_page_indexs[self.b_req_idx, :length]) + self.page_table[:, length:].fill_(0) + else: + self.page_table[:, :max_seq_len_k].copy_( + model.req_manager.req_to_token_indexs[self.b_req_idx, :max_seq_len_k] + ) + self.page_table[:, max_seq_len_k:].fill_(0) return diff --git a/lightllm/models/deepseek2/flashinfer_struct.py b/lightllm/models/deepseek2/flashinfer_struct.py index a00c45601..25ebe3889 100644 --- a/lightllm/models/deepseek2/flashinfer_struct.py +++ b/lightllm/models/deepseek2/flashinfer_struct.py @@ -3,16 +3,21 @@ import numpy as np import torch.distributed as dist from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo -from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.envs_utils import get_env_start_args, get_page_size from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index +def cdiv(a, b): + return (a + b - 1) // b + + class Deepseek2FlashInferStateInfo(Deepseek2InferStateInfo): def __init__(self): super().__init__() self.prefill_wrapper = None self.decode_wrapper = None self.flashinfer_extra_state = None + self.page_size = get_page_size() def init_some_extra_state(self, model, input_ids: torch.Tensor): super().init_some_extra_state(model, input_ids) @@ -23,24 +28,37 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): if not self.is_prefill: if get_env_start_args().enable_flashinfer_decode: self.q_indptr = torch.arange(self.batch_size + 1, dtype=torch.int32).to(input_ids.device) + length = cdiv(self.flashinfer_extra_state.max_seq_length, self.page_size) if self.batch_size <= model.graph_max_batch_size: self.kv_indices = self.flashinfer_extra_state.kv_indices_buffer[self.microbatch_index][ - : self.batch_size * self.flashinfer_extra_state.max_seq_length + : self.batch_size * length ] else: self.kv_indices = torch.empty( - self.batch_size * self.flashinfer_extra_state.max_seq_length, + self.batch_size * length, dtype=torch.int32, device=input_ids.device, ) - repack_kv_index( - self.req_manager.req_to_token_indexs, - self.b_req_idx, - self.b_seq_len, - self.b_start_loc, - self.max_len_in_batch, - self.kv_indices, - ) + if "page_size_variable" in model.mode: + b_page_len = cdiv(self.b_seq_len, self.page_size) + self.kv_starts[1:] = b_page_len.cumsum(0) + repack_kv_index( + self.req_manager.req_to_page_indexs, + self.b_req_idx, + b_page_len, + self.kv_starts[:-1], + cdiv(self.max_len_in_batch, self.page_size), + self.kv_indices, + ) + else: + repack_kv_index( + self.req_manager.req_to_token_indexs, + self.b_req_idx, + self.b_seq_len, + self.b_start_loc, + self.max_len_in_batch, + self.kv_indices, + ) if self.decode_wrapper is None: self.decode_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( self.flashinfer_extra_state.workspace_buffer, @@ -58,7 +76,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): 
self.flashinfer_extra_state.tp_q_head_num, self.flashinfer_extra_state.kv_lora_rank, self.flashinfer_extra_state.qk_rope_head_dim, - 1, + self.page_size, False, # causal self.flashinfer_extra_state.softmax_scale, self.flashinfer_extra_state.q_data_type, @@ -97,7 +115,7 @@ def copy_for_cuda_graph(self, new_infer_state): new_infer_state.flashinfer_extra_state.tp_q_head_num, new_infer_state.flashinfer_extra_state.kv_lora_rank, new_infer_state.flashinfer_extra_state.qk_rope_head_dim, - 1, + self.page_size, False, # causal new_infer_state.flashinfer_extra_state.softmax_scale, new_infer_state.flashinfer_extra_state.q_data_type, diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index eccbe430d..6d9e64773 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -95,6 +95,18 @@ def _bind_attention(self): self._token_attention_kernel = partial( Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding_fp8, self ) + elif "page_size_variable" in self.mode: + self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self) + if get_env_start_args().enable_fa3: + self._token_attention_kernel = partial( + Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashattention_paged, self + ) + elif get_env_start_args().enable_flashinfer_decode: + self._token_attention_kernel = partial( + Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashinfer_paged, self + ) + else: + raise Exception("Page size variable mode is not supported in other backends.") else: self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self) if get_env_start_args().enable_fa3: @@ -576,6 +588,35 @@ def _token_gqa_decode_attention_flashattention( ) return o_tensor + def _token_gqa_decode_attention_flashattention_paged( + self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None + ): + q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] + q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) + kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, infer_state.page_size, 1, self.qk_rope_head_dim) + kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, infer_state.page_size, 1, self.kv_lora_rank) + k_descale, v_descale = None, None + o_tensor = flash_attn_with_kvcache( + q=q_rope, + k_cache=k_rope, + v_cache=kv_nope, + qv=q_nope, + page_table=infer_state.page_table, + cache_seqlens=infer_state.b_seq_len, + cu_seqlens_q=infer_state.cu_seqlens_q, + cu_seqlens_k_new=infer_state.cu_seqlens_k, + max_seqlen_q=1, + softmax_scale=self.softmax_scale, + causal=True, + window_size=(-1, -1), + softcap=0.0, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=False, + ) + return o_tensor + def _token_gqa_decode_attention_flashinfer( self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None ): @@ -595,6 +636,25 @@ def _token_gqa_decode_attention_flashinfer( ) return o_tensor + def _token_gqa_decode_attention_flashinfer_paged( + self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None + ): + q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] + 
q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) + + kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) + + infer_state.decode_wrapper.run( + q_nope, + q_rope, + kv[:, :, : -self.qk_rope_head_dim].reshape(-1, infer_state.page_size, 1, self.kv_lora_rank), + kv[:, :, -self.qk_rope_head_dim :].reshape(-1, infer_state.page_size, 1, self.qk_rope_head_dim), + out=o_tensor, + return_lse=False, + ) + return o_tensor + def _token_gqa_decode_attention_flashdecoding( self, q, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None ): diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index 9101cb963..c2380c4ab 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -10,6 +10,7 @@ from lightllm.models.llama.model import LlamaTpPartModel from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager +from lightllm.common.deepseek2_page_size_variable_mem_manager import Deepseek2PageSizeVariableMemoryManager from lightllm.common.deepseek2_fp8kv_mem_manager import Deepseek2FP8KVMemoryManager from lightllm.utils.log_utils import init_logger from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale @@ -97,6 +98,10 @@ def _init_mem_manager(self): manager_class = Deepseek2MemoryManager if "triton_fp8kv" in self.mode: manager_class = Deepseek2FP8KVMemoryManager + elif "page_size_variable" in self.mode: + manager_class = Deepseek2PageSizeVariableMemoryManager + elif self.mode: + raise ValueError(f"Unsupported mode for deepseek2: {self.mode}") # mtp 模式下需要在mem manger上扩展draft model使用的layer added_mtp_layer_num = 0 diff --git a/lightllm/models/deepseek2/triton_kernel/destindex_copy_kv.py b/lightllm/models/deepseek2/triton_kernel/destindex_copy_kv.py index 5b922604d..39deb1b6f 100644 --- a/lightllm/models/deepseek2/triton_kernel/destindex_copy_kv.py +++ b/lightllm/models/deepseek2/triton_kernel/destindex_copy_kv.py @@ -34,7 +34,7 @@ def _fwd_kernel_destindex_copy_kv( offs_d_nope = tl.arange(0, BLOCK_DMODEL_NOPE) offs_d_rope = tl.arange(0, BLOCK_DMODEL_ROPE) - dest_index = tl.load(Dest_loc + cur_index) + dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) kv_nope_ptrs = KV_nope + cur_index * stride_kv_nope_bs + stride_kv_nope_d * offs_d_nope[None, :] kv_rope_ptrs = KV_rope + cur_index * stride_kv_rope_bs + stride_kv_rope_d * offs_d_rope[None, :] diff --git a/lightllm/models/deepseek2/triton_kernel/sample_kv.py b/lightllm/models/deepseek2/triton_kernel/sample_kv.py index 6259c3ccd..af0aaa2f6 100644 --- a/lightllm/models/deepseek2/triton_kernel/sample_kv.py +++ b/lightllm/models/deepseek2/triton_kernel/sample_kv.py @@ -44,7 +44,7 @@ def _sample_kv_kernel( Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_m, mask=offs_m < block_end_loc, other=0, - ) + ).to(tl.int64) off_kv_nope = kv_loc[:, None] * stride_input_dim + offs_nope_d[None, :] off_kv_rope = kv_loc[:, None] * stride_input_dim + (offs_rope_d + BLOCK_DMODEL)[None, :] kv_nope = tl.load(KV_input + off_kv_nope, mask=offs_m[:, None] < block_end_loc, other=0.0) diff --git a/lightllm/models/llama/flashattention_infer_struct.py b/lightllm/models/llama/flashattention_infer_struct.py index 98f628f07..28611e901 100644 --- a/lightllm/models/llama/flashattention_infer_struct.py +++ b/lightllm/models/llama/flashattention_infer_struct.py @@ -3,17 +3,22 @@ import numpy as np import torch.distributed as dist from 
lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.envs_utils import get_env_start_args, get_page_size from lightllm.utils.dist_utils import get_current_device_id from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index from lightllm.common.basemodel.batch_objs import ModelInput +def cdiv(a, b): + return (a + b - 1) // b + + class FlashAttentionStateInfo(LlamaInferStateInfo): _shared_page_table_buffer = None def __init__(self): super().__init__() + self.page_size = get_page_size() @classmethod def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int): @@ -28,32 +33,33 @@ def _init_flash_attention_state(self, model, input_ids: torch.Tensor): if self.is_prefill: self.cu_seqlens_q = self.b1_cu_q_seq_len.int() self.cu_seqlens_k = self.b1_cu_kv_seq_len.int() - self.page_table = torch.empty( - (self.batch_size, self.max_seq_len), dtype=torch.int32, device=input_ids.device - ) - self.page_table.copy_(model.req_manager.req_to_token_indexs[self.b_req_idx, : self.max_seq_len]) + length = cdiv(self.max_seq_len, self.page_size) + self.page_table = torch.empty((self.batch_size, length), dtype=torch.int32, device=input_ids.device) + if "page_size_variable" in model.mode: + self.page_table.copy_(model.req_manager.req_to_page_indexs[self.b_req_idx, :length]) + else: + self.page_table.copy_(model.req_manager.req_to_token_indexs[self.b_req_idx, :length]) else: # Meta information of flashattention for decoding self.cu_seqlens_q = self.b1_cu_q_seq_len.int() self.cu_seqlens_k = self.b1_cu_kv_seq_len.int() max_seq_len_k = self.max_kv_seq_len if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch: - page_buffer = FlashAttentionStateInfo.get_page_table_buffer( - model.graph_max_batch_size, model.graph_max_len_in_batch + length = cdiv(model.graph_max_len_in_batch, self.page_size) + page_buffer = FlashAttentionStateInfo.get_page_table_buffer(model.graph_max_batch_size, length) + self.page_table = page_buffer[self.microbatch_index][: self.batch_size * length].reshape( + self.batch_size, length ) - self.page_table = page_buffer[self.microbatch_index][ - : self.batch_size * model.graph_max_len_in_batch - ].reshape(self.batch_size, model.graph_max_len_in_batch) else: - self.page_table = torch.empty( - (self.batch_size, self.max_len_in_batch), dtype=torch.int32, device=input_ids.device - ) + length = cdiv(self.max_len_in_batch, self.page_size) + self.page_table = torch.empty((self.batch_size, length), dtype=torch.int32, device=input_ids.device) - self.page_table[:, :max_seq_len_k].copy_( - model.req_manager.req_to_token_indexs[self.b_req_idx, :max_seq_len_k], - non_blocking=True, - ) - self.page_table[:, max_seq_len_k:].fill_(0) + length = cdiv(max_seq_len_k, self.page_size) + if "page_size_variable" in model.mode: + self.page_table[:, :length].copy_(model.req_manager.req_to_page_indexs[self.b_req_idx, :length]) + else: + self.page_table[:, :length].copy_(model.req_manager.req_to_token_indexs[self.b_req_idx, :length]) + self.page_table[:, length:].fill_(0) if "offline_calibration_fp8kv" in model.mode: if self.is_prefill: diff --git a/lightllm/models/llama/flashinfer_struct.py b/lightllm/models/llama/flashinfer_struct.py index a0c40b57a..3b9a378c4 100644 --- a/lightllm/models/llama/flashinfer_struct.py +++ b/lightllm/models/llama/flashinfer_struct.py @@ -3,16 +3,21 @@ import numpy as np import torch.distributed as dist from 
lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.envs_utils import get_env_start_args, get_page_size from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index +def cdiv(a, b): + return (a + b - 1) // b + + class LlamaFlashInferStateInfo(LlamaInferStateInfo): def __init__(self): super().__init__() self.prefill_wrapper = None self.decode_wrapper = None self.flashinfer_extra_state = None + self.page_size = get_page_size() def init_some_extra_state(self, model, input_ids: torch.Tensor): super().init_some_extra_state(model, input_ids) @@ -22,29 +27,41 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): if not self.is_prefill: if get_env_start_args().enable_flashinfer_decode: - self.kv_last_page_len_buffer = torch.full( - (self.batch_size,), 1, dtype=torch.int32, device=input_ids.device - ) + self.kv_last_page_len = torch.full((self.batch_size,), 1, dtype=torch.int32, device=input_ids.device) + length = cdiv(self.flashinfer_extra_state.max_seq_length, self.page_size) if self.batch_size <= model.graph_max_batch_size: self.kv_indices = self.flashinfer_extra_state.kv_indices_buffer[self.microbatch_index][ - : self.batch_size * self.flashinfer_extra_state.max_seq_length + : self.batch_size * length ] else: self.kv_indices = torch.empty( - self.batch_size * self.flashinfer_extra_state.max_seq_length, + self.batch_size * length, dtype=torch.int32, device=input_ids.device, ) - repack_kv_index( - self.req_manager.req_to_token_indexs, - self.b_req_idx, - self.b_seq_len, - self.b_start_loc, - self.max_len_in_batch, - self.kv_indices, - ) self.kv_starts = self.b1_cu_kv_seq_len.int() + if "page_size_variable" in model.mode: + b_page_len = cdiv(self.b_seq_len, self.page_size) + self.kv_starts[1:] = b_page_len.cumsum(0) + self.kv_last_page_len = self.b_seq_len - (b_page_len - 1) * self.page_size + repack_kv_index( + self.req_manager.req_to_page_indexs, + self.b_req_idx, + b_page_len, + self.kv_starts[:-1], + cdiv(self.max_kv_seq_len, self.page_size), + self.kv_indices, + ) + else: + repack_kv_index( + self.req_manager.req_to_token_indexs, + self.b_req_idx, + self.b_seq_len, + self.b_start_loc, + self.max_kv_seq_len, + self.kv_indices, + ) if self.decode_wrapper is None: self.decode_wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( self.flashinfer_extra_state.workspace_buffer, @@ -53,16 +70,16 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): use_tensor_cores=True, paged_kv_indptr_buffer=self.kv_starts, paged_kv_indices_buffer=self.kv_indices, - paged_kv_last_page_len_buffer=self.kv_last_page_len_buffer, + paged_kv_last_page_len_buffer=self.kv_last_page_len, ) self.decode_wrapper.plan( self.kv_starts, self.kv_indices, - self.kv_last_page_len_buffer, + self.kv_last_page_len, self.flashinfer_extra_state.tp_q_head_num, self.flashinfer_extra_state.tp_kv_head_num, self.flashinfer_extra_state.head_dim, - 1, + self.page_size, q_data_type=self.flashinfer_extra_state.q_data_type, kv_data_type=self.flashinfer_extra_state.kv_data_type, non_blocking=True, @@ -72,19 +89,33 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): q_starts = self.b1_cu_q_seq_len.int() kv_starts = self.b1_cu_kv_seq_len.int() kv_last_page_len = torch.full((self.batch_size,), 1, dtype=torch.int32, device=input_ids.device) + length = cdiv(self.flashinfer_extra_state.max_seq_length, self.page_size) kv_indices = torch.empty( - self.batch_size * 
self.flashinfer_extra_state.max_seq_length, + self.batch_size * length, dtype=torch.int32, device=input_ids.device, ) - repack_kv_index( - self.req_manager.req_to_token_indexs, - self.b_req_idx, - self.b_seq_len, - kv_starts[:-1], - self.max_kv_seq_len, - kv_indices, - ) + if "page_size_variable" in model.mode: + b_page_len = cdiv(self.b_seq_len, self.page_size) + kv_starts[1:] = b_page_len.cumsum(0) + kv_last_page_len = self.b_seq_len - (b_page_len - 1) * self.page_size + repack_kv_index( + self.req_manager.req_to_page_indexs, + self.b_req_idx, + b_page_len, + kv_starts[:-1], + cdiv(self.max_kv_seq_len, self.page_size), + kv_indices, + ) + else: + repack_kv_index( + self.req_manager.req_to_token_indexs, + self.b_req_idx, + self.b_seq_len, + kv_starts[:-1], + self.max_kv_seq_len, + kv_indices, + ) self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( self.flashinfer_extra_state.workspace_buffer, qo_indptr_buf=q_starts, @@ -100,7 +131,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): self.flashinfer_extra_state.tp_q_head_num, self.flashinfer_extra_state.tp_kv_head_num, self.flashinfer_extra_state.head_dim, - 1, + self.page_size, causal=True, pos_encoding_mode="NONE", logits_soft_cap=0.0, @@ -115,11 +146,11 @@ def copy_for_cuda_graph(self, new_infer_state): self.decode_wrapper.plan( new_infer_state.kv_starts, new_infer_state.kv_indices, - new_infer_state.kv_last_page_len_buffer, + new_infer_state.kv_last_page_len, new_infer_state.flashinfer_extra_state.tp_q_head_num, new_infer_state.flashinfer_extra_state.tp_kv_head_num, new_infer_state.flashinfer_extra_state.head_dim, - 1, + self.page_size, q_data_type=new_infer_state.flashinfer_extra_state.q_data_type, kv_data_type=new_infer_state.flashinfer_extra_state.kv_data_type, non_blocking=True, diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index 4b06a75c3..334865879 100755 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -87,6 +87,14 @@ def _bind_attention(self): self._copy_kv_to_mem_cache = partial( LlamaTransformerLayerInfer._copy_kv_to_mem_cache_with_calibration, self ) + elif "page_size_variable" in self.mode: + self._context_attention_kernel = partial( + LlamaTransformerLayerInfer._paged_context_attention_flashattention, self + ) + self._token_attention_kernel = partial( + LlamaTransformerLayerInfer._paged_token_decode_attention_flashattention, self + ) + self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) elif not self.mode: self._context_attention_kernel = partial( LlamaTransformerLayerInfer._context_attention_flashattention, self @@ -99,9 +107,16 @@ def _bind_attention(self): raise Exception(f"Unsupported mode for fa3 backend: {self.mode}") return elif get_env_start_args().enable_flashinfer_prefill: - self._context_attention_kernel = partial( - LlamaTransformerLayerInfer._context_attention_flashinfer_kernel, self - ) + if "page_size_variable" in self.mode: + self._context_attention_kernel = partial( + LlamaTransformerLayerInfer._paged_context_attention_flashinfer_kernel, self + ) + elif not self.mode: + self._context_attention_kernel = partial( + LlamaTransformerLayerInfer._context_attention_flashinfer_kernel, self + ) + else: + raise Exception(f"Unsupported mode for flashinfer backend: {self.mode}") else: self._context_attention_kernel = 
partial(LlamaTransformerLayerInfer._context_attention_kernel, self) if "ppl_int8kv" in self.mode: @@ -166,6 +181,12 @@ def _bind_attention(self): self._copy_kv_to_mem_cache = partial( LlamaTransformerLayerInfer._copy_kv_to_mem_cache_with_calibration, self ) + elif "page_size_variable" in self.mode: + assert get_env_start_args().enable_flashinfer_prefill and get_env_start_args().enable_flashinfer_decode + self._token_attention_kernel = partial( + LlamaTransformerLayerInfer._paged_token_decode_attention_flashinfer, self + ) + self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) elif not self.mode: if get_env_start_args().enable_flashinfer_decode: self._token_attention_kernel = partial( @@ -266,6 +287,20 @@ def _context_attention_flashinfer_kernel( ) return o_tensor + def _paged_context_attention_flashinfer_kernel( + self, q, kv, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None + ) -> torch.Tensor: + o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out + kv = infer_state.mem_manager.kv_buffer[self.layer_num_].view( + -1, infer_state.page_size, 2 * self.tp_k_head_num_, self.head_dim_ + ) + infer_state.prefill_wrapper.run( + q.view(q.shape[0], -1, self.head_dim_), + (kv[:, :, : self.tp_k_head_num_, :], kv[:, :, self.tp_k_head_num_ :, :]), + out=o_tensor.view(q.shape[0], -1, self.head_dim_), + ) + return o_tensor + def _context_attention_kernel( self, q, kv, infer_state: LlamaInferStateInfo, layer_weight, out=None ) -> torch.Tensor: @@ -317,6 +352,38 @@ def _context_attention_kernel_ppl_int8kv( ) return o_tensor + def _paged_context_attention_flashattention( + self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None + ): + cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( + -1, infer_state.page_size, self.tp_k_head_num_, self.head_dim_ + ) + cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : + ].reshape(-1, infer_state.page_size, self.tp_v_head_num_, self.head_dim_) + q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_) + k_descale, v_descale = None, None # disable quantization + Lq = q.shape[-1] + sm_scale = 1.0 / (Lq ** 0.5) + o = flash_attn_with_kvcache( + q=q, + k_cache=cache_k, + v_cache=cache_v, + page_table=infer_state.page_table, + cache_seqlens=infer_state.b_seq_len, + cu_seqlens_q=infer_state.cu_seqlens_q, + cu_seqlens_k_new=infer_state.cu_seqlens_k, + max_seqlen_q=infer_state.q_max_seq_len, + softmax_scale=sm_scale, + causal=True, + window_size=(-1, -1), + softcap=0.0, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=False, + ) + return o + def _context_attention_flashattention(self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None): cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( -1, 1, self.tp_k_head_num_, self.head_dim_ @@ -546,6 +613,23 @@ def _token_decode_attention_flashinfer(self, q, infer_state: LlamaFlashInferStat ) return o_tensor + def _paged_token_decode_attention_flashinfer( + self, q, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None + ): + batch_size = infer_state.batch_size + calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) + + o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out + kv = infer_state.mem_manager.kv_buffer[self.layer_num_].view( + -1, infer_state.page_size, 2 * self.tp_k_head_num_, 
self.head_dim_ + ) + infer_state.decode_wrapper.run( + q.view(calcu_shape1), + (kv[:, :, : self.tp_k_head_num_, :], kv[:, :, self.tp_k_head_num_ :, :]), + out=o_tensor.view(calcu_shape1), + ) + return o_tensor + def _token_decode_attention_normal(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): total_token_num = infer_state.total_token_num batch_size = infer_state.batch_size @@ -824,6 +908,38 @@ def _token_decode_attention_gqa_flashdecoding_vsm( alloc_tensor_func=self.alloc_tensor, ) + def _paged_token_decode_attention_flashattention( + self, q, infer_state: FlashAttentionStateInfo, layer_weight, out=None + ): + cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( + -1, infer_state.page_size, self.tp_k_head_num_, self.head_dim_ + ) + cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : + ].reshape(-1, infer_state.page_size, self.tp_v_head_num_, self.head_dim_) + q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_) + k_descale, v_descale = None, None # disable quantization + Lq = q.shape[-1] + sm_scale = 1.0 / (Lq ** 0.5) + o = flash_attn_with_kvcache( + q=q, + k_cache=cache_k, + v_cache=cache_v, + page_table=infer_state.page_table, + cache_seqlens=infer_state.b_seq_len, + cu_seqlens_q=infer_state.cu_seqlens_q, + cu_seqlens_k_new=infer_state.cu_seqlens_k, + max_seqlen_q=1, + softmax_scale=sm_scale, + causal=False, + window_size=(-1, -1), + softcap=0.0, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=False, + ) + return o + def _token_decode_attention_flashattention(self, q, infer_state: FlashAttentionStateInfo, layer_weight, out=None): cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( -1, 1, self.tp_k_head_num_, self.head_dim_ diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index dfbf0c84b..296ed3a2c 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -165,7 +165,7 @@ def make_argument_parser() -> argparse.ArgumentParser: nargs="+", help="""Model mode: [triton_int8kv | ppl_int8kv | ppl_fp16 | triton_flashdecoding | triton_gqa_attention | triton_gqa_flashdecoding | triton_fp8kv | offline_calibration_fp8kv - | export_fp8kv_calibration + | export_fp8kv_calibration | page_size_variable triton_flashdecoding mode is for long context, current support llama llama2 qwen; triton_gqa_attention and triton_gqa_flashdecoding is fast kernel for model which use GQA; triton_int8kv mode use int8 to store kv cache, can increase token capacity, use triton kernel; @@ -177,6 +177,8 @@ def make_argument_parser() -> argparse.ArgumentParser: Calibration need to disable cudagraph and use fa3 or flashinfer backend. 
ppl_int8kv mode use int8 to store kv cache, and use ppl fast kernel; ppl_fp16 mode use ppl fast fp16 decode attention kernel; + page_size_variable allow to use page size > 1, use PAGE_SIZE env to set page size, + page_size_variable only support fa3 and flashinfer backend for now you need to read source code to make sure the supported detail mode for all models""", ) parser.add_argument( diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index c2a87b4c3..2e03301b6 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -125,6 +125,13 @@ def normal_or_p_d_start(args): "--enable_flashinfer_prefill and --enable_flashinfer_decode" ) assert args.disable_cudagraph is True, "export_fp8kv_calibration mode need disable cudagraph" + if "page_size_variable" in args.mode: + assert args.enable_fa3 is True or ( + args.enable_flashinfer_prefill is True and args.enable_flashinfer_decode is True + ), ( + "page_size_variable mode need enable fa3 or flashinfer, add --enable_fa3 or " + "--enable_flashinfer_prefill and --enable_flashinfer_decode" + ) # 部分模式还不能支持与高级动态调度算法协同,to do. if args.diverse_mode: diff --git a/lightllm/server/router/dynamic_prompt/paged_radix_cache.py b/lightllm/server/router/dynamic_prompt/paged_radix_cache.py new file mode 100644 index 000000000..687fa1a22 --- /dev/null +++ b/lightllm/server/router/dynamic_prompt/paged_radix_cache.py @@ -0,0 +1,431 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/managers/router/radix_cache.py +import torch +import numpy as np +from typing import Tuple, Dict, Set, List +from sortedcontainers import SortedSet +from .shared_arr import SharedArray +from lightllm.utils.envs_utils import get_page_size + + +class UniqueTimeIdGenerator: + def __init__(self): + self.counter = 0 + + def generate_time_id(self): + self.counter += 1 + return self.counter + + +time_gen = UniqueTimeIdGenerator() + + +class TreeNode: + def __init__(self): + self.children: Dict[int, TreeNode] = {} # page_hash -> TreeNode + self.parent: TreeNode = None + self.token_id_key: torch.Tensor = None + self.token_mem_index_value: torch.Tensor = None # 用于记录存储的 token_index 为每个元素在 token mem 中的index位置 + self.ref_counter = 0 + self.time_id = time_gen.generate_time_id() # 用于标识时间周期 + + self.node_value_len = 0 + self.node_prefix_total_len = 0 + self.total_children_count = 0 + self.page_size = get_page_size() + self._page_size_is_power_of_2 = (self.page_size & (self.page_size - 1)) == 0 + self._page_size_mask = self.page_size - 1 if self._page_size_is_power_of_2 else None + + def get_compare_key(self): + return (0 if self.ref_counter == 0 else 1, self.total_children_count, self.time_id) + + def _compute_key(self, tokens: torch.Tensor) -> int: + page_tokens = tokens[: self.page_size] + return page_tokens.item() if self.page_size == 1 else hash(page_tokens.cpu().numpy().tobytes()) + + def find_matched_child(self, token_id_key: torch.Tensor) -> Tuple["TreeNode", int]: + target_key = self._compute_key(token_id_key) + if target_key in self.children: + child = self.children[target_key] + prefix_len = match(token_id_key, child.token_id_key) + # 只匹配page_size的整数倍长度 + if self.page_size > 1: + if prefix_len % self.page_size != 0: + if self._page_size_is_power_of_2: + # 位运算加速 + prefix_len = prefix_len & ~self._page_size_mask + else: + prefix_len = (prefix_len // self.page_size) * self.page_size + if prefix_len == 0: + return None, 0 + return child, prefix_len + + return None, 0 + + def split_node(self, prefix_len): + split_parent_node = 
TreeNode() + split_parent_node.parent = self.parent + self.parent.children[self._compute_key(self.token_id_key)] = split_parent_node + + split_parent_node.token_id_key = self.token_id_key[0:prefix_len] + split_parent_node.token_mem_index_value = self.token_mem_index_value[0:prefix_len] + split_parent_node.children = {} + + remaining_tokens = self.token_id_key[prefix_len:] + split_parent_node.children[self._compute_key(remaining_tokens)] = self + split_parent_node.ref_counter = self.ref_counter + split_parent_node.total_children_count = 1 + + new_len = len(split_parent_node.token_mem_index_value) + split_parent_node.node_value_len = new_len + split_parent_node.node_prefix_total_len = split_parent_node.parent.node_prefix_total_len + new_len + + self.token_id_key = remaining_tokens + self.token_mem_index_value = self.token_mem_index_value[prefix_len:] + self.parent = split_parent_node + new_len = len(self.token_mem_index_value) + self.node_value_len = new_len + self.node_prefix_total_len = self.parent.node_prefix_total_len + new_len + return split_parent_node + + def add_and_return_new_child(self, token_id_key, token_mem_index_value): + child = TreeNode() + child.token_id_key = token_id_key + child.token_mem_index_value = token_mem_index_value + + self.children[self._compute_key(token_id_key)] = child + child.parent = self + self.total_children_count += 1 + + new_len = len(child.token_mem_index_value) + child.node_value_len = new_len + child.node_prefix_total_len = child.parent.node_prefix_total_len + new_len + return child + + def remove_child(self, child_node: "TreeNode"): + del self.children[self._compute_key(child_node.token_id_key)] + child_node.parent = None + self.total_children_count -= 1 + return + + def update_time(self): + self.time_id = time_gen.generate_time_id() + + def is_leaf(self): + return self.total_children_count == 0 + + +def match(t1: torch.Tensor, t2: torch.Tensor) -> int: + # Ensure same shape for comparison: flatten and get min length + t1_flat = t1.flatten() + t2_flat = t2.flatten() + min_len = min(t1_flat.size(0), t2_flat.size(0)) + + # Compare elements and find first mismatch + diff = t1_flat[:min_len] != t2_flat[:min_len] + mismatch_indices = torch.nonzero(diff) + + if mismatch_indices.numel() == 0: + return min_len # All matched up to min_len + else: + return mismatch_indices[0].item() + + +class PagedRadixCache: + """ + unique_name 主要用于解决单机,多实列部署时的shm冲突 + """ + + def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None): + self.mem_manager = mem_manager + self._key_dtype = torch.int64 + self._value_dtype = torch.int64 + # 预计算page_size相关的常量 + self.page_size = get_page_size() + self._page_size_is_power_of_2 = (self.page_size & (self.page_size - 1)) == 0 + self._page_size_mask = self.page_size - 1 if self._page_size_is_power_of_2 else None + + self.root_node = TreeNode() + self.root_node.token_id_key = torch.zeros((0,), device="cpu", dtype=self._key_dtype) + self.root_node.token_mem_index_value = torch.zeros((0,), device="cpu", dtype=self._value_dtype) + self.root_node.ref_counter = 1 # 初始化为 1 保证永远不会被 evict 掉 + + self.evict_tree_set: Set[TreeNode] = SortedSet(key=lambda x: x.get_compare_key()) # 自定义比较器 + self.evict_tree_set.add(self.root_node) + + self.refed_tokens_num = SharedArray(f"{unique_name}_refed_tokens_num_{rank_in_node}", (1,), dtype=np.int64) + self.refed_tokens_num.arr[0] = 0 + self.tree_total_tokens_num = SharedArray( + f"{unique_name}_tree_total_tokens_num_{rank_in_node}", (1,), dtype=np.int64 + ) + 
self.tree_total_tokens_num.arr[0] = 0
+
+    def _get_page_aligned_key(self, key, value=None, free_truncated=False):
+        aligned_len = len(key)
+        if aligned_len == 0:
+            return None, None
+        # when page_size > 1, the usable key length must be an integer multiple of page_size
+        if self.page_size > 1:
+            if aligned_len % self.page_size != 0:
+                if self._page_size_is_power_of_2:
+                    # fast path via bitwise ops
+                    aligned_len = aligned_len & ~self._page_size_mask
+                else:
+                    aligned_len = (aligned_len // self.page_size) * self.page_size
+
+            # free the truncated tail
+            if free_truncated and aligned_len < len(key) and self.mem_manager is not None:
+                truncated_value = value[aligned_len:] if value is not None else key[aligned_len:]
+                if len(truncated_value) > 0:
+                    self.mem_manager.free(truncated_value)
+
+            return (
+                key[:aligned_len] if aligned_len > 0 else None,
+                value[:aligned_len] if value is not None and aligned_len > 0 else None,
+            )
+        return key, value
+
+    def insert(self, key, value=None):
+        if value is None:
+            value = key
+
+        assert len(key) == len(value)  # and len(key) >= 1
+        key, value = self._get_page_aligned_key(key, value, free_truncated=True)
+        if key is None:
+            return 0
+        return self._insert_helper(self.root_node, key, value)
+
+    def _insert_helper(self, node: TreeNode, key, value):
+        if node.is_leaf():
+            self.evict_tree_set.discard(node)
+
+        try:
+            child, prefix_len = node.find_matched_child(key)
+            if child is not None:
+                if prefix_len == len(key):
+                    if child.is_leaf():
+                        self.evict_tree_set.discard(child)
+                    child.update_time()
+                    if child.is_leaf():
+                        self.evict_tree_set.add(child)
+                    return prefix_len
+                elif prefix_len < len(key) and prefix_len < len(child.token_id_key):
+                    if child.is_leaf():
+                        self.evict_tree_set.discard(child)
+
+                    remaining_key = key[prefix_len:]
+                    remaining_value = value[prefix_len:]
+                    split_parent_node = child.split_node(prefix_len)
+                    new_node = split_parent_node.add_and_return_new_child(remaining_key, remaining_value)
+                    # update total token num
+                    self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value)
+                    if new_node.is_leaf():
+                        self.evict_tree_set.add(new_node)
+
+                    if split_parent_node.is_leaf():
+                        self.evict_tree_set.add(split_parent_node)
+
+                    if child.is_leaf():
+                        self.evict_tree_set.add(child)
+                    return prefix_len
+                elif prefix_len < len(key) and prefix_len == len(child.token_id_key):
+                    return prefix_len + self._insert_helper(child, key[prefix_len:], value[prefix_len:])
+                else:
+                    assert False, "can not run to here"
+
+            new_node = node.add_and_return_new_child(key, value)
+            # update total token num
+            self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value)
+            if new_node.is_leaf():
+                self.evict_tree_set.add(new_node)
+            return 0
+        finally:
+            node.update_time()
+            if node.is_leaf():
+                self.evict_tree_set.add(node)
+
+    def match_prefix(self, key, update_refs=False):
+        assert len(key) != 0
+        key, _ = self._get_page_aligned_key(key)
+        if key is None:
+            return None, 0, None
+
+        ans_value_list = []
+        tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs)
+        if tree_node != self.root_node:
+            if len(ans_value_list) != 0:
+                value = torch.concat(ans_value_list)
+            else:
+                value = torch.zeros((0,), device="cpu", dtype=self._value_dtype)
+            return tree_node, len(value), value
+        else:
+            self.dec_node_ref_counter(self.root_node)
+            return None, 0, None
+
+    def _match_prefix_helper(self, node: TreeNode, key, ans_value_list: list, update_refs=False) -> TreeNode:
+        if node.is_leaf():
+            self.evict_tree_set.discard(node)
+
+        if update_refs:
+            node.ref_counter += 1
+            # from 0 to 1 need update refs token num
+            if node.ref_counter == 1:
+                self.refed_tokens_num.arr[0] += len(node.token_mem_index_value)
+
+        try:
+            if len(key) == 0:
+                return node
+
+            child, prefix_len = node.find_matched_child(key)
+            if child is not None:
+                if prefix_len == len(child.token_id_key):
+                    ans_value_list.append(child.token_mem_index_value)
+                    return self._match_prefix_helper(child, key[prefix_len:], ans_value_list, update_refs=update_refs)
+                elif prefix_len < len(child.token_id_key):
+                    if child.is_leaf():
+                        self.evict_tree_set.discard(child)
+
+                    split_parent_node = child.split_node(prefix_len)
+                    ans_value_list.append(split_parent_node.token_mem_index_value)
+
+                    if update_refs:
+                        split_parent_node.ref_counter += 1
+                        # from 0 to 1 need update refs token num
+                        if split_parent_node.ref_counter == 1:
+                            self.refed_tokens_num.arr[0] += len(split_parent_node.token_mem_index_value)
+
+                    if child.is_leaf():
+                        self.evict_tree_set.add(child)
+                    if split_parent_node.is_leaf():
+                        self.evict_tree_set.add(split_parent_node)
+
+                    return split_parent_node
+                else:
+                    assert False, "error state"
+
+            return node
+        finally:
+            node.update_time()
+            if node.is_leaf():
+                self.evict_tree_set.add(node)
+
+    def evict(self, need_remove_tokens, evict_callback):
+        if self.tree_total_tokens_num.arr[0] - self.refed_tokens_num.arr[0] < need_remove_tokens:
+            assert False, f"""can not free tree tokens {need_remove_tokens},
+                              tree_total_tokens_num {self.tree_total_tokens_num.arr[0]},
+                              refed_tokens_num {self.refed_tokens_num.arr[0]}"""
+        num_evicted = 0
+        while num_evicted < need_remove_tokens:
+            node: TreeNode = self.evict_tree_set.pop(0)
+            assert node.ref_counter == 0 and node.is_leaf() and node != self.root_node, "error evict tree node state"
+            num_evicted += len(node.token_mem_index_value)
+            evict_callback(node.token_mem_index_value)
+            # update total token num
+            self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value)
+            parent_node: TreeNode = node.parent
+            parent_node.remove_child(node)
+            if parent_node.is_leaf():
+                self.evict_tree_set.add(parent_node)
+
+        return
+
+    def assert_leafs_is_right(self):
+        for node in self.evict_tree_set:
+            if node.is_leaf() and node.ref_counter == 0:
+                a = node.token_mem_index_value.cuda()
+                assert (self.mem_manager.mem_state[a] == 1).sum().item() == len(a)
+
+    def clear_tree_nodes(self):
+        """
+        This function is only called in tests.
+        """
+        while True:
+            node: TreeNode = self.evict_tree_set.pop(0)
+            if node != self.root_node:
+                parent_node: TreeNode = node.parent
+                parent_node.remove_child(node)
+                if parent_node.is_leaf():
+                    self.evict_tree_set.add(parent_node)
+            else:
+                break
+
+        self.tree_total_tokens_num.arr[0] = 0
+        self.refed_tokens_num.arr[0] = 0
+        return
+
+    def dec_node_ref_counter(self, node: TreeNode):
+        if node is None:
+            return
+        # if the node being dereferenced is a leaf, remove it from evict_tree_set first
+        old_node = node
+        if old_node.is_leaf():
+            self.evict_tree_set.discard(old_node)
+
+        while node is not None:
+            if node.ref_counter == 1:
+                self.refed_tokens_num.arr[0] -= len(node.token_mem_index_value)
+            node.ref_counter -= 1
+            node = node.parent
+
+        # add it back
+        if old_node.is_leaf():
+            self.evict_tree_set.add(old_node)
+        return
+
+    def get_refed_tokens_num(self):
+        return self.refed_tokens_num.arr[0]
+
+    def get_tree_total_tokens_num(self):
+        return self.tree_total_tokens_num.arr[0]
+
+    def print_self(self, indent=0):
+        self._print_helper(self.root_node, indent)
+
+    def _print_helper(self, node: TreeNode, indent):
+        print(
+            " " * indent,
+            f"k: {node.token_id_key[0:10]} v: {node.token_mem_index_value[0:10]} refs: {node.ref_counter} \
+                time_id: {node.time_id} prefix_total_len: {node.node_prefix_total_len} \
+                node_value_len: {node.node_value_len}",
+        )
+        for _, child in node.children.items():
+            self._print_helper(child, indent=indent + 2)
+        return
+
+    def free_radix_cache_to_get_enough_token(
+        self, need_token_num=None, b_seq_len=None, b_ready_cache_len=None, is_prefill=False
+    ):
+        assert self.mem_manager is not None
+        need_pages = 0
+        can_use_pages = 0
+        if hasattr(self.mem_manager, "can_use_page_size") and self.page_size > 1 and b_seq_len is not None:
+
+            def get_need_page_size(page_size, b_seq_len, b_ready_cache_len=None, is_prefill=False):
+                need_new_pages = 0
+                if is_prefill:
+                    need_tokens_array = b_seq_len - b_ready_cache_len
+                    need_pages_array = (need_tokens_array + page_size - 1) // page_size
+                    need_new_pages = need_pages_array.sum()
+                else:
+                    mask = (b_seq_len - 1) % page_size == 0
+                    need_new_pages = mask.sum()
+                return need_new_pages
+
+            need_pages = get_need_page_size(self.page_size, b_seq_len, b_ready_cache_len, is_prefill)
+            can_use_pages = self.mem_manager.can_use_page_size
+        if need_token_num > self.mem_manager.can_use_mem_size or need_pages > can_use_pages:
+            need_evict_single_token_num = need_token_num - self.mem_manager.can_use_mem_size
+            need_evict_page_token_num = (need_pages - can_use_pages) * self.page_size
+            need_evict_token_num = max(need_evict_single_token_num, need_evict_page_token_num)
+            remaining_tokens = self.get_tree_total_tokens_num() - self.get_refed_tokens_num()
+            need_evict_token_num = min(need_evict_token_num, remaining_tokens)
+            release_mems = []
+
+            def release_mem(mem_index):
+                release_mems.append(mem_index)
+                return
+
+            self.evict(need_evict_token_num, release_mem)
+            if release_mems:
+                mem_index = torch.concat(release_mems)
+                self.mem_manager.free(mem_index)
+        return
diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py
index 65ec4354b..a60d0a942 100644
--- a/lightllm/server/router/dynamic_prompt/radix_cache.py
+++ b/lightllm/server/router/dynamic_prompt/radix_cache.py
@@ -333,7 +333,9 @@ def _print_helper(self, node: TreeNode, indent):
             self._print_helper(child, indent=indent + 2)
         return
 
-    def free_radix_cache_to_get_enough_token(self, need_token_num):
+    def free_radix_cache_to_get_enough_token(
+        self, need_token_num=None, b_seq_len=None, b_ready_cache_len=None, is_prefill=False
+    ):
         assert self.mem_manager is not None
         if need_token_num > self.mem_manager.can_use_mem_size:
             need_evict_token_num = need_token_num - self.mem_manager.can_use_mem_size
diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py
index 67d69aa38..1add8d98b 100644
--- a/lightllm/server/router/model_infer/infer_batch.py
+++ b/lightllm/server/router/model_infer/infer_batch.py
@@ -340,7 +340,9 @@ def _match_radix_cache(self):
             self.shared_kv_node = share_node
             ready_cache_len = share_node.node_prefix_total_len
             # the cpu-to-gpu copy is a stream-blocking operation
-            g_infer_context.req_manager.req_to_token_indexs[self.req_idx, 0:ready_cache_len] = value_tensor
+            g_infer_context.req_manager.mem_manager.set_prefix_cache_to_req(
+                self.req_idx, 0, ready_cache_len, value_tensor
+            )
             self.cur_kv_len = int(ready_cache_len)  # serialization: the value may be numpy.int64, convert with int()
             self.shm_req.prompt_cache_len = self.cur_kv_len  # record the prompt cache hit length
@@ -461,7 +463,7 @@ def diverse_copy(self, req_manager, is_prefill):
             req = g_infer_context.requests_mapping[req_id]
             req.finish_status.set_status(FinishStatus.NO_FINISH)
             input_len = 
req.get_chuncked_input_token_len() - req_manager.req_to_token_indexs[req.req_idx][prefix_len:input_len] = cache_token_id + req_manager.mem_manager.set_prefix_cache_to_req(req.req_idx, prefix_len, input_len, cache_token_id) assert input_len == pre_input_len diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index fd75afdbf..5284d7fa9 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -10,6 +10,7 @@ from lightllm.utils.log_utils import init_logger from lightllm.models import get_model from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache +from lightllm.server.router.dynamic_prompt.paged_radix_cache import PagedRadixCache from lightllm.server.router.model_infer.infer_batch import InferReq, InferReqUpdatePack from lightllm.server.router.token_load import TokenLoad from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock @@ -139,8 +140,9 @@ def init_model(self, kvargs): self.model, self.is_multimodal = get_model(model_cfg, model_kvargs) self.model: TpPartBaseModel = self.model # for easy typing set_random_seed(2147483647) + radix_cache_class = PagedRadixCache if "page_size_variable" in self.mode else RadixCache self.radix_cache = ( - RadixCache( + radix_cache_class( get_unique_server_name(), self.model.mem_manager.size, self.rank_in_node, diff --git a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py index 10090a576..448c0d987 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py @@ -77,8 +77,12 @@ def padded_prepare_prefill_inputs( # dynamic prompt cache 准备 token g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num) + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token( + input_ids.shape[0] - padded_req_num, b_seq_len, b_ready_cache_len, True + ) + mem_indexes = g_infer_context.req_manager.mem_manager.alloc( + input_ids.shape[0] - padded_req_num, b_req_idx, b_seq_len, b_ready_cache_len, True + ) g_infer_state_lock.release() if padded_req_num > 0: @@ -162,8 +166,10 @@ def padded_prepare_decode_inputs( # dynamic prompt cache 准备 token g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0] - padded_req_num) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(b_seq_len.shape[0] - padded_req_num) + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0] - padded_req_num, b_seq_len) + mem_indexes = g_infer_context.req_manager.mem_manager.alloc( + b_seq_len.shape[0] - padded_req_num, b_req_idx, b_seq_len + ) g_infer_state_lock.release() if padded_req_num > 0: diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index d5bba1ae5..e5e871d83 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ 
b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -55,8 +55,12 @@ def prepare_prefill_inputs( # dynamic prompt cache 准备 token g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0]) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0]) + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token( + input_ids.shape[0], b_seq_len, b_ready_cache_len, True + ) + mem_indexes = g_infer_context.req_manager.mem_manager.alloc( + input_ids.shape[0], b_req_idx, b_seq_len, b_ready_cache_len, True + ) g_infer_state_lock.release() model_input = ModelInput( @@ -111,8 +115,8 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In # dynamic prompt cache 准备 token g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0]) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(b_seq_len.shape[0]) + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0], b_seq_len) + mem_indexes = g_infer_context.req_manager.mem_manager.alloc(b_seq_len.shape[0], b_req_idx, b_seq_len) g_infer_state_lock.release() model_input = ModelInput( diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl.py b/lightllm/server/router/req_queue/chunked_prefill/impl.py index f1dae4cac..ff5d2aca5 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/impl.py +++ b/lightllm/server/router/req_queue/chunked_prefill/impl.py @@ -3,6 +3,11 @@ from ...batch import Batch, Req from lightllm.server.router.req_queue.base_queue import BaseQueue from lightllm.common.basemodel.infer_lock import g_router_lock +from lightllm.utils.envs_utils import get_page_size + + +def cdiv(a, b): + return (a + b - 1) // b class ChunkedPrefillQueue(BaseQueue): @@ -21,8 +26,9 @@ def _init_cache_list(self, current_batch: Batch, is_busy): return # @calculate_time(show=True, min_cost_ms=0.1) - def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens): - self.cache_len_list.append(req.get_tuple_tokens(is_busy, self.router_max_new_token_len)) # hard to analysis + def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens, new_batch_prefill_need_pages): + token_infos = req.get_tuple_tokens(is_busy, self.router_max_new_token_len) + self.cache_len_list.append(token_infos) # hard to analysis self.cache_len_list.sort(key=lambda x: -x[1]) left_out_len_array = np.array([e[1] for e in self.cache_len_list]) @@ -32,9 +38,11 @@ def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() with g_router_lock.obj: + page_size = get_page_size() + page_remaining = (len(self.cache_len_list) - 1) * page_size if page_size > 1 else 0 ok_token_num = ( need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index) - < self.max_total_tokens + < self.max_total_tokens - page_remaining ) ok_req_num = len(self.cache_len_list) <= self.running_max_req_size @@ -49,9 +57,9 @@ def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens / self.max_total_tokens, self.dp_index, ) - return True, new_batch_first_router_need_tokens + return True, new_batch_first_router_need_tokens, new_batch_prefill_need_pages else: - return False, new_batch_first_router_need_tokens + return False, 
new_batch_first_router_need_tokens, new_batch_prefill_need_pages # @calculate_time(show=True, min_cost_ms=10) def generate_new_batch(self, current_batch: Batch): @@ -77,6 +85,7 @@ def generate_new_batch(self, current_batch: Batch): waiting_queue = self.waiting_req_list + new_batch_prefill_need_pages = cdiv(new_batch_first_router_need_tokens, get_page_size()) for req in waiting_queue: if req.is_aborted: # 由于管理的复杂性,只有没有被调度运行过的请求可以因为abort直接在队列中忽略掉. @@ -84,8 +93,8 @@ def generate_new_batch(self, current_batch: Batch): aborted_count += 1 abort_req_list.append(req) continue - ok_insert, new_batch_first_router_need_tokens = self._can_add_new_req( - req, is_busy, new_batch_first_router_need_tokens + ok_insert, new_batch_first_router_need_tokens, new_batch_prefill_need_pages = self._can_add_new_req( + req, is_busy, new_batch_first_router_need_tokens, new_batch_prefill_need_pages ) if ok_insert: can_run_list.append(req) diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index b78784d82..1cb030cb1 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -149,6 +149,15 @@ def get_kv_quant_calibration_inference_count(): return int(os.getenv("LIGHTLLM_KV_QUANT_CALIBRARTION_INFERENCE_COUNT", 4000)) +@lru_cache(maxsize=None) +def get_page_size(): + try: + args = get_env_start_args() + return int(os.getenv("PAGE_SIZE", 64)) if "page_size_variable" in args.mode else 1 + except: + return 1 + + g_model_init_done = False diff --git a/test/benchmark/static_inference/model_infer.py b/test/benchmark/static_inference/model_infer.py index 73a99ff28..f59e01a6f 100644 --- a/test/benchmark/static_inference/model_infer.py +++ b/test/benchmark/static_inference/model_infer.py @@ -242,7 +242,9 @@ def run_forward_once( b_seq_len[i] = input_len total_token_num = batch_size * input_len - mem_indexes = model_part.req_manager.mem_manager.alloc(test_data.shape[0]).cuda() + mem_indexes = model_part.req_manager.mem_manager.alloc( + test_data.shape[0], b_req_idx, b_seq_len, b_ready_cache_len, True + ).cuda() rank_id = model_kvargs["rank_id"] @@ -303,7 +305,7 @@ def run_forward_once( step_start = time.time() total_token_num += batch_size b_seq_len += 1 - mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0]).cuda() + mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0], b_req_idx, b_seq_len).cuda() max_len_in_batch = input_len + i + 1 logits = decode_fn( model_part, diff --git a/test/benchmark/static_inference/model_infer_mtp.py b/test/benchmark/static_inference/model_infer_mtp.py index eb36bc873..3df56cdcc 100644 --- a/test/benchmark/static_inference/model_infer_mtp.py +++ b/test/benchmark/static_inference/model_infer_mtp.py @@ -126,7 +126,9 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ b_seq_len[i] = input_len total_token_num = input_len * batch_size - mem_indexes = main_model.req_manager.mem_manager.alloc(test_data.shape[0]).cuda() + mem_indexes = main_model.req_manager.mem_manager.alloc( + test_data.shape[0], b_req_idx, b_seq_len, b_ready_cache_len, True + ).cuda() # Main model Prefill model_input = ModelInput( batch_size=batch_size, @@ -193,7 +195,9 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ nopad_b_seq_idx = torch.tensor(nopad_b_seq_idx, dtype=torch.int32, device="cuda") nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda") - mem_indexes = main_model.req_manager.mem_manager.alloc(batch_size * (len(draft_models) + 1)).cuda() + 
mem_indexes = main_model.req_manager.mem_manager.alloc( + batch_size * (len(draft_models) + 1), nopad_b_seq_idx, nopad_b_seq_len + ).cuda() model_input = ModelInput( batch_size=batch_size * (len(draft_models) + 1), diff --git a/unit_tests/models/llama/test_context_flashattention_nopad.py b/unit_tests/models/llama/test_context_flashattention_nopad.py index f24ab619b..94e61cfda 100644 --- a/unit_tests/models/llama/test_context_flashattention_nopad.py +++ b/unit_tests/models/llama/test_context_flashattention_nopad.py @@ -10,7 +10,6 @@ context_attention_fwd_no_prompt_cache, ) from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.common.req_manager import ReqManager logger = init_logger(__name__) @@ -56,8 +55,6 @@ def test_context_attention_fwd(batch, seqlen, q_heads, kv_heads, head_dim): infer_state.batch_size = Z infer_state.max_len_in_batch = N_CTX infer_state.total_token_num = Z * N_CTX - infer_state.req_manager = ReqManager(Z, N_CTX, None) - infer_state.req_manager.req_to_token_indexs = req_to_token_indexs infer_state.b_req_idx = b_req_idx infer_state.b_seq_len = b_seq_len infer_state.b_ready_cache_len = b_ready_cache_len @@ -73,7 +70,7 @@ def test_context_attention_fwd(batch, seqlen, q_heads, kv_heads, head_dim): infer_state.b_seq_len, infer_state.b_ready_cache_len, infer_state.max_len_in_batch, - infer_state.req_manager.req_to_token_indexs, + req_to_token_indexs, ) batch_size = Z diff --git a/unit_tests/models/llama/test_context_flashattention_nopad_fa3_paged.py b/unit_tests/models/llama/test_context_flashattention_nopad_fa3_paged.py new file mode 100644 index 000000000..e7702f084 --- /dev/null +++ b/unit_tests/models/llama/test_context_flashattention_nopad_fa3_paged.py @@ -0,0 +1,163 @@ +import torch +import time +import pytest +import triton as tl +import numpy as np +import torch.nn.functional as F +from lightllm.utils.log_utils import init_logger +from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( + context_attention_fwd, +) +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.utils.sgl_utils import flash_attn_with_kvcache +from lightllm.common.basemodel.triton_kernel.q_per_head_fp8_quant import q_per_head_fp8_quant + +logger = init_logger(__name__) + +seed = 42 +torch.manual_seed(seed) + +if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def kv_quantize_per_head_fp8(kv_buffer: torch.Tensor, seq_lens): + device = kv_buffer.device + B = seq_lens.size(0) + min_fp8 = torch.finfo(torch.float8_e4m3fn).min + max_fp8 = torch.finfo(torch.float8_e4m3fn).max + _, S_max, H, D = kv_buffer.shape + seq_range = torch.arange(S_max, device=device)[None, :] + valid_mask = (seq_range < seq_lens[:, None]).view(B, S_max, 1, 1) + masked = kv_buffer * valid_mask + max_per_bh = masked.abs().amax(dim=(1, 3)) # [B, H] + scales = torch.where(max_per_bh > 0, max_per_bh / max_fp8, torch.ones_like(max_per_bh)).to(torch.float32) + scales_exp = scales.view(B, 1, H, 1) + q = (kv_buffer / scales_exp).clamp(min_fp8, max_fp8).to(torch.float8_e4m3fn) + return q, scales + + +@pytest.mark.parametrize( + "batch, seqlen, q_heads, kv_heads, head_dim", + [ + (a, b, c, d, e) + for a in [1, 16, 32, 128, 512] + for b in [16, 32, 512, 1024] + for c in [28] + for d in [4] + for e in [128] + ], +) +def test_context_attention_fwd_fa3_fp8(batch, seqlen, q_heads, kv_heads, head_dim): + Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim + dtype = torch.bfloat16 + 
page_size = 4 + kv = torch.randn((Z * N_CTX // page_size, page_size, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + max_input_len = Z * N_CTX + req_to_page_indexs = ( + torch.randperm(max_input_len // page_size, dtype=torch.int32).cuda().view(Z, N_CTX // page_size) + ) + req_to_token_indexs = ( + req_to_page_indexs.unsqueeze(-1) * page_size + torch.arange(page_size, dtype=torch.int32, device="cuda") + ).reshape(Z, N_CTX) + b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * (N_CTX // 2) + rand_num = torch.randint_like(b_seq_len, high=(N_CTX // 2), dtype=torch.int32, device="cuda") + b_seq_len += rand_num + b_ready_cache_len = torch.zeros_like(b_seq_len, dtype=torch.int32, device="cuda") + if N_CTX > 1: + b_ready_cache_len = torch.randint_like(b_seq_len, high=(N_CTX - 1) // 2, dtype=torch.int32, device="cuda") + b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() + q_lens = b_seq_len - b_ready_cache_len + q_start_loc = q_lens.cumsum(0) - q_lens + + q = torch.randn((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + o = torch.zeros((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + o1 = torch.zeros((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + infer_state = LlamaInferStateInfo() + infer_state.batch_size = Z + infer_state.max_len_in_batch = N_CTX + infer_state.total_token_num = Z * N_CTX + infer_state.b_req_idx = b_req_idx + infer_state.b_seq_len = b_seq_len + infer_state.b_ready_cache_len = b_ready_cache_len + infer_state.b_start_loc = q_start_loc + + context_attention_fwd( + q, + kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, :KV_HEADS, :], + kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, KV_HEADS:, :], + o, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.b_ready_cache_len, + infer_state.max_len_in_batch, + req_to_token_indexs, + ) + + batch_size = Z + head_dim = HEAD_DIM + q_heads = Q_HEADS + kv_heads = KV_HEADS + page_table = torch.empty((batch_size, N_CTX), dtype=torch.int32, device="cuda") + page_table.copy_(req_to_token_indexs[b_req_idx, :N_CTX]) + + q_starts = torch.zeros((Z + 1,)).int().cuda() + q_starts[1:] = torch.cumsum(b_seq_len - b_ready_cache_len, dim=0) + kv_starts = torch.zeros_like(q_starts) + kv_starts[1:] = torch.cumsum(b_seq_len, dim=0) + + k_cache = kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, :KV_HEADS, :] + v_cache = kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, KV_HEADS:, :] + o1 = flash_attn_with_kvcache( + q=q, + k_cache=k_cache.reshape(-1, 1, kv_heads, head_dim), + v_cache=v_cache.reshape(-1, 1, kv_heads, head_dim), + page_table=page_table, + cache_seqlens=infer_state.b_seq_len, + cu_seqlens_q=q_starts, + cu_seqlens_k_new=kv_starts, + max_seqlen_q=N_CTX, + causal=True, + window_size=(-1, -1), + softcap=0.0, + return_softmax_lse=False, + ) + + assert torch.allclose(o, o1, atol=1e-2, rtol=1e-2) + cos_sim1 = F.cosine_similarity(o, o1).mean() + print(f"cos_sim1: {cos_sim1}") + assert cos_sim1.item() == 1 + + k_cache_paged = k_cache.reshape(-1, page_size, kv_heads, head_dim) + v_cache_paged = v_cache.reshape(-1, page_size, kv_heads, head_dim) + + page_table_paged = torch.empty((batch_size, N_CTX // page_size), dtype=torch.int32, device="cuda") + page_table_paged.copy_(req_to_page_indexs[b_req_idx, : N_CTX // page_size]) + + o2 = flash_attn_with_kvcache( + q=q, + k_cache=k_cache_paged, + v_cache=v_cache_paged, + page_table=page_table_paged, + cache_seqlens=infer_state.b_seq_len, + cu_seqlens_q=q_starts, + cu_seqlens_k_new=kv_starts, + max_seqlen_q=N_CTX, + causal=True, + 
window_size=(-1, -1), + softcap=0.0, + return_softmax_lse=False, + ) + + assert torch.allclose(o1, o2, atol=1e-2, rtol=1e-2) + cos_sim2 = F.cosine_similarity(o1, o2).mean() + print(f"cos_sim2: {cos_sim2}") + assert cos_sim2.item() == 1 + + +if __name__ == "__main__": + test_context_attention_fwd_fa3_fp8(32, 16384, 32, 4, 128) diff --git a/unit_tests/models/llama/test_context_flashattention_nopad_flashinfer_paged.py b/unit_tests/models/llama/test_context_flashattention_nopad_flashinfer_paged.py new file mode 100644 index 000000000..763a80015 --- /dev/null +++ b/unit_tests/models/llama/test_context_flashattention_nopad_flashinfer_paged.py @@ -0,0 +1,214 @@ +import torch +import time +import pytest +import numpy as np +import torch.nn.functional as F +import flashinfer +from lightllm.utils.log_utils import init_logger +from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( + context_attention_fwd, + context_attention_fwd_no_prompt_cache, +) +from lightllm.models.llama.infer_struct import LlamaInferStateInfo + +logger = init_logger(__name__) + +seed = 42 +torch.manual_seed(seed) + +if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +@pytest.mark.parametrize( + "batch, seqlen, q_heads, kv_heads, head_dim", + [ + (a, b, c, d, e) + for a in [1, 16, 32, 128, 512] + for b in [16, 32, 512, 1024] + for c in [28] + for d in [4] + for e in [128] + ], +) +def test_context_attention_fwd(batch, seqlen, q_heads, kv_heads, head_dim): + Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim + dtype = torch.bfloat16 + page_size = 4 + kv = torch.randn((Z * N_CTX // page_size, page_size, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + max_input_len = Z * N_CTX + req_to_page_indexs = ( + torch.randperm(max_input_len // page_size, dtype=torch.int32).cuda().view(Z, N_CTX // page_size) + ) + req_to_token_indexs = ( + req_to_page_indexs.unsqueeze(-1) * page_size + torch.arange(page_size, dtype=torch.int32, device="cuda") + ).reshape(Z, N_CTX) + + b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * N_CTX + b_ready_cache_len = torch.zeros_like(b_seq_len, dtype=torch.int32, device="cuda") + b_ready_cache_len = torch.randint_like(b_seq_len, high=N_CTX - 1, dtype=torch.int32, device="cuda") + b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() + q_lens = b_seq_len - b_ready_cache_len + q_start_loc = q_lens.cumsum(0) - q_lens + + q = torch.randn((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + o = torch.zeros((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + o1 = torch.zeros((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + infer_state = LlamaInferStateInfo() + infer_state.batch_size = Z + infer_state.max_len_in_batch = N_CTX + infer_state.total_token_num = Z * N_CTX + infer_state.b_req_idx = b_req_idx + infer_state.b_seq_len = b_seq_len + infer_state.b_ready_cache_len = b_ready_cache_len + infer_state.b_start_loc = q_start_loc + + context_attention_fwd( + q, + kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, :KV_HEADS, :], + kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, KV_HEADS:, :], + o, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.b_ready_cache_len, + infer_state.max_len_in_batch, + req_to_token_indexs, + ) + + batch_size = Z + head_dim = HEAD_DIM + q_heads = Q_HEADS + kv_heads = KV_HEADS + workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8).to(0) + q_starts = torch.zeros((Z + 1,)).int().cuda() + 
q_starts[1:] = torch.cumsum(b_seq_len - b_ready_cache_len, dim=0) + + num_pages_per_seq = torch.ceil(b_seq_len.float() / page_size).int() + kv_starts = torch.zeros((Z + 1,)).int().cuda() + kv_starts[1:] = torch.cumsum(num_pages_per_seq, dim=0) + + q_indptr = q_starts.int() + kv_indptr = kv_starts.int() + + total_pages = num_pages_per_seq.sum().item() + kv_indices = torch.zeros(total_pages, dtype=torch.int32, device="cuda") + + # 设置kv_indices + b_start_loc = num_pages_per_seq.cumsum(0) - num_pages_per_seq + for req, sl, start in zip(b_req_idx, num_pages_per_seq, b_start_loc): + kv_indices[start : start + sl] = req_to_page_indexs[req][:sl] + + kv_last_page_len_buffer = torch.empty(batch_size, device="cuda:0", dtype=torch.int32) + wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, + qo_indptr_buf=q_indptr, + paged_kv_indptr_buf=kv_indptr, + paged_kv_indices_buf=kv_indices, + paged_kv_last_page_len_buf=kv_last_page_len_buffer, + ) + + # 设置kv_last_page_len + kv_last_page_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + for i in range(Z): + seq_len = b_seq_len[i].item() + remainder = seq_len % page_size + kv_last_page_len[i] = remainder if remainder > 0 else page_size + + wrapper.plan( + q_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + q_heads, + kv_heads, + head_dim, + page_size, + causal=True, + pos_encoding_mode="NONE", + logits_soft_cap=0.0, + q_data_type=q.dtype, + kv_data_type=kv.dtype, + ) + k_cache = kv[:, :, :KV_HEADS, :] + v_cache = kv[:, :, KV_HEADS:, :] + wrapper.run(q, (k_cache, v_cache), out=o1, return_lse=False) + cos_sim1 = F.cosine_similarity(o, o1).mean() + assert cos_sim1 == 1.0 + + +@pytest.mark.parametrize( + "batch, seqlen, q_heads, kv_heads, head_dim", + [ + (a, b, c, d, e) + for a in [1, 16, 32, 128, 512] + for b in [16, 32, 512, 1024] + for c in [28] + for d in [4] + for e in [128] + ], +) +def test_context_attention_fwd_no_prompt_cache(batch, seqlen, q_heads, kv_heads, head_dim): + Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim + dtype = torch.bfloat16 + q = torch.randn((Z * N_CTX, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + k = torch.randn((Z * N_CTX, KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + v = torch.randn((Z * N_CTX, KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * N_CTX + b_start_loc = b_seq_len.cumsum(0) - b_seq_len + + o = torch.zeros((Z * N_CTX, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + o1 = torch.zeros((Z * N_CTX, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + infer_state = LlamaInferStateInfo() + infer_state.batch_size = Z + infer_state.max_len_in_batch = N_CTX + infer_state.b_seq_len = b_seq_len + infer_state.b_start_loc = b_start_loc + + context_attention_fwd_no_prompt_cache( + q, + k, + v, + o, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + ) + + head_dim = HEAD_DIM + q_heads = Q_HEADS + kv_heads = KV_HEADS + workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8).to(0) + q_starts = torch.zeros((Z + 1,)).int().cuda() + q_starts[1:] = torch.cumsum(b_seq_len, dim=0) + kv_starts = torch.zeros_like(q_starts) + kv_starts[1:] = torch.cumsum(b_seq_len, dim=0) + q_indptr = q_starts.int() + kv_indptr = kv_starts.int() + wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( + workspace_buffer, + ) + wrapper.plan( + qo_indptr=q_indptr, + kv_indptr=kv_indptr, + num_qo_heads=q_heads, + num_kv_heads=kv_heads, + 
head_dim_qk=head_dim, + head_dim_vo=head_dim, + q_data_type=dtype, + causal=True, + ) + wrapper.run(q, k, v, out=o1, return_lse=False) + + # assert torch.allclose(o, o1, atol=1e-2, rtol=0) + cos_sim1 = F.cosine_similarity(o, o1).mean() + assert cos_sim1 == 1.0 + + +if __name__ == "__main__": + test_context_attention_fwd(32, 16384, 32, 4, 128) # 16384 is divisible by 4 diff --git a/unit_tests/models/llama/test_token_attention_nopad_fa3_paged.py b/unit_tests/models/llama/test_token_attention_nopad_fa3_paged.py new file mode 100644 index 000000000..1de2fbc34 --- /dev/null +++ b/unit_tests/models/llama/test_token_attention_nopad_fa3_paged.py @@ -0,0 +1,186 @@ +import torch +import time +import pytest +import numpy as np +import torch.nn.functional as F +from lightllm.utils.log_utils import init_logger +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.models.llama.triton_kernel.gqa_decode_flashattention_nopad import gqa_decode_attention_fwd +from lightllm.utils.sgl_utils import flash_attn_with_kvcache + +logger = init_logger(__name__) + +seed = 42 +torch.manual_seed(seed) + +if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def kv_quantize_per_head_fp8(kv_buffer: torch.Tensor, seq_lens): + device = kv_buffer.device + B = seq_lens.size(0) + min_fp8 = torch.finfo(torch.float8_e4m3fn).min + max_fp8 = torch.finfo(torch.float8_e4m3fn).max + _, S_max, H, D = kv_buffer.shape + seq_range = torch.arange(S_max, device=device)[None, :] + valid_mask = (seq_range < seq_lens[:, None]).view(B, S_max, 1, 1) + masked = kv_buffer * valid_mask + max_per_bh = masked.float().abs().amax(dim=(1, 3)) # [B, H] + scales = torch.where(max_per_bh > 0, max_per_bh / max_fp8, torch.ones_like(max_per_bh)) + scales_exp = scales.view(B, 1, H, 1) + q = (kv_buffer / scales_exp).clamp(min_fp8, max_fp8).to(torch.float8_e4m3fn) + return q, scales + + +def ref_token_attention_nopad(q, k, v, o, q_h, h_dim, infer_state, req_to_token_indexs): + from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd + + total_token_num = infer_state.total_token_num + batch_size = infer_state.batch_size + calcu_shape1 = (batch_size, q_h, h_dim) + + att_m_tensor = torch.empty((q_h, total_token_num), dtype=torch.float32).cuda() + + token_att_fwd( + q.view(calcu_shape1), + k, + att_m_tensor, + req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + ) + + from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import ( + token_softmax_reducev_fwd, + ) + + token_softmax_reducev_fwd( + att_m_tensor, + v, + o, + req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + ) + return o + + +@pytest.mark.parametrize( + "batch, seqlen, q_heads, kv_heads, head_dim", + [ + (a, b, c, d, e) + for a in [1, 16, 32, 128, 512] + for b in [16, 32, 512, 1024] + for c in [28] + for d in [4] + for e in [128] + ], +) +def test_token_attention_nopad_fa3_fp8(batch, seqlen, q_heads, kv_heads, head_dim): + Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim + dtype = torch.bfloat16 + page_size = 4 + q = torch.randn((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + kv = torch.randn((Z * N_CTX, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + max_input_len = Z * N_CTX + req_to_page_indexs = ( + torch.randperm(max_input_len // page_size, dtype=torch.int32).cuda().view(Z, N_CTX // page_size) + 
) + req_to_token_indexs = ( + req_to_page_indexs.unsqueeze(-1) * page_size + torch.arange(page_size, dtype=torch.int32, device="cuda") + ).reshape(Z, N_CTX) + b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * (N_CTX // 2) + rand_num = torch.randint_like(b_seq_len, high=(N_CTX // 2), dtype=torch.int32, device="cuda") + b_seq_len += rand_num + b_start_loc = b_seq_len.cumsum(0) - b_seq_len + b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() + + o = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + o1 = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + infer_state = LlamaInferStateInfo() + infer_state.batch_size = Z + infer_state.max_len_in_batch = N_CTX + infer_state.total_token_num = Z * N_CTX + infer_state.b_req_idx = b_req_idx + infer_state.b_seq_len = b_seq_len + infer_state.b_start_loc = b_start_loc + + ref_token_attention_nopad( + q, + kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, :KV_HEADS, :], + kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, KV_HEADS:, :], + o, + Q_HEADS, + HEAD_DIM, + infer_state, + req_to_token_indexs, + ) + + batch_size = Z + head_dim = HEAD_DIM + q_heads = Q_HEADS + kv_heads = KV_HEADS + kv_starts = torch.zeros((Z + 1,)).int().cuda() + kv_starts[1:] = torch.cumsum(b_seq_len, dim=0) + q_starts = torch.arange(0, Z + 1).int().cuda() + page_table = torch.empty((batch_size, N_CTX), dtype=torch.int32).to(0) + page_table.copy_(req_to_token_indexs[b_req_idx, :N_CTX]) + + k_cache = kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, :KV_HEADS, :].contiguous() + v_cache = kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, KV_HEADS:, :].contiguous() + o1 = flash_attn_with_kvcache( + q=q, + k_cache=k_cache.view(-1, 1, kv_heads, head_dim), + v_cache=v_cache.view(-1, 1, kv_heads, head_dim), + page_table=page_table, + cache_seqlens=infer_state.b_seq_len, + cu_seqlens_q=q_starts, + cu_seqlens_k_new=kv_starts, + max_seqlen_q=1, + causal=False, + window_size=(-1, -1), + softcap=0.0, + return_softmax_lse=False, + ) + + # assert torch.allclose(o, o1, atol=1e-1, rtol=1e-1) + cos_sim1 = F.cosine_similarity(o, o1).mean() + print(cos_sim1) + assert cos_sim1 == 1 + + k_cache_paged = k_cache.reshape(-1, page_size, kv_heads, head_dim) + v_cache_paged = v_cache.reshape(-1, page_size, kv_heads, head_dim) + + page_table_paged = torch.empty((batch_size, N_CTX // page_size), dtype=torch.int32, device="cuda") + page_table_paged.copy_(req_to_page_indexs[b_req_idx, : N_CTX // page_size]) + + o2 = flash_attn_with_kvcache( + q=q, + k_cache=k_cache_paged, + v_cache=v_cache_paged, + page_table=page_table_paged, + cache_seqlens=infer_state.b_seq_len, + cu_seqlens_q=q_starts, + cu_seqlens_k_new=kv_starts, + max_seqlen_q=1, + causal=False, + window_size=(-1, -1), + softcap=0.0, + return_softmax_lse=False, + ) + + assert torch.allclose(o1, o2, atol=1e-2, rtol=1e-2) + cos_sim2 = F.cosine_similarity(o1, o2).mean() + print(cos_sim2) + assert cos_sim2.item() == 1 + + +if __name__ == "__main__": + test_token_attention_nopad_fa3_fp8(16, 16384, 28, 4, 128) diff --git a/unit_tests/models/llama/test_token_attention_nopad_flashinfer_paged.py b/unit_tests/models/llama/test_token_attention_nopad_flashinfer_paged.py new file mode 100644 index 000000000..9bb97be99 --- /dev/null +++ b/unit_tests/models/llama/test_token_attention_nopad_flashinfer_paged.py @@ -0,0 +1,169 @@ +import torch +import time +import pytest +import numpy as np +import torch.nn.functional as F +import flashinfer +from lightllm.utils.log_utils import init_logger +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from 
lightllm.models.llama.triton_kernel.gqa_decode_flashattention_nopad import gqa_decode_attention_fwd + +logger = init_logger(__name__) + +seed = 42 +torch.manual_seed(seed) + +if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def ref_token_attention_nopad(q, k, v, o, q_h, h_dim, infer_state, req_to_token_indexs): + from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd + + total_token_num = infer_state.total_token_num + batch_size = infer_state.batch_size + calcu_shape1 = (batch_size, q_h, h_dim) + + att_m_tensor = torch.empty((q_h, total_token_num), dtype=torch.float32).cuda() + + token_att_fwd( + q.view(calcu_shape1), + k, + att_m_tensor, + req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + ) + + from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import ( + token_softmax_reducev_fwd, + ) + + token_softmax_reducev_fwd( + att_m_tensor, + v, + o, + req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + ) + return o + + +@pytest.mark.parametrize( + "batch, seqlen, q_heads, kv_heads, head_dim", + [ + (a, b, c, d, e) + for a in [1, 16, 32, 128, 512] + for b in [16, 32, 512, 1024] + for c in [28] + for d in [4] + for e in [128] + ], +) +def test_token_attention_nopad(batch, seqlen, q_heads, kv_heads, head_dim): + Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim + dtype = torch.bfloat16 + page_size = 4 + q = torch.randn((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + kv = torch.randn((Z * N_CTX // page_size, page_size, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + max_input_len = Z * N_CTX + req_to_page_indexs = ( + torch.randperm(max_input_len // page_size, dtype=torch.int32).cuda().view(Z, N_CTX // page_size) + ) + req_to_token_indexs = ( + req_to_page_indexs.unsqueeze(-1) * page_size + torch.arange(page_size, dtype=torch.int32, device="cuda") + ).reshape(Z, N_CTX) + + b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * N_CTX + b_start_loc = torch.arange(Z).cuda().int() * N_CTX + b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() + + o = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + o1 = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + infer_state = LlamaInferStateInfo() + infer_state.batch_size = Z + infer_state.max_len_in_batch = N_CTX + infer_state.total_token_num = Z * N_CTX + infer_state.b_req_idx = b_req_idx + infer_state.b_seq_len = b_seq_len + infer_state.b_start_loc = b_start_loc + + ref_token_attention_nopad( + q, + kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, :KV_HEADS, :], + kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, KV_HEADS:, :], + o, + Q_HEADS, + HEAD_DIM, + infer_state, + req_to_token_indexs, + ) + # gqa_decode_attention_fwd( + # q, + # kv[:,:KV_HEADS,:], + # kv[:,KV_HEADS:,:], + # o, + # req_to_token_indexs, + # infer_state.b_req_idx, + # infer_state.b_seq_len, + # ) + + batch_size = Z + head_dim = HEAD_DIM + q_heads = Q_HEADS + kv_heads = KV_HEADS + workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8).to(0) + + num_pages_per_seq = torch.ceil(b_seq_len.float() / page_size).int() + kv_indptr = torch.zeros(Z + 1, dtype=torch.int32, device="cuda") + kv_indptr[1:] = torch.cumsum(num_pages_per_seq, dim=0) + + # Fill the paged KV data indices + total_pages = kv_indptr[-1].item() + kv_indices = torch.zeros(total_pages, dtype=torch.int32, 
device="cuda") + b_start_loc = num_pages_per_seq.cumsum(0) - num_pages_per_seq + for req, sl, start in zip(b_req_idx, num_pages_per_seq, b_start_loc): + kv_indices[start : start + sl] = req_to_page_indexs[req][:sl] + + # Calculate last page lengths + kv_last_page_len = torch.zeros(Z, dtype=torch.int32, device="cuda") + for i in range(Z): + seq_len = b_seq_len[i].item() + remainder = seq_len % page_size + kv_last_page_len[i] = remainder if remainder > 0 else page_size + + kv_last_page_len_buffer = torch.empty(batch_size, device="cuda:0", dtype=torch.int32) + wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, + "NHD", + use_cuda_graph=True, + use_tensor_cores=True, + paged_kv_indptr_buffer=kv_indptr, + paged_kv_indices_buffer=kv_indices, + paged_kv_last_page_len_buffer=kv_last_page_len_buffer, + ) + wrapper.plan( + kv_indptr, + kv_indices, + kv_last_page_len, + q_heads, + kv_heads, + head_dim, + page_size, + q_data_type=dtype, + non_blocking=True, + ) + wrapper.run(q, (kv[:, :, :KV_HEADS, :], kv[:, :, KV_HEADS:, :]), out=o1, return_lse=False) + cos_sim = F.cosine_similarity(o, o1).mean() + assert cos_sim == 1.0 + + +if __name__ == "__main__": + test_token_attention_nopad(32, 16384, 32, 4, 128)