Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lightllm/models/gemma3/layer_infer/pre_layer_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def context_forward(self, input_ids, infer_state, layer_weight):
continue
# pull the img_embeds by uid from shm
data = read_shm(get_shm_name_embed(img["uuid"]))
img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1))
img_weight.append(bytes2tensor(data, torch_dtype=dtype).cuda().reshape(img["token_num"], -1))
img_start_token_ids.append(img["token_id"])
img_token_lens.append(img["token_num"])
img_start_locs.append(img_start_loc)
Expand Down
2 changes: 1 addition & 1 deletion lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
continue
# pull the img_embeds by uid from shm
data = read_shm(get_shm_name_embed(img["uuid"]))
img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1))
img_weight.append(bytes2tensor(data, torch_dtype=dtype).cuda().reshape(img["token_num"], -1))
img_start_token_ids.append(img["token_id"])
img_token_lens.append(img["token_num"])
img_start_locs.append(img_start_loc)
Expand Down
24 changes: 13 additions & 11 deletions lightllm/server/embed_cache/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,19 @@


def tensor2bytes(t: torch.Tensor):
# t = t.cpu().numpy().tobytes()
# return t
buf = BytesIO()
torch.save(t.detach().cpu(), buf)
buf.seek(0)
return buf.read()


def bytes2tensor(b):
# return torch.from_numpy(np.frombuffer(b, dtype=np.float16)).cuda()
return torch.load(BytesIO(b))
if t.dtype == torch.float32:
t = t.cpu().numpy().tobytes()
else:
t = t.cpu().to(torch.uint16).numpy().tobytes()

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

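For non-float32 dtypes, tensor2bytes calls t.cpu().to(torch.uint16), which converts each element's numeric value to an integer instead of reinterpreting its bits. That destroys the data: fractional parts are truncated (0.5 becomes 0) and negative values cannot be represented at all. Use t.cpu().contiguous().view(torch.uint16) to preserve the bit representation. Check explicitly for torch.float16 and torch.bfloat16, and raise for any other dtype rather than silently falling through to the uint16 path.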

Suggested change
else:
t = t.cpu().to(torch.uint16).numpy().tobytes()
elif t.dtype == torch.float16 or t.dtype == torch.bfloat16:
t_view = t.cpu().contiguous().view(torch.uint16)
t = t_view.numpy().tobytes()
else:
raise TypeError(f"Unsupported dtype for tensor2bytes: {t.dtype}. Only float32, float16, bfloat16 are explicitly supported.")

return t


def bytes2tensor(b, torch_dtype=torch.bfloat16):
if torch_dtype == torch.float32:
arr_loaded = np.frombuffer(b, dtype=np.float32)
else:
arr_loaded = np.frombuffer(b, dtype=np.uint16)
return torch.from_numpy(arr_loaded).to(torch_dtype)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

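The bytes2tensor function has the mirror-image bug: .to(torch_dtype) converts the uint16 integer values to floats instead of reinterpreting their bits. Call .view(torch_dtype) on the tensor created from arr_loaded to reinterpret the bits correctly. Also call .copy() on the NumPy array returned by np.frombuffer: that array is a read-only view over the bytes object, and torch.from_numpy shares memory with its input, so without the copy the resulting tensor is backed by non-writable memory. Explicitly handle float32 and float16/bfloat16, and raise for any other dtype.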

    if torch_dtype == torch.float32:
        arr_loaded = np.frombuffer(b, dtype=np.float32)
        return torch.from_numpy(arr_loaded)
    elif torch_dtype == torch.float16 or torch_dtype == torch.bfloat16:
        arr_loaded_uint16 = np.frombuffer(b, dtype=np.uint16)
        return torch.from_numpy(arr_loaded_uint16.copy()).view(torch_dtype)
    else:
        raise TypeError(f"Unsupported torch_dtype for bytes2tensor: {torch_dtype}. This function is optimized for float32, float16, bfloat16.")



def create_shm(name, data):
Expand Down