diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index dfbf0c84b..a3d8b18bc 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -42,6 +42,13 @@ def make_argument_parser() -> argparse.ArgumentParser:
         default=42000,
         help="p d mode, decode node used for kv move manager rpyc server port",
     )
+    parser.add_argument(
+        "--select_p_d_node_func",
+        type=str,
+        default="round_robin",
+        choices=["random", "round_robin", "memory"],
+        help="select p d node func, can be round_robin, random or memory",
+    )
     parser.add_argument(
         "--config_server_host",
         type=str,
diff --git a/lightllm/server/httpserver/pd_loop.py b/lightllm/server/httpserver/pd_loop.py
index 10a4a8ec5..bd9e9ca3d 100644
--- a/lightllm/server/httpserver/pd_loop.py
+++ b/lightllm/server/httpserver/pd_loop.py
@@ -176,9 +176,45 @@ async def _pd_process_generate(
         logger.error(str(e))
 
 
+def _get_load_info(have_finished_req: bool) -> "dict | None":
+    """Build the load report attached to a token pack for the PD master.
+
+    Args:
+        have_finished_req: whether the pack being forwarded contains a
+            finished request.
+
+    Returns:
+        None when ``have_finished_req`` is false. Otherwise a dict with
+        ``mem_len`` (the smallest per-dp-rank dynamic max load, as a float)
+        and ``client_ip_port`` (this node's ``host_ip:port`` string).
+
+    Raises:
+        RuntimeError: if ``g_objs.shared_token_load`` is not set.
+    """
+    if not have_finished_req:
+        return None
+
+    # Imported at call time, not at module import time.
+    from lightllm.server.api_http import g_objs
+
+    # Explicit check rather than ``assert`` so it also runs under ``python -O``.
+    if g_objs.shared_token_load is None:
+        raise RuntimeError("shared_token_load is not initialized")
+    current_load = [
+        float(g_objs.shared_token_load.get_dynamic_max_load(dp_index)) for dp_index in range(g_objs.args.dp)
+    ]
+    return {
+        "mem_len": min(current_load),
+        "client_ip_port": f"{g_objs.httpserver_manager.host_ip}:{g_objs.args.port}",
+    }
+
+
 # 转发token的task
 async def _up_tokens_to_pd_master(forwarding_queue: AsyncQueue, websocket):
     while True:
         handle_list = await forwarding_queue.wait_to_get_all_data()
+
         if handle_list:
-            await websocket.send(pickle.dumps((ObjType.TOKEN_PACKS, handle_list)))
+            have_finished_req = any(finish_status.is_finished() for _, _, _, finish_status in handle_list)
+            load_info = _get_load_info(have_finished_req)
+            await websocket.send(pickle.dumps((ObjType.TOKEN_PACKS, handle_list, load_info)))
diff --git a/lightllm/server/httpserver_for_pd_master/manager.py
b/lightllm/server/httpserver_for_pd_master/manager.py index 05b2d987c..e0311efaf 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -25,36 +25,29 @@ from lightllm.utils.statics_utils import MovingAverage from lightllm.server.httpserver.manager import AsyncQueue from lightllm.utils.error_utils import ServerBusyError +from .node_info_recorder import PredictNodeInfoRecorder +from .pd_selector import ( + create_selector, + PDSelector, +) logger = init_logger(__name__) -class HttpServerManagerForPDMaster: - def __init__( - self, - args, - metric_port, - ): +class PDManager: + def __init__(self, args): self.args = args - self.metric_client = MetricClient(metric_port) - self.id_gen = ReqIDGenerator() self.prefill_nodes: List[PD_Client_Obj] = [] self.decode_nodes: List[PD_Client_Obj] = [] - self.url_to_pd_nodes: Dict[str, PD_Client_Obj] = {} - - self.req_id_to_out_inf: Dict[int, ReqStatus] = {} - self.infos_queues = None # 这个需要延迟初始化,否则使用的loop不对 - - self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code) - - self.first_time_costs = MovingAverage() - self.per_token_costs = MovingAverage() + self.node_info_recorder: PredictNodeInfoRecorder = PredictNodeInfoRecorder() + self.selector: PDSelector = create_selector(args.select_p_d_node_func, self.prefill_nodes, self.decode_nodes, self) return async def register_pd(self, pd_info_json, websocket): pd_client = PD_Client_Obj(**pd_info_json) pd_client.websocket = websocket - self.url_to_pd_nodes[pd_client.client_ip_port] = pd_client + self.node_info_recorder.register_node(pd_client) + if pd_client.mode == "prefill": self.prefill_nodes = [e for e in self.prefill_nodes if e.client_ip_port != pd_client.client_ip_port] self.prefill_nodes.append(pd_client) @@ -62,22 +55,69 @@ async def register_pd(self, pd_info_json, websocket): self.decode_nodes = [e for e in self.decode_nodes if e.client_ip_port != 
pd_client.client_ip_port] self.decode_nodes.append(pd_client) else: - assert False + assert False, f"mode must in ['prefill', 'decode'], but get {pd_client.mode}" + + await self.selector.update_nodes(self.prefill_nodes, self.decode_nodes) logger.info(f"mode: {pd_client.mode} url: {pd_client.client_ip_port} registed") return async def remove_pd(self, pd_info_json): pd_client = PD_Client_Obj(**pd_info_json) - try: - del self.url_to_pd_nodes[pd_client.client_ip_port] - except: - pass + self.node_info_recorder.remove_node(pd_client) + self.prefill_nodes = [e for e in self.prefill_nodes if e.client_ip_port != pd_client.client_ip_port] self.decode_nodes = [e for e in self.decode_nodes if e.client_ip_port != pd_client.client_ip_port] + + await self.selector.update_nodes(self.prefill_nodes, self.decode_nodes) + logger.info(f"mode: {pd_client.mode} url: {pd_client.client_ip_port} removed") return + def update_node_load_info(self, load_info: dict): + """更新节点负载信息""" + if load_info is None: + return + self.node_info_recorder.update_node_load_info(load_info) + + def get_predict_node_infos(self): + """获取所有节点的预测负载信息""" + return self.node_info_recorder.get_predict_node_infos() + + async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: + p_node, d_node = await self.selector.select_p_d_node(prompt, sampling_params, multimodal_params) + self.node_info_recorder.update_predict_node_info(p_node, d_node, prompt, sampling_params, multimodal_params) + return p_node, d_node + +class HttpServerManagerForPDMaster: + def __init__( + self, + args, + metric_port, + ): + self.args = args + self.metric_client = MetricClient(metric_port) + self.id_gen = ReqIDGenerator() + + self.pd_manager = PDManager(args) + + self.req_id_to_out_inf: Dict[int, ReqStatus] = {} + self.infos_queues = None # 这个需要延迟初始化,否则使用的loop不对 + + self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, 
trust_remote_code=args.trust_remote_code) + + self.first_time_costs = MovingAverage() + self.per_token_costs = MovingAverage() + return + + async def register_pd(self, pd_info_json, websocket): + await self.pd_manager.register_pd(pd_info_json, websocket) + return + + async def remove_pd(self, pd_info_json): + await self.pd_manager.remove_pd(pd_info_json) + return + async def update_req_status(self, upkv_status: UpKVStatus): try: group_request_id = convert_sub_id_to_group_id(upkv_status.group_request_id) @@ -108,11 +148,7 @@ def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwar async def select_p_d_node( self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams ) -> Tuple[PD_Client_Obj, PD_Client_Obj]: - import random - - p_node = random.choice(self.prefill_nodes) - d_node = random.choice(self.decode_nodes) - return p_node, d_node + return await self.pd_manager.select_p_d_node(prompt, sampling_params, multimodal_params) async def generate( self, @@ -264,7 +300,7 @@ async def _wait_to_token_package( request: Request, ): out_token_counter = 0 - first_token_cost_ms = sys.float_info.max + first_token_cost_ms = float('inf') group_request_id = sampling_params.group_request_id unfinished_count = sampling_params.best_of is_first_token = True @@ -368,7 +404,16 @@ async def handle_loop(self): try: for obj in objs: if obj[0] == ObjType.TOKEN_PACKS: - for sub_req_id, text, metadata, finish_status in obj[1]: + # 检查是否包含节点信息 + if len(obj) >= 3: + handle_list, load_info = obj[1], obj[2] + # 更新节点负载信息 + self.pd_manager.update_node_load_info(load_info) + else: + # 兼容旧格式 + handle_list = obj[1] + + for sub_req_id, text, metadata, finish_status in handle_list: finish_status: FinishStatus = finish_status group_req_id = convert_sub_id_to_group_id(sub_req_id) try: diff --git a/lightllm/server/httpserver_for_pd_master/node_info_recorder.py b/lightllm/server/httpserver_for_pd_master/node_info_recorder.py new file mode 
100644
index 000000000..e7c2e3eb0
--- /dev/null
+++ b/lightllm/server/httpserver_for_pd_master/node_info_recorder.py
@@ -0,0 +1,138 @@
+import copy
+
+from ..pd_io_struct import PD_Client_Obj
+from lightllm.server.httpserver.manager import SamplingParams, MultimodalParams
+from typing import Union, List, Dict
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+class NodeInfoRecorder:
+    """Record the load each registered prefill/decode node last reported.
+
+    Prefill and decode records are kept in separate dicts, both keyed by
+    ``client_ip_port``.
+    """
+
+    def __init__(self):
+        self.prefill_node_info: dict = {}
+        self.decode_node_info: dict = {}
+
+    def register_node(self, pd_client: PD_Client_Obj):
+        """Create (or overwrite) the record of ``pd_client`` with ``mem_len`` 0."""
+        node_info = {
+            "node_id": pd_client.node_id,
+            "client_ip_port": pd_client.client_ip_port,
+            "mode": pd_client.mode,
+            "node": pd_client,
+            "mem_len": 0,
+            # "batch_size": 0,
+        }
+        if pd_client.mode == "prefill":
+            self.prefill_node_info[pd_client.client_ip_port] = node_info
+        elif pd_client.mode == "decode":
+            self.decode_node_info[pd_client.client_ip_port] = node_info
+        else:
+            assert False, f"mode must in ['prefill', 'decode'], but get {pd_client.mode}"
+
+    def remove_node(self, pd_client: PD_Client_Obj):
+        """Drop the record of ``pd_client``; an unknown node is ignored."""
+        # pop() instead of del: removing a node that has no record must not
+        # raise KeyError.
+        if pd_client.mode == "prefill":
+            self.prefill_node_info.pop(pd_client.client_ip_port, None)
+        elif pd_client.mode == "decode":
+            self.decode_node_info.pop(pd_client.client_ip_port, None)
+        else:
+            assert False, f"mode must in ['prefill', 'decode'], but get {pd_client.mode}"
+
+    def update_node_load_info(self, load_info: dict):
+        """Store the ``mem_len`` carried by a load report.
+
+        A report without ``client_ip_port``, or for a node that has no record,
+        is logged and ignored.
+        """
+        if "client_ip_port" not in load_info:
+            logger.warning("Received load info without client_ip_port")
+            return
+
+        ip_port = load_info["client_ip_port"]
+        if ip_port in self.prefill_node_info:
+            self.prefill_node_info[ip_port]["mem_len"] = load_info["mem_len"]
+        elif ip_port in self.decode_node_info:
+            self.decode_node_info[ip_port]["mem_len"] = load_info["mem_len"]
+        else:
+            logger.warning(f"Received load info for unknown node: {ip_port}")
+
+
+class PredictNodeInfoRecorder(NodeInfoRecorder):
+    """NodeInfoRecorder that also keeps a predicted load per node.
+
+    A prediction starts as a copy of the node's reported record, grows with
+    every request routed to the node, and is reset to the reported record
+    whenever a new load report for that node arrives.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.prefill_predict_node_info: dict = {}
+        self.decode_predict_node_info: dict = {}
+
+    def register_node(self, pd_client: PD_Client_Obj):
+        """Register ``pd_client`` and seed its prediction from the new record."""
+        super().register_node(pd_client)
+        ip_port = pd_client.client_ip_port
+        if pd_client.mode == "prefill":
+            self.prefill_predict_node_info[ip_port] = copy.copy(self.prefill_node_info[ip_port])
+        elif pd_client.mode == "decode":
+            self.decode_predict_node_info[ip_port] = copy.copy(self.decode_node_info[ip_port])
+
+    def remove_node(self, pd_client: PD_Client_Obj):
+        """Drop both the reported record and the prediction of ``pd_client``."""
+        super().remove_node(pd_client)
+        if pd_client.mode == "prefill":
+            self.prefill_predict_node_info.pop(pd_client.client_ip_port, None)
+        elif pd_client.mode == "decode":
+            self.decode_predict_node_info.pop(pd_client.client_ip_port, None)
+
+    def update_node_load_info(self, load_info: dict):
+        """Apply a load report, then reset that node's prediction to it."""
+        super().update_node_load_info(load_info)
+        # The base class has already logged malformed or unknown reports, so
+        # they are just skipped here instead of raising KeyError.
+        ip_port = load_info.get("client_ip_port")
+        if ip_port in self.prefill_node_info:
+            self.prefill_predict_node_info[ip_port] = copy.copy(self.prefill_node_info[ip_port])
+        elif ip_port in self.decode_node_info:
+            self.decode_predict_node_info[ip_port] = copy.copy(self.decode_node_info[ip_port])
+
+    def update_predict_node_info(
+        self,
+        p_node: PD_Client_Obj,
+        d_node: PD_Client_Obj,
+        prompt: Union[str, List[int]],
+        sampling_params: SamplingParams,
+        multimodal_params: MultimodalParams,
+    ):
+        """Charge a newly routed request to the chosen nodes' predictions.
+
+        The prefill node is charged ``len(prompt)`` and the decode node
+        ``sampling_params.max_new_tokens``. A node that is None or has no
+        prediction is skipped, so this bookkeeping cannot fail the request.
+        """
+        # NOTE(review): for a str prompt, len() counts characters rather than
+        # tokens - confirm that approximation is intended.
+        p_info = None if p_node is None else self.prefill_predict_node_info.get(p_node.client_ip_port)
+        if p_info is not None:
+            p_info["mem_len"] += len(prompt)
+        d_info = None if d_node is None else self.decode_predict_node_info.get(d_node.client_ip_port)
+        if d_info is not None:
+            d_info["mem_len"] += sampling_params.max_new_tokens
+
+    def get_predict_node_infos(self) -> Dict[str, dict]:
+        """Return the predictions as ``{"prefill": {...}, "decode": {...}}``."""
+        return {
+            "prefill": self.prefill_predict_node_info,
+            "decode": self.decode_predict_node_info,
+        }
diff --git a/lightllm/server/httpserver_for_pd_master/pd_selector/__init__.py b/lightllm/server/httpserver_for_pd_master/pd_selector/__init__.py
new file mode 100644
index 000000000..dae927341
--- /dev/null
+++ b/lightllm/server/httpserver_for_pd_master/pd_selector/__init__.py
@@
-0,0 +1,40 @@
+from typing import List
+
+# Import from pd_io_struct directly: manager.py imports this package, so
+# importing PD_Client_Obj back from manager would be a circular import.
+from lightllm.server.pd_io_struct import PD_Client_Obj
+from .pd_selector import (
+    PDSelector,
+    RandomSelector,
+    RoundRobinSelector,
+    MemorySelector,
+)
+
+__all__ = [
+    "PDSelector",
+    "RandomSelector",
+    "RoundRobinSelector",
+    "MemorySelector",
+    "create_selector",
+]
+
+
+def create_selector(
+    selector_type: str,
+    prefill_nodes: List[PD_Client_Obj],
+    decode_nodes: List[PD_Client_Obj],
+    pd_manager,
+) -> PDSelector:
+    """Instantiate the selector named by ``selector_type``.
+
+    Accepts "random", "round_robin" or "memory"; any other value raises
+    ValueError.
+    """
+    if selector_type == "random":
+        return RandomSelector(prefill_nodes, decode_nodes, pd_manager)
+    elif selector_type == "round_robin":
+        return RoundRobinSelector(prefill_nodes, decode_nodes, pd_manager)
+    elif selector_type == "memory":
+        return MemorySelector(prefill_nodes, decode_nodes, pd_manager)
+    else:
+        raise ValueError(f"Invalid selector type: {selector_type}")
diff --git a/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py b/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py
new file mode 100644
index 000000000..bae2555db
--- /dev/null
+++ b/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py
@@ -0,0 +1,114 @@
+import random
+
+from typing import Union, List, Tuple, Dict
+from lightllm.server.pd_io_struct import PD_Client_Obj
+from lightllm.server.core.objs import SamplingParams
+from lightllm.server.multimodal_params import MultimodalParams
+
+
+class PDSelector:
+    """Base class of the strategies that pick a (prefill, decode) node pair."""
+
+    def __init__(self, prefill_nodes: List[PD_Client_Obj], decode_nodes: List[PD_Client_Obj], pd_manager):
+        self.prefill_nodes: List[PD_Client_Obj] = prefill_nodes
+        self.decode_nodes: List[PD_Client_Obj] = decode_nodes
+        self.pd_manager = pd_manager
+
+    async def update_nodes(self, prefill_nodes, decode_nodes):
+        """Replace the candidate node lists."""
+        self.prefill_nodes = prefill_nodes
+        self.decode_nodes = decode_nodes
+
+    async def select_p_d_node(
+        self,
+        prompt: Union[str, List[int]],
+        sampling_params: SamplingParams,
+        multimodal_params: MultimodalParams,
+    ) -> Tuple[PD_Client_Obj, PD_Client_Obj]:
+        """Return the (prefill node, decode node) pair for one request."""
+        raise NotImplementedError("Subclass must implement this method")
+
+
+class RandomSelector(PDSelector):
+    """Pick each node with ``random.choice``."""
+
+    async def select_p_d_node(
+        self,
+        prompt: Union[str, List[int]],
+        sampling_params: SamplingParams,
+        multimodal_params: MultimodalParams,
+    ) -> Tuple[PD_Client_Obj, PD_Client_Obj]:
+        p_node = random.choice(self.prefill_nodes)
+        d_node = random.choice(self.decode_nodes)
+        return p_node, d_node
+
+
+class RoundRobinSelector(PDSelector):
+    """Cycle through the prefill nodes and the decode nodes independently."""
+
+    def __init__(self, prefill_nodes: List[PD_Client_Obj], decode_nodes: List[PD_Client_Obj], pd_manager):
+        super().__init__(prefill_nodes, decode_nodes, pd_manager)
+        self.prefill_node_index: int = 0
+        self.decode_node_index: int = 0
+
+    async def select_p_d_node(
+        self,
+        prompt: Union[str, List[int]],
+        sampling_params: SamplingParams,
+        multimodal_params: MultimodalParams,
+    ) -> Tuple[PD_Client_Obj, PD_Client_Obj]:
+        """Return the next node of each list and advance both cursors.
+
+        Raises:
+            IndexError: if either node list is empty.
+        """
+        if not self.prefill_nodes or not self.decode_nodes:
+            raise IndexError("no prefill or decode node is registered")
+        # Wrap the cursors before indexing: update_nodes() may have installed
+        # shorter lists since the cursors were last advanced.
+        p_index = self.prefill_node_index % len(self.prefill_nodes)
+        d_index = self.decode_node_index % len(self.decode_nodes)
+        p_node = self.prefill_nodes[p_index]
+        d_node = self.decode_nodes[d_index]
+        self.prefill_node_index = (p_index + 1) % len(self.prefill_nodes)
+        self.decode_node_index = (d_index + 1) % len(self.decode_nodes)
+        return p_node, d_node
+
+
+class MemorySelector(PDSelector):
+    """Pick the nodes whose predicted ``mem_len`` is smallest."""
+
+    @staticmethod
+    def _get_min_node(nodes: List[PD_Client_Obj], node_infos: Dict[str, dict], key: str) -> PD_Client_Obj:
+        """Return the node with the smallest ``node_infos[client_ip_port][key]``.
+
+        Nodes without an entry in ``node_infos`` are not considered; when no
+        node has one, a random node is returned instead.
+        """
+        known = [node for node in nodes if node.client_ip_port in node_infos]
+        if not known:
+            return random.choice(nodes)
+        # min() returns the first of several equally loaded nodes.
+        return min(known, key=lambda node: node_infos[node.client_ip_port][key])
+
+    async def select_p_d_node(
+        self,
+        prompt: Union[str, List[int]],
+        sampling_params: SamplingParams,
+        multimodal_params: MultimodalParams,
+    ) -> Tuple[PD_Client_Obj, PD_Client_Obj]:
+        """Return the least-loaded prefill node and decode node.
+
+        Loads come from ``pd_manager.get_predict_node_infos()``. Without a
+        ``pd_manager`` each node is chosen at random, or is None when its
+        list is empty.
+        """
+        if self.pd_manager is None:
+            p_node = random.choice(self.prefill_nodes) if self.prefill_nodes else None
+            d_node = random.choice(self.decode_nodes) if self.decode_nodes else None
+            return p_node, d_node
+
+        node_infos = self.pd_manager.get_predict_node_infos()
+        p_node = self._get_min_node(self.prefill_nodes, node_infos["prefill"], "mem_len")
+        d_node = self._get_min_node(self.decode_nodes, node_infos["decode"], "mem_len")
+        return p_node, d_node