Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 149 additions & 1 deletion fastdeploy/rl/dynamic_weight_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,105 @@
from fastdeploy.config import FDConfig
from fastdeploy.inter_communicator import KVCacheStatus, ModelWeightsStatus

""" -------------------------------------------------------- """
import json
import logging

from fastdeploy.model_executor.models.model_base import ModelRegistry
from fastdeploy.model_executor.utils import multi_switch_config_context, process_final_after_loading
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Bug: `ModelRegistry` 和 `multi_switch_config_context` 在 `_update_ipc_async` 中被使用但未导入，运行时会抛出 NameError。

需要添加以下导入:

from fastdeploy.model_executor.models.model_base import ModelRegistry
from fastdeploy.model_executor.utils import process_final_after_loading, multi_switch_config_context

from paddle.base import core as paddle_core
import zmq

# Mapping from serialized dtype strings (as sent in bucket layout metadata)
# to the corresponding paddle dtype objects.
_DTYPE_MAP = {
    f"paddle.{_name}": getattr(paddle, _name)
    for _name in ("float32", "float16", "bfloat16", "int32", "int64", "uint8")
}


def receive_all_buckets(gpu_id, ipc_root="/shared_ipc_meta", target_device=None, recv_timeout_ms=300_000):
    """Connect to a ReshardWorkerThread and receive all weight buckets.

    Acts as the receiving end of a ZMQ PAIR socket. Each incoming message is
    either the termination sentinel ``b"END"`` or a JSON payload containing:
      * ``buffer``: CUDA IPC handle metadata used to rebuild a flat uint8 buffer.
      * ``layout``: a list of ``(name, dtype_key, byte_offset, n_bytes, shape)``
        entries describing how to slice individual params out of that buffer.

    Args:
        gpu_id: The training-side GPU id (determines the IPC socket path).
        ipc_root: Root directory for IPC socket files.
        target_device: The CUDA device id to rebuild tensors on. Defaults to
            ``gpu_id``. NOTE(review): currently unused — the rebuild device is
            taken from ``FLAGS_selected_gpus`` below; confirm intended behavior.
        recv_timeout_ms: ZMQ receive timeout in milliseconds.

    Yields:
        (name, paddle.Tensor) tuples with original dtype and shape restored.

    Raises:
        zmq.Again: If no message arrives within ``recv_timeout_ms``.
    """
    if target_device is None:
        target_device = gpu_id

    ctx = zmq.Context()
    sock = ctx.socket(zmq.PAIR)
    sock.setsockopt(zmq.LINGER, 0)
    sock.setsockopt(zmq.RCVTIMEO, recv_timeout_ms)
    # NOTE(review): ipc_root already starts with "/", so this produces
    # "ipc:////shared_ipc_meta/..." — harmless on Linux, but the address must
    # stay byte-identical to the one the sender binds, so it is left as-is.
    ipc_addr = f"ipc:///{ipc_root}/ipc_metas_{gpu_id}"
    sock.connect(ipc_addr)
    logger.info("[Receiver] Connected to %s", ipc_addr)

    bucket_count = 0

    try:
        while True:
            raw = sock.recv()

            # --- Termination sentinel ---
            if raw == b"END":
                logger.info("[Receiver] Received END after %d bucket(s)", bucket_count)
                break

            # --- Decode bucket message ---
            message = json.loads(raw.decode())
            buffer_meta = message["buffer"]
            layout = message["layout"]

            # --- Rebuild flat uint8 buffer from CUDA IPC handle ---
            # The IPC handle was JSON-serialized as latin-1 text; restore bytes.
            buffer_meta[0] = buffer_meta[0].encode("latin-1")
            # Rebuild on the locally selected device, not the sender's device id.
            buffer_meta[6] = int(os.getenv("FLAGS_selected_gpus", "0"))
            lod_tensor = paddle_core.LoDTensor._new_shared_cuda(tuple(buffer_meta))
            flat_buf = paddle.to_tensor(lod_tensor)

            paddle.device.synchronize()

            # --- Slice and reconstruct individual tensors ---
            n_params = len(layout)
            for name, dtype_key, byte_offset, n_bytes, shape in layout:
                # .clone() ensures zero storage offset so that .view(dtype) reads
                # from the correct position. Without it, Paddle's view(dtype) may
                # ignore the slice offset and start from byte 0, which makes the
                # first param (offset=0) correct but all subsequent ones wrong.
                param_bytes = flat_buf[byte_offset : byte_offset + n_bytes].clone()
                target_dtype = _DTYPE_MAP[dtype_key]
                param = param_bytes.view(target_dtype).reshape(shape)
                # _md5sum() is expensive; only compute it when debug logging is on.
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug("ori param_key: %s, md5=%s", name, param._md5sum())
                # Clone to own memory so we can release the IPC flat buffer.
                yield name, param.clone()

            bucket_count += 1
            logger.info(
                "[Receiver] Bucket %d: %d params, flat_buf size=%d bytes",
                bucket_count,
                n_params,
                flat_buf.numel(),
            )

            # --- Release IPC buffer reference, then ack ---
            del flat_buf, lod_tensor
            sock.send(b"OK")

    except zmq.Again:
        logger.warning("[Receiver] Timeout waiting for data (timeout=%dms)", recv_timeout_ms)
        raise
    finally:
        sock.close()
        ctx.term()
        logger.info("[Receiver] Socket closed, context terminated")

""" -------------------------------------------------------- """


class DynamicWeightManager:
"""Manages model weights loading, updating and shared state across processes."""
Expand Down Expand Up @@ -175,7 +274,8 @@
# step3 : update model weight
strategy_handlers = {
"ipc_snapshot": self._update_ipc_snapshot,
"ipc": self._update_ipc,
# "ipc": self._update_ipc,
"ipc": self._update_ipc_async,

This comment was marked as outdated.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 reload_model_weights(第354行)仍将 "ipc" 映射到旧方法 self._update_ipc,而此处已改为 self._update_ipc_async,两个策略映射不一致,可能导致不同调用路径行为不同。建议同步更新 reload_model_weights 中的映射。

}

if handler := strategy_handlers.get(self.load_config.load_strategy):
Expand All @@ -190,6 +290,42 @@
# step5: recapture cuda_graph
# step6: update weight status signal

def _update_ipc_async(self):
    """Update weights using async IPC buckets (elastic recovery path).

    Streams weight buckets from the training side via ``receive_all_buckets``,
    loads them into a lazily-constructed temporary model, copies the resulting
    state into the serving model, then explicitly releases the temporary
    model's GPU memory.
    """
    logger.info(
        "[ALLOC] check memory before _update_ipc_async: %s",
        paddle.device.cuda.memory_allocated() / (1024**3),
    )
    logger.info(
        "[RESERVED] check memory before _update_ipc_async: %s",
        paddle.device.cuda.memory_reserved() / (1024**3),
    )
    # LazyGuard defers parameter allocation until weights are actually loaded.
    context = paddle.LazyGuard()
    architectures = f"{self.fd_config.model_config.architectures[0]}RL"
    if self.fd_config.quant_config is not None:
        quantization_context = multi_switch_config_context(
            (self.fd_config.quant_config, "is_checkpoint_bf16", True),
            (self.fd_config.load_config, "dynamic_load_weight", False),
        )
    else:
        # bf16: no quant config to patch, only disable dynamic weight loading.
        quantization_context = multi_switch_config_context(
            (self.fd_config.load_config, "dynamic_load_weight", False)
        )
    gpu_id = self._get_gpu_id()
    weights_iterator = receive_all_buckets(
        gpu_id, ipc_root="/shared_ipc_meta", target_device=None, recv_timeout_ms=300_000
    )
    with quantization_context:
        with context:
            model_cls = ModelRegistry.get_class(architectures)
            tmp_model = model_cls(self.fd_config)
        tmp_model.eval()
        tmp_model.load_weights(weights_iterator)
        if self.fd_config.speculative_config.model_type != "mtp":
            process_final_after_loading(tmp_model, self.fd_config)
        self._capture_model_state()  # thd test
    self._update_model_from_state(tmp_model.state_dict(), "raw")
    # Explicitly release the temporary model's GPU memory before dropping it.
    for param in tmp_model.state_dict().values():
        param._clear_data()
    tmp_model.state_dict().clear()
    tmp_model = None
    logger.info(
        "[ALLOC] check memory after _update_ipc_async: %s",
        paddle.device.cuda.memory_allocated() / (1024**3),
    )
    logger.info(
        "[RESERVED] check memory after _update_ipc_async: %s",
        paddle.device.cuda.memory_reserved() / (1024**3),
    )

def restart_communication_group(self):
if not self.first_load:
start_time = time.perf_counter()
Expand Down Expand Up @@ -331,10 +467,22 @@
paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)

paddle.device.cuda.empty_cache()
logger.info(f"[ALLOC] check memory before clear_param_data: {paddle.device.cuda.memory_allocated() / (1024**3)}")
logger.info(f"[RESERVED] check memory before clear_param_data: {paddle.device.cuda.memory_reserved() / (1024**3)}")
# step2: release model weight
for model in self.model_list:
for param in model.state_dict().values():
param._clear_data()
del param
self.state_dict = {} # thd test
gc.collect()
import ctypes
try:
ctypes.CDLL("libc.so.6").malloc_trim(0)
except Exception as e:
logger.warning(f"malloc_trim failed: {e}")
logger.info(f"[ALLOC] check memory after clear_param_data: {paddle.device.cuda.memory_allocated() / (1024**3)}")
logger.info(f"[RESERVED] check memory after clear_param_data: {paddle.device.cuda.memory_reserved() / (1024**3)}")

self._verify_parameters("clearance")

Expand Down
Loading