66from typing import List , Optional
77
88import torch
9-
10- from cache import PyLocalCacheService , PyState
119from lightllm .utils .envs_utils import get_unique_server_name
1210from lightllm .utils .log_utils import init_logger
1311
1412logger = init_logger (__name__ )
1513
14+ try :
15+ from cache import PyLocalCacheService , PyState
16+ except ImportError as e :
17+ logger .error (
18+ "Failed to import LightMem library. Please install it first.\n "
19+ "You can install it by building from source: https://github.com/ModelTC/LightMem"
20+ )
21+ raise ImportError ("LightMem library is required for disk cache functionality" ) from e
22+
1623
1724@dataclass
1825class _PagePayload :
@@ -23,21 +30,28 @@ class _PagePayload:
2330class DiskCacheWorker :
2431 """Background worker that offloads CPU KV pages to disk using kvcache."""
2532
26- def __init__ (self , disk_cache_storage_size : float , cpu_cache_client ):
33+ def __init__ (
34+ self ,
35+ disk_cache_storage_size : float ,
36+ cpu_cache_client ,
37+ disk_cache_dir : Optional [str ] = None ,
38+ ):
2739 self .cpu_cache_client = cpu_cache_client
2840 self ._pages_all_idle = False
2941
3042 assert disk_cache_storage_size > 0
3143 storage_size = int (disk_cache_storage_size * (1024 ** 3 ))
32- num_shard = 32
33- num_worker = 32
44+ num_shard = 64
45+ num_worker = 48
46+ max_concurrent_write_tasks = 16
3447
35- cache_dir = os . getenv ( "LIGHTLLM_DISK_CACHE_DIR" )
48+ cache_dir = disk_cache_dir
3649 if not cache_dir :
3750 cache_dir = os .path .join (tempfile .gettempdir (), f"lightllm_disk_cache_{ get_unique_server_name ()} " )
3851 os .makedirs (cache_dir , exist_ok = True )
3952 cache_file = os .path .join (cache_dir , "cache_file" )
4053
54+ self .max_concurrent_write_tasks = max_concurrent_write_tasks
4155 self ._page_major_tensor = self ._prepare_tensor (cpu_cache_client .cpu_kv_cache_tensor )
4256
4357 self .service = PyLocalCacheService (
@@ -49,7 +63,7 @@ def __init__(self, disk_cache_storage_size: float, cpu_cache_client):
4963 )
5064
5165 logger .info (
52- "blueswhen disk cache worker initialized: dir=%s size_bytes=%d shards=%d workers=%d pages_per_block=%d" ,
66+ "disk cache worker initialized: dir=%s size_bytes=%d shards=%d workers=%d pages_per_block=%d" ,
5367 cache_dir ,
5468 storage_size ,
5569 num_shard ,
@@ -63,35 +77,15 @@ def _prepare_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
6377
6478 def run (self ) -> None :
6579 while True :
66- time .sleep (0.01 )
80+ time .sleep (0.1 )
6781 payload_groups = self ._gather_offload_payloads ()
68- # self._log_idle_once()
6982 if not payload_groups :
7083 continue
7184 for payloads in payload_groups :
7285 if not payloads :
7386 continue
7487 self ._persist_pages_to_disk (payloads )
7588
76- def _log_idle_once (self ) -> int :
77- locked_pages = 0
78- self .cpu_cache_client .lock .acquire_sleep1ms ()
79- try :
80- for page_idx in range (self .cpu_cache_client .page_num ):
81- page_item = self .cpu_cache_client .page_items .get_item_by_index (page_idx )
82- if not page_item .is_ready_recycle () or page_item .ref_count != 0 :
83- locked_pages += 1
84- finally :
85- self .cpu_cache_client .lock .release ()
86-
87- if locked_pages == 0 :
88- if not self ._pages_all_idle :
89- logger .info ("blueswhen all cpu cache pages are idle and ready to reuse" )
90- self ._pages_all_idle = True
91- else :
92- self ._pages_all_idle = False
93- return locked_pages
94-
9589 def _gather_offload_payloads (self ) -> List [List [_PagePayload ]]:
9690 self .cpu_cache_client .lock .acquire_sleep1ms ()
9791 try :
@@ -109,6 +103,7 @@ def _gather_offload_payloads(self) -> List[List[_PagePayload]]:
109103 finally :
110104 self .cpu_cache_client .lock .release ()
111105
106+ # Persist page data to disk
112107 def _persist_pages_to_disk (self , payloads : List [_PagePayload ]) -> None :
113108 if not payloads :
114109 return
@@ -120,16 +115,21 @@ def _persist_pages_to_disk(self, payloads: List[_PagePayload]) -> None:
120115 kv_indexer = torch .tensor (page_indexes , dtype = torch .int32 , device = "cpu" )
121116 query_result = self .service .query (tokens )
122117 if not all (query_result ):
118+ # Throttle concurrent writes so disk-read operations keep enough resources
119+ while (
120+ self .service .active_threads ("r" ) and self .service .active_threads ("w" ) >= self .max_concurrent_write_tasks
121+ ):
122+ time .sleep (0.001 )
123+
123124 task = self .service .create (tokens = tokens , kv_page_indexer = kv_indexer , mode = "w" )
124- while not task .ready ():
125+ # Stop waiting once the data is safe; no need to wait for the full write to complete
126+ while not task .data_safe ():
125127 time .sleep (0.001 )
126128
127129 self .cpu_cache_client .lock .acquire_sleep1ms ()
128130 self .cpu_cache_client .update_pages_status_to_ready_recycle (page_list = page_indexes , deref = True )
129131 self .cpu_cache_client .lock .release ()
130132
131- # self._log_idle_once()
132-
133133 def blocks_exist (self , tokens : List [int ], start_pos : int = 0 ) -> bool :
134134 if not tokens or start_pos < 0 or start_pos >= len (tokens ):
135135 return False
@@ -141,12 +141,18 @@ def blocks_exist(self, tokens: List[int], start_pos: int = 0) -> bool:
141141 return False
142142 return all (query_result [block_start :block_end ])
143143
144+ # Load page data from disk back into memory
144145 def load_pages (self , tokens : List [int ], page_indexes : List [int ], start_pos : int = 0 ) -> bool :
145146 if not tokens or not page_indexes or len (tokens ) != len (page_indexes ):
146147 return False
147148 if start_pos < 0 or start_pos >= len (tokens ):
148149 return False
149150
151+ # If a write operation is in progress, skip this load request (currently disabled)
152+ # if self.service.active_threads("w") > 0:
153+ # logger.warning("disk cache worker is busy writing, skip load_pages")
154+ # return False
155+
150156 kv_indexer = torch .tensor (page_indexes , dtype = torch .int32 , device = "cpu" )
151157 task = self .service .create (tokens = tokens , kv_page_indexer = kv_indexer , mode = "r" , start_pos = start_pos )
152158 while not task .ready ():
0 commit comments