Description
Checklist
- 1. I have searched related issues but cannot get the expected help.
- 2. The bug has not been fixed in the latest version.
- 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
Describe the bug
Running Qwen2.5-72B-Instruct-AWQ on 4 GPUs, with Cherry Studio as the frontend. Sending plain text works, but sending an image always triggers an error, regardless of whether the image is large or small. After the model loads, each of the 4 GPUs still has about 3 GB of free VRAM, so the memory is not overflowing. I also tried the Qwen2.5-3B-Instruct-AWQ version, and it fails with the same error.
Reproduction
lmdeploy serve api_server \
  --model-name Qwen2.5-VL-72B-Instruct \
  --model-format awq \
  --session-len 10240 \
  --max-batch-size 16 \
  --max-concurrent-requests 16 \
  --log-level INFO \
  --tp 4 \
  --server-port 8080 \
  --cache-max-entry-count 0.17 \
  /home/yun/model/Qwen2.5-VL-72B-Instruct-AWQ
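With the server above running, the failure can be triggered by any OpenAI-compatible client. The sketch below is a minimal, hedged example: it assumes the server is reachable at http://localhost:8080, uses a hypothetical local file test.jpg, and uses the openai Python package instead of Cherry Studio (which is what was actually used to hit the bug).

# Minimal client-side sketch of the failing request (assumptions: the server
# from the command above listens on http://localhost:8080; ./test.jpg is any
# small local image -- both large and small images trigger the error).
import base64

from openai import OpenAI  # pip install openai

client = OpenAI(base_url="http://localhost:8080/v1", api_key="none")

with open("test.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode()

# A text-only request like this one returns normally.
resp = client.chat.completions.create(
    model="Qwen2.5-VL-72B-Instruct",
    messages=[{"role": "user", "content": "Hello"}],
)
print(resp.choices[0].message.content)

# Adding an image part is what makes the server fail with the
# split_with_sizes RuntimeError shown in the traceback below.
resp = client.chat.completions.create(
    model="Qwen2.5-VL-72B-Instruct",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"},
            },
        ],
    }],
)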
Environment
(lmdeploy) yun@alex:~$ lmdeploy check_env
/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
import pynvml # type: ignore[import]
sys.platform: linux
Python: 3.10.19 (main, Oct 21 2025, 16:43:05) [GCC 11.2.0]
CUDA available: True
MUSA available: False
numpy_random_seed: 2147483648
GPU 0,1,2,3: Tesla T10
CUDA_HOME: /usr/local/cuda-12.9:/usr/local/cuda-12.9:
GCC: gcc (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
PyTorch: 2.8.0+cu128
PyTorch compiling details: PyTorch built with:
- GCC 13.3
- C++ Version: 201703
- Intel(R) oneAPI Math Kernel Library Version 2024.2-Product Build 20240605 for Intel(R) 64 architecture applications
- Intel(R) MKL-DNN v3.7.1 (Git Hash 8d263e693366ef8db40acc569cc7d8edf644556d)
- OpenMP 201511 (a.k.a. OpenMP 4.5)
- LAPACK is enabled (usually provided by MKL)
- NNPACK is enabled
- CPU capability usage: AVX512
- CUDA Runtime 12.8
- NVCC architecture flags: -gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90;-gencode;arch=compute_100,code=sm_100;-gencode;arch=compute_120,code=sm_120
- CuDNN 91.0.2 (built against CUDA 12.9)
- Built with CuDNN 90.8
- Magma 2.6.1
- Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, COMMIT_SHA=a1cb3cc05d46d198467bebbb6e8fba50a325d4e7, CUDA_VERSION=12.8, CUDNN_VERSION=9.8.0, CXX_COMPILER=/opt/rh/gcc-toolset-13/root/usr/bin/c++, CXX_FLAGS= -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DLIBKINETO_NOXPUPTI=ON -DUSE_FBGEMM -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -DC10_NODEPRECATED -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=range-loop-construct -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-unknown-pragmas -Wno-unused-parameter -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=old-style-cast -faligned-new -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-dangling-reference -Wno-error=dangling-reference -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, TORCH_VERSION=2.8.0, USE_CUDA=ON, USE_CUDNN=ON, USE_CUSPARSELT=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_GLOO=ON, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF, USE_XCCL=OFF, USE_XPU=OFF,
TorchVision: 0.23.0+cu128
LMDeploy: 0.10.2+
transformers: 4.57.1
fastapi: 0.120.0
pydantic: 2.12.3
triton: 3.4.0
NVIDIA Topology:
       GPU0    GPU1    GPU2    GPU3    CPU Affinity    NUMA Affinity    GPU NUMA ID
GPU0    X      PHB     NODE    NODE    0-19            0               N/A
GPU1   PHB      X      NODE    NODE    0-19            0               N/A
GPU2   NODE    NODE     X      NODE    0-19            0               N/A
GPU3   NODE    NODE    NODE     X      0-19            0               N/A
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
Error traceback
(lmdeploy) yun@alex:~$ lmdeploy serve api_server \
> --model-name Qwen2.5-VL-72B-Instruct \
> --model-format awq \
> --session-len 10240 \
> --max-batch-size 16 \
> --max-concurrent-requests 16 \
> --log-level INFO \
> --tp 4 \
> --server-port 8080 \
> --cache-max-entry-count 0.50 \
> /home/yun/model/Qwen2.5-VL-72B-Instruct-AWQ
/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
import pynvml # type: ignore[import]
2025-11-10 13:19:30,451 - lmdeploy - INFO - builder.py:66 - matching vision model: Qwen2VLModel
The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. Note that this behavior will be extended to all models in a future release.
2025-11-10 13:19:36,339 - lmdeploy - INFO - async_engine.py:304 - input backend=turbomind, backend_config=TurbomindEngineConfig(dtype='auto', model_format='awq', tp=4, dp=1, device_num=None, attn_tp_size=None, attn_dp_size=None, mlp_tp_size=None, mlp_dp_size=None, outer_dp_size=None, session_len=10240, max_batch_size=16, cache_max_entry_count=0.5, cache_chunk_size=-1, cache_block_seq_len=64, enable_prefix_caching=False, quant_policy=0, rope_scaling_factor=0.0, use_logn_attn=False, download_dir=None, revision=None, max_prefill_token_num=8192, num_tokens_per_iter=0, max_prefill_iters=1, devices=None, empty_init=False, communicator='nccl', hf_overrides=None, enable_metrics=False)
2025-11-10 13:19:36,339 - lmdeploy - INFO - async_engine.py:305 - input chat_template_config=None
2025-11-10 13:19:37,073 - lmdeploy - INFO - async_engine.py:317 - updated chat_template_onfig=ChatTemplateConfig(model_name='hf', model_path='/home/yun/model/Qwen2.5-VL-72B-Instruct-AWQ', system=None, meta_instruction=None, eosys=None, user=None, eoh=None, assistant=None, eoa=None, tool=None, eotool=None, separator=None, capability=None, stop_words=None)
`torch_dtype` is deprecated! Use `dtype` instead!
2025-11-10 13:19:37,500 - lmdeploy - WARNING - converter.py:65 - data type fallback to float16 since torch.cuda.is_bf16_supported is False
2025-11-10 13:19:37,955 - lmdeploy - INFO - turbomind.py:261 - turbomind model config:
{
"model_config": {
"model_name": "",
"chat_template": "",
"model_arch": "Qwen2_5_VLForConditionalGeneration",
"head_num": 64,
"kv_head_num": 8,
"hidden_units": 8192,
"vocab_size": 152064,
"embedding_size": 152064,
"tokenizer_size": 151665,
"num_layer": 80,
"inter_size": [
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696,
29696
],
"norm_eps": 1e-06,
"attn_bias": 1,
"mlp_bias": false,
"window_size": [],
"attn_sink": false,
"qk_norm": false,
"size_per_head": 128,
"group_size": 128,
"data_type": "float16",
"weight_type": "int4",
"expert_weight_type": "int4",
"session_len": 10240,
"attn_tp_size": 4,
"mlp_tp_size": 4,
"model_format": "awq",
"expert_num": [],
"expert_router_bias": false,
"expert_inter_size": 0,
"experts_per_token": 0,
"activation_type": "",
"moe_shared_gate": false,
"norm_topk_prob": false,
"routed_scale": 1.0,
"topk_group": 1,
"topk_method": "greedy",
"moe_group_num": 1,
"q_lora_rank": 0,
"kv_lora_rank": 0,
"qk_rope_dim": 0,
"v_head_dim": 0,
"tune_layer_num": 1
},
"attention_config": {
"softmax_scale": 0.0,
"cache_block_seq_len": 64,
"use_logn_attn": 0,
"max_position_embeddings": 128000,
"rope_param": {
"type": "mrope",
"base": 1000000.0,
"dim": 128,
"factor": 1.0,
"max_position_embeddings": null,
"attention_factor": 1.0,
"beta_fast": 32,
"beta_slow": 1,
"low_freq_factor": null,
"high_freq_factor": null,
"original_max_position_embeddings": null,
"mrope_section": [
16,
24,
24
]
}
},
"lora_config": {
"lora_policy": "",
"lora_r": 0,
"lora_scale": 0.0,
"lora_max_wo_r": 0,
"lora_rank_pattern": "",
"lora_scale_pattern": ""
},
"engine_config": {
"dtype": "auto",
"model_format": "awq",
"tp": 4,
"dp": 1,
"device_num": 4,
"attn_tp_size": 4,
"attn_dp_size": 1,
"mlp_tp_size": 4,
"mlp_dp_size": 1,
"outer_dp_size": 1,
"session_len": 10240,
"max_batch_size": 16,
"cache_max_entry_count": 0.5,
"cache_chunk_size": -1,
"cache_block_seq_len": 64,
"enable_prefix_caching": false,
"quant_policy": 0,
"rope_scaling_factor": 0.0,
"use_logn_attn": false,
"download_dir": null,
"revision": null,
"max_prefill_token_num": 8192,
"num_tokens_per_iter": 8192,
"max_prefill_iters": 2,
"devices": [
0,
1,
2,
3
],
"empty_init": false,
"communicator": "nccl",
"hf_overrides": null,
"enable_metrics": false
}
}
[TM][WARNING] [LlamaTritonModel] `max_context_token_num` is not set, default to 10240.
[TM][INFO] Model:
head_num: 64
kv_head_num: 8
size_per_head: 128
num_layer: 80
vocab_size: 152064
attn_bias: 1
qk_norm: 0
max_batch_size: 16
max_context_token_num: 10240
num_tokens_per_iter: 8192
max_prefill_iters: 2
session_len: 10240
cache_max_entry_count: 0.5
cache_block_seq_len: 64
cache_chunk_size: -1
enable_prefix_caching: 0
model_name:
model_dir:
quant_policy: 0
group_size: 128
expert_per_token: 0
moe_method: 1
2025-11-10 13:19:38,492 - lmdeploy - WARNING - turbomind.py:237 - get 5609 model params
[TM][INFO] [BlockManager] block_size = 5.000 MB
[TM][INFO] [BlockManager] block_size = 5.000 MB
[TM][INFO] [BlockManager] max_block_count = 482
[TM][INFO] [BlockManager] block_size = 5.000 MB
[TM][INFO] [BlockManager] block_size = 5.000 MB
[TM][INFO] [BlockManager] max_block_count = 482
[TM][INFO] [BlockManager] chunk_size = 482
[TM][INFO] [BlockManager] max_block_count = 482
[TM][INFO] [BlockManager] chunk_size = 482
[TM][INFO] [BlockManager] max_block_count = 482
[TM][INFO] [BlockManager] chunk_size = 482
[TM][INFO] [BlockManager] chunk_size = 482
[TM][WARNING] [SegMgr] prefix caching is disabled
[TM][WARNING] [SegMgr] prefix caching is disabled
[TM][WARNING] [SegMgr] prefix caching is disabled
[TM][WARNING] [SegMgr] prefix caching is disabled
[TM][INFO] [Gemm2] Tuning sequence: 8, 16, 32, 48, 64, 96, 128, 192, 256, 384, 512, 768, 1024, 1536, 2048, 3072, 4096, 6144, 8192, 8208
[TM][INFO] [Gemm2] 8
[TM][INFO] [Gemm2] 16
[TM][INFO] [Gemm2] 32
[TM][INFO] [Gemm2] 48
[TM][INFO] [Gemm2] 64
[TM][INFO] [Gemm2] 96
[TM][INFO] [Gemm2] 128
[TM][INFO] [Gemm2] 192
[TM][INFO] [Gemm2] 256
[TM][INFO] [Gemm2] 384
[TM][INFO] [Gemm2] 512
[TM][INFO] [Gemm2] 768
[TM][INFO] [Gemm2] 1024
[TM][INFO] [Gemm2] 1536
[TM][INFO] [Gemm2] 2048
[TM][INFO] [Gemm2] 3072
[TM][INFO] [Gemm2] 4096
[TM][INFO] [Gemm2] 6144
[TM][INFO] [Gemm2] 8192
[TM][INFO] [Gemm2] 8208
[TM][INFO] [Gemm2] Tuning finished in 13.25 seconds.
[TM][INFO] LlamaBatch<T>::Start()
[TM][INFO] LlamaBatch<T>::Start()
[TM][INFO] LlamaBatch<T>::Start()
[TM][INFO] LlamaBatch<T>::Start()
2025-11-10 13:21:00,255 - lmdeploy - INFO - async_engine.py:335 - updated backend_config=TurbomindEngineConfig(dtype='auto', model_format='awq', tp=4, dp=1, device_num=4, attn_tp_size=4, attn_dp_size=1, mlp_tp_size=4, mlp_dp_size=1, outer_dp_size=1, session_len=10240, max_batch_size=16, cache_max_entry_count=0.5, cache_chunk_size=-1, cache_block_seq_len=64, enable_prefix_caching=False, quant_policy=0, rope_scaling_factor=0.0, use_logn_attn=False, download_dir=None, revision=None, max_prefill_token_num=8192, num_tokens_per_iter=8192, max_prefill_iters=2, devices=[0, 1, 2, 3], empty_init=False, communicator='nccl', hf_overrides=None, enable_metrics=False)
HINT: Please open http://0.0.0.0:8080 in a browser for detailed api usage!!!
HINT: Please open http://0.0.0.0:8080 in a browser for detailed api usage!!!
HINT: Please open http://0.0.0.0:8080 in a browser for detailed api usage!!!
INFO: Started server process [1092876]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://0.0.0.0:8080 (Press CTRL+C to quit)
INFO: 192.168.9.112:62550 - "POST /v1/chat/completions HTTP/1.1" 200 OK
Exception in callback _raise_exception_on_finish(<Future finis...2113437780]')>) at /home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/lmdeploy/vl/engine.py:16
handle: <Handle _raise_exception_on_finish(<Future finis...2113437780]')>) at /home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/lmdeploy/vl/engine.py:16>
Traceback (most recent call last):
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/asyncio/events.py", line 80, in _run
self._context.run(self._callback, *self._args)
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/lmdeploy/vl/engine.py", line 23, in _raise_exception_on_finish
raise e
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/lmdeploy/vl/engine.py", line 19, in _raise_exception_on_finish
task.result()
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/concurrent/futures/thread.py", line 58, in run
result = self.fn(*self.args, **self.kwargs)
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/lmdeploy/vl/model/qwen2.py", line 119, in forward
image_embeds = self.model.visual(pixel_values, grid_thw=image_grid_thw)
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py", line 453, in forward
hidden_states = blk(
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/transformers/modeling_layers.py", line 94, in __call__
return super().__call__(*args, **kwargs)
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/accelerate/hooks.py", line 175, in new_forward
output = module._old_forward(*args, **kwargs)
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py", line 281, in forward
hidden_states = hidden_states + self.attn(
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/accelerate/hooks.py", line 175, in new_forward
output = module._old_forward(*args, **kwargs)
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py", line 240, in forward
splits = [
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py", line 241, in <listcomp>
torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/functional.py", line 222, in split
return tensor.split(split_size_or_sections, dim)
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/_tensor.py", line 1052, in split
return torch._VF.split_with_sizes(self, split_size, dim)
RuntimeError: split_with_sizes expects split_sizes have only non-negative entries, but got split_sizes=[-2113437780]
ERROR: Exception in ASGI application
+ Exception Group Traceback (most recent call last):
| File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/_utils.py", line 79, in collapse_excgroups
| yield
| File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/responses.py", line 271, in __call__
| async with anyio.create_task_group() as task_group:
| File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 781, in __aexit__
| raise BaseExceptionGroup(
| exceptiongroup.ExceptionGroup: unhandled errors in a TaskGroup (1 sub-exception)
+-+---------------- 1 ----------------
| Traceback (most recent call last):
| File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/uvicorn/protocols/http/h11_impl.py", line 403, in run_asgi
| result = await app( # type: ignore[func-returns-value]
| File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/uvicorn/middleware/proxy_headers.py", line 60, in __call__
| return await self.app(scope, receive, send)
| File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/fastapi/applications.py", line 1134, in __call__
| await super().__call__(scope, receive, send)
| File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/applications.py", line 113, in __call__
| await self.middleware_stack(scope, receive, send)
| File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/middleware/errors.py", line 186, in __call__
| raise exc
| File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/middleware/errors.py", line 164, in __call__
| await self.app(scope, receive, _send)
| File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/middleware/base.py", line 189, in __call__
| raise app_exc
| File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/middleware/base.py", line 144, in coro
| await self.app(scope, receive_or_disconnect, send_no_error)
| File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/middleware/cors.py", line 85, in __call__
| await self.app(scope, receive, send)
| File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/middleware/exceptions.py", line 63, in __call__
| await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
| File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/_exception_handler.py", line 53, in wrapped_app
| raise exc
| File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/_exception_handler.py", line 42, in wrapped_app
| await app(scope, receive, sender)
| File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/fastapi/middleware/asyncexitstack.py", line 18, in __call__
| await self.app(scope, receive, send)
| File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/routing.py", line 716, in __call__
| await self.middleware_stack(scope, receive, send)
| File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/routing.py", line 736, in app
| await route.handle(scope, receive, send)
| File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/routing.py", line 290, in handle
| await self.app(scope, receive, send)
| File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/fastapi/routing.py", line 124, in app
| await wrap_app_handling_exceptions(app, request)(scope, receive, send)
| File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/_exception_handler.py", line 53, in wrapped_app
| raise exc
| File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/_exception_handler.py", line 42, in wrapped_app
| await app(scope, receive, sender)
| File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/fastapi/routing.py", line 111, in app
| await response(scope, receive, send)
| File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/responses.py", line 270, in __call__
| with collapse_excgroups():
| File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/contextlib.py", line 153, in __exit__
| self.gen.throw(typ, value, traceback)
| File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/_utils.py", line 85, in collapse_excgroups
| raise exc
| File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/responses.py", line 274, in wrap
| await func()
| File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/responses.py", line 254, in stream_response
| async for chunk in self.body_iterator:
| File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/lmdeploy/serve/openai/api_server.py", line 489, in completion_stream_generator
| async for res in result_generator:
| File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/lmdeploy/serve/async_engine.py", line 777, in generate
| prompt_input = await self._get_prompt_input(prompt,
| File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/lmdeploy/serve/vl_async_engine.py", line 98, in _get_prompt_input
| results = await self.vl_encoder.async_infer(results)
| File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/lmdeploy/vl/engine.py", line 61, in async_infer
| outputs = await future
| RuntimeError: split_with_sizes expects split_sizes have only non-negative entries, but got split_sizes=[-2113437780]
+------------------------------------
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/uvicorn/protocols/http/h11_impl.py", line 403, in run_asgi
result = await app( # type: ignore[func-returns-value]
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/uvicorn/middleware/proxy_headers.py", line 60, in __call__
return await self.app(scope, receive, send)
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/fastapi/applications.py", line 1134, in __call__
await super().__call__(scope, receive, send)
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/applications.py", line 113, in __call__
await self.middleware_stack(scope, receive, send)
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/middleware/errors.py", line 186, in __call__
raise exc
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/middleware/errors.py", line 164, in __call__
await self.app(scope, receive, _send)
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/middleware/base.py", line 189, in __call__
raise app_exc
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/middleware/base.py", line 144, in coro
await self.app(scope, receive_or_disconnect, send_no_error)
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/middleware/cors.py", line 85, in __call__
await self.app(scope, receive, send)
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/middleware/exceptions.py", line 63, in __call__
await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/_exception_handler.py", line 53, in wrapped_app
raise exc
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/_exception_handler.py", line 42, in wrapped_app
await app(scope, receive, sender)
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/fastapi/middleware/asyncexitstack.py", line 18, in __call__
await self.app(scope, receive, send)
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/routing.py", line 716, in __call__
await self.middleware_stack(scope, receive, send)
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/routing.py", line 736, in app
await route.handle(scope, receive, send)
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/routing.py", line 290, in handle
await self.app(scope, receive, send)
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/fastapi/routing.py", line 124, in app
await wrap_app_handling_exceptions(app, request)(scope, receive, send)
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/_exception_handler.py", line 53, in wrapped_app
raise exc
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/_exception_handler.py", line 42, in wrapped_app
await app(scope, receive, sender)
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/fastapi/routing.py", line 111, in app
await response(scope, receive, send)
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/responses.py", line 270, in __call__
with collapse_excgroups():
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/contextlib.py", line 153, in __exit__
self.gen.throw(typ, value, traceback)
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/_utils.py", line 85, in collapse_excgroups
raise exc
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/responses.py", line 274, in wrap
await func()
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/responses.py", line 254, in stream_response
async for chunk in self.body_iterator:
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/lmdeploy/serve/openai/api_server.py", line 489, in completion_stream_generator
async for res in result_generator:
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/lmdeploy/serve/async_engine.py", line 777, in generate
prompt_input = await self._get_prompt_input(prompt,
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/lmdeploy/serve/vl_async_engine.py", line 98, in _get_prompt_input
results = await self.vl_encoder.async_infer(results)
File "/home/yun/miniconda3/envs/lmdeploy/lib/python3.10/site-packages/lmdeploy/vl/engine.py", line 61, in async_infer
outputs = await future
RuntimeError: split_with_sizes expects split_sizes have only non-negative entries, but got split_sizes=[-2113437780]
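For reference, the crash happens when torch.split inside the Qwen2.5-VL vision attention receives a negative split size; the huge negative value looks like an overflowed or corrupted length, though the root cause is not confirmed here. The snippet below is only a sketch showing that a single negative entry in the split sizes reproduces this exact RuntimeError message; it does not claim to reproduce the upstream computation.

# Minimal sketch: torch.split raises the same error whenever a computed split
# size is negative (the upstream cause in modeling_qwen2_5_vl.py is not
# confirmed here; the shapes below are arbitrary).
import torch

x = torch.randn(1, 4, 10, 8)  # (batch, heads, seq_len, head_dim)

try:
    # Split sizes along dim=2 must be non-negative; a corrupted single-entry
    # list such as [-2] triggers the error seen in the log above.
    torch.split(x, [-2], dim=2)
except RuntimeError as e:
    print(e)
    # split_with_sizes expects split_sizes have only non-negative entries,
    # but got split_sizes=[-2]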