
Commit 9096012

support interleave text and image in messages (#4141)
* support interleave text and image in messages
* change qwen2-vl and qwen3-vl chat template
* update ut ci flow
* update ut ci flow
* update
* fix qwen vl series backward compatibility
* add ut for internvl series chat template
* simplify Response's __str__ and __repr__
* update
* use superclass's method
* fix
1 parent fe0bce0 commit 9096012

File tree

11 files changed: +498 additions, -214 deletions

.github/workflows/unit-test.yml

Lines changed: 3 additions & 22 deletions
@@ -35,43 +35,24 @@ jobs:
     runs-on: [self-hosted, linux-a100-s2]
     timeout-minutes: 4320 # 72hours
     container:
-      image: nvidia/cuda:11.8.0-devel-ubuntu22.04
+      image: openmmlab/lmdeploy:dev-cu12.8
       options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e CUDA_VISIBLE_DEVICES=2,3 --pull never"
       volumes:
         - /nvme/share_data/github-actions/pip-cache:/root/.cache/pip
         - /nvme/share_data/github-actions/hf_home:/root/.cache/huggingface
         - /nvme/share_data/github-actions/packages:/root/packages
         - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro
     steps:
-      - name: Setup systems
-        run: |
-          apt-get update -y && apt-get install -y software-properties-common wget git curl &&\
-          add-apt-repository ppa:deadsnakes/ppa -y && apt-get update -y && apt-get install -y --no-install-recommends \
-          ninja-build rapidjson-dev libgoogle-glog-dev gdb python3.10 python3.10-dev python3.10-venv \
-          && apt-get clean -y && rm -rf /var/lib/apt/lists/* && cd /opt && python3 -m venv py3
-          echo "PATH=/opt/py3/bin:$PATH" >> "$GITHUB_ENV"
       - name: Clone repository
-        uses: actions/checkout@v2
-      - name: Install pytorch
-        run: |
-          python3 -V
-          python3 -m pip cache dir
-          python3 -m pip install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu118
+        uses: actions/checkout@v5
       - name: Install lmdeploy
         run: |
-          python3 -m pip install packaging protobuf transformers_stream_generator matplotlib
-          # manually install flash attn
-          python3 -m pip install /root/packages/cu118/flash_attn-*.whl
-          python3 -m pip install -r requirements_cuda.txt -r requirements/test.txt
+          python3 -m pip install -r requirements/test.txt
           python3 -m pip install -e .
       - name: Check env
         run: |
           python3 -m pip list
           lmdeploy check_env
-      - name: Test lmdeploy csrc
-        run: |
-          #./build/bin/build/bin/unittest
-          echo "TODO"
       - name: Test lmdeploy python UT
         run: |
           coverage run --branch --source lmdeploy -m pytest -rsE tests

docker/Dockerfile_dev

Lines changed: 4 additions & 0 deletions
@@ -34,4 +34,8 @@ RUN --mount=type=cache,target=/root/.cache/uv \
     uv pip install -r requirements_cuda.txt --extra-index-url https://download.pytorch.org/whl/cu128 && \
     uv pip install -e .
 
+# install flash_attn
+RUN --mount=type=cache,target=/root/.cache/uv \
+    uv pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.8cxx11abiTRUE-cp312-cp312-linux_x86_64.whl
+
 RUN uv cache clean

lmdeploy/messages.py

Lines changed: 28 additions & 10 deletions
@@ -473,17 +473,35 @@ class Response:
     index: int = 0
     routed_experts: Any = None
 
+    def __str__(self):
+        return f'text={self.text}\n{self._format_none_text_fields()}'
+
     def __repr__(self):
-        logits = 'logits=None' if self.logits is None else f'logits.shape={self.logits.shape}\nlogits={self.logits}'
-        hidden_state = (
-            'last_hidden_state=None' if self.last_hidden_state is None else
-            f'last_hidden_state.shape={self.last_hidden_state.shape}\nlast_hidden_state={self.last_hidden_state}')
-        routed_experts = 'routed_experts=None' if self.routed_experts is None else \
-            f'routed_experts.shape={self.routed_experts.shape}'
-
-        s = (f'text={self.text!r}\ngenerate_token_len={self.generate_token_len}\nfinish_reason="{self.finish_reason}"\n'
-             f'token_ids={self.token_ids}\nlog_probs={self.logprobs}\n{logits}\n{hidden_state}\n{routed_experts}')
-        return s
+        return f'text={self.text!r}\n{self._format_none_text_fields()}'
+
+    def _format_none_text_fields(self):
+        fields = []
+        fields.append(f'input_token_len={self.input_token_len}')
+        fields.append(f'generate_token_len={self.generate_token_len}')
+        fields.append(f'finish_reason="{self.finish_reason}"')
+        fields.append(f'token_ids={self.token_ids}')
+        fields.append(f'logprobs={self.logprobs}')
+
+        # Helper function to format tensor information
+        def _format_tensor(name: str, tensor: Optional[torch.Tensor]) -> List[str]:
+            if tensor is None:
+                return [f'{name}=None']
+            try:
+                return [f'{name}.shape={tensor.shape}', f'{name}={tensor}']
+            except: # noqa
+                # in case tensor is not torch.Tensor or has no shape
+                return [f'{name}={tensor}']
+
+        # Format tensor fields
+        fields.extend(_format_tensor('logits', self.logits))
+        fields.extend(_format_tensor('last_hidden_state', self.last_hidden_state))
+        fields.extend(_format_tensor('routed_experts', self.routed_experts))
+        return '\n'.join(fields)
 
 
 # modified from https://github.com/vllm-project/vllm/blob/main/vllm/v1/engine/__init__.py
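
The net effect of this change: str(response) prints the generated text verbatim, repr(response) quotes it, and both delegate the remaining fields to _format_none_text_fields(). A minimal, self-contained sketch of the same pattern follows; it uses a stand-in dataclass, not the actual lmdeploy.messages.Response.

    from dataclasses import dataclass, field
    from typing import Any, List, Optional


    @dataclass
    class ResponseSketch:
        """Stand-in with only the fields the formatting helpers touch."""
        text: str = ''
        input_token_len: int = 0
        generate_token_len: int = 0
        finish_reason: Optional[str] = None
        token_ids: List[int] = field(default_factory=list)
        logprobs: Any = None
        logits: Any = None
        last_hidden_state: Any = None
        routed_experts: Any = None

        def __str__(self):
            # raw text first, then the shared non-text fields
            return f'text={self.text}\n{self._format_none_text_fields()}'

        def __repr__(self):
            # identical layout, but the text is repr-quoted so newlines stay visible
            return f'text={self.text!r}\n{self._format_none_text_fields()}'

        def _format_none_text_fields(self):
            fields = [
                f'input_token_len={self.input_token_len}',
                f'generate_token_len={self.generate_token_len}',
                f'finish_reason="{self.finish_reason}"',
                f'token_ids={self.token_ids}',
                f'logprobs={self.logprobs}',
            ]
            for name in ('logits', 'last_hidden_state', 'routed_experts'):
                value = getattr(self, name)
                if value is None:
                    fields.append(f'{name}=None')
                elif hasattr(value, 'shape'):
                    fields.extend([f'{name}.shape={value.shape}', f'{name}={value}'])
                else:
                    fields.append(f'{name}={value}')
            return '\n'.join(fields)


    r = ResponseSketch(text='a cat\nsits', generate_token_len=3, finish_reason='stop', token_ids=[64, 9, 17])
    print(str(r))   # 'text=a cat', then 'sits' on its own line, then the other fields
    print(repr(r))  # "text='a cat\nsits'" on one line, then the other fields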

lmdeploy/vl/model/base.py

Lines changed: 21 additions & 0 deletions
@@ -181,6 +181,27 @@ def collect_images(messages):
                 }) for x in content if x['type'] == 'image'])
         return images
 
+    @staticmethod
+    def IMAGE_TOKEN_included(messages):
+        """Check whether the IMAGE_TOKEN is included in the messages.
+
+        Args:
+            messages (List[Dict]): a list of message
+        Returns:
+            bool: whether the IMAGE_TOKEN is included in the messages
+        """
+        for message in messages:
+            role, content = message['role'], message['content']
+            if role != 'user':
+                continue
+            if isinstance(content, str) and '<IMAGE_TOKEN>' in content:
+                return True
+            elif isinstance(content, List):
+                content = [x['text'] for x in content if x['type'] == 'text']
+                if any('<IMAGE_TOKEN>' in x for x in content):
+                    return True
+        return False
+
     def to_pytorch_with_input_ids(self, messages):
         """Pack the preprocessing results in a format compatible with what is
         required by pytorch engine when input_ids are provided directly.
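
To show what the new helper distinguishes, here are the two user-message shapes it has to tell apart. The payloads below are made up for illustration; only the import path comes from the file above.

    from lmdeploy.vl.model.base import VisonModel

    # legacy style: the caller already embeds <IMAGE_TOKEN> placeholders in a plain string
    legacy = [dict(role='user', content='<IMAGE_TOKEN>\nDescribe the image.')]

    # interleaved style added by this PR: text and image items mixed in one content list
    interleaved = [
        dict(role='user',
             content=[
                 dict(type='text', text='Compare this '),
                 dict(type='image_url', image_url=dict(url='https://example.com/a.jpg')),
                 dict(type='text', text=' with this '),
                 dict(type='image_url', image_url=dict(url='https://example.com/b.jpg')),
             ])
    ]

    assert VisonModel.IMAGE_TOKEN_included(legacy)          # True  -> backward-compatible path
    assert not VisonModel.IMAGE_TOKEN_included(interleaved)  # False -> interleave path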

lmdeploy/vl/model/internvl.py

Lines changed: 32 additions & 25 deletions
@@ -76,9 +76,9 @@ def __init__(self,
                  hf_config: AutoConfig = None,
                  backend: str = ''):
         super().__init__(model_path, with_llm, max_memory, hf_config, backend)
-        IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
+        self.image_token = '<IMG_CONTEXT>'
         tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
-        self.image_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
+        self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
 
     def build_preprocessor(self):
         self.config = self.hf_config
@@ -224,8 +224,8 @@ def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:
         messages.append(dict(role='forward', content=outputs))
         return messages
 
-    @staticmethod
     def proc_messages(
+        self,
         messages,
         chat_template,
         sequence_start,
@@ -235,32 +235,39 @@ def proc_messages(
         """Apply chat template to get the prompt."""
         prompt_messages = []
         IMAGE_TOKEN = '<IMAGE_TOKEN>'
-        for message in messages:
-            if isinstance(message['content'], str):
-                prompt_messages.append(message)
-                continue
-            elif message['role'] in ['preprocess', 'forward']:
-                continue
-            n_images = len([1 for x in message['content'] if x['type'] == 'image'])
-            content = [x.get('text', '') for x in message['content'] if x['type'] == 'text']
-            if len(content) == 0:
-                content.append('')
-            prompt = content[0]
-            if IMAGE_TOKEN in prompt and f'<img>{IMAGE_TOKEN}' not in prompt:
-                prompt = prompt.replace(f'{IMAGE_TOKEN}', f'<img>{IMAGE_TOKEN}</img>')
-                prompt = prompt.replace('</img><img>', '')
-                prompt = prompt.replace('<img><img>', '<img>')
-                prompt = prompt.replace('</img></img>', '</img>')
-            elif IMAGE_TOKEN not in prompt:
-                prompt = f'<img>{IMAGE_TOKEN * n_images}</img>\n' + prompt
-            else:
-                pass
-            prompt_messages.append(dict(role='user', content=prompt))
+        messages = [x for x in messages if x['role'] not in ['preprocess', 'forward']]
+        if VisonModel.IMAGE_TOKEN_included(messages):
+            # backward compatibility
+            for message in messages:
+                role, content = message['role'], message['content']
+                if role != 'user' or isinstance(content, str):
+                    prompt_messages.append(message)
+                    continue
+                content = [x['text'] for x in content if x['type'] == 'text']
+                prompt = ''.join(content)
+                prompt = prompt.replace(f'{IMAGE_TOKEN}', f'<img>{self.image_token}</img>')
+                prompt_messages.append(dict(role='user', content=prompt))
+        else:
+            for message in messages:
+                role, content = message['role'], message['content']
+                if role != 'user' or isinstance(content, str):
+                    prompt_messages.append(message)
+                    continue
+                _content = []
+                for item in content:
+                    item_type = item['type']
+                    if item_type == 'text':
+                        _content.append(item['text'])
+                    elif item_type in ['image', 'image_url']:
+                        _content.append(f'<img>{self.image_token}</img>\n')
+                    else:
+                        raise ValueError(f'Unsupported message type: {item["type"]}')
+                prompt_messages.append(dict(role='user', content=''.join(_content)))
         prompt = chat_template.messages2prompt(prompt_messages,
                                                sequence_start,
                                                tools=tools,
                                                enable_thinking=enable_thinking)
-        return prompt, IMAGE_TOKEN
+        return prompt, self.image_token
 
     def to_pytorch(self,
                    messages,
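
In the new interleave branch, each image item is replaced in place by an <img><IMG_CONTEXT></img> block, so the image's position in the content list is preserved. A standalone sketch of just that decoration step follows; it does not call the real class, and '<IMG_CONTEXT>' is the value __init__ now stores in self.image_token.

    image_token = '<IMG_CONTEXT>'  # InternVL's image context token, per __init__ above

    content = [
        dict(type='text', text='First picture: '),
        dict(type='image', image='<placeholder>'),
        dict(type='text', text='Second picture: '),
        dict(type='image', image='<placeholder>'),
        dict(type='text', text='What changed between them?'),
    ]

    parts = []
    for item in content:
        if item['type'] == 'text':
            parts.append(item['text'])
        elif item['type'] in ['image', 'image_url']:
            # image markers land exactly where the image appears in the list
            parts.append(f'<img>{image_token}</img>\n')

    print(''.join(parts))
    # First picture: <img><IMG_CONTEXT></img>
    # Second picture: <img><IMG_CONTEXT></img>
    # What changed between them?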

lmdeploy/vl/model/internvl3_hf.py

Lines changed: 4 additions & 69 deletions
@@ -6,7 +6,7 @@
 from transformers.processing_utils import ImagesKwargs, ProcessingKwargs
 
 from lmdeploy.utils import get_logger
-from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
+from lmdeploy.vl.model.internvl import VISION_MODELS, InternVLVisionModel
 from lmdeploy.vl.model.utils import disable_logging
 
 logger = get_logger('lmdeploy')
@@ -32,7 +32,7 @@ class InternVLProcessorKwargs(ProcessingKwargs, total=False):
 
 
 @VISION_MODELS.register_module()
-class InternVL3VisionModel(VisonModel):
+class InternVL3VisionModel(InternVLVisionModel):
     """Internvl3 vision model."""
 
     _arch = ['InternVLForConditionalGeneration', 'InternS1ForConditionalGeneration']
@@ -44,11 +44,12 @@ def __init__(self,
                  hf_config: AutoConfig = None,
                  backend: str = ''):
         super().__init__(model_path, with_llm, max_memory, hf_config, backend)
-        self.arch = hf_config.architectures[0]
+        self.arch = self.hf_config.architectures[0]
 
     def build_preprocessor(self):
         self.processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True)
         tokenizer = self.processor.tokenizer
+        self.image_token = self.processor.image_token
         self.image_token_id = tokenizer.context_image_token_id
         self.image_tokens_per_patch = self.processor.image_seq_length
         self.tokenizer_init_kwargs = tokenizer.init_kwargs
@@ -145,69 +146,3 @@ def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:
             outputs.extend([x.reshape(-1, x.shape[-1]) for x in feats])
         messages.append(dict(role='forward', content=outputs))
         return messages
-
-    @staticmethod
-    def proc_messages(
-        messages,
-        chat_template,
-        sequence_start,
-        tools: Optional[List[object]] = None,
-        enable_thinking: Optional[bool] = None,
-    ):
-        """Apply chat template to get the prompt."""
-        prompt_messages = []
-        IMAGE_TOKEN = '<IMAGE_TOKEN>'
-        for message in messages:
-            if isinstance(message['content'], str):
-                prompt_messages.append(message)
-                continue
-            elif message['role'] in ['preprocess', 'forward']:
-                continue
-            n_images = len([1 for x in message['content'] if x['type'] == 'image'])
-            content = [x.get('text', '') for x in message['content'] if x['type'] == 'text']
-            prompt = content[0]
-            if IMAGE_TOKEN in prompt and f'<img>{IMAGE_TOKEN}' not in prompt:
-                prompt = prompt.replace(f'{IMAGE_TOKEN}', f'<img>{IMAGE_TOKEN}</img>')
-                prompt = prompt.replace('</img><img>', '')
-                prompt = prompt.replace('<img><img>', '<img>')
-                prompt = prompt.replace('</img></img>', '</img>')
-            elif IMAGE_TOKEN not in prompt:
-                prompt = f'<img>{IMAGE_TOKEN * n_images}</img>\n' + prompt
-            else:
-                pass
-            prompt_messages.append(dict(role='user', content=prompt))
-        prompt = chat_template.messages2prompt(prompt_messages,
-                                               sequence_start,
-                                               tools=tools,
-                                               enable_thinking=enable_thinking)
-        return prompt, IMAGE_TOKEN
-
-    def to_pytorch(self,
-                   messages,
-                   chat_template,
-                   tokenizer,
-                   sequence_start,
-                   tools: Optional[List[object]] = None,
-                   enable_thinking: Optional[bool] = None,
-                   **kwargs):
-        prompt, IMAGE_TOKEN = self.proc_messages(messages,
-                                                 chat_template,
-                                                 sequence_start,
-                                                 tools=tools,
-                                                 enable_thinking=enable_thinking)
-        return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)
-
-    def to_turbomind(self,
-                     messages,
-                     chat_template,
-                     tokenizer,
-                     sequence_start,
-                     tools: Optional[List[object]] = None,
-                     enable_thinking: Optional[bool] = None,
-                     **kwargs):
-        prompt, IMAGE_TOKEN = self.proc_messages(messages,
-                                                 chat_template,
-                                                 sequence_start,
-                                                 tools=tools,
-                                                 enable_thinking=enable_thinking)
-        return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)

lmdeploy/vl/model/qwen2.py

Lines changed: 31 additions & 24 deletions
@@ -32,8 +32,8 @@ def build_preprocessor(self):
         from transformers import AutoProcessor
         self.processor = AutoProcessor.from_pretrained(self.model_path)
         tokenizer = self.processor.tokenizer
-        image_token = self.processor.image_token
-        self.image_token_id = tokenizer.encode(image_token)[-1]
+        self.image_token = self.processor.image_token
+        self.image_token_id = tokenizer.encode(self.image_token)[-1]
 
     def preprocess(self, messages: List[Dict]) -> List[Dict]:
         """Refer to `super().preprocess()` for spec."""
@@ -124,33 +124,40 @@ def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:
         messages.append(dict(role='forward', content=outputs))
         return messages
 
-    @staticmethod
-    def proc_messages(messages, chat_template, sequence_start):
+    def proc_messages(self, messages, chat_template, sequence_start):
         """Apply chat template to get the prompt."""
         prompt_messages = []
         IMAGE_TOKEN = '<IMAGE_TOKEN>'
-        for message in messages:
-            if isinstance(message['content'], str):
+        messages = [x for x in messages if x['role'] not in ['preprocess', 'forward']]
+        if VisonModel.IMAGE_TOKEN_included(messages):
+            # backward compatibility
+            for message in messages:
+                role, content = message['role'], message['content']
+                if role != 'user' or isinstance(content, str):
+                    prompt_messages.append(message)
+                    continue
+                content = [x['text'] for x in content if x['type'] == 'text']
+                prompt = ''.join(content)
+                prompt = prompt.replace(IMAGE_TOKEN, f'<|vision_start|>{self.image_token}<|vision_end|>')
+                prompt_messages.append(dict(role='user', content=prompt))
+        else:
+            for message in messages:
+                role, content = message['role'], message['content']
+                if role != 'user' or isinstance(content, str):
+                    prompt_messages.append(message)
+                    continue
+                _content = []
+                for item in content:
+                    if item['type'] == 'text':
+                        _content.append(item['text'])
+                    elif item['type'] in ['image', 'image_url']:
+                        _content.append(f'<|vision_start|>{self.image_token}<|vision_end|>')
+                    else:
+                        raise ValueError(f'Unsupported message type: {item["type"]}')
+                message = dict(role=role, content=''.join(_content))
                 prompt_messages.append(message)
-                continue
-            elif message['role'] in ['images', 'preprocess', 'forward']:
-                continue
-            n_images = len([1 for x in message['content'] if x['type'] == 'image'])
-            content = [item['text'] for item in message['content'] if item['type'] == 'text']
-            prompt = content[0]
-            if IMAGE_TOKEN in prompt and '<|vision_start|>' not in prompt:
-                prompt = prompt.replace(IMAGE_TOKEN, f'<|vision_start|>{IMAGE_TOKEN}<|vision_end|>')
-            else:
-                # Qwen2-VL-2B-Instruct will concat image and user prompt
-                # according to their order in the content list
-                # we insert image token before user prompt by default. The
-                # user can use custom image token position if they want the
-                # same decorated prompt as Qwen2-VL
-                prompt = f'<|vision_start|>{IMAGE_TOKEN}<|vision_end|>' * \
-                    n_images + prompt
-            prompt_messages.append(dict(role=message['role'], content=prompt))
         prompt = chat_template.messages2prompt(prompt_messages, sequence_start)
-        return prompt, IMAGE_TOKEN
+        return prompt, self.image_token
 
     @staticmethod
     def get_mrope_info(seq_len: int,
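
The backward-compatible branch keeps old prompts working: any <IMAGE_TOKEN> a caller already placed is rewritten with Qwen2-VL's vision markers around the processor's image token. A small sketch of that rewrite follows; the literal '<|image_pad|>' is an assumption standing in for self.image_token, which is read from the HF processor at runtime.

    IMAGE_TOKEN = '<IMAGE_TOKEN>'
    image_token = '<|image_pad|>'  # assumed value of self.image_token for Qwen2-VL

    content = [dict(type='text', text=f'{IMAGE_TOKEN}\nWhat is in the image?')]
    prompt = ''.join(x['text'] for x in content if x['type'] == 'text')
    prompt = prompt.replace(IMAGE_TOKEN, f'<|vision_start|>{image_token}<|vision_end|>')

    print(prompt)
    # <|vision_start|><|image_pad|><|vision_end|>
    # What is in the image?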
