
Commit 6aabc47

Author: niushengxiao
feat: refine
1 parent a89723c commit 6aabc47

9 files changed: +99 -42 lines changed

lightllm/server/api_cli.py

Lines changed: 6 additions & 0 deletions
@@ -537,4 +537,10 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--disk_cache_storage_size", type=float, default=10, help="""The capacity of disk cache. GB used."""
     )
+    parser.add_argument(
+        "--disk_cache_dir",
+        type=str,
+        default=None,
+        help="""Directory used to persist disk cache data. Defaults to a temp directory when not set.""",
+    )
     return parser
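
A quick sketch of how the new flag behaves, as a standalone argparse round-trip (only the option names, types, and defaults come from the diff; the real parser carries many more options):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--disk_cache_storage_size", type=float, default=10)
parser.add_argument("--disk_cache_dir", type=str, default=None)

# Explicit directory: the disk cache worker persists cache data there.
args = parser.parse_args(["--disk_cache_dir", "/mnt/nvme/lightllm_cache"])
assert args.disk_cache_dir == "/mnt/nvme/lightllm_cache"

# Flag omitted: disk_cache_dir stays None, and DiskCacheWorker falls back to a
# per-server directory under tempfile.gettempdir() (see its diff below).
args = parser.parse_args([])
assert args.disk_cache_dir is None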

lightllm/server/core/objs/atomic_lock.py

Lines changed: 5 additions & 1 deletion
@@ -29,9 +29,13 @@ def __exit__(self, exc_type, exc_val, exc_tb):

     # acquire_sleep1ms and release are operations used in certain scenarios to acquire the lock explicitly
     def acquire_sleep1ms(self):
+        last_log_time = time.monotonic()
         with atomics.atomicview(buffer=self.shm.buf, atype=atomics.INT) as a:
             while not a.cmpxchg_weak(0, 1):
-                logger.warning("acquire_sleep1ms wait for 1ms")
+                now = time.monotonic()
+                if now - last_log_time >= 0.1:
+                    logger.warning("acquire_sleep1ms wait for 100ms")
+                    last_log_time = now
                 time.sleep(0.001)
         pass
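
The change above throttles a per-iteration warning in a spin loop to at most one log line every 100ms. The same pattern in isolation, as a minimal sketch that substitutes an ordinary threading.Lock for lightllm's shared-memory atomic:

import threading
import time
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("spin_demo")

def acquire_with_throttled_warning(try_acquire, log_interval_s: float = 0.1) -> None:
    # Spin until try_acquire() succeeds, warning at most once per log_interval_s.
    last_log_time = time.monotonic()
    while not try_acquire():
        now = time.monotonic()
        if now - last_log_time >= log_interval_s:
            logger.warning("still waiting for lock")
            last_log_time = now
        time.sleep(0.001)  # same 1ms backoff as acquire_sleep1ms

lock = threading.Lock()
acquire_with_throttled_warning(lambda: lock.acquire(blocking=False))
lock.release()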

lightllm/server/core/objs/start_args_type.py

Lines changed: 1 addition & 0 deletions
@@ -111,6 +111,7 @@ class StartArgs:
     cpu_cache_token_page_size: int = field(default=64)
     enable_disk_cache: bool = field(default=False)
     disk_cache_storage_size: float = field(default=10)
+    disk_cache_dir: Optional[str] = field(default=None)
     # zmp ports
     router_port: int = field(default=None)
     detokenization_port: int = field(default=None)
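
The new field mirrors the CLI flag one-to-one. A minimal standalone sketch of the pattern (CacheArgs is a made-up stand-in; StartArgs itself carries many more fields):

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class CacheArgs:
    enable_disk_cache: bool = field(default=False)
    disk_cache_storage_size: float = field(default=10)
    disk_cache_dir: Optional[str] = field(default=None)

args = CacheArgs(enable_disk_cache=True, disk_cache_dir="/data/kv_cache")
assert args.disk_cache_dir == "/data/kv_cache"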

lightllm/server/httpserver/manager.py

Lines changed: 8 additions & 2 deletions
@@ -614,15 +614,21 @@ async def _wait_to_token_package(
                 f"disk_prompt_cache_ratio:{disk_prompt_cache_ratio} "
                 f"mtp_avg_token_per_step:{mtp_avg_token_per_step} "
             )
+            if prompt_cache_len > 0:
+                logger.info(
+                    f"[gpu cache hit] "
+                    f"prompt_cache_len:{prompt_cache_len} "
+                    f"prompt_cache_ratio:{prompt_cache_ratio} "
+                )
             if cpu_prompt_cache_len > 0:
                 logger.info(
-                    f"blueswhen "
+                    f"[cpu cache hit] "
                     f"cpu_prompt_cache_len:{cpu_prompt_cache_len} "
                     f"cpu_prompt_cache_ratio:{cpu_prompt_cache_ratio} "
                 )
             if disk_prompt_cache_len > 0:
                 logger.info(
-                    f"blueswhen "
+                    f"[disk cache hit] "
                     f"disk_prompt_cache_len:{disk_prompt_cache_len} "
                     f"disk_prompt_cache_ratio:{disk_prompt_cache_ratio} "
                 )
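
This replaces a leftover "blueswhen" debug prefix with labeled [gpu|cpu|disk cache hit] log lines, each emitted only when the corresponding tier actually hit. A standalone sketch of the logging shape (the helper and its arguments are hypothetical; only the tier labels and the len/ratio fields mirror the diff):

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("cache_hit_demo")

def log_cache_hits(total_tokens: int, gpu_len: int, cpu_len: int, disk_len: int) -> None:
    # One labeled line per tier that actually hit, mirroring the diff above.
    for tier, hit_len in (("gpu", gpu_len), ("cpu", cpu_len), ("disk", disk_len)):
        if hit_len > 0:
            ratio = hit_len / max(total_tokens, 1)
            logger.info(f"[{tier} cache hit] len:{hit_len} ratio:{ratio:.3f}")

log_cache_hits(total_tokens=2048, gpu_len=512, cpu_len=1024, disk_len=0)
# -> [gpu cache hit] len:512 ratio:0.250
# -> [cpu cache hit] len:1024 ratio:0.500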

lightllm/server/multi_level_kv_cache/cpu_cache_client.py

Lines changed: 27 additions & 4 deletions
@@ -1,7 +1,7 @@
 import ctypes
 import torch
 import numpy as np
-from lightllm.utils.envs_utils import get_env_start_args, get_unique_server_name
+from lightllm.utils.envs_utils import get_env_start_args, get_unique_server_name, get_disk_cache_prompt_limit_length
 from typing import List, Optional, Tuple
 from lightllm.utils.log_utils import init_logger
 from .shm_objs import ShmDict, ShmLinkedList, _LinkedListItem, IntList
@@ -38,10 +38,12 @@ def __init__(self, only_create_meta_data: bool, init_shm_data: bool):
         return

     @staticmethod
+    # Negative encoding, used to mark a page index as the first page of an offload group
     def _encode_offload_head(page_index: int) -> int:
         return -(page_index + 1)

     @staticmethod
+    # Decode back to the page index, and report whether it is the first page of an offload group
     def _decode_offload_value(value: int) -> Tuple[int, bool]:
         if value < 0:
             return -(value + 1), True
@@ -126,6 +128,19 @@ def update_pages_status_to_ready(
             assert cur_page.ref_count > 0
             cur_page.ref_count -= 1

+        # Gate on prompt length: shorter prompts are not offloaded to disk
+        limit_length = get_disk_cache_prompt_limit_length()
+        if (
+            disk_offload_enable
+            and offload_candidates
+            and len(page_list) * self.args.cpu_cache_token_page_size < limit_length
+        ):
+            logger.info(
+                f"skip disk offload for small page, " f"length = {len(page_list) * self.args.cpu_cache_token_page_size}"
+            )
+            self.mark_pages_recyclable(page_list=offload_candidates)
+            return
+
         if disk_offload_enable and offload_candidates:
             for idx, page_index in enumerate(offload_candidates):
                 if idx == 0:
@@ -225,11 +240,19 @@ def recycle_pages(self, page_list: List[int]):
             if page_index == -1:
                 continue
             cur_page: _CpuPageStatus = self.page_items.get_item_by_index(page_index)
-            cur_page.del_self_from_list()
-            if not cur_page.is_empty() and cur_page.hash_key != 0:
+
+            if cur_page.ref_count != 0:
+                if cur_page.status == cur_page.LOADING and cur_page.ref_count == 1:
+                    cur_page.ref_count = 0
+                else:
+                    continue
+
+            if cur_page.hash_key != 0:
                 existing_index = self.page_hash_dict.get(cur_page.hash_key)
-                if existing_index is not None:
+                if existing_index is not None and existing_index == cur_page.self_index:
                     self.page_hash_dict.remove(cur_page.hash_key)
+
+            cur_page.del_self_from_list()
             cur_page.hash_key = 0
             cur_page.status = cur_page.EMPTY
             cur_page.ref_count = 0
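
The offload-head encoding above packs a boolean into the sign of a page index: head pages are stored as -(index + 1), so index 0 stays distinguishable from the head marker, while plain pages are stored as the index itself. A self-contained sketch of the scheme:

from typing import Tuple

def encode_offload_head(page_index: int) -> int:
    # Mark page_index as the first page of an offload group.
    return -(page_index + 1)

def decode_offload_value(value: int) -> Tuple[int, bool]:
    # Return (page_index, is_offload_head).
    if value < 0:
        return -(value + 1), True
    return value, False

assert decode_offload_value(encode_offload_head(0)) == (0, True)
assert decode_offload_value(encode_offload_head(7)) == (7, True)
assert decode_offload_value(7) == (7, False)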

lightllm/server/multi_level_kv_cache/disk_cache_worker.py

Lines changed: 37 additions & 31 deletions
@@ -6,13 +6,20 @@
 from typing import List, Optional

 import torch
-
-from cache import PyLocalCacheService, PyState
 from lightllm.utils.envs_utils import get_unique_server_name
 from lightllm.utils.log_utils import init_logger

 logger = init_logger(__name__)

+try:
+    from cache import PyLocalCacheService, PyState
+except ImportError as e:
+    logger.error(
+        "Failed to import LightMem library. Please install it first.\n"
+        "You can install it by building from source: https://github.com/ModelTC/LightMem"
+    )
+    raise ImportError("LightMem library is required for disk cache functionality") from e
+

 @dataclass
 class _PagePayload:
@@ -23,21 +30,28 @@ class _PagePayload:
 class DiskCacheWorker:
     """Background worker that offloads CPU KV pages to disk using kvcache."""

-    def __init__(self, disk_cache_storage_size: float, cpu_cache_client):
+    def __init__(
+        self,
+        disk_cache_storage_size: float,
+        cpu_cache_client,
+        disk_cache_dir: Optional[str] = None,
+    ):
         self.cpu_cache_client = cpu_cache_client
         self._pages_all_idle = False

         assert disk_cache_storage_size > 0
         storage_size = int(disk_cache_storage_size * (1024 ** 3))
-        num_shard = 32
-        num_worker = 32
+        num_shard = 64
+        num_worker = 48
+        max_concurrent_write_tasks = 16

-        cache_dir = os.getenv("LIGHTLLM_DISK_CACHE_DIR")
+        cache_dir = disk_cache_dir
         if not cache_dir:
             cache_dir = os.path.join(tempfile.gettempdir(), f"lightllm_disk_cache_{get_unique_server_name()}")
         os.makedirs(cache_dir, exist_ok=True)
         cache_file = os.path.join(cache_dir, "cache_file")

+        self.max_concurrent_write_tasks = max_concurrent_write_tasks
         self._page_major_tensor = self._prepare_tensor(cpu_cache_client.cpu_kv_cache_tensor)

         self.service = PyLocalCacheService(
@@ -49,7 +63,7 @@ def __init__(self, disk_cache_storage_size: float, cpu_cache_client):
         )

         logger.info(
-            "blueswhen disk cache worker initialized: dir=%s size_bytes=%d shards=%d workers=%d pages_per_block=%d",
+            "disk cache worker initialized: dir=%s size_bytes=%d shards=%d workers=%d pages_per_block=%d",
             cache_dir,
             storage_size,
             num_shard,
@@ -63,35 +77,15 @@ def _prepare_tensor(self, tensor: torch.Tensor) -> torch.Tensor:

     def run(self) -> None:
         while True:
-            time.sleep(0.01)
+            time.sleep(0.1)
             payload_groups = self._gather_offload_payloads()
-            # self._log_idle_once()
             if not payload_groups:
                 continue
             for payloads in payload_groups:
                 if not payloads:
                     continue
                 self._persist_pages_to_disk(payloads)

-    def _log_idle_once(self) -> int:
-        locked_pages = 0
-        self.cpu_cache_client.lock.acquire_sleep1ms()
-        try:
-            for page_idx in range(self.cpu_cache_client.page_num):
-                page_item = self.cpu_cache_client.page_items.get_item_by_index(page_idx)
-                if not page_item.is_ready_recycle() or page_item.ref_count != 0:
-                    locked_pages += 1
-        finally:
-            self.cpu_cache_client.lock.release()
-
-        if locked_pages == 0:
-            if not self._pages_all_idle:
-                logger.info("blueswhen all cpu cache pages are idle and ready to reuse")
-                self._pages_all_idle = True
-        else:
-            self._pages_all_idle = False
-        return locked_pages
-
     def _gather_offload_payloads(self) -> List[List[_PagePayload]]:
         self.cpu_cache_client.lock.acquire_sleep1ms()
         try:
@@ -109,6 +103,7 @@ def _gather_offload_payloads(self) -> List[List[_PagePayload]]:
         finally:
             self.cpu_cache_client.lock.release()

+    # Write data to disk
     def _persist_pages_to_disk(self, payloads: List[_PagePayload]) -> None:
         if not payloads:
             return
@@ -120,16 +115,21 @@ def _persist_pages_to_disk(self, payloads: List[_PagePayload]) -> None:
         kv_indexer = torch.tensor(page_indexes, dtype=torch.int32, device="cpu")
         query_result = self.service.query(tokens)
         if not all(query_result):
+            # Cap write concurrency to leave resources for read operations
+            while (
+                self.service.active_threads("r") and self.service.active_threads("w") >= self.max_concurrent_write_tasks
+            ):
+                time.sleep(0.001)
+
             task = self.service.create(tokens=tokens, kv_page_indexer=kv_indexer, mode="w")
-            while not task.ready():
+            # Stop waiting once the data is safe; the write need not be fully complete
+            while not task.data_safe():
                 time.sleep(0.001)

         self.cpu_cache_client.lock.acquire_sleep1ms()
         self.cpu_cache_client.update_pages_status_to_ready_recycle(page_list=page_indexes, deref=True)
         self.cpu_cache_client.lock.release()

-        # self._log_idle_once()
-
     def blocks_exist(self, tokens: List[int], start_pos: int = 0) -> bool:
         if not tokens or start_pos < 0 or start_pos >= len(tokens):
             return False
@@ -141,12 +141,18 @@ def blocks_exist(self, tokens: List[int], start_pos: int = 0) -> bool:
             return False
         return all(query_result[block_start:block_end])

+    # Load data from disk into memory
     def load_pages(self, tokens: List[int], page_indexes: List[int], start_pos: int = 0) -> bool:
         if not tokens or not page_indexes or len(tokens) != len(page_indexes):
             return False
         if start_pos < 0 or start_pos >= len(tokens):
             return False

+        # Check whether a write is in progress and skip this load if so; currently unused
+        # if self.service.active_threads("w") > 0:
+        #     logger.warning("disk cache worker is busy writing, skip load_pages")
+        #     return False
+
         kv_indexer = torch.tensor(page_indexes, dtype=torch.int32, device="cpu")
         task = self.service.create(tokens=tokens, kv_page_indexer=kv_indexer, mode="r", start_pos=start_pos)
         while not task.ready():
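
Besides the tuning changes, the commit makes the LightMem `cache` module a guarded optional dependency here, and manager.py (next file) defers its DiskCacheWorker import until disk cache is actually enabled. A minimal runnable sketch combining both patterns (the function below is illustrative, not lightllm's API):

import logging

logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger("optional_dep_demo")

def make_cache_service(enable_disk_cache: bool):
    # Deferred import: processes that never enable disk cache never touch
    # the optional backend, so a missing install cannot break them.
    if not enable_disk_cache:
        return None
    try:
        from cache import PyLocalCacheService  # LightMem, optional dependency
    except ImportError as e:
        logger.error(
            "Failed to import LightMem library. Please install it first.\n"
            "You can install it by building from source: https://github.com/ModelTC/LightMem"
        )
        raise ImportError("LightMem library is required for disk cache functionality") from e
    return PyLocalCacheService

assert make_cache_service(False) is None  # disabled: no import attempted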

lightllm/server/multi_level_kv_cache/manager.py

Lines changed: 5 additions & 3 deletions
@@ -14,7 +14,6 @@
 from lightllm.server.core.objs.io_objs import GroupReqIndexes
 from lightllm.utils.graceful_utils import graceful_registry
 from .cpu_cache_client import CpuKvCacheClient
-from .disk_cache_worker import DiskCacheWorker
 from lightllm.utils.log_utils import init_logger

 logger = init_logger(__name__)
@@ -45,9 +44,12 @@ def __init__(
         self.disk_cache_worker = None
         self.disk_cache_thread = None
         if self.args.enable_disk_cache:
+            from .disk_cache_worker import DiskCacheWorker
+
             self.disk_cache_worker = DiskCacheWorker(
                 disk_cache_storage_size=self.args.disk_cache_storage_size,
                 cpu_cache_client=self.cpu_cache_client,
+                disk_cache_dir=self.args.disk_cache_dir,
             )
             self.disk_cache_thread = threading.Thread(target=self.disk_cache_worker.run, daemon=True)
             self.disk_cache_thread.start()
@@ -71,8 +73,8 @@ def _handle_group_req_cpu_cache_match(self, group_req_indexes: GroupReqIndexes,
             if current_time - start_time >= self.cpu_cache_time_out:
                 self.send_to_router.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL)
                 logger.warning(
-                    f"blueswhen cpu cache match time out {current_time - start_time}s, "
-                    "group_req_id: {group_req_indexes.group_req_id}"
+                    f"cpu cache match time out {current_time - start_time}s, "
+                    f"group_req_id: {group_req_indexes.group_req_id}"
                 )
                 return

lightllm/utils/envs_utils.py

Lines changed: 5 additions & 0 deletions
@@ -194,3 +194,8 @@ def enable_radix_tree_timer_merge() -> bool:
 @lru_cache(maxsize=None)
 def get_radix_tree_merge_update_delta() -> int:
     return int(os.getenv("LIGHTLMM_RADIX_TREE_MERGE_DELTA", 6000))
+
+
+@lru_cache(maxsize=None)
+def get_disk_cache_prompt_limit_length():
+    return int(os.getenv("LIGHTLLM_DISK_CACHE_PROMPT_LIMIT_LENGTH", 10000))
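The new getter reads the offload threshold once per process (lru_cache memoizes the env lookup), and cpu_cache_client.py compares it against page_count * cpu_cache_token_page_size. A small sketch of the gate arithmetic (the 64-token page size is the StartArgs default above; 10000 is the threshold default):

import os
from functools import lru_cache

@lru_cache(maxsize=None)
def get_disk_cache_prompt_limit_length() -> int:
    return int(os.getenv("LIGHTLLM_DISK_CACHE_PROMPT_LIMIT_LENGTH", 10000))

def should_skip_disk_offload(num_pages: int, page_size: int = 64) -> bool:
    # Prompts shorter than the threshold stay out of disk cache: the write
    # cost outweighs the chance of a useful future hit.
    return num_pages * page_size < get_disk_cache_prompt_limit_length()

assert should_skip_disk_offload(num_pages=100)      # 6400 tokens < 10000 -> skip
assert not should_skip_disk_offload(num_pages=200)  # 12800 tokens >= 10000 -> offload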

requirements.txt

Lines changed: 5 additions & 1 deletion
@@ -87,4 +87,8 @@ librosa==0.11.0
 cuda_bindings==12.9.0
 orjson==3.11.2
 setproctitle==1.3.6
-xxhash==3.6.0
+xxhash==3.6.0
+torchvision==0.23.0
+interegular
+partial_json_parser
+websockets
