diff --git a/docker/Dockerfile.nixl b/docker/Dockerfile.nixl new file mode 100644 index 000000000..ef37d21d6 --- /dev/null +++ b/docker/Dockerfile.nixl @@ -0,0 +1,83 @@ +ARG CUDA_VERSION=12.6.1 +FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu22.04 +ARG PYTHON_VERSION=3.10 +ARG MAMBA_VERSION=24.7.1-0 +ARG TARGETPLATFORM +ENV PATH=/opt/conda/bin:$PATH \ + CONDA_PREFIX=/opt/conda + +RUN chmod 777 -R /tmp && apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + ca-certificates \ + libssl-dev \ + curl \ + g++ \ + make \ + git && \ + rm -rf /var/lib/apt/lists/* + +RUN case ${TARGETPLATFORM} in \ + "linux/arm64") MAMBA_ARCH=aarch64 ;; \ + *) MAMBA_ARCH=x86_64 ;; \ + esac && \ + curl -fsSL -o ~/mambaforge.sh -v "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh" && \ + bash ~/mambaforge.sh -b -p /opt/conda && \ + rm ~/mambaforge.sh + +RUN case ${TARGETPLATFORM} in \ + "linux/arm64") exit 1 ;; \ + *) /opt/conda/bin/conda update -y conda && \ + /opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \ + esac && \ + /opt/conda/bin/conda clean -ya + + +WORKDIR /root + +COPY ./requirements.txt /lightllm/requirements.txt +RUN --mount=type=cache,target=/root/.cache/pip pip install -r /lightllm/requirements.txt --ignore-installed --extra-index-url https://download.pytorch.org/whl/cu124 + +RUN --mount=type=cache,target=/root/.cache/pip pip install vllm --pre --extra-index-url https://wheels.vllm.ai/nightly +RUN --mount=type=cache,target=/root/.cache/pip git clone https://github.com/ModelTC/LightKernel.git && cd LightKernel && pip install --no-deps -v . + +RUN apt-get update && apt-get install -y libnuma-dev # for sgl_kernel + +RUN apt-get update && apt-get install -y cmake automake autotools-dev libtool libz-dev && \ + DEBIAN_FRONTEND=noninteractive apt-get -y install --reinstall libibverbs-dev rdma-core ibverbs-utils libibumad-dev; \ + rm -rf /usr/lib/ucx && \ + rm -rf /opt/hpcx/ucx && \ + cd /usr/local/src && \ + git clone https://github.com/openucx/ucx.git && \ + cd ucx && \ + git checkout v1.19.x && \ + ./autogen.sh && ./configure \ + --enable-shared \ + --disable-static \ + --disable-doxygen-doc \ + --enable-optimizations \ + --enable-cma \ + --enable-devel-headers \ + --with-cuda=/usr/local/cuda \ + --with-verbs=yes \ + --with-dm \ + --with-gdrcopy=/usr/local \ + --with-efa \ + --enable-mt && \ + make -j && \ + make -j install-strip && \ + ldconfig; + +RUN apt-get update && apt-get install -y pkg-config tmux net-tools; \ + cd /usr/local/src; \ + pip install --upgrade meson pybind11 patchelf; \ + git clone https://github.com/ai-dynamo/nixl.git -b main && \ + cd nixl && \ + rm -rf build && \ + mkdir build && \ + meson setup build/ --prefix=/usr/local/nixl --buildtype=release && \ + cd build && \ + ninja && \ + ninja install && \ + cd .. && pip install . --no-deps; + +COPY . 
/lightllm +RUN pip install -e /lightllm --no-cache-dir diff --git a/docker/Dockerfile.nixl.deepep b/docker/Dockerfile.nixl.deepep new file mode 100644 index 000000000..8ca06e109 --- /dev/null +++ b/docker/Dockerfile.nixl.deepep @@ -0,0 +1,121 @@ +ARG CUDA_VERSION=12.6.1 +FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu22.04 + +ARG PYTHON_VERSION=3.10 +ARG MAMBA_VERSION=24.7.1-0 +ARG TARGETPLATFORM + +ENV PATH=/opt/conda/bin:$PATH \ + CONDA_PREFIX=/opt/conda + +RUN chmod 777 -R /tmp && apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + ca-certificates \ + libssl-dev \ + curl \ + g++ \ + make \ + git && \ + rm -rf /var/lib/apt/lists/* + +RUN case ${TARGETPLATFORM} in \ + "linux/arm64") MAMBA_ARCH=aarch64 ;; \ + *) MAMBA_ARCH=x86_64 ;; \ + esac && \ + curl -fsSL -o ~/mambaforge.sh -v "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh" && \ + bash ~/mambaforge.sh -b -p /opt/conda && \ + rm ~/mambaforge.sh + +RUN case ${TARGETPLATFORM} in \ + "linux/arm64") exit 1 ;; \ + *) /opt/conda/bin/conda update -y conda && \ + /opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \ + esac && \ + /opt/conda/bin/conda clean -ya + + +WORKDIR /root + +COPY ./requirements.txt /lightllm/requirements.txt +RUN --mount=type=cache,target=/root/.cache/pip pip install -r /lightllm/requirements.txt --ignore-installed --extra-index-url https://download.pytorch.org/whl/cu124 + +RUN --mount=type=cache,target=/root/.cache/pip pip install vllm --pre --extra-index-url https://wheels.vllm.ai/nightly +RUN --mount=type=cache,target=/root/.cache/pip git clone https://github.com/ModelTC/LightKernel.git && cd LightKernel && pip install --no-deps -v . + +RUN apt-get update && apt-get install -y libnuma-dev wget devscripts debhelper dh-make build-essential dkms +RUN apt-get install -y ibverbs-providers infiniband-diags perftest rdma-core libibverbs-dev librdmacm-dev + +ENV CUDA_HOME=/usr/local/cuda \ + GDRCOPY_HOME=/usr/src/gdrdrv-2.4.4/ + +RUN mkdir -p /tmp/gdrcopy && cd /tmp \ + && git clone https://github.com/NVIDIA/gdrcopy.git -b v2.4.4 \ + && cd gdrcopy/packages \ + && CUDA=/usr/local/cuda ./build-deb-packages.sh \ + && dpkg -i gdrdrv-dkms_*.deb libgdrapi_*.deb gdrcopy-tests_*.deb gdrcopy_*.deb \ + && cd / && rm -rf /tmp/gdrcopy + + # Fix DeepEP IBGDA symlink +RUN ln -sf /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so + +RUN wget https://developer.download.nvidia.com/compute/redist/nvshmem/3.3.9/source/nvshmem_src_cuda12-all-all-3.3.9.tar.gz \ + && tar -xf nvshmem_src_cuda12-all-all-3.3.9.tar.gz && mv nvshmem_src nvshmem \ + && cd nvshmem \ + && rm -f /root/nvshmem_src_cuda12-all-all-3.3.9.tar.gz \ + && NVSHMEM_SHMEM_SUPPORT=0 \ + NVSHMEM_UCX_SUPPORT=0 \ + NVSHMEM_USE_NCCL=0 \ + NVSHMEM_MPI_SUPPORT=0 \ + NVSHMEM_IBGDA_SUPPORT=1 \ + NVSHMEM_PMIX_SUPPORT=0 \ + NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \ + NVSHMEM_USE_GDRCOPY=1 \ + cmake -S . -B build/ -DCMAKE_INSTALL_PREFIX=/root/nvshmem/install -DCMAKE_CUDA_ARCHITECTURES=90 \ + && cmake --build build --target install -j64 + +ARG DEEPEP_COMMIT=b6ce310bb0b75079682d09bc2ebc063a074fbd58 +RUN git clone https://github.com/deepseek-ai/DeepEP.git && cd DeepEP && git checkout ${DEEPEP_COMMIT} && cd .. 
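+# Build DeepEP against the NVSHMEM install produced above (NVSHMEM_DIR points at
+# the cmake install prefix /root/nvshmem/install); DEEPEP_COMMIT pins the revision.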
+ +WORKDIR /root/DeepEP +ENV NVSHMEM_DIR=/root/nvshmem/install +RUN NVSHMEM_DIR=/root/nvshmem/install python setup.py install + +RUN apt-get update && apt-get install -y cmake automake autotools-dev libtool libz-dev && \ + DEBIAN_FRONTEND=noninteractive apt-get -y install --reinstall libibverbs-dev rdma-core ibverbs-utils libibumad-dev; \ + rm -rf /usr/lib/ucx && \ + rm -rf /opt/hpcx/ucx && \ + cd /usr/local/src && \ + git clone https://github.com/openucx/ucx.git && \ + cd ucx && \ + git checkout v1.19.x && \ + ./autogen.sh && ./configure \ + --enable-shared \ + --disable-static \ + --disable-doxygen-doc \ + --enable-optimizations \ + --enable-cma \ + --enable-devel-headers \ + --with-cuda=/usr/local/cuda \ + --with-verbs=yes \ + --with-dm \ + --with-gdrcopy=/usr/local \ + --with-efa \ + --enable-mt && \ + make -j && \ + make -j install-strip && \ + ldconfig; + +RUN apt-get update && apt-get install -y pkg-config tmux net-tools ; \ + cd /usr/local/src; \ + pip install --upgrade meson pybind11 patchelf; \ + git clone https://github.com/ai-dynamo/nixl.git -b main && \ + cd nixl && \ + rm -rf build && \ + mkdir build && \ + meson setup build/ --prefix=/usr/local/nixl --buildtype=release && \ + cd build && \ + ninja && \ + ninja install && \ + cd .. && pip install . --no-deps; + +COPY . /lightllm +RUN pip install -e /lightllm --no-cache-dir diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 4516e18c3..9f976df9d 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -177,6 +177,10 @@ def _init_kv_move_buffer(self): # p d 分离的推理模式下才需要做这一步初始化 if self.run_mode in ["prefill", "decode"]: self.mem_manager.alloc_kv_move_buffer(self.mem_manager.size) + elif self.run_mode in ["nixl_prefill", "nixl_decode"]: + page_num = int(os.getenv("PD_NIXL_MOVE_PAGE_NUM", 32)) + page_size = int(os.getenv("PD_NIXL_MOVE_PAGE_SIZE", 1024)) + self.mem_manager.alloc_paged_kv_move_buffer(page_num, page_size) def _check_mem_size(self): self.max_total_token_num = self.mem_manager.size diff --git a/lightllm/common/deepseek2_mem_manager.py b/lightllm/common/deepseek2_mem_manager.py index 6ddec24e2..c0a0b72b9 100644 --- a/lightllm/common/deepseek2_mem_manager.py +++ b/lightllm/common/deepseek2_mem_manager.py @@ -36,6 +36,12 @@ def alloc_kv_move_buffer(self, max_req_total_len): self.token_dim_size = self.kv_move_buffer.shape[-1] * self.kv_move_buffer.shape[-2] return + def alloc_paged_kv_move_buffer(self, page_num, page_size): + self.kv_move_buffer = torch.empty( + (page_num, page_size, self.layer_num, self.head_num, self.head_dim), dtype=self.dtype, device="cuda" + ) + return + def send_to_decode_node( self, move_tasks: List[KVMoveTask], diff --git a/lightllm/common/mem_manager.py b/lightllm/common/mem_manager.py index 4142ce4aa..4f8292bf2 100755 --- a/lightllm/common/mem_manager.py +++ b/lightllm/common/mem_manager.py @@ -96,6 +96,14 @@ def alloc_kv_move_buffer(self, max_req_total_len): self.token_dim_size = self.kv_move_buffer.shape[-2] * self.kv_move_buffer.shape[-1] return + def alloc_paged_kv_move_buffer(self, page_num, page_size): + if isinstance(self, MemoryManager) and type(self) != MemoryManager: + raise NotImplementedError("subclass need reimpl this method") + self.kv_move_buffer = torch.empty( + (page_num, page_size, self.layer_num, 2 * self.head_num, self.head_dim), dtype=self.dtype, device="cuda" + ) + return + def send_to_decode_node( self, move_tasks: List[KVMoveTask], diff --git a/lightllm/server/api_cli.py 
b/lightllm/server/api_cli.py index dfbf0c84b..855a36b7d 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -7,7 +7,7 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--run_mode", type=str, - choices=["normal", "prefill", "decode", "pd_master", "config_server"], + choices=["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "pd_master", "config_server"], default="normal", help="""set run mode, normal is started for a single server, prefill decode pd_master is for pd split run mode, config_server is for pd split mode used to register pd_master node, and get pd_master node list, @@ -54,6 +54,20 @@ def make_argument_parser() -> argparse.ArgumentParser: default=None, help="The port number for the config server in config_server mode.", ) + parser.add_argument( + "--pd_nixl_remote_prefill_http_port", + type=int, + default=42001, + help="nixl pd mode, prefill node used for triggering prefill http port.", + ) + + parser.add_argument( + "--pd_nixl_remote_prefill_port", + type=int, + default=42002, + help="nixl pd mode, prefill and decode used for meta exchange.", + ) + parser.add_argument( "--model_name", type=str, diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index c2a87b4c3..2fdf29de4 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -67,7 +67,7 @@ def normal_or_p_d_start(args): enable_mps() - if args.run_mode not in ["normal", "prefill", "decode"]: + if args.run_mode not in ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode"]: return assert args.zmq_mode in ["tcp://", "ipc:///tmp/"] diff --git a/lightllm/server/core/objs/__init__.py b/lightllm/server/core/objs/__init__.py index 5594df6a0..ab390a602 100644 --- a/lightllm/server/core/objs/__init__.py +++ b/lightllm/server/core/objs/__init__.py @@ -1,5 +1,5 @@ from .sampling_params import SamplingParams -from .req import Req, FinishStatus +from .req import Req, FinishStatus, PDNIXLChunkedPrefillReq from .shm_req_manager import ShmReqManager from .rpc_shm import RpcShmParams, RpcShmResults, ShmSyncStatusArray from .start_args_type import StartArgs diff --git a/lightllm/server/core/objs/io_objs/__init__.py b/lightllm/server/core/objs/io_objs/__init__.py index 80f4f0772..0a443a57c 100644 --- a/lightllm/server/core/objs/io_objs/__init__.py +++ b/lightllm/server/core/objs/io_objs/__init__.py @@ -1 +1 @@ -from .group_req import GroupReqIndexes, GroupReqObjs, AbortedReqCmd +from .group_req import GroupReqIndexes, GroupReqObjs, AbortedReqCmd, NIXLRemotePrefillDoneCmd, ReqCmd diff --git a/lightllm/server/core/objs/io_objs/group_req.py b/lightllm/server/core/objs/io_objs/group_req.py index d16dc4d06..d322552df 100644 --- a/lightllm/server/core/objs/io_objs/group_req.py +++ b/lightllm/server/core/objs/io_objs/group_req.py @@ -29,5 +29,13 @@ def to_group_req_index(self): @dataclass -class AbortedReqCmd: +class ReqCmd: req_id: int + + +class AbortedReqCmd(ReqCmd): + pass + + +class NIXLRemotePrefillDoneCmd(ReqCmd): + pass diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index 06a728925..2a489da2a 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -105,6 +105,7 @@ def get_str(self): f"shm_cur_kv_len:{self.shm_cur_kv_len}," f"shm_cur_output_len:{self.shm_cur_output_len}," f"finish_status:{self.finish_status.is_finished()}" + f"group_id: {self.group_req_id}" ) def init( @@ -326,3 +327,63 @@ def post_init( # 错误问题。 self.sample_params.max_new_tokens = 
self.sample_params.max_new_tokens + self.prefix_token_ids.size + 6 return + + +class PdNixlReqState(ctypes.Structure): + _pack_ = 4 + _MAX_TP_SIZE = 32 + _fields_ = [("dp_world_size", ctypes.c_int), ("state", ctypes.c_int * _MAX_TP_SIZE)] + + def __init__(self): + self.dp_world_size = 0 + self.state = (ctypes.c_int * self._MAX_TP_SIZE)(*([0] * self._MAX_TP_SIZE)) + + def set_dp_world_size(self, size: int): + assert size < self._MAX_TP_SIZE, f"size {size} > max size {self._MAX_TP_SIZE}" + self.dp_world_size = size + ctypes.memset(ctypes.addressof(self.state), 0, (self.dp_world_size + 1) * ctypes.sizeof(ctypes.c_int)) + + def set_tp_state(self, tp_id: int, state: int): + assert ( + self.dp_world_size > 0 and tp_id >= 0 and tp_id < self.dp_world_size + ), f"tp_id {tp_id} out of range [0, {self.dp_world_size})" + self.state[tp_id] = state + + def set_state(self): + assert self.dp_world_size > 0, "dp_world_size should be set before calling this" + unique_state = np.unique(self.state[: self.dp_world_size]) + self.state[self.dp_world_size] = unique_state[0] + return unique_state[0] + + def get_state(self): + assert self.dp_world_size > 0, "dp_world_size should be set before calling this" + return self.state[self.dp_world_size] + + +class PDNIXLChunkedPrefillReq(ChunkedPrefillReq): + _pack_ = 4 + _fields_ = ChunkedPrefillReq._fields_ + [ + # 用于pd nixl状态同步 + ("pd_nixl_req_state", PdNixlReqState), + ("router_nixl_rpd", ctypes.c_bool), + ] + + def post_init(self): + self.router_nixl_rpd = False + + def set_dp_world_size(self, dp_world_size): + self.pd_nixl_req_state.set_dp_world_size(dp_world_size) + self.router_nixl_rpd = False + + # called by each tp rank, no contention + def set_pd_req_rank_state(self, tp_id: int, state: int): + self.pd_nixl_req_state.set_tp_state(tp_id, state) + + # state: -1 for failed, 0 for in progress, 1 for success + # set by router + def set_pd_req_state(self): + return self.pd_nixl_req_state.set_state() + + # read by all rank + def get_pd_req_state(self): + return self.pd_nixl_req_state.get_state() diff --git a/lightllm/server/core/objs/shm_req_manager.py b/lightllm/server/core/objs/shm_req_manager.py index b88d4576d..ed9d40f00 100644 --- a/lightllm/server/core/objs/shm_req_manager.py +++ b/lightllm/server/core/objs/shm_req_manager.py @@ -3,7 +3,7 @@ from lightllm.utils.envs_utils import get_unique_server_name from multiprocessing import shared_memory from lightllm.utils.log_utils import init_logger -from .req import Req, ChunkedPrefillReq, TokenHealingReq +from .req import Req, ChunkedPrefillReq, TokenHealingReq, PDNIXLChunkedPrefillReq from .shm_array import ShmArray from .atomic_array_lock import AtomicShmArrayLock, AtomicLockItem from .atomic_lock import AtomicShmLock @@ -33,6 +33,9 @@ def get_req_class_type(self): if args.token_healing_mode: return TokenHealingReq + if args.run_mode in ["nixl_prefill", "nixl_decode"]: + return PDNIXLChunkedPrefillReq + return ChunkedPrefillReq def get_max_req_num(self): diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index d4a205a15..4c2006fe4 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -6,7 +6,10 @@ @dataclass class StartArgs: - run_mode: str = field(default="normal", metadata={"choices": ["normal", "prefill", "decode", "pd_master"]}) + run_mode: str = field( + default="normal", + metadata={"choices": ["normal", "prefill", "decode", "pd_master", "nixl_prefill", "nixl_decode"]}, + ) host: str = 
field(default="127.0.0.1") port: int = field(default=8000) zmq_mode: str = field( diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index de552d80c..68416ea91 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -101,7 +101,7 @@ def __init__( self.metric_client = MetricClient(metric_port) self.pd_mode: NodeRole = NodeRole(self.args.run_mode) - assert self.pd_mode in [NodeRole.P, NodeRole.D, NodeRole.NORMAL] + assert self.pd_mode in [NodeRole.P, NodeRole.D, NodeRole.NORMAL, NodeRole.NP, NodeRole.ND] self.id_gen = ReqIDGenerator() self.first_time_costs = MovingAverage() self.per_token_costs = MovingAverage() @@ -236,7 +236,7 @@ def alloc_req_id(self, sampling_params, is_health_req: bool = False): # health 请求 request_id 为负数,直接返回 if is_health_req: return sampling_params.group_request_id - if self.pd_mode == NodeRole.NORMAL: + if self.pd_mode.is_normal(): if not self.is_multinode_tp: group_request_id = self.id_gen.generate_id() else: @@ -246,7 +246,7 @@ def alloc_req_id(self, sampling_params, is_health_req: bool = False): assert sampling_params.group_request_id != -1 group_request_id = sampling_params.group_request_id sampling_params.group_request_id = group_request_id - elif self.pd_mode == NodeRole.P or self.pd_mode == NodeRole.D: + elif self.pd_mode.is_P_or_D(): assert sampling_params.group_request_id is not None, "p d mode, group_request_id must be setting" group_request_id = sampling_params.group_request_id else: @@ -424,7 +424,7 @@ async def transfer_to_next_module_or_node( # 多节点纯tp 运行模式下,master 节点需要将请求转发给slave节点. if self.is_multinode_tp_master: for sender in self.multinode_req_manager: - sender.send_pyobj( + await sender.send_pyobj( (prompt, sampling_params, original_multimodal_params), protocol=pickle.HIGHEST_PROTOCOL, ) @@ -437,35 +437,37 @@ async def transfer_to_next_module( group_req_objs: Optional[GroupReqObjs] = None, ): - if self.pd_mode == NodeRole.P: + if self.pd_mode.is_P(): if self.enable_multimodal: - self.send_to_visual.send_pyobj( + await self.send_to_visual.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, ) else: - self.send_to_router.send_pyobj( + + # P 模式下,直接将请求发送到路由器 + await self.send_to_router.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, ) return - if self.pd_mode == NodeRole.D: + if self.pd_mode.is_D(): # 在 D 模式下,不需要传输真的多模态参数,因为其已经被 P 处理好了, 传输一个空的即可 - self.send_to_router.send_pyobj( + await self.send_to_router.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, ) return - if self.pd_mode == NodeRole.NORMAL: + if self.pd_mode.is_normal(): if self.enable_multimodal: - self.send_to_visual.send_pyobj( + await self.send_to_visual.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, ) else: - self.send_to_router.send_pyobj( + await self.send_to_router.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, ) @@ -510,7 +512,7 @@ async def _wait_to_token_package( # pd master 节点需要这个做统计信息, 所以放在元数据中返回给 pd master 节点 metadata["prompt_tokens"] = prompt_tokens # p 节点返回 prompt_ids 信息,防止 d 节点重新 encode - if self.pd_mode == NodeRole.P and is_first_token: + if self.pd_mode.is_P() and is_first_token: metadata["prompt_ids"] = prompt_ids prompt_cache_len = metadata.pop("prompt_cache_len", 0) @@ -619,7 +621,7 @@ async def recycle_resource_loop(self): continue logger.info( - f"left req id {req_status.group_req_objs.group_req_id}" + f"left req id 
{req_status.group_req_objs.group_req_id} " f"can release {req_status.group_req_objs.shm_req_objs[0].can_released_mark} " f"refcount {req_status.group_req_objs.shm_req_objs[0].ref_count}" ) diff --git a/lightllm/server/httpserver/pd_loop.py b/lightllm/server/httpserver/pd_loop.py index 10a4a8ec5..94394182d 100644 --- a/lightllm/server/httpserver/pd_loop.py +++ b/lightllm/server/httpserver/pd_loop.py @@ -5,6 +5,7 @@ import socket import httpx import base64 +import zmq from typing import Dict, Optional from lightllm.server.pd_io_struct import NodeRole, ObjType from lightllm.server.httpserver.async_queue import AsyncQueue @@ -33,6 +34,8 @@ async def pd_handle_loop(manager: HttpServerManager): manager.host_ip = manager.args.host asyncio.create_task(timer_log(manager)) + if manager.pd_mode.is_NP_or_ND(): + asyncio.create_task(pd_handle_loop_from_d(manager)) id_to_handle_task: Dict[int, asyncio.Task] = {} @@ -92,7 +95,8 @@ async def _pd_handle_task(manager: HttpServerManager, pd_master_obj: PD_Master_O logger.info(f"Sent registration JSON: {regist_json}") # 转发任务 - forwarding_tokens_task = asyncio.create_task(_up_tokens_to_pd_master(forwarding_queue, websocket)) + if manager.pd_mode != NodeRole.NP: # nixl prefill don't need up token to master + forwarding_tokens_task = asyncio.create_task(_up_tokens_to_pd_master(forwarding_queue, websocket)) # 接收 pd master 发来的请求,并推理后,将生成的token转发回pd master。 while True: @@ -182,3 +186,33 @@ async def _up_tokens_to_pd_master(forwarding_queue: AsyncQueue, websocket): handle_list = await forwarding_queue.wait_to_get_all_data() if handle_list: await websocket.send(pickle.dumps((ObjType.TOKEN_PACKS, handle_list))) + + +async def pd_handle_loop_from_d(manager: HttpServerManager): + if manager.pd_mode != NodeRole.NP: + return + + context = zmq.asyncio.Context(2) + manager.recv_from_d = context.socket(zmq.PULL) + manager.recv_from_d.bind(f"tcp://*:{manager.args.pd_nixl_remote_prefill_http_port}") + + while True: + try: + ( + prompt, + sampling_params, + multimodal_params, + ) = await manager.recv_from_d.recv_pyobj() + + # 触发推理的task + async def pd_process_generate(manager: "HttpServerManager", prompt, sampling_params, multimodal_params): + try: + async for _, _, _, _ in manager.generate(prompt, sampling_params, multimodal_params, None): + pass + except BaseException as e: + logger.error(str(e)) + + asyncio.create_task(pd_process_generate(manager, prompt, sampling_params, multimodal_params)) + + except Exception as e: + logger.exception(f"pd loop generate error: {str(e)}") diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index 05b2d987c..52c9c729d 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -1,20 +1,15 @@ import sys -import zmq -import zmq.asyncio import asyncio import uvloop -import rpyc import time -import hashlib import datetime -import aiohttp import ujson as json import pickle asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) from typing import Union, List, Tuple, Dict from lightllm.server.core.objs import FinishStatus -from ..pd_io_struct import PD_Client_Obj, UpKVStatus, ObjType +from ..pd_io_struct import PD_Client_Obj, UpKVStatus, ObjType, NodeRole from lightllm.server.core.objs import SamplingParams from ..multimodal_params import MultimodalParams from ..tokenizer import get_tokenizer @@ -55,10 +50,11 @@ async def register_pd(self, pd_info_json, websocket): pd_client = PD_Client_Obj(**pd_info_json) 
pd_client.websocket = websocket self.url_to_pd_nodes[pd_client.client_ip_port] = pd_client - if pd_client.mode == "prefill": + client_pd_mode: NodeRole = NodeRole(pd_client.mode) + if client_pd_mode.is_P(): self.prefill_nodes = [e for e in self.prefill_nodes if e.client_ip_port != pd_client.client_ip_port] self.prefill_nodes.append(pd_client) - elif pd_client.mode == "decode": + elif client_pd_mode.is_D(): self.decode_nodes = [e for e in self.decode_nodes if e.client_ip_port != pd_client.client_ip_port] self.decode_nodes.append(pd_client) else: @@ -110,8 +106,8 @@ async def select_p_d_node( ) -> Tuple[PD_Client_Obj, PD_Client_Obj]: import random - p_node = random.choice(self.prefill_nodes) - d_node = random.choice(self.decode_nodes) + p_node = random.choice(self.prefill_nodes) if self.prefill_nodes else None + d_node = random.choice(self.decode_nodes) if self.decode_nodes else None return p_node, d_node async def generate( @@ -133,6 +129,10 @@ async def generate( p_node, d_node = await self.select_p_d_node(prompt, sampling_params, multimodal_params) + if not p_node or not d_node: + logger.error(f"{group_request_id}: No p_node or d_node found") + return + results_generator = self._wait_to_token_package( p_node, d_node, @@ -253,6 +253,43 @@ async def fetch_stream( return + async def fetch_stream_nixl( + self, + p_node: PD_Client_Obj, + d_node: PD_Client_Obj, + prompt: Union[str, List[int]], + sampling_params: SamplingParams, + multimodal_params: MultimodalParams, + request: Request, + ): + group_request_id = sampling_params.group_request_id + + req_status = ReqStatus(group_request_id, p_node, d_node) + self.req_id_to_out_inf[group_request_id] = req_status + + p_start_args = p_node.start_args + prefill_node_dict = { + "node_id": p_start_args["pd_node_id"], + "ip": p_start_args["host"], + "rpyc_port": p_start_args["pd_nixl_remote_prefill_port"], + "max_new_tokens": sampling_params.max_new_tokens, + "pd_master_node_id": self.args.pd_node_id, + } + + sampling_params.move_kv_to_decode_node.initialize(prefill_node_dict) + sampling_params.suggested_dp_index = -1 + + await d_node.websocket.send_bytes(pickle.dumps((ObjType.REQ, (prompt, sampling_params, multimodal_params)))) + + while True: + await req_status.wait_to_ready() + if await request.is_disconnected(): + raise Exception(f"req_id {group_request_id} disconnected") + if await req_status.can_read(self.req_id_to_out_inf): + token_list = await req_status.pop_all_tokens() + for sub_req_id, request_output, metadata, finish_status in token_list: + yield sub_req_id, request_output, metadata, finish_status + async def _wait_to_token_package( self, p_node: PD_Client_Obj, @@ -269,7 +306,11 @@ async def _wait_to_token_package( unfinished_count = sampling_params.best_of is_first_token = True - async for sub_req_id, out_str, metadata, finish_status in self.fetch_stream( + client_mode: NodeRole = NodeRole(d_node.mode) + + fetch_stream = self.fetch_stream_nixl if client_mode.is_NP_or_ND() else self.fetch_stream + + async for sub_req_id, out_str, metadata, finish_status in fetch_stream( p_node, d_node, prompt, sampling_params, multimodal_params, request ): if await request.is_disconnected(): diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py index e3c1d19d2..2c5355d28 100644 --- a/lightllm/server/multimodal_params.py +++ b/lightllm/server/multimodal_params.py @@ -155,6 +155,12 @@ def to_dict(self): ret["audios"] = [a.to_dict() for a in self.audios] return ret + @classmethod + def from_dict(cls, data: dict): + if "images" 
not in data: + return cls() + return cls(images=data["images"]) + def to_origin_dict(self): """ 将内容转换为原始请求的形式,主要用于请求转发 diff --git a/lightllm/server/pd_io_struct.py b/lightllm/server/pd_io_struct.py index 414e3c74a..4267afaee 100644 --- a/lightllm/server/pd_io_struct.py +++ b/lightllm/server/pd_io_struct.py @@ -13,23 +13,30 @@ class NodeRole(enum.Enum): P = "prefill" D = "decode" + + NP = "nixl_prefill" + ND = "nixl_decode" + NORMAL = "normal" PD_MASTER = "pd_master" def is_D(self): - return self == NodeRole.D + return self == NodeRole.D or self == NodeRole.ND def is_P(self): - return self == NodeRole.P + return self == NodeRole.P or self == NodeRole.NP def is_normal(self): return self == NodeRole.NORMAL def is_P_or_NORMAL(self): - return (self == NodeRole.P) or (self == NodeRole.NORMAL) + return self.is_P() or self.is_normal() def is_P_or_D(self): - return (self == NodeRole.P) or (self == NodeRole.D) + return self.is_P() or self.is_D() + + def is_NP_or_ND(self): + return self == NodeRole.NP or self == NodeRole.ND class ObjType(enum.Enum): @@ -47,8 +54,8 @@ class PD_Client_Obj: websocket: WebSocket = None # 用于通信的 websocket 连接对象 def __post_init__(self): - if self.mode not in ["prefill", "decode"]: - error_info = f"""mode must in ["prefill", "decode"], but get {self.mode}""" + if self.mode not in ["prefill", "decode", "nixl_prefill", "nixl_decode"]: + error_info = f"""mode must in ["prefill", "decode", "nixl_prefill", "nixl_decode"], but get {self.mode}""" logger.error(error_info) raise ValueError(error_info) return @@ -114,6 +121,23 @@ class PDTransJoinInfo: connect_id: str +@dataclass +class RemotePrefillServerInfo: + perfill_server_id: int + prefill_server_ip: str + prefill_server_port: int + + +@dataclass +class DistInfo: + world_size: int + nnodes: int + dp_size: int + dp_world_size: int + dp_size_in_node: int + node_world_size: int + + @dataclass class PDTransLeaveInfo: decode_id: int diff --git a/lightllm/server/router/batch.py b/lightllm/server/router/batch.py index 1336cd1dc..e933fbb83 100644 --- a/lightllm/server/router/batch.py +++ b/lightllm/server/router/batch.py @@ -1,7 +1,7 @@ import time from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Union -from lightllm.server.core.objs import ShmReqManager, Req +from lightllm.server.core.objs import ShmReqManager, Req, PDNIXLChunkedPrefillReq from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 24b8a9ddb..78fd65102 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -11,12 +11,12 @@ import torch.multiprocessing as mp import torch.distributed as dist import multiprocessing -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union from .batch import Batch, Req from .model_infer.model_rpc import start_model_process, ModelRpcClient from .req_queue import build_req_queue -from lightllm.server.core.objs.io_objs import GroupReqIndexes, AbortedReqCmd -from lightllm.server.core.objs import ShmReqManager, StartArgs +from lightllm.server.core.objs.io_objs import GroupReqIndexes, AbortedReqCmd, NIXLRemotePrefillDoneCmd, ReqCmd +from lightllm.server.core.objs import ShmReqManager, StartArgs, PDNIXLChunkedPrefillReq from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient from .shm_reqs_io_buffer import ShmReqsIOBuffer from lightllm.utils.log_utils import init_logger, log_time_ready @@ -27,6 +27,7 @@ from 
lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.server.pd_io_struct import DistInfo logger = init_logger(__name__) @@ -44,6 +45,7 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_por self.schedule_time_interval = args.schedule_time_interval # 默认30ms 的调度周期 # 兼容多机纯tp的运行模式,这时候 1 // 2 == 0, 需要兼容 self.dp_size_in_node = max(1, args.dp // self.nnodes) + self.dp_world_size = self.world_size // self.dp_size self.is_multinode_tp = args.nnodes > 1 and args.dp == 1 self.is_multinode_tp_master = self.is_multinode_tp and args.node_rank == 0 self.is_multinode_tp_slave = self.is_multinode_tp and args.node_rank != 0 @@ -85,8 +87,8 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_por ) self.metric_client = MetricClient(metric_port) - self.is_pd_run_mode = self.args.run_mode in ["prefill", "decode"] - self.is_pd_decode_mode = self.args.run_mode == "decode" + self.is_pd_run_mode = self.args.run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"] + self.is_pd_decode_mode = self.args.run_mode in ["decode", "nixl_decode"] # p d 分离模式下,需要调度锁来同步调度端和推理端的一些数据操作 # 主要是为了防止调度失误,造成 OOM 等错误 self.router_lock = mp.Lock() @@ -106,12 +108,14 @@ async def wait_to_model_ready(self): self.mem_queues: List[torch.multiprocessing.Queue] = [ torch.multiprocessing.Queue() for _ in range(self.node_world_size) ] + self.result_queues: List[mp.Queue] = [mp.Queue() for _ in range(self.node_world_size)] self.rpc_event = multiprocessing.Event() self.rpc_finished_event = multiprocessing.Event() assert (self.world_size % self.nnodes) == 0 node_world_size = self.world_size // self.nnodes for rank_id in range(self.node_rank * node_world_size, (self.node_rank + 1) * node_world_size): + rpc_model = await start_model_process( args=self.args, rank=rank_id, @@ -120,7 +124,8 @@ async def wait_to_model_ready(self): rpc_event=self.rpc_event, rpc_finished_event=self.rpc_finished_event, info_queue=self.info_queue, - mem_queue=self.mem_queues[(rank_id % node_world_size)], + result_queue=self.result_queues[rank_id % node_world_size], + mem_queue=self.mem_queues[rank_id % node_world_size], router_lock=self.router_lock, ) self.model_rpc_servers.append(rpc_model) @@ -173,7 +178,7 @@ async def wait_to_model_ready(self): get_unique_server_name(), self.max_total_token_num, node_world_size=self.node_world_size, - dp_world_size=self.world_size // self.dp_size, + dp_world_size=self.dp_world_size, ) self.req_queue = build_req_queue(self.args, self, self.dp_size_in_node) logger.info(f"use req queue {self.req_queue.__class__.__name__}") @@ -186,6 +191,30 @@ async def wait_to_model_ready(self): start_prefill_kv_move_manager_process(self.args, self.info_queue, self.mem_queues) + if self.args.run_mode == "nixl_prefill": + from lightllm.server.router.model_infer.mode_backend.pd_nixl.pd_remote_prefill import ( + start_pd_remote_prefill_server_process, + ) + + dist_info = DistInfo( + self.world_size, + self.nnodes, + self.dp_size, + self.dp_world_size, + self.dp_size_in_node, + self.node_world_size, + ) + + start_pd_remote_prefill_server_process( + self.args.pd_node_id, + dist_info=dist_info, + http_server_port=self.args.pd_nixl_remote_prefill_http_port, + server_port=self.args.pd_nixl_remote_prefill_port, + from_backend_queue=self.info_queue, + to_backend_queues=self.result_queues, + agent_meta_queues=self.mem_queues, + ) + if self.args.run_mode == 
"decode": # 启动 decode kv move 管理进程 from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.decode_node_impl import ( @@ -194,6 +223,28 @@ async def wait_to_model_ready(self): start_decode_kv_move_manager_process(self.args, self.info_queue, self.mem_queues) + if self.args.run_mode == "nixl_decode": + from lightllm.server.router.model_infer.mode_backend.pd_nixl.pd_remote_prefill import ( + start_pd_remote_prefill_client_process, + ) + + dist_info = DistInfo( + self.world_size, + self.nnodes, + self.dp_size, + self.dp_world_size, + self.dp_size_in_node, + self.node_world_size, + ) + + start_pd_remote_prefill_client_process( + self.args.pd_node_id, + dist_info, + from_backend_queue=self.info_queue, + to_backend_queues=self.result_queues, + agent_meta_queues=self.mem_queues, + ) + return def _get_schedule_time_interval(self): @@ -276,9 +327,20 @@ async def _step(self): await self._add_batch(new_batch) self._filter_reqs_from_running_batch() + + filter_cmds = [] + aborted_reqs = self._get_aborted_reqs_from_running_batch() if aborted_reqs: - await self._aborted_reqs(aborted_reqs=aborted_reqs) + filter_cmds.extend([AbortedReqCmd(req_id=r.request_id) for r in aborted_reqs]) + + if self.args.run_mode == "nixl_decode": + remote_prefill_done_reqs = self._get_nixl_rpd_reqs_from_running_batch() + if remote_prefill_done_reqs: + filter_cmds.extend([NIXLRemotePrefillDoneCmd(req_id=r.request_id) for r in remote_prefill_done_reqs]) + + if filter_cmds: + await self._filter_running_reqs(filter_cmds) return async def _add_batch(self, batch: Batch): @@ -292,8 +354,7 @@ async def _add_batch(self, batch: Batch): logger.debug(f"Prefill Batch: {batch.simple_log()} \n") return - async def _aborted_reqs(self, aborted_reqs: List[Req]): - cmds = [AbortedReqCmd(req_id=r.request_id) for r in aborted_reqs] + async def _filter_running_reqs(self, cmds: List[ReqCmd]): while not self.shm_reqs_io_buffer.is_empty(): await asyncio.sleep(0.02) @@ -325,6 +386,16 @@ def _get_aborted_reqs_from_running_batch(self) -> List[Req]: ans.append(req) return ans + def _get_nixl_rpd_reqs_from_running_batch(self) -> List[Req]: + ans = [] + if self.running_batch is None: + return ans + for req in self.running_batch.reqs: + req: PDNIXLChunkedPrefillReq + if req.set_pd_req_state() != 0: + ans.append(req) + return ans + def _get_paused_req_num(self) -> int: if self.running_batch is None: return 0 @@ -361,6 +432,8 @@ def _add_req(self, group_req_indexes: GroupReqIndexes): req = self.shm_req_manager.get_req_obj_by_index(req_index) req.multimodal_params = group_req_indexes.multimodal_params req.start_time = group_req_indexes.time_mark + if isinstance(req, PDNIXLChunkedPrefillReq): + req.set_dp_world_size(self.dp_world_size) req_group.append(req) logger.info(f"router recive req id {req.request_id} cost time {time.time() - req.start_time} s") diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 67d69aa38..0f44a1b07 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -10,7 +10,7 @@ from typing import List, Dict, Tuple, Optional, Callable, Any from lightllm.common.req_manager import ReqManager from lightllm.utils.infer_utils import mark_start, mark_end -from lightllm.server.core.objs import Req, SamplingParams, FinishStatus, ShmReqManager +from lightllm.server.core.objs import Req, SamplingParams, FinishStatus, ShmReqManager, PDNIXLChunkedPrefillReq from 
lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode from lightllm.utils.log_utils import init_logger from lightllm.server.req_id_generator import convert_sub_id_to_group_id @@ -124,7 +124,7 @@ def _save_promptcache_kvbuffer(self): https://arxiv.org/abs/2403.01241 """ prompt_cache_token_id = list(self.radix_cache.root_node.children.values())[0].token_id_key - print(f"prompt_cache_token_id : {prompt_cache_token_id}") + # print(f"prompt_cache_token_id : {prompt_cache_token_id}") index = range(len(prompt_cache_token_id)) prompt_cache_kv_buffer = self.radix_cache.mem_manager.get_index_kv_buffer(index) torch.save(prompt_cache_kv_buffer, f"prompt_cache_rank_{dist.get_rank()}.pt") @@ -299,6 +299,9 @@ def __init__( self.need_out_token_id_statistics = True self.out_token_id_count: Dict[int, int] = None + self.infer_nixl_rpd = False + self.in_prefill_or_transfer = False + # mtp_step 用来记录一个请求 draft模型每步需要生成的token数量 # 正常模式下,这个值为0,在 mtp 模式下,这个值为 draft 模型每步需要生成的token数量 self.mtp_step: int = get_env_start_args().mtp_step @@ -483,8 +486,11 @@ def handle( eos_ids: List[int], extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]], is_master_in_dp: bool, + call_post_handle_for_chunk: bool, ): if self.output_len <= 0: + if call_post_handle_for_chunk and extra_post_req_handle_func: + extra_post_req_handle_func(self.req_obj, next_token_id, next_token_logprob) return req_obj = self.req_obj diff --git a/lightllm/server/router/model_infer/mode_backend/__init__.py b/lightllm/server/router/model_infer/mode_backend/__init__.py index 3f30e8b4e..1558f000e 100644 --- a/lightllm/server/router/model_infer/mode_backend/__init__.py +++ b/lightllm/server/router/model_infer/mode_backend/__init__.py @@ -14,3 +14,7 @@ from .continues_batch.pd_mode.decode_node_impl.decode_impl_for_dp import DPForDecodeNode from .continues_batch.pd_mode.prefill_node_impl.prefill_impl import ChunckedPrefillForPrefillNode from .continues_batch.pd_mode.prefill_node_impl.prefill_impl_for_dp import DPChunkedForPrefillNode +from .pd_nixl.impl_for_pd_prefill import PDNIXLBackendForPrefillNode +from .pd_nixl.impl_for_pd_decode import PDNIXLBackendForDecodeNode +from .pd_nixl.impl_for_pd_decode_dp import PDNIXLDPBackendForDecodeNode +from .pd_nixl.impl_for_pd_prefill_dp import PDNIXLDPBackendForPrefillNode diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index fd75afdbf..ebfcab2b3 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -19,7 +19,7 @@ from lightllm.utils.dist_utils import init_distributed_env from lightllm.utils.envs_utils import get_unique_server_name from lightllm.server.core.objs import ShmReqManager, StartArgs -from lightllm.server.core.objs.io_objs import AbortedReqCmd +from lightllm.server.core.objs.io_objs import AbortedReqCmd, NIXLRemotePrefillDoneCmd, ReqCmd from lightllm.server.router.model_infer.infer_batch import g_infer_context from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager from lightllm.utils.dist_utils import get_global_rank, get_global_world_size, get_dp_size @@ -47,6 +47,7 @@ def __init__(self) -> None: self.decode_mask_func: Optional[Callable[[List[InferReq], torch.Tensor], None]] = None # extra_post_req_handle_func 用于添加请求InferReq的状态变化中添加额外的后处理信息,主要是状态机相关的调整等。 self.extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None + 
self.call_post_handle_for_chunk: bool = False self.enable_decode_microbatch_overlap = get_env_start_args().enable_decode_microbatch_overlap self.enable_prefill_microbatch_overlap = get_env_start_args().enable_prefill_microbatch_overlap @@ -315,12 +316,14 @@ def _read_reqs_buffer_and_init_reqs(self): cmds: List = self.shm_reqs_io_buffer.read_obj() self.shm_reqs_io_buffer.sub_state() if cmds: - if isinstance(cmds[0], AbortedReqCmd): + if isinstance(cmds[0], ReqCmd): for obj in cmds: - obj: AbortedReqCmd = obj if obj.req_id in g_infer_context.requests_mapping: req: InferReq = g_infer_context.requests_mapping[obj.req_id] - req.infer_aborted = True + if isinstance(obj, AbortedReqCmd): + req.infer_aborted = True + elif isinstance(obj, NIXLRemotePrefillDoneCmd): + req.infer_nixl_rpd = True else: self._init_reqs(reqs=cmds) return @@ -501,6 +504,7 @@ def _post_handle( next_token_logprobs: List[float], run_reqs_update_packs: List[InferReqUpdatePack], extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None, + call_post_handle_for_chunk: bool = False, ): """ extra_post_req_handle_func 用于提供在一个请求确定输出的时候,给出额外的后处理操作,主要是用于 @@ -517,6 +521,7 @@ def _post_handle( eos_ids=self.eos_id, extra_post_req_handle_func=extra_post_req_handle_func, is_master_in_dp=self.is_master_in_dp, + call_post_handle_for_chunk=call_post_handle_for_chunk, ) g_infer_context.req_manager.req_sampling_params_manager.update_reqs_token_counter( diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 39d345ff5..5e45ac051 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -51,7 +51,6 @@ def infer_loop(self): # 关闭overlap 模式 if not self.support_overlap: event_pack._close_overlap() - event_pack.wait_to_forward() self._try_read_new_reqs() @@ -139,6 +138,7 @@ def prefill_normal( next_token_logprobs=next_token_logprobs_cpu, run_reqs_update_packs=update_packs, extra_post_req_handle_func=self.extra_post_req_handle_func, + call_post_handle_for_chunk=self.call_post_handle_for_chunk, ) # 第四阶段 event_pack.notify_pre_post_handle() @@ -253,6 +253,7 @@ def prefill_mtp( next_token_logprobs=next_token_logprobs_cpu, run_reqs_update_packs=update_packs, extra_post_req_handle_func=self.extra_post_req_handle_func, + call_post_handle_for_chunk=self.call_post_handle_for_chunk, ) # 第四阶段 diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index a90a946fd..557d2dd87 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -160,6 +160,7 @@ def prefill_normal( next_token_logprobs=next_token_logprobs_cpu, run_reqs_update_packs=update_packs, extra_post_req_handle_func=self.extra_post_req_handle_func, + call_post_handle_for_chunk=self.call_post_handle_for_chunk, ) # 第四阶段 event_pack.notify_pre_post_handle() @@ -290,6 +291,7 @@ def prefill_overlap(self, event_pack: OverlapEventPack, prefill_reqs: List[Infer next_token_logprobs=next_token_logprobs_cpu, run_reqs_update_packs=update_packs, extra_post_req_handle_func=self.extra_post_req_handle_func, + call_post_handle_for_chunk=self.call_post_handle_for_chunk, ) # 第四阶段 event_pack.notify_pre_post_handle() @@ -437,6 +439,7 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: 
List[InferReq] next_token_logprobs=next_token_logprobs_cpu, run_reqs_update_packs=update_packs, extra_post_req_handle_func=self.extra_post_req_handle_func, + call_post_handle_for_chunk=self.call_post_handle_for_chunk, ) # 第四阶段 @@ -661,6 +664,7 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I next_token_logprobs=next_token_logprobs_cpu, run_reqs_update_packs=update_packs, extra_post_req_handle_func=self.extra_post_req_handle_func, + call_post_handle_for_chunk=self.call_post_handle_for_chunk, ) event_pack.notify_pre_post_handle() else: diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index d5bba1ae5..5cbe57c58 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -89,7 +89,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In run_reqs.append(req) b_req_idx.append(req.req_idx) seq_len = req.get_cur_total_len() - assert req.cur_kv_len == seq_len - 1 + assert req.cur_kv_len == seq_len - 1, f"{req.cur_kv_len} {seq_len}" b_seq_len.append(seq_len) total_token_num += seq_len max_len_in_batch = max(max_len_in_batch, seq_len) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py new file mode 100644 index 000000000..dde3f793e --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py @@ -0,0 +1,535 @@ +import time +import torch.multiprocessing as mp +import torch +from typing import List +import queue +import numpy as np +import asyncio +import threading + + +from lightllm.utils.log_utils import init_logger +from lightllm.server.core.objs.req import PDNIXLChunkedPrefillReq +from lightllm.server.router.model_infer.mode_backend.chunked_prefill.impl import ChunkedPrefillBackend +from lightllm.server.router.model_infer.mode_backend.dp_backend.impl import DPChunkedPrefillBackend +from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq, InferReqUpdatePack + +from .nixl_kv_transporter import NixlMetadata, NixlKVTransporter +from .pd_remote_prefill_obj import ( + PrefillRequest, + RemoteRequest, + RemoteRequstType, + ConnectRequest, + KVMoveRequest, + RemotePrefillStatus, + ThreadSafeDict, + TransferState, + SafePageIndexScheduler, + RemoteTransferType, + RemoteTransferStatusType, + PageTransferAck, + NotificationType, + Notification, +) + +logger = init_logger(__name__) + + +class PDNIXLBackendBase(object): + _THREAD_WAIT_INTERVAL = 0.001 + + def __init__(self, to_remote_queue: mp.Queue, from_remote_queue: mp.Queue, nixl_meta_queue: mp.Queue): + super().__init__() + self.to_remote_queue = to_remote_queue + self.from_remote_queue = from_remote_queue + self.nixl_meta_queue = nixl_meta_queue + self.prefill_post_handle_queue = queue.Queue() + + # for decode + self.remote_prefilled_reqs: ThreadSafeDict = ThreadSafeDict() + self.request_to_page_ids: ThreadSafeDict = ThreadSafeDict() + self.request_to_first_token: ThreadSafeDict = ThreadSafeDict() + + # for prefill + self.remote_prefill_requests: ThreadSafeDict = ThreadSafeDict() + self.inflght_transfer_requests: ThreadSafeDict = ThreadSafeDict() + + self.page_copy_stream = torch.cuda.Stream() + + def init_custom(self): + self.nixl_agent = NixlKVTransporter(self.args.pd_node_id, self.rank_in_node) + 
self.nixl_agent.register_kv_buffer(self.model.mem_manager.kv_buffer) + self.nixl_agent.register_kv_move_buffer(self.model.mem_manager.kv_move_buffer) + self.page_scheduer = SafePageIndexScheduler(self.nixl_agent.num_pages) + + self.nixl_meta_queue.put( + ( + self.nixl_agent.agent_metadata, + self.nixl_agent.num_tokens, + self.nixl_agent.num_pages, + self.nixl_agent.local_mem_desc, + self.nixl_agent.local_page_mem_desc, + ) + ) + + def _start_async_loop(self, async_loop_func): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(async_loop_func()) + + async def _handle_remote_prefill(self, req_status: RemotePrefillStatus): + group_req_id = req_status.group_req_id + status = req_status.status + if status != RemoteTransferStatusType.SUCCESS: + logger.warning(f"remote prefill reqeust: {group_req_id} done with state: {status}") + + ret = None + if run_req := self.remote_prefilled_reqs.get(group_req_id, None): + if ( + req_status.transfer_type == RemoteTransferType.PAGE_TRANSFER + and status == RemoteTransferStatusType.SUCCESS + ): + kv_start, kv_len = req_status.kv_start, req_status.kv_len + token_ids = g_infer_context.req_manager.req_to_token_indexs[run_req.req_idx][ + kv_start : kv_start + kv_len + ] # gpu tensor + self.model.mem_manager.kv_buffer[:, token_ids, :, :] = self.model.mem_manager.kv_move_buffer[ + req_status.page_id + ][:kv_len].transpose(0, 1) + ret = PageTransferAck(group_req_id=group_req_id, page_id=req_status.page_id) + + if req_status.is_last or status != RemoteTransferStatusType.SUCCESS: + shm_req: PDNIXLChunkedPrefillReq = run_req.shm_req + self.remote_prefilled_reqs.pop(group_req_id) + self.request_to_first_token[group_req_id] = (req_status.next_token_id, req_status.next_token_logprob) + shm_req.set_pd_req_rank_state(self.rank_in_dp, status.value) + + if self.is_master_in_dp: + logger.info( + f"remote prefill request: {group_req_id} done with status: {status} " + f"took: {time.time() - run_req.remote_prefill_start} seconds" + ) + + ret = None + + else: + if self.is_master_in_dp: + logger.warning(f"remote prefill reqeust: {group_req_id} not found") + + return ret + + async def _prefill_wait_loop_async(self): + while True: + # from local + try: + req_status = self.from_remote_queue.get_nowait() + await self._handle_remote_prefill(req_status) + except queue.Empty: + pass + + # from remote + notifies = self.nixl_agent.get_new_notifs() + for agent_name, req_statuses in notifies.items(): + with torch.cuda.stream(self.page_copy_stream): + acks = [] + for req_statuses_bytes in req_statuses: + noti: Notification = Notification.from_bytes(req_statuses_bytes) + if noti.type == NotificationType.REMOTE_MD: + self.nixl_agent.connect_to_remote(agent_name, noti.data) + elif noti.type == NotificationType.TRANSFER_NOTIFY: + for req_status in noti.data: + prefill_status = RemotePrefillStatus.deserialize(req_status) + ack = await self._handle_remote_prefill(prefill_status) + if ack: + acks.append(ack) + if len(acks) > 0: + + # wait for copy done + self.page_copy_stream.synchronize() + # logger.info(f"send {len(acks)} acks to {agent_name}") + self.nixl_agent.send_transfer_notify(agent_name, acks) + + await asyncio.sleep(PDNIXLBackendBase._THREAD_WAIT_INTERVAL) + + def _handle_chunked_transfer(self, req: InferReq, next_token_id: int = None, next_token_logprob: float = None): + if next_token_id: + next_token_id = int(next_token_id) + next_token_logprob = float(next_token_logprob) + + shm_req: PDNIXLChunkedPrefillReq = req.shm_req + group_req_id = 
shm_req.group_req_id + if group_req_id not in self.remote_prefill_requests: + logger.info(f"remote prefill request {group_req_id} not found") + return + + remote_request: PrefillRequest = self.remote_prefill_requests[group_req_id] + if remote_request.transfer_state is None: + remote_request.transfer_state = TransferState( + start_time=time.time(), + current_chunk_id=0, + transfered_kv_len=remote_request.data.local_cached_len, + current_kv_len=req.cur_kv_len, + is_finished=req.finish_status.is_finished(), + token_index=self.model.req_manager.req_to_token_indexs[req.req_idx].tolist(), + free_page_ids=remote_request.data.page_ids.copy(), + next_token_id=next_token_id, + next_token_logprob=next_token_logprob, + lock=threading.Lock(), + ) + shm_req.set_pd_req_rank_state(self.rank_in_dp, RemoteTransferStatusType.IN_PROGRESS.value) + req.in_prefill_or_transfer = True + self.inflght_transfer_requests[group_req_id] = req + else: + transfer_state: TransferState = remote_request.transfer_state + with transfer_state.lock: + transfer_state.current_chunk_id += 1 + transfer_state.current_kv_len = req.cur_kv_len + transfer_state.is_finished = req.finish_status.is_finished() + if next_token_id: + transfer_state.next_token_id = next_token_id + transfer_state.next_token_logprob = next_token_logprob + + async def _transfer_kv_to_remote_paged_batch(self, transfer_reqs: List[KVMoveRequest]): + requests_by_agents = dict() + transfer_pages = self.page_scheduer.borrow(len(transfer_reqs)) + # first copy the kv to transfer pages & build notification + with torch.cuda.stream(self.page_copy_stream): + for trans_req, page_index in zip(transfer_reqs, transfer_pages): + trans_req: KVMoveRequest + group_req_id = trans_req.group_req_id + remote_request: PrefillRequest = self.remote_prefill_requests.get(group_req_id) + transfer_state: TransferState = remote_request.transfer_state + decode_id: int = remote_request.decode_id + if decode_id not in requests_by_agents: + requests_by_agents[decode_id] = ([], [], []) + + with transfer_state.lock: + + start_kv_len = transfer_state.transfered_kv_len + trans_kv_len = min(trans_req.cur_kv_len - trans_req.prev_kv_len, self.nixl_agent.page_size) + trans_kv_index = transfer_state.token_index[start_kv_len : start_kv_len + trans_kv_len] + self.model.mem_manager.kv_move_buffer[page_index][:trans_kv_len] = self.model.mem_manager.kv_buffer[ + :, trans_kv_index, :, : + ].transpose(0, 1) + + receive_page = transfer_state.free_page_ids.pop(0) + requests_by_agents[decode_id][0].append(page_index) + requests_by_agents[decode_id][1].append(receive_page) + is_last = ( + transfer_state.is_finished and start_kv_len + trans_kv_len == transfer_state.current_kv_len + ) + + requests_by_agents[decode_id][2].append( + RemotePrefillStatus( + transfer_type=RemoteTransferType.PAGE_TRANSFER, + group_req_id=group_req_id, + status=RemoteTransferStatusType.SUCCESS, + chunk_id=transfer_state.current_chunk_id, + is_last=is_last, + page_id=receive_page, + kv_start=start_kv_len, + kv_len=trans_kv_len, + next_token_id=transfer_state.next_token_id, + next_token_logprob=transfer_state.next_token_logprob, + ) + ) + transfer_state.transfered_kv_len += trans_kv_len + + # wait copy done + self.page_copy_stream.synchronize() + for decode_id, (transfer_pages, receive_pages, notifications) in requests_by_agents.items(): + assert len(transfer_reqs) == len(receive_pages), "transfer_reqs and receive_pages should have same length" + # transfer + self.nixl_agent.write_blocks_paged(decode_id, transfer_pages, receive_pages, 
notifications) + + async def _handle_transfer_loop(self): + while True: + free_transfer_pages = self.page_scheduer.current_size() + if free_transfer_pages < 1: + await asyncio.sleep(PDNIXLBackendBase._THREAD_WAIT_INTERVAL) + continue + + transfer_reqs = [] + for group_req_id, req in self.inflght_transfer_requests.items(): + remote_request: PrefillRequest = self.remote_prefill_requests.get(group_req_id) + transfer_state: TransferState = remote_request.transfer_state + with transfer_state.lock: + if transfer_state.completed() or len(transfer_state.free_page_ids) == 0: + continue + + if transfer_state.transfered_kv_len >= transfer_state.current_kv_len: + continue + + transfer_reqs.append( + KVMoveRequest( + group_req_id=group_req_id, + prev_kv_len=transfer_state.transfered_kv_len, + cur_kv_len=transfer_state.current_kv_len, + ) + ) + if len(transfer_reqs) >= free_transfer_pages: + break + + if len(transfer_reqs) > 0: + await self._transfer_kv_to_remote_paged_batch(transfer_reqs) + + await asyncio.sleep(PDNIXLBackendBase._THREAD_WAIT_INTERVAL) + + async def _wait_page_transfer_loop(self): + while True: + # local pages can be reused as soon as transfer is done + done_pages, done_requests = await self.nixl_agent.get_done_page_transfers() + if len(done_pages): + self.page_scheduer.return_(done_pages) + + # release requests when prefill done + for req_id, status in done_requests: + if req_id not in self.inflght_transfer_requests: + if self.is_master_in_dp: + logger.warning(f"{req_id} not found in inflght_transfer_requests") + continue + + req: InferReq = self.inflght_transfer_requests[req_id] + shm_req: PDNIXLChunkedPrefillReq = req.shm_req + shm_req.set_pd_req_rank_state(self.rank_in_dp, status.value) + transfer_state = self.remote_prefill_requests[req_id].transfer_state + if self.is_master_in_dp: + logger.info( + f"req: {req_id} kv transfer with state: {status} " + f"took: {time.time() - transfer_state.start_time} seconds" + ) + # only delete success transfers, failed / aborted will delete after send abort notification + if status == RemoteTransferStatusType.SUCCESS: + del self.inflght_transfer_requests[req_id] + del self.remote_prefill_requests[req_id] + + # remote pages should be released after nofication received + notifies = self.nixl_agent.get_new_notifs() + for _, trans_acks in notifies.items(): + for trans_ack_bytes in trans_acks: + trans_acks_noti: Notification = Notification.from_bytes(trans_ack_bytes) + assert trans_acks_noti.type == NotificationType.TRANSFER_NOTIFY_ACK + for trans_ack in trans_acks_noti.data: + ack = PageTransferAck.deserialize(trans_ack) + remote_request: PrefillRequest = self.remote_prefill_requests.get(ack.group_req_id) + if remote_request is None: + continue + + transfer_state: TransferState = remote_request.transfer_state + with transfer_state.lock: + transfer_state.free_page_ids.append(ack.page_id) + + await asyncio.sleep(PDNIXLBackendBase._THREAD_WAIT_INTERVAL) + + async def _wait_transfer_loop(self): + while True: + done_req_ids = self.nixl_agent.get_done_tranfers() + for req_id, state in done_req_ids: + if state != 1: + logger.info(f"wait transfer done: {req_id} state: {state}") + + if req_id not in self.inflght_transfer_requests: + if self.is_master_in_dp: + logger.warning(f"{req_id} not found in inflght_transfer_requests") + continue + + req: InferReq = self.inflght_transfer_requests[req_id] + shm_req: PDNIXLChunkedPrefillReq = req.shm_req + shm_req.set_pd_req_rank_state(self.rank_in_dp, state) + transfer_state = 
self.remote_prefill_requests[req_id].transfer_state + if self.is_master_in_dp: + logger.info( + f"req: {req_id} kv transfer with state: {state} " + f"took: {time.time() - transfer_state.start_time} seconds" + ) + del self.remote_prefill_requests[req_id] + del self.inflght_transfer_requests[req_id] + + await asyncio.sleep(PDNIXLBackendBase._THREAD_WAIT_INTERVAL) + + async def _handle_prefill_loop(self): + while True: + request: RemoteRequest = self.from_remote_queue.get() + if request.type == RemoteRequstType.REMOTE_CONNECT: + request: ConnectRequest + logger.info(f"connect request received from: {request.decode_id}") + self.nixl_agent.add_remote_agent( + NixlMetadata( + id=request.decode_id, + num_tokens=request.num_tokens, + num_pages=request.num_pages, + agent_metadatas=request.agent_metadatas, + agent_mem_descs=request.agent_mem_descs, + agent_page_mem_descs=request.agent_page_mem_descs, + ) + ) + self.to_remote_queue.put("OK") + + if request.type == RemoteRequstType.REMOTE_PREFILL: + request: PrefillRequest + group_request_id = request.data.sampling_params.group_request_id + logger.info( + f"prefill request received from decode: {request.decode_id} " + f"and group request id: {group_request_id}" + ) + self.remote_prefill_requests[group_request_id] = request + + def _transfer_kv_to_remote(self, req: InferReq, group_req_id: int, cur_kv_len: int, is_finished: bool): + start = time.time() + remote_request: PrefillRequest = self.remote_prefill_requests[group_req_id] + + transfer_state = remote_request.transfer_state + token_index = self.model.req_manager.req_to_token_indexs[req.req_idx] + + kv_transfer_req = KVMoveRequest( + group_req_id=group_req_id, + token_ids=token_index[:cur_kv_len].tolist(), + prev_kv_len=transfer_state.current_kv_len, + cur_kv_len=cur_kv_len, + ) + if transfer_state.current_chunk_id == 0: + self.inflght_transfer_requests[group_req_id] = req + logger.debug( + f"put {group_req_id} into inflght_transfer_requests and size: {len(self.inflght_transfer_requests)}" + ) + + # kick off kv transfer + self.nixl_agent.write_blocks(kv_transfer_req, remote_request, is_finished) + + transfer_state.current_kv_len = cur_kv_len + transfer_state.current_chunk_id += 1 + logger.info( + f"transfer kv to remote: {group_req_id} " + f"current chunk id: {transfer_state.current_chunk_id} {cur_kv_len} " + f"took: {time.time() - start} seconds" + ) + + def _post_remote_prefill(self, req: InferReq, success: bool = True): + + req.in_prefill_or_transfer = False + group_req_id = req.shm_req.group_req_id + req.cur_kv_len = req.get_cur_total_len() - 1 + + if self.is_master_in_dp: + req.shm_req.shm_cur_kv_len = req.cur_kv_len + if group_req_id in self.request_to_page_ids: + self.page_scheduer.return_(self.request_to_page_ids[group_req_id]) + del self.request_to_page_ids[group_req_id] + + if not success: + self.request_to_first_token.pop(group_req_id, None) + return + + assert group_req_id in self.request_to_first_token, f"{group_req_id} not in request_to_first_token dict" + token_id, token_logprob = self.request_to_first_token.pop(group_req_id) + + # (TODO) figure out how to update req_to_next_token_ids + # req.cur_output_len += 1 + + # pack = InferReqUpdatePack(req, req.cur_output_len) + # pack.handle( + # token_id, + # token_logprob, + # eos_ids=self.eos_id, + # extra_post_req_handle_func=None, + # is_master_in_dp=self.is_master_in_dp, + # call_post_handle_for_chunk=False + # ) + return token_id + + def _decode_filter_reqs(self, prefill_reqs: List[InferReq], decode_reqs: List[InferReq]): + 
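# Split requests on the decode side: remote prefills that finished are promoted to decode, failed ones are collected for cleanup, in-flight ones keep waiting, and the rest go through local prefill. +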
new_prefill_reqs: List[InferReq] = [] + remote_prefill_reqs: List[InferReq] = [] + failed_prefill_reqs: List[InferReq] = [] + next_token_ids: List[int] = [] + rpd_reqs: List[InferReq] = [] + + for req in prefill_reqs: + if req.in_prefill_or_transfer: + if req.infer_nixl_rpd: + shm_req: PDNIXLChunkedPrefillReq = req.shm_req + # state is updated by router + state = shm_req.get_pd_req_state() + if state == RemoteTransferStatusType.SUCCESS.value: # success + next_token_ids.append(self._post_remote_prefill(req)) + rpd_reqs.append(req) + elif state == RemoteTransferStatusType.FAILED.value: + self._post_remote_prefill(req, False) + failed_prefill_reqs.append(req) + else: + logger.warning(f"remote prefill request {shm_req.group_req_id} unexpected state {state}") + else: + remote_prefill_reqs.append(req) + else: + new_prefill_reqs.append(req) + + if rpd_reqs: + # g_infer_context.req_manager.req_sampling_params_manager.update_reqs_token_counter( + # rpd_reqs, + # next_token_ids, + # ) + decode_reqs.extend(rpd_reqs) + + return new_prefill_reqs, decode_reqs, failed_prefill_reqs, remote_prefill_reqs + + def _prepare_remote_prefill_inputs(self, req_objs: List[InferReq]): + run_reqs = [] + start_loc = 0 + input_ids = [] + nopad_b_req_idx = [] + nopad_b_start_loc = [] + nopad_b_seq_len = [] + + for req in req_objs: + run_reqs.append(req) + nopad_b_req_idx.append(req.req_idx) + nopad_b_start_loc.append(start_loc) + + input_token_ids = req.get_input_token_ids() + seq_len = len(input_token_ids) + input_token_len = seq_len - req.cur_kv_len + input_id = input_token_ids[req.cur_kv_len :] + nopad_b_seq_len.append(seq_len) + input_ids.append(input_id) + start_loc += input_token_len + + nopad_b_start_loc.append(start_loc) # last request + + input_ids = np.concatenate(input_ids, dtype=np.int64) + + if g_infer_context.radix_cache is not None: + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0]) + mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0]) + + req_to_token_indexs = g_infer_context.req_manager.req_to_token_indexs + for idx, req_idx in enumerate(nopad_b_req_idx): + cur_kv_len = req_objs[idx].cur_kv_len + seq_len = nopad_b_seq_len[idx] + mem_start = nopad_b_start_loc[idx] + mem_end = nopad_b_start_loc[idx + 1] + req_to_token_indexs[req_idx, cur_kv_len : nopad_b_seq_len[idx]] = mem_indexes[mem_start:mem_end] + + kwargs = { + "batch_size": len(run_reqs), + "mem_indexes": mem_indexes.tolist(), + "b_start_loc": nopad_b_start_loc, + } + + return kwargs, run_reqs + + def _prefill_abort_remote(self, req_objs: List[InferReq]): + for req_obj in req_objs: + group_req_id = req_obj.shm_req.group_req_id + if group_req_id in self.remote_prefill_requests: + self.nixl_agent.send_abort_notify(self.remote_prefill_requests[group_req_id].decode_id, group_req_id) + del self.remote_prefill_requests[group_req_id] + if group_req_id in self.inflght_transfer_requests: + del self.inflght_transfer_requests[group_req_id] + + +class PDNIXLBackendBaseChunked(PDNIXLBackendBase, ChunkedPrefillBackend): + pass + + +class PDNIXLBackendBaseDPChunked(PDNIXLBackendBase, DPChunkedPrefillBackend): + pass diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode.py new file mode 100644 index 000000000..943e51cb2 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode.py @@ -0,0 +1,113 @@ +import os +import time +import torch.multiprocessing as 
mp +import threading +from concurrent.futures import ThreadPoolExecutor +from typing import List, Tuple, Dict +from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq +from lightllm.server.core.objs.req import PDNIXLChunkedPrefillReq +from lightllm.utils.log_utils import init_logger +from lightllm.server.multimodal_params import MultimodalParams + +from .pd_remote_prefill_obj import ( + RemotePrefillTask, + RemotePrefillServerInfo, + RemotePrefillRequest, + RemoteTransferStatusType, +) + +from .impl_for_pd_base import PDNIXLBackendBaseChunked + +logger = init_logger(__name__) + + +class PDNIXLBackendForDecodeNode(PDNIXLBackendBaseChunked): + def __init__(self, prefill_task_queue: mp.Queue, prefill_done_queue: mp.Queue, nix_meta_queue: mp.Queue) -> None: + super().__init__(prefill_task_queue, prefill_done_queue, nix_meta_queue) + self.classed_req_strict_prefill = False + self.support_overlap = True + + def init_custom(self): + super(type(self), self).init_custom() + self.wait_prefill_thread = threading.Thread( + target=self._start_async_loop, args=(self._prefill_wait_loop_async,), daemon=True + ) + max_workers = int(os.getenv("PD_NIXL_MOVE_PAGE_POOL_SIZE", 4)) + self.wait_move_page_pool = ThreadPoolExecutor(max_workers) + self.wait_prefill_thread.start() + return + + def _build_remote_prefill_task(self, index: int, kwargs: Dict, req: InferReq): + prefill_node = req.shm_req.sample_params.move_kv_to_decode_node.to_dict() + prefill_node_info = RemotePrefillServerInfo( + perfill_server_id=prefill_node["node_id"], + prefill_server_ip=prefill_node["ip"], + prefill_server_port=prefill_node["rpyc_port"], + ) + + mem_indexes = kwargs.get("mem_indexes") + b_start_loc = kwargs.get("b_start_loc") + prefill_request = RemotePrefillRequest( + prompt=req.shm_req.get_prompt_ids(), + sampling_params=req.shm_req.sample_params, + multimodal_params=MultimodalParams.from_dict(req.multimodal_params), + local_cached_len=req.cur_kv_len, + token_ids=mem_indexes[b_start_loc[index] : b_start_loc[index + 1]], + page_ids=self.page_scheduer.borrow(), # get page ids for this request, blocking when not enough pages + ) + return RemotePrefillTask(server_info=prefill_node_info, prefill_request=prefill_request) + + def _trigger_remote_prefill(self, req_id: int, index: int, kwargs: Dict, req: InferReq): + remote_prefill_task = self._build_remote_prefill_task(index, kwargs, req) + self.request_to_page_ids[req_id] = remote_prefill_task.prefill_request.page_ids + self.to_remote_queue.put(remote_prefill_task) + + def _pre_handle_finished_reqs(self, finished_reqs: List[InferReq]): + new_finished_reqs = [] + for req in finished_reqs: + if req.infer_aborted and req.in_prefill_or_transfer: + # those are in progress, we will handle them later + pass + else: + new_finished_reqs.append(req) + + finished_reqs = new_finished_reqs + + def _get_classed_reqs( + self, + req_ids: List[int] = None, + no_decode: bool = False, + strict_prefill: bool = False, + recover_paused: bool = False, + ): + prefill_reqs, decode_reqs = super(type(self), self)._get_classed_reqs( + req_ids, no_decode, strict_prefill, recover_paused + ) + prefill_reqs, decode_reqs, failed_reqs, _ = self._decode_filter_reqs(prefill_reqs, decode_reqs) + + if failed_reqs: + g_infer_context.filter_reqs(failed_reqs) + + if prefill_reqs: + kwargs, run_reqs = self._prepare_remote_prefill_inputs(prefill_reqs) + for idx, run_req in enumerate(run_reqs): + run_req: InferReq = run_req + shm_req: PDNIXLChunkedPrefillReq = run_req.shm_req + # forward each req to 
remote prefill + # since the token index are the same across TPs, we only need to trigger prefill on master + if self.is_master_in_dp: + run_req.remote_prefill_start = time.time() + # since this function may blocking the calling thread, so we do it in a thread pool + self.wait_move_page_pool.submit( + self._trigger_remote_prefill, shm_req.group_req_id, idx, kwargs, run_req + ) + + shm_req.set_pd_req_rank_state( + self.rank_in_dp, RemoteTransferStatusType.IN_PROGRESS.value + ) # set in progress state + run_req.in_prefill_or_transfer = True + self.remote_prefilled_reqs[shm_req.group_req_id] = run_req + + prefill_reqs.clear() + + return prefill_reqs, decode_reqs diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode_dp.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode_dp.py new file mode 100644 index 000000000..7bb8abc69 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode_dp.py @@ -0,0 +1,21 @@ +from types import MethodType +import torch.multiprocessing as mp +from lightllm.utils.log_utils import init_logger + +from .impl_for_pd_base import PDNIXLBackendBaseDPChunked +from .impl_for_pd_decode import PDNIXLBackendForDecodeNode + +logger = init_logger(__name__) + + +class PDNIXLDPBackendForDecodeNode(PDNIXLBackendBaseDPChunked): + def __init__(self, prefill_task_queue: mp.Queue, prefill_done_queue: mp.Queue, nix_meta_queue: mp.Queue) -> None: + self.init_custom = MethodType(PDNIXLBackendForDecodeNode.init_custom, self) + super().__init__(prefill_task_queue, prefill_done_queue, nix_meta_queue) + self.classed_req_strict_prefill = False + self.support_overlap = True + + self._build_remote_prefill_task = MethodType(PDNIXLBackendForDecodeNode._build_remote_prefill_task, self) + self._trigger_remote_prefill = MethodType(PDNIXLBackendForDecodeNode._trigger_remote_prefill, self) + self._pre_handle_finished_reqs = MethodType(PDNIXLBackendForDecodeNode._pre_handle_finished_reqs, self) + self._get_classed_reqs = MethodType(PDNIXLBackendForDecodeNode._get_classed_reqs, self) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill.py new file mode 100644 index 000000000..90e2e690d --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill.py @@ -0,0 +1,67 @@ +import threading +import torch.multiprocessing as mp +from typing import List, Tuple +from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq +from lightllm.server.core.objs.req import PDNIXLChunkedPrefillReq +from lightllm.utils.log_utils import init_logger + +from .impl_for_pd_base import PDNIXLBackendBaseChunked +from .pd_remote_prefill_obj import RemoteTransferStatusType + +logger = init_logger(__name__) + + +class PDNIXLBackendForPrefillNode(PDNIXLBackendBaseChunked): + def __init__(self, transfer_task_queue: mp.Queue, transfer_done_queue: mp.Queue, nixl_meta_queue: mp.Queue) -> None: + super().__init__(transfer_task_queue, transfer_done_queue, nixl_meta_queue) + self.support_overlap = False + self.classed_req_no_decode = True + self.extra_post_req_handle_func = self._handle_chunked_transfer + self.call_post_handle_for_chunk = True + + def init_custom(self): + super(type(self), self).init_custom() + self.handle_prefill_loop_thread = threading.Thread( + target=self._start_async_loop, args=(self._handle_prefill_loop,), daemon=True + ) + self.wait_transfer_loop_thread = 
threading.Thread( + target=self._start_async_loop, args=(self._wait_page_transfer_loop,), daemon=True + ) + self.handle_transfer_loop_thread = threading.Thread( + target=self._start_async_loop, args=(self._handle_transfer_loop,), daemon=True + ) + + self.handle_prefill_loop_thread.start() + self.handle_transfer_loop_thread.start() + self.wait_transfer_loop_thread.start() + return + + def _pre_handle_finished_reqs(self, finished_reqs: List[InferReq]): + new_finished_reqs = [] + need_remote_aborted_reqs = [] + for req in finished_reqs: + if req.in_prefill_or_transfer: + if req.infer_aborted: + need_remote_aborted_reqs.append(req) + new_finished_reqs.append(req) + else: + if req.infer_nixl_rpd: + shm_req: PDNIXLChunkedPrefillReq = req.shm_req + state = shm_req.get_pd_req_state() + if state == RemoteTransferStatusType.SUCCESS.value: # success + req.in_prefill_or_transfer = False + new_finished_reqs.append(req) + elif state == RemoteTransferStatusType.FAILED.value: # failure + need_remote_aborted_reqs.append(req) + req.in_prefill_or_transfer = False + new_finished_reqs.append(req) + else: + logger.warning(f"remote prefill request {shm_req.group_req_id} unexpected state {state}") + else: + pass + else: + new_finished_reqs.append(req) + + finished_reqs = new_finished_reqs + + self._prefill_abort_remote(need_remote_aborted_reqs) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill_dp.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill_dp.py new file mode 100644 index 000000000..cbdb722ea --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_prefill_dp.py @@ -0,0 +1,22 @@ +from types import MethodType +import torch.multiprocessing as mp +from lightllm.server.router.model_infer.infer_batch import g_infer_context +from lightllm.utils.log_utils import init_logger + +from .impl_for_pd_base import PDNIXLBackendBaseDPChunked +from .impl_for_pd_prefill import PDNIXLBackendForPrefillNode + +logger = init_logger(__name__) + + +class PDNIXLDPBackendForPrefillNode(PDNIXLBackendBaseDPChunked): + def __init__(self, transfer_task_queue: mp.Queue, transfer_done_queue: mp.Queue, nixl_meta_queue: mp.Queue) -> None: + self.init_custom = MethodType(PDNIXLBackendForPrefillNode.init_custom, self) + super().__init__(transfer_task_queue, transfer_done_queue, nixl_meta_queue) + + self.support_overlap = False + self.classed_req_no_decode = True + self.call_post_handle_for_chunk = True + self.extra_post_req_handle_func = self._handle_chunked_transfer + + self._pre_handle_finished_reqs = MethodType(PDNIXLBackendForPrefillNode._pre_handle_finished_reqs, self) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py new file mode 100644 index 000000000..fffc22858 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py @@ -0,0 +1,374 @@ +from collections import defaultdict +from typing import Dict, List, Any +from torch import Tensor +from dataclasses import dataclass +import queue +import pickle +import time + +from lightllm.utils.log_utils import init_logger + +from .pd_remote_prefill_obj import ( + RemoteAgent, + KVMoveRequest, + PrefillRequest, + RemotePrefillStatus, + ThreadSafeDict, + KVMoveRequestState, + PageTransferAck, + RemoteTransferStatusType, + RemoteTransferType, + NotificationType, + Notification, +) + + +logger = init_logger(__name__) + +try: + from 
nixl._api import nixl_agent as NixlWrapper + from nixl._api import nixlBind + + logger.info("Nixl is available") +except ImportError: + logger.warning("nixl is not installed, which is required for pd disagreggation!!!") + NixlWrapper = None + + +@dataclass +class NixlMetadata: + id: str + num_tokens: list[int] + num_pages: list[int] + agent_metadatas: list[bytes] + agent_mem_descs: list[bytes] + agent_page_mem_descs: list[bytes] + + +class NixlKVTransporter: + def __init__(self, node_id: int, tp_idx: int): + self.node_id = node_id + self.tp_idx = tp_idx + self.nixl_agent = NixlWrapper(self.agent_name, None) + + self.num_layers = -1 + self.num_tokens = -1 + self.num_heads = -1 + self.head_dims = -1 + self.token_len = -1 + self.num_pages = -1 + self.page_size = -1 + self.page_len = -1 + + self.reg_desc = None + self.local_xfer_handles = None + self.page_reg_desc = None + self.page_local_xfer_handles = None + + self.remote_agents = defaultdict(list) + + self.inflight_transfers: ThreadSafeDict = ThreadSafeDict() + self.inflight_page_transfers: ThreadSafeDict = ThreadSafeDict() + + @property + def agent_name(self) -> str: + return f"{self.node_id}_{self.tp_idx}" + + @property + def agent_metadata(self): + return self.nixl_agent.get_agent_metadata() + + @property + def local_mem_desc(self): + return self.nixl_agent.get_serialized_descs(self.reg_desc) + + @property + def local_page_mem_desc(self): + return self.nixl_agent.get_serialized_descs(self.page_reg_desc) + + def get_new_notifs(self): + return self.nixl_agent.get_new_notifs() + + def _create_xfer_handles(self, reg_desc: nixlBind.nixlRegDList, num_tokens: int, agent_name: str = ""): + base_addr, _, device_id, _ = reg_desc[0] + layer_len = num_tokens * self.token_len + tokens_data = [0] * (self.num_layers * num_tokens) + idx = 0 + for layer_id in range(self.num_layers): + for token_id in range(num_tokens): + tokens_data[idx] = ( + base_addr + layer_id * layer_len + token_id * self.token_len, + self.token_len, + device_id, + ) + idx += 1 + descs = self.nixl_agent.get_xfer_descs(tokens_data, "VRAM", True) + return self.nixl_agent.prep_xfer_dlist(agent_name, descs, is_sorted=True) + + def register_kv_buffer(self, kv_buffer: Tensor): + self.num_layers, self.num_tokens, self.num_heads, self.head_dim = kv_buffer.shape + self.token_len = self.num_heads * self.head_dim * kv_buffer.element_size() + + self.reg_desc = self.nixl_agent.register_memory(kv_buffer) + self.local_xfer_handles = self._create_xfer_handles(self.reg_desc, self.num_tokens) + + def _create_paged_xfer_handles(self, reg_desc: nixlBind.nixlRegDList, page_num: int, agent_name: str = ""): + base_addr, _, device_id, _ = reg_desc[0] + pages_data = [] + for page_id in range(page_num): + pages_data.append((base_addr + page_id * self.page_len, self.page_len, device_id)) + descs = self.nixl_agent.get_xfer_descs(pages_data, "VRAM", True) + return self.nixl_agent.prep_xfer_dlist(agent_name, descs, is_sorted=True) + + def register_kv_move_buffer(self, kv_move_buffer: Tensor): + self.num_pages, self.page_size, _, _, _ = kv_move_buffer.shape + self.page_len = self.page_size * self.num_layers * self.token_len + self.page_reg_desc = self.nixl_agent.register_memory(kv_move_buffer) + self.page_local_xfer_handles = self._create_paged_xfer_handles(self.page_reg_desc, self.num_pages) + + def add_remote_agent(self, remote_agent: NixlMetadata): + for idx, (agent_metadata, num_tokens, num_pages, agent_mem_desc, agent_page_mem_desc) in enumerate( + zip( + remote_agent.agent_metadatas, + 
remote_agent.num_tokens, + remote_agent.num_pages, + remote_agent.agent_mem_descs, + remote_agent.agent_page_mem_descs, + ) + ): + if self.tp_idx != idx: + self.remote_agents[remote_agent.id].append(None) + continue + + peer_name = self.nixl_agent.add_remote_agent(agent_metadata) + if isinstance(peer_name, bytes): + peer_name = peer_name.decode() + + self.nixl_agent.send_notif( + peer_name, Notification(type=NotificationType.REMOTE_MD, data=self.agent_metadata).to_bytes() + ) + + mem_desc = self.nixl_agent.deserialize_descs(agent_mem_desc) + kv_xfer_handles = self._create_xfer_handles(mem_desc, num_tokens, agent_name=peer_name) + + page_mem_desc = self.nixl_agent.deserialize_descs(agent_page_mem_desc) + kv_page_xfer_handles = self._create_paged_xfer_handles(page_mem_desc, num_pages, agent_name=peer_name) + + logger.info("Added remote agent %s with mem desc %s", peer_name, page_mem_desc) + self.remote_agents[remote_agent.id].append( + RemoteAgent( + name=peer_name, + kv_mem_desc=mem_desc, + num_tokens=num_tokens, + kv_xfer_handles=kv_xfer_handles, + kv_page_mem_desc=page_mem_desc, + num_pages=num_pages, + kv_page_xfer_handles=kv_page_xfer_handles, + ) + ) + + def connect_to_remote(self, name: str, remote_md: bytes): + target = self.nixl_agent.add_remote_agent(remote_md) + if isinstance(target, bytes): + target = target.decode() + assert name == target, "Target name {} does not match remote name {}".format(target, name) + + def _get_token_desc_ids(self, token_ids: List[int], num_tokens: int): + token_ids_len, idx = len(token_ids), 0 + descs_ids = [0] * (self.num_layers * token_ids_len) + for layer_id in range(self.num_layers): + for token_id in token_ids: + descs_ids[idx] = layer_id * num_tokens + token_id + idx += 1 + return descs_ids + + def write_blocks(self, request: KVMoveRequest, prefill_request: PrefillRequest, is_finished: bool): + group_reqeust_id = request.group_req_id + skip_kv_move_len = prefill_request.data.local_cached_len + + # current kv len is less than remote cached kv len, just skip + if request.cur_kv_len <= skip_kv_move_len: + return + + kv_move_start = max(skip_kv_move_len, request.prev_kv_len) + kv_move_end = request.cur_kv_len + + src_token_ids = request.token_ids[kv_move_start:] + dst_token_ids = prefill_request.data.token_ids[ + kv_move_start - skip_kv_move_len : kv_move_end - skip_kv_move_len + ] + + remote_agent: RemoteAgent = self.remote_agents[prefill_request.decode_id][ + self.tp_idx + ] # TODO one-one mapping now + + if len(src_token_ids) > 0: + assert len(src_token_ids) == len(dst_token_ids), ( + f"{len(src_token_ids)} {len(dst_token_ids)} {kv_move_start} " + f"{kv_move_end} {skip_kv_move_len}, {len(prefill_request.data.token_ids)}" + ) + src_token_descs = self._get_token_desc_ids(src_token_ids, self.num_tokens) + dst_token_descs = self._get_token_desc_ids(dst_token_ids, remote_agent.num_tokens) + + src_handle = self.local_xfer_handles + dst_handle = remote_agent.kv_xfer_handles + + notify_status = ( + RemotePrefillStatus( + group_req_id=group_reqeust_id, + status=1, + chunk_id=prefill_request.transfer_state.current_chunk_id, + is_last=is_finished, + ).serialize() + if is_finished + else b"" + ) + + handle = self.nixl_agent.make_prepped_xfer( + "WRITE", src_handle, src_token_descs, dst_handle, dst_token_descs, notify_status + ) + + status = self.nixl_agent.transfer(handle) + assert status != "ERR" + + if group_reqeust_id not in self.inflight_transfers: + self.inflight_transfers[group_reqeust_id] = KVMoveRequestState( + handles=[], done_handles=[], 
remote_agent=remote_agent, abort=False, is_last_arrived=False + ) + + self.inflight_transfers[group_reqeust_id].handles.append(handle) + + if is_finished: + self.inflight_transfers[group_reqeust_id].is_last_arrived = True + + return handle + + return None + + def write_blocks_paged( + self, + remote_id: int, + transfer_pages: List[int], + receive_pages: List[int], + notifications: List[RemotePrefillStatus], + ): + remote_agent: RemoteAgent = self.remote_agents[remote_id][self.tp_idx] + src_handle = self.page_local_xfer_handles + dst_handle = remote_agent.kv_page_xfer_handles + notify_status = Notification(type=NotificationType.TRANSFER_NOTIFY, data=[n.serialize() for n in notifications]) + handle = self.nixl_agent.make_prepped_xfer( + "WRITE", src_handle, transfer_pages, dst_handle, receive_pages, notify_status.to_bytes() + ) + status = self.nixl_agent.transfer(handle) + assert status != "ERR", f"Transfer failed with status {status} for handle {handle}" + self.inflight_page_transfers[handle] = (transfer_pages, receive_pages, notifications, remote_agent) + + def send_transfer_notify(self, agent_name: str, acks: List[PageTransferAck]): + assert len(acks) > 0, "Acks should not be empty" + acks_noti = Notification(type=NotificationType.TRANSFER_NOTIFY_ACK, data=[ack.serialize() for ack in acks]) + self.nixl_agent.send_notif(agent_name, acks_noti.to_bytes()) + + def send_abort_notify(self, remote_id: int, group_req_id: int): + remote_agent: RemoteAgent = self.remote_agents[remote_id][self.tp_idx] + notify_status = RemotePrefillStatus( + group_req_id=group_req_id, + transfer_type=RemoteTransferType.PAGE_TRANSFER, + status=RemoteTransferStatusType.FAILED, + is_last=True, + ) + self.nixl_agent.send_notif( + remote_agent.name, + Notification(type=NotificationType.TRANSFER_NOTIFY, data=[notify_status.serialize()]).to_bytes(), + ) + + if group_req_id in self.inflight_transfers: + self.inflight_transfers[group_req_id].abort = True + + async def get_done_page_transfers(self): + done_pages = [] + done_requests = [] + for handle, (transfer_pages, _, notifications, _) in self.inflight_page_transfers.items(): + xfer_state = self.nixl_agent.check_xfer_state(handle) + if xfer_state == "DONE": + done_pages.extend(transfer_pages) + done_requests.extend( + [(x.group_req_id, RemoteTransferStatusType.SUCCESS) for x in notifications if x.is_last] + ) + self.nixl_agent.release_xfer_handle(handle) + del self.inflight_page_transfers[handle] + + elif xfer_state == "PROC": + continue + else: + logger.warning(f"Transfer failed with state {xfer_state} for handle {handle}") + done_pages.extend(transfer_pages) + done_requests.extend([(x.group_req_id, RemoteTransferStatusType.FAILED) for x in notifications]) + self.nixl_agent.release_xfer_handle(handle) + del self.inflight_page_transfers[handle] + + return done_pages, done_requests + + def get_done_tranfers(self): + done_req_ids = [] + for req_id, kv_move_state in self.inflight_transfers.items(): + kv_move_state: KVMoveRequestState + if kv_move_state.abort: + logger.warning(f"{req_id} Transfer aborted") + done_req_ids.append((req_id, -1)) + continue + + if not kv_move_state.is_last_arrived: + continue + + remote_agent: RemoteAgent = kv_move_state.remote_agent + + left_handles = [] + failed = False + for handle in kv_move_state.handles: + if failed: + left_handles.append(handle) + continue + + xfer_state = self.nixl_agent.check_xfer_state(handle) + + if xfer_state == "DONE": + kv_move_state.done_handles.append(handle) + elif xfer_state == "PROC": + 
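# transfer still in flight, keep the handle and check it again on the next pass +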
left_handles.append(handle) + else: + logger.warning(f"{req_id} Transfer failed with state {xfer_state}") + failed = True + kv_move_state.done_handles.append(handle) + notify_failed_status = RemotePrefillStatus( + group_req_id=req_id, status=-1, chunk_id=-1, is_last=True + ) + self.nixl_agent.send_notif(remote_agent.name, notify_failed_status.serialize()) + + kv_move_state.handles = left_handles + + if failed: + done_req_ids.append((req_id, -1)) + elif len(left_handles) == 0: + done_req_ids.append((req_id, 1)) + + for req_id, _ in done_req_ids: + kv_move_state: KVMoveRequestState = self.inflight_transfers[req_id] + for handle in kv_move_state.handles + kv_move_state.done_handles: + # release will abort inflight transfer + self.nixl_agent.release_xfer_handle(handle) + + del self.inflight_transfers[req_id] + return done_req_ids + + def shutdown(self): + self.nixl_agent.deregister_memory(self.reg_desc) + self.nixl_agent.release_dlist_handle(self.local_xfer_handles) + self.nixl_agent.release_dlist_handle(self.page_local_xfer_handles) + for id, agents in self.remote_agents.items(): + for agent in agents: + self.nixl_agent.remove_remote_agent(agent.name) + self.nixl_agent.release_dlist_handle(agent.kv_xfer_handles) + self.nixl_agent.release_dlist_handle(agent.kv_page_xfer_handles) + + for handle in self.inflight_page_transfers: + self.nixl_agent.release_xfer_handle(handle) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill.py new file mode 100644 index 000000000..d1fa4003a --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill.py @@ -0,0 +1,314 @@ +from typing import List, Any +import zmq +import inspect +import random +import time + +import torch.multiprocessing as mp + +from lightllm.utils.log_utils import init_logger +from lightllm.utils.net_utils import get_hostname_ip +from lightllm.utils.graceful_utils import graceful_registry +from lightllm.server.pd_io_struct import DistInfo + +from .pd_remote_prefill_obj import ( + ConnectRequest, + RemoteRequest, + RemoteRequstType, + PrefillRequest, + RemotePrefillRequest, + RemotePrefillServerInfo, + RemotePrefillTask, + RemotePrefillStatus, + RemoteTransferStatusType, + RemoteTransferType, + SockWithPoller, +) +from .nixl_kv_transporter import NixlMetadata + +logger = init_logger(__name__) + + +class PDRemotePrefillBase: + def __init__( + self, + id: int, + dist_info: DistInfo, + from_backend_queue: mp.Queue, + to_backend_queues: List[mp.Queue], + agent_meta_queues: List[mp.Queue], # need send kv cache to this process and register with nixl + ): + self.id = id + self.dist_info = dist_info + assert len(agent_meta_queues) == dist_info.node_world_size + self.agent_meta_queues = agent_meta_queues + self.from_backend_queue = from_backend_queue + self.to_backend_queues = to_backend_queues + self.local_agent_meta = None + + def local_init(self): + agent_metas = NixlMetadata( + id=self.id, + agent_metadatas=[], + num_tokens=[], + num_pages=[], + agent_mem_descs=[], + agent_page_mem_descs=[], + ) + for tp in range(self.dist_info.node_world_size): + agent_metadata, num_tokens, num_pages, mem_desc, page_mem_desc = self.agent_meta_queues[tp].get(timeout=60) + logger.info(f"Received agent_metadata from {tp} with mem reg: {mem_desc}") + agent_metas.num_tokens.append(num_tokens) + agent_metas.num_pages.append(num_pages) + agent_metas.agent_metadatas.append(agent_metadata) + 
agent_metas.agent_mem_descs.append(mem_desc)
+            agent_metas.agent_page_mem_descs.append(page_mem_desc)
+
+        self.local_agent_meta = agent_metas
+        logger.info("All local kv cache registered.")
+
+
+class PDRemotePrefillServer(PDRemotePrefillBase):
+    def __init__(
+        self,
+        id: int,
+        dist_info: DistInfo,
+        http_server_port: int,
+        server_port: int,
+        from_backend_queue: mp.Queue,
+        to_backend_queues: List[mp.Queue],
+        agent_meta_queues: List[mp.Queue],
+    ):
+        super().__init__(id, dist_info, from_backend_queue, to_backend_queues, agent_meta_queues)
+        # map from client id to decode server info
+        self.remote_decode_clients = {}
+
+        # build control path
+        _ctx = zmq.Context()
+        self.recv_from_decode = SockWithPoller(_ctx.socket(zmq.ROUTER))
+        self.host_ip = get_hostname_ip()
+        self.recv_from_decode.bind(f"tcp://{self.host_ip}:{server_port}")
+
+        # build trigger remote prefill path
+        self.send_to_httpserver = SockWithPoller(_ctx.socket(zmq.PUSH))
+        self.send_to_httpserver.connect(f"tcp://{self.host_ip}:{http_server_port}")
+
+    def main_loop(self):
+        self.local_init()
+        while True:
+            try:
+                client_obj, request = self.recv_from_decode.recv_pyobj_multipart()
+                request: RemoteRequest
+                logger.info(f"received request from decode, type: {request.type}")
+
+                if request.type == RemoteRequstType.REMOTE_CONNECT:
+                    # forward request to all prefill server
+                    for queue in self.to_backend_queues:
+                        queue.put(request)
+
+                    success = True
+                    for idx in range(self.dist_info.node_world_size):
+                        ack = self.from_backend_queue.get()
+                        logger.info(f"received ack from backend {idx}: {ack}")
+                        if ack != "OK":
+                            success = False
+                            break
+
+                    self.recv_from_decode.send_pyobj_multipart(client_obj, success)
+                    logger.info(f"Sent ack to decode: {success}")
+                    if not success:
+                        logger.warning(f"Remote connect failed: {request}")
+
+                if request.type == RemoteRequstType.REMOTE_PREFILL:
+                    request: PrefillRequest = request
+                    if self.dist_info.dp_size_in_node > 1:
+                        group_req_id = request.data.sampling_params.group_request_id
+                        suggested_dp_index = request.data.sampling_params.suggested_dp_index
+                        if suggested_dp_index < 0:  # not likely to happen
+                            # randint is inclusive on both ends, so the valid range is [0, dp_size_in_node - 1]
+                            suggested_dp_index = random.randint(0, self.dist_info.dp_size_in_node - 1)
+                            request.data.sampling_params.suggested_dp_index = suggested_dp_index
+                            logger.warning(
+                                f"Suggested dp index is negative for {group_req_id}, set to {suggested_dp_index}"
+                            )
+
+                        for local_rank in range(
+                            suggested_dp_index * self.dist_info.dp_world_size,
+                            (suggested_dp_index + 1) * self.dist_info.dp_world_size,
+                        ):
+                            self.to_backend_queues[local_rank].put(request)
+                    else:
+                        for queue in self.to_backend_queues:
+                            queue.put(request)
+
+                    self.send_to_httpserver.send_pyobj(
+                        (request.data.prompt, request.data.sampling_params, request.data.multimodal_params)
+                    )
+
+            except Exception as e:
+                logger.error(f"Error in remote prefill server loop: {e}", exc_info=e)
+
+
+class PDRemotePrefillClient(PDRemotePrefillBase):
+    def __init__(
+        self,
+        id: int,
+        dist_info: DistInfo,
+        from_backend_queue: mp.Queue,  # only tp0 will trigger prefill
+        to_backend_queues: List[mp.Queue],  # one to many done queue
+        agent_meta_queues: List[mp.Queue],
+    ):
+        super().__init__(id, dist_info, from_backend_queue, to_backend_queues, agent_meta_queues)
+        # map from server id to prefill server info
+        self.remote_prefill_servers = {}
+        self.client_socket_cnt = 0
+        self._ctx = zmq.Context()
+
+    def _connect_server(self, server_ip: str, port: int):
+        _socket = self._ctx.socket(zmq.DEALER)
+        _socket.setsockopt_string(zmq.IDENTITY, f"{self.id}_{self.client_socket_cnt}")
+        self.client_socket_cnt += 1
+        connect_str = f"tcp://{server_ip}:{port}"
+        _socket.connect(connect_str)
+        return SockWithPoller(_socket)
+
+    def _send_nixl_agent(self, socket: SockWithPoller):
+        socket.send_pyobj(
+            ConnectRequest(
+                type=RemoteRequstType.REMOTE_CONNECT,
+                decode_id=self.id,
+                num_tokens=self.local_agent_meta.num_tokens,
+                num_pages=self.local_agent_meta.num_pages,
+                agent_metadatas=self.local_agent_meta.agent_metadatas,
+                agent_mem_descs=self.local_agent_meta.agent_mem_descs,
+                agent_page_mem_descs=self.local_agent_meta.agent_page_mem_descs,
+            )
+        )
+
+        success = socket.recv_pyobj(timeout=60)
+        logger.info(f"recv remote nixl connect response {success}")
+        if success is None:
+            logger.warning("timeout to recv remote nixl connect response")
+            return False
+
+        return success
+
+    def connect_to_prefill_server(self, server_info: RemotePrefillServerInfo):
+        if server_info.perfill_server_id in self.remote_prefill_servers:
+            return True
+
+        # build control path if not exist
+        _socket = self._connect_server(server_info.prefill_server_ip, server_info.prefill_server_port)
+        success = self._send_nixl_agent(_socket)
+        if success:
+            self.remote_prefill_servers[server_info.perfill_server_id] = (_socket, server_info)
+            return True
+        else:
+            logger.warning("Remote Prefill Server Connect Failed")
+            return False
+
+    def main_loop(self):
+        self.local_init()
+        while True:
+            try:
+                prefill_tasks: RemotePrefillTask = self.from_backend_queue.get()
+                # connect first
+                if self.connect_to_prefill_server(prefill_tasks.server_info):
+                    # do prefill
+                    self.remote_prefill(prefill_tasks.server_info.perfill_server_id, prefill_tasks.prefill_request)
+                else:
+                    # failed to connect a remote, notify every backend queue of the failure
+                    for queue in self.to_backend_queues:
+                        queue.put(
+                            RemotePrefillStatus(
+                                transfer_type=RemoteTransferType.PAGE_TRANSFER,
+                                group_req_id=prefill_tasks.prefill_request.sampling_params.group_request_id,
+                                status=RemoteTransferStatusType.FAILED,
+                                is_last=True,
+                            )
+                        )
+            except Exception as e:
+                logger.error(f"Remote prefill client loop error: {e}", exc_info=e)
+
+    # send the request to the remote server to run the prefill
+    def remote_prefill(self, server_id: int, prefill_request: RemotePrefillRequest):
+        socket, _ = self.remote_prefill_servers[server_id]
+        prefill_request.sampling_params.max_new_tokens = 1
+        socket.send_pyobj(
+            PrefillRequest(
+                type=RemoteRequstType.REMOTE_PREFILL, decode_id=self.id, data=prefill_request, transfer_state=None
+            )
+        )
+
+
+def remote_prefill_server_loop(
+    id: int,
+    dist_info: DistInfo,
+    http_server_port: int,
+    server_port: int,
+    from_backend_queue: mp.Queue,
+    to_backend_queues: List[mp.Queue],
+    agent_meta_queues: List[mp.Queue],
+):
+    graceful_registry(inspect.currentframe().f_code.co_name)
+    server = PDRemotePrefillServer(
+        id, dist_info, http_server_port, server_port, from_backend_queue, to_backend_queues, agent_meta_queues
+    )
+    server.main_loop()
+
+
+def start_pd_remote_prefill_server_process(
+    id: int,
+    dist_info: DistInfo,
+    http_server_port: int,
+    server_port: int,
+    from_backend_queue: mp.Queue,
+    to_backend_queues: List[mp.Queue],
+    agent_meta_queues: List[mp.Queue],
+):
+    proc = mp.Process(
+        target=remote_prefill_server_loop,
+        args=(id, dist_info, http_server_port, server_port, from_backend_queue, to_backend_queues, agent_meta_queues),
+    )
+    proc.start()
+    assert proc.is_alive()
+    logger.info(f"remote prefill server with id: {id} started!")
+    return proc
+
+
+def remote_prefill_client_loop(
+    id: int,
+    dist_info: 
DistInfo, + from_backend_queue: mp.Queue, + to_backend_queues: List[mp.Queue], + agent_meta_queues: List[mp.Queue], +): + graceful_registry(inspect.currentframe().f_code.co_name) + + client = PDRemotePrefillClient( + id, + dist_info, + from_backend_queue, + to_backend_queues, + agent_meta_queues, + ) + client.main_loop() + + +def start_pd_remote_prefill_client_process( + id: int, + dist_info: DistInfo, + from_backend_queue: mp.Queue, + to_backend_queues: List[mp.Queue], + agent_meta_queues: List[mp.Queue], +): + + proc = mp.Process( + target=remote_prefill_client_loop, + args=(id, dist_info, from_backend_queue, to_backend_queues, agent_meta_queues), + ) + proc.start() + assert proc.is_alive() + logger.info(f"remote prefill client with id: {id} started!") + return proc diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py new file mode 100644 index 000000000..99a61cc43 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py @@ -0,0 +1,301 @@ +from dataclasses import dataclass, asdict +from enum import Enum +import json +from typing import List, Union, Optional, Any +from threading import Lock, Condition +import pickle +import zmq +import threading + +from lightllm.utils.log_utils import init_logger +from lightllm.server.core.objs import SamplingParams +from lightllm.server.multimodal_params import MultimodalParams +from lightllm.server.pd_io_struct import RemotePrefillServerInfo + +logger = init_logger(__name__) + +try: + from nixl._api import nixlBind, nixl_prepped_dlist_handle, nixl_xfer_handle + +except ImportError: + logger.error("nixl is not installed, which is required for pd disagreggation!!!") + raise + + +class RemoteRequstType(Enum): + REMOTE_CONNECT = 1 + REMOTE_PREFILL = 2 + + +@dataclass +class RemotePrefillRequest: + prompt: Union[str, List[int]] + sampling_params: SamplingParams + multimodal_params: MultimodalParams + local_cached_len: int # will skip transfer + token_ids: List[int] # mem cache indexes + page_ids: List[int] # transfer page indexes + + +@dataclass +class RemotePrefillTask: + server_info: RemotePrefillServerInfo + prefill_request: RemotePrefillRequest + + +@dataclass +class RemoteRequest: + type: RemoteRequstType + + +@dataclass +class ConnectRequest(RemoteRequest): + decode_id: int + num_tokens: List[int] + num_pages: List[int] + agent_metadatas: List[bytes] + agent_mem_descs: List[bytes] + agent_page_mem_descs: List[bytes] + + +@dataclass +class TransferState: + start_time: float + lock: threading.Lock + free_page_ids: List[int] + + current_kv_len: int = 0 + current_chunk_id: int = 0 + + transfered_kv_len: int = 0 + transfered_chunk_id: int = 0 + + token_index: List[int] = None + is_finished: bool = False + + next_token_id: int = None + next_token_logprob: float = None + + def completed(self): + return self.is_finished and self.transfered_kv_len == self.current_kv_len + + +@dataclass +class PrefillRequest(RemoteRequest): + decode_id: int + data: RemotePrefillRequest + # transfer status + transfer_state: Optional[TransferState] + + +@dataclass +class KVMoveRequest: + group_req_id: int + prev_kv_len: int + cur_kv_len: int + + +@dataclass +class RemoteAgent: + name: str + num_tokens: int + num_pages: int + kv_mem_desc: nixlBind.nixlRegDList + kv_xfer_handles: nixl_prepped_dlist_handle + kv_page_mem_desc: nixlBind.nixlRegDList + kv_page_xfer_handles: nixl_prepped_dlist_handle + + +@dataclass +class 
KVMoveRequestState: + handles: List[nixl_xfer_handle] + done_handles: List[nixl_xfer_handle] + remote_agent: RemoteAgent + abort: bool + is_last_arrived: bool + + +class SerializableBase: + def to_dict(self): + return asdict(self) + + def serialize(self): + return json.dumps(self.to_dict()).encode() + + @classmethod + def from_dict(cls, dict_obj): + return cls(**dict_obj) + + @classmethod + def deserialize(cls, data: bytes): + return cls.from_dict(json.loads(data.decode())) + + +class RemoteTransferType(Enum): + TOKEN_TRANSFER = 1 + PAGE_TRANSFER = 2 + + +class RemoteTransferStatusType(Enum): + FAILED = -1 + IN_PROGRESS = 0 + SUCCESS = 1 + + +@dataclass +class RemotePrefillStatus(SerializableBase): + transfer_type: RemoteTransferType + group_req_id: int + status: RemoteTransferStatusType + chunk_id: int = -1 + is_last: bool = False + page_id: int = -1 + kv_start: int = 0 + kv_len: int = 0 + next_token_id: int = None + next_token_logprob: float = None + + def to_dict(self): + dict_obj = asdict(self) + dict_obj["transfer_type"] = self.transfer_type.name + dict_obj["status"] = self.status.name + return dict_obj + + @classmethod + def from_dict(cls, dict_obj): + dict_obj["transfer_type"] = RemoteTransferType[dict_obj["transfer_type"]] + dict_obj["status"] = RemoteTransferStatusType[dict_obj["status"]] + return cls(**dict_obj) + + +@dataclass +class PageTransferAck(SerializableBase): + group_req_id: int + page_id: int + + +class NotificationType(Enum): + REMOTE_MD = 1 + TRANSFER_NOTIFY = 2 + TRANSFER_NOTIFY_ACK = 3 + + +@dataclass +class Notification: + type: NotificationType + data: Union[bytes, List[bytes]] + + def to_bytes(self): + return pickle.dumps(self) + + @classmethod + def from_bytes(cls, data): + return pickle.loads(data) + + +class ThreadSafeDict: + def __init__(self): + self._dict = {} + self._lock = Lock() + + def __getitem__(self, key): + with self._lock: + return self._dict[key] + + def __setitem__(self, key, value): + with self._lock: + self._dict[key] = value + + def __delitem__(self, key): + with self._lock: + del self._dict[key] + + def __contains__(self, key): + with self._lock: + return key in self._dict + + def __len__(self) -> int: + with self._lock: + return len(self._dict) + + def get(self, key, default=None): + with self._lock: + return self._dict.get(key, default) + + def items(self): + with self._lock: + return list(self._dict.items()) + + def keys(self): + with self._lock: + return list(self._dict.keys()) + + def pop(self, key: Any, default: Optional[Any] = None) -> Any: + with self._lock: + return self._dict.pop(key, default) + + def values(self): + with self._lock: + return list(self._dict.values()) + + def clear(self) -> None: + with self._lock: + self._dict.clear() + + +class SockWithPoller: + def __init__(self, sock: zmq.Socket): + self.sock = sock + self.poller = zmq.Poller() + self.poller.register(self.sock, zmq.POLLIN) + + def recv_pyobj(self, timeout: int = 5): + socks = dict(self.poller.poll(timeout * 1000)) + if socks: + if socks.get(self.sock) == zmq.POLLIN: + return self.sock.recv_pyobj() + else: + None + + def send_pyobj(self, obj: Any): + return self.sock.send_pyobj(obj) + + def recv_pyobj_multipart(self): + client_id, data = self.sock.recv_multipart() + return client_id, pickle.loads(data) + + def send_pyobj_multipart(self, client_id: bytes, data: Any): + return self.sock.send_multipart([client_id, pickle.dumps(data)]) + + def bind(self, addr: str): + return self.sock.bind(addr) + + def connect(self, addr: str): + return self.sock.connect(addr) + + 
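+# Usage note (illustrative sketch only, not part of the patched code): SafePageIndexScheduler below is a
+# blocking pool of kv-move-buffer page indexes shared by the transfer loops, e.g.:
+#     scheduler = SafePageIndexScheduler(num_pages=8)
+#     pages = scheduler.borrow(2)   # blocks until two page indexes are free
+#     ...copy KV chunks into the borrowed pages and post the NIXL write...
+#     scheduler.return_(pages)      # wakes up any waiting borrower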
+class SafePageIndexScheduler: + def __init__(self, num_pages: int): + self.num_pages = num_pages + self.items = list(range(num_pages)) + self.lock = Lock() + self.cond = Condition(self.lock) + + def borrow(self, num_pages: int = 2) -> List[int]: + if num_pages > self.num_pages: + raise ValueError(f"Cannot borrow {num_pages} pages, only {self.num_pages} available.") + + with self.cond: + while len(self.items) < num_pages: + self.cond.wait() + ret, self.items = self.items[:num_pages], self.items[num_pages:] + return ret + + def return_(self, items: List[int]): + with self.cond: + self.items.extend(items) + self.cond.notify_all() + + def current_size(self) -> int: + with self.lock: + return len(self.items) diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index f007a3b86..bb10bf83c 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -20,6 +20,10 @@ DPForDecodeNode, ChunckedPrefillForPrefillNode, DPChunkedForPrefillNode, + PDNIXLBackendForPrefillNode, + PDNIXLBackendForDecodeNode, + PDNIXLDPBackendForPrefillNode, + PDNIXLDPBackendForDecodeNode, ) from lightllm.server.router.model_infer.mode_backend.redundancy_expert_manager import RedundancyExpertManager from lightllm.server.core.objs import RpcShmParams, RpcShmResults, ShmSyncStatusArray @@ -41,12 +45,14 @@ def __init__( rpc_event: multiprocessing.Event, rpc_finished_event: multiprocessing.Event, info_queue: mp.Queue, + result_queue: mp.Queue, mem_queue: mp.Queue, ): super().__init__() self.args: StartArgs = args self.node_world_size = node_world_size self.info_queue = info_queue + self.result_queue = result_queue self.mem_queue = mem_queue self.rpc_event = rpc_event self.rpc_finished_event = rpc_finished_event @@ -116,17 +122,32 @@ def init_model(self, kvargs): assert not (is_outlines_constraint_mode and is_xgrammar_constraint_mode), "only one constraint mode can be true" is_prefill_node = self.args.run_mode == "prefill" is_decode_node = self.args.run_mode == "decode" + is_nixl_prefill_node = self.args.run_mode == "nixl_prefill" + is_nixl_decode_node = self.args.run_mode == "nixl_decode" if is_prefill_node: if self.args.dp > 1: self.backend = DPChunkedForPrefillNode(self.info_queue, self.mem_queue) else: self.backend = ChunckedPrefillForPrefillNode(self.info_queue, self.mem_queue) + elif is_nixl_prefill_node: + if self.args.dp > 1: + self.backend = PDNIXLDPBackendForPrefillNode(self.info_queue, self.result_queue, self.mem_queue) + else: + self.backend = PDNIXLBackendForPrefillNode(self.info_queue, self.result_queue, self.mem_queue) + elif is_decode_node: if self.args.dp > 1: self.backend = DPForDecodeNode(self.info_queue, self.mem_queue) else: self.backend = DecodeNode(self.info_queue, self.mem_queue) + + elif is_nixl_decode_node: + if self.args.dp > 1: + self.backend = PDNIXLDPBackendForDecodeNode(self.info_queue, self.result_queue, self.mem_queue) + else: + self.backend = PDNIXLBackendForDecodeNode(self.info_queue, self.result_queue, self.mem_queue) + elif self.args.dp > 1: self.backend = DPChunkedPrefillBackend() elif use_reward_model: @@ -197,6 +218,7 @@ def _init_env( rank_in_node, node_world_size, info_queue, + result_queue, mem_queue, router_lock, rpc_event: mp.Event, @@ -215,7 +237,7 @@ def _init_env( g_router_lock.obj = router_lock model_rpc_server = ModelRpcServer( - args, rank, rank_in_node, node_world_size, rpc_event, rpc_finished_event, info_queue, mem_queue + args, rank, rank_in_node, node_world_size, 
rpc_event, rpc_finished_event, info_queue, result_queue, mem_queue ) success_event.set() @@ -231,6 +253,7 @@ async def start_model_process( rpc_event, rpc_finished_event, info_queue: mp.Queue, + result_queue: mp.Queue, mem_queue: mp.Queue, router_lock: mp.Queue, ): @@ -245,6 +268,7 @@ async def start_model_process( rank_in_node, node_world_size, info_queue, + result_queue, mem_queue, router_lock, rpc_event, diff --git a/lightllm/server/router/req_queue/__init__.py b/lightllm/server/router/req_queue/__init__.py index 10867b6e5..3e58b056d 100644 --- a/lightllm/server/router/req_queue/__init__.py +++ b/lightllm/server/router/req_queue/__init__.py @@ -13,9 +13,9 @@ def _get_req_queue_class(args, router, dp_size_in_node: int): return ChunkedPrefillQueue if args.first_token_constraint_mode: return ChunkedPrefillQueue - if args.run_mode == "decode": + if args.run_mode in ["decode"]: return QueueForPDDecode - if args.run_mode == "prefill": + if args.run_mode in ["prefill", "nixl_prefill", "nixl_decode"]: return ChunkedPrefillQueue if args.disable_chunked_prefill: diff --git a/lightllm/utils/health_check.py b/lightllm/utils/health_check.py index f6305e209..ee0778b65 100644 --- a/lightllm/utils/health_check.py +++ b/lightllm/utils/health_check.py @@ -70,7 +70,7 @@ async def health_check(args, httpserver_manager: HttpServerManager, request: Req health_obj.begin_check() try: request_dict = {"inputs": "你好!", "parameters": {"do_sample": True, "temperature": 0.8, "max_new_tokens": 2}} - if args.run_mode == "prefill": + if args.run_mode in ["prefill", "nixl_prefill"]: request_dict["parameters"]["max_new_tokens"] = 1 prompt = request_dict.pop("inputs") sample_params_dict = request_dict["parameters"]