Commit 6ab0a40

support vlm online spec train

1 parent 7194c99 commit 6ab0a40
27 files changed: +2017 −717 lines

angelslim/compressor/speculative/__init__.py
Lines changed: 2 additions & 6 deletions

@@ -14,11 +14,9 @@
 
 from .benchmark import BenchmarkConfig, BenchmarkEngine, BenchmarkMode
 from .train import (
-    DataCollatorWithPadding,
     DatasetManager,
     DraftModelConfig,
-    OfflineEagle3Trainer,
-    OnlineEagle3Trainer,
+    Eagle3TrainerFactory,
     TargetHead,
     convert_sharegpt_data,
     convert_ultrachat_data,
@@ -35,10 +33,8 @@
     "create_draft_model",
     "DraftModelConfig",
     "create_target_model",
-    "OnlineEagle3Trainer",
-    "OfflineEagle3Trainer",
+    "Eagle3TrainerFactory",
     "data_generation_work_flow",
-    "DataCollatorWithPadding",
     "convert_sharegpt_data",
     "convert_ultrachat_data",
     "DatasetManager",

angelslim/compressor/speculative/train/__init__.py
Lines changed: 2 additions & 5 deletions

@@ -1,5 +1,4 @@
 from .data import (
-    DataCollatorWithPadding,
     DatasetManager,
     convert_sharegpt_data,
     convert_ultrachat_data,
@@ -12,16 +11,14 @@
     create_draft_model,
     create_target_model,
 )
-from .trainer import OfflineEagle3Trainer, OnlineEagle3Trainer
+from .trainer import Eagle3TrainerFactory
 
 __all__ = [
     "create_draft_model",
     "DraftModelConfig",
     "create_target_model",
-    "OnlineEagle3Trainer",
-    "OfflineEagle3Trainer",
+    "Eagle3TrainerFactory",
     "data_generation_work_flow",
-    "DataCollatorWithPadding",
     "convert_sharegpt_data",
     "convert_ultrachat_data",
     "DatasetManager",
Lines changed: 42 additions & 0 deletions

@@ -0,0 +1,42 @@
+{
+  "architectures": [
+    "Eagle3LlamaForCausalLM"
+  ],
+  "model_type": "llama",
+  "attention_bias": false,
+  "attention_dropout": 0.0,
+  "bos_token_id": 151643,
+  "dtype": "bfloat16",
+  "eos_token_id": 151645,
+  "head_dim": 128,
+  "hidden_act": "silu",
+  "hidden_size": 2560,
+  "initializer_range": 0.02,
+  "intermediate_size": 9728,
+  "max_position_embeddings": 262144,
+  "num_attention_heads": 32,
+  "num_hidden_layers": 1,
+  "num_key_value_heads": 8,
+  "rms_norm_eps": 1e-06,
+  "rope_scaling": {
+    "type": "default",
+    "rope_type": "default",
+    "mrope_interleaved": true,
+    "mrope_section": [
+      24,
+      20,
+      20
+    ]
+  },
+  "rope_theta": 5000000,
+  "use_cache": true,
+  "vocab_size": 151936,
+  "tie_word_embeddings": true,
+  "transformers_version": "4.57.1",
+  "image_token_id": 151655,
+  "video_token_id": 151656,
+  "vision_end_token_id": 151653,
+  "vision_start_token_id": 151652,
+  "draft_vocab_size": 32000,
+  "modal_type": "VLM"
+}
Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
+{
+  "architectures": [
+    "Eagle3LlamaForCausalLM"
+  ],
+  "model_type": "llama",
+  "attention_bias": false,
+  "attention_dropout": 0.0,
+  "bos_token_id": 151643,
+  "dtype": "bfloat16",
+  "eos_token_id": 151645,
+  "head_dim": 128,
+  "hidden_act": "silu",
+  "hidden_size": 2560,
+  "initializer_range": 0.02,
+  "intermediate_size": 9728,
+  "max_position_embeddings": 262144,
+  "num_attention_heads": 32,
+  "num_hidden_layers": 1,
+  "num_key_value_heads": 8,
+  "rms_norm_eps": 1e-06,
+  "rope_scaling": null,
+  "rope_theta": 5000000,
+  "use_cache": true,
+  "vocab_size": 151936,
+  "tie_word_embeddings": true,
+  "transformers_version": "4.57.1",
+  "image_token_id": 151655,
+  "video_token_id": 151656,
+  "vision_end_token_id": 151653,
+  "vision_start_token_id": 151652,
+  "draft_vocab_size": 32000,
+  "modal_type": "VLM"
+}
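
The two new draft-model configs (their file paths are not shown in this diff view) are identical except for rope_scaling: the first enables interleaved M-RoPE with a 24/20/20 section split for multimodal position encoding, while the second sets rope_scaling to null for standard 1D RoPE. A small sketch of how loading code might branch on these fields; the filename below is illustrative:

    import json

    # Illustrative filename -- the diff view does not name these config files.
    with open("eagle3_draft_config.json") as f:
        cfg = json.load(f)

    if cfg["modal_type"] == "VLM" and cfg["rope_scaling"] is not None:
        # M-RoPE: the (temporal, height, width) sections together cover half
        # of head_dim, Qwen2/3-VL style: 2 * (24 + 20 + 20) == 128.
        t, h, w = cfg["rope_scaling"]["mrope_section"]
        assert 2 * (t + h + w) == cfg["head_dim"]
    else:
        # Text-only variant: rope_scaling is null, plain 1D RoPE with rope_theta.
        print("standard RoPE, theta =", cfg["rope_theta"])
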

angelslim/compressor/speculative/train/data/__init__.py
Lines changed: 1 addition & 6 deletions

@@ -14,16 +14,11 @@
 
 from .chat_templates import get_supported_chat_template_type_strings
 from .data_generation import data_generation_work_flow
-from .data_utils import (
-    DataCollatorWithPadding,
-    convert_sharegpt_data,
-    convert_ultrachat_data,
-)
+from .data_utils import convert_sharegpt_data, convert_ultrachat_data
 from .dataset import DatasetManager
 
 __all__ = [
     "DatasetManager",
-    "DataCollatorWithPadding",
     "convert_sharegpt_data",
     "convert_ultrachat_data",
     "data_generation_work_flow",

angelslim/compressor/speculative/train/data/chat_templates.py
Lines changed: 23 additions & 0 deletions

@@ -28,12 +28,14 @@ class ChatTemplateType(Enum):
 
     QWEN3 = "qwen3"
     HUNYUAN = "hunyuan"
+    QWEN3_VL = "qwen3_vl"
 
 
 # String to ChatTemplateType mapping
 CHAT_TEMPLATE_TYPE_MAPPING = {
     "qwen3": ChatTemplateType.QWEN3,
     "hunyuan": ChatTemplateType.HUNYUAN,
+    "qwen3_vl": ChatTemplateType.QWEN3_VL,
 }
 
 
@@ -93,6 +95,27 @@ def _initialize_templates(self) -> Dict[ChatTemplateType, ChatTemplate]:
                     "please don't share false information."
                 ),
             ),
+            ChatTemplateType.QWEN3_VL: ChatTemplate(
+                user_header="<|im_start|>user\n",
+                assistant_header="<|im_start|>assistant\n",
+                system_prompt=[
+                    {
+                        "type": "text",
+                        "text": (
+                            "You are a helpful, respectful and honest assistant. "
+                            "Always answer as helpfully as possible, while being safe. "
+                            "Your answers should not include any harmful, unethical, "
+                            "racist, sexist, toxic, dangerous, or illegal content. "
+                            "Please ensure that your responses are socially unbiased "
+                            "and positive in nature.\n\nIf a question does not make "
+                            "any sense, or is not factually coherent, explain why "
+                            "instead of answering something not correct. If you "
+                            "don't know the answer to a question, please don't share "
+                            "false information."
+                        ),
+                    }
+                ],
+            ),
         }
 
     def get_template(self, chat_template_type: ChatTemplateType) -> ChatTemplate:
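
Note that the QWEN3_VL system_prompt is a list of typed content parts rather than the plain string the text-only templates use, matching the message layout multimodal chat processors consume. A sketch of a full conversation in that layout (illustrative only; this usage code is not part of the commit):

    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are a helpful assistant."}],
        },
        {
            "role": "user",
            "content": [
                # An image part followed by a text part, mirroring the
                # typed-content structure of the QWEN3_VL system prompt above.
                {"type": "image", "image": "file:///path/to/example.jpg"},
                {"type": "text", "text": "Describe this image."},
            ],
        },
    ]
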

angelslim/compressor/speculative/train/data/data_utils.py
Lines changed: 75 additions & 26 deletions

@@ -21,6 +21,7 @@
     "convert_sharegpt_data",
     "convert_ultrachat_data",
     "DataCollatorWithPadding",
+    "VLMDataCollatorWithPadding",
 ]
 
 
@@ -100,33 +101,47 @@ def process_token_dict_to_mappings(
     return d2t, t2d
 
 
+def paddingtensor(intensors, N):
+    B, n, S = intensors.shape
+    padding_tensor = torch.zeros(B, N - n, S, dtype=intensors.dtype)
+    outtensors = torch.cat((intensors, padding_tensor), dim=1)
+    return outtensors
+
+
+def paddingtensor2D(intensors, N):
+    B, n = intensors.shape
+    padding_tensor = torch.zeros(B, N - n, dtype=intensors.dtype)
+    outtensors = torch.cat((intensors, padding_tensor), dim=1)
+    return outtensors
+
+
+def paddingtensor3D(tensor_list):
+    max_h = max(tensor.shape[-2] for tensor in tensor_list)
+    max_w = max(tensor.shape[-1] for tensor in tensor_list)
+    out_tensor_list = []
+    for tensor in tensor_list:
+        if tensor.ndim == 2:
+            tensor = tensor.unsqueeze(0)
+        b, h, w = tensor.shape
+        outtensor = torch.zeros(b, max_h, max_w, dtype=tensor.dtype)
+        outtensor[:, :h, :w] = tensor
+        out_tensor_list.append(outtensor)
+    return torch.cat(out_tensor_list)
+
+
 class DataCollatorWithPadding:
-    def paddingtensor(self, intensors, N):
-        B, n, S = intensors.shape
-        padding_tensor = torch.zeros(B, N - n, S, dtype=intensors.dtype)
-        outtensors = torch.cat((intensors, padding_tensor), dim=1)
-        return outtensors
-
-    def paddingtensor2D(self, intensors, N):
-        B, n = intensors.shape
-        padding_tensor = torch.zeros(B, N - n, dtype=intensors.dtype)
-        outtensors = torch.cat((intensors, padding_tensor), dim=1)
-        return outtensors
 
     def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
         max_length = max(item["input_ids"].shape[1] for item in features)
         batch_input_ids = torch.cat(
-            [self.paddingtensor2D(item["input_ids"], max_length) for item in features]
+            [paddingtensor2D(item["input_ids"], max_length) for item in features]
        )
         batch_attention_mask = torch.cat(
-            [
-                self.paddingtensor2D(item["attention_mask"], max_length)
-                for item in features
-            ]
+            [paddingtensor2D(item["attention_mask"], max_length) for item in features]
         )
         batch_loss_mask = torch.cat(
-            [self.paddingtensor2D(item["loss_mask"], max_length) for item in features]
+            [paddingtensor2D(item["loss_mask"], max_length) for item in features]
         )
 
         batch = {
@@ -142,15 +157,49 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
             "hidden_states" in item and "target_hiddens" in item for item in features
         ):
             batch["hidden_states"] = torch.cat(
-                [
-                    self.paddingtensor(item["hidden_states"], max_length)
-                    for item in features
-                ]
+                [paddingtensor(item["hidden_states"], max_length) for item in features]
             )
             batch["target_hiddens"] = torch.cat(
-                [
-                    self.paddingtensor(item["target_hiddens"], max_length)
-                    for item in features
-                ]
+                [paddingtensor(item["target_hiddens"], max_length) for item in features]
+            )
+        return batch
+
+
+class VLMDataCollatorWithPadding:
+
+    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
+        max_length = max(item["input_ids"].shape[1] for item in features)
+        batch_input_ids = torch.cat(
+            [paddingtensor2D(item["input_ids"], max_length) for item in features]
+        )
+        batch_attention_mask = torch.cat(
+            [paddingtensor2D(item["attention_mask"], max_length) for item in features]
+        )
+        batch_loss_mask = torch.cat(
+            [paddingtensor2D(item["loss_mask"], max_length) for item in features]
+        )
+
+        batch = {
+            "input_ids": batch_input_ids,
+            "attention_mask": batch_attention_mask,
+            "loss_mask": batch_loss_mask,
+        }
+
+        if "pixel_values" in features[0]:
+            batch["pixel_values"] = paddingtensor3D(
+                [item["pixel_values"] for item in features]
+            )
+        if "video_pixel_values" in features[0]:
+            batch["video_pixel_values"] = paddingtensor3D(
+                [item["video_pixel_values"] for item in features]
             )
+        if "image_grid_thw" in features[0]:
+            batch["image_grid_thw"] = paddingtensor3D(
+                [item["image_grid_thw"] for item in features]
+            )
+        if "video_grid_thw" in features[0]:
+            batch["video_grid_thw"] = paddingtensor3D(
+                [item["video_grid_thw"] for item in features]
            )
+
         return batch
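
Pulling paddingtensor, paddingtensor2D, and paddingtensor3D up to module level lets the new VLMDataCollatorWithPadding share them with the original collator. A minimal usage sketch built from the field names the collator checks; the tensor shapes are illustrative assumptions (e.g. [1, seq_len] token tensors and [num_patches, patch_dim] pixel tensors), not values from this commit:

    import torch

    from angelslim.compressor.speculative.train.data.data_utils import (
        VLMDataCollatorWithPadding,
    )

    # Two samples with different sequence lengths and patch counts; shapes
    # are illustrative. paddingtensor2D right-pads token tensors with zeros,
    # and paddingtensor3D zero-pads 2-D vision tensors to a common (h, w).
    features = [
        {
            "input_ids": torch.ones(1, 10, dtype=torch.long),
            "attention_mask": torch.ones(1, 10, dtype=torch.long),
            "loss_mask": torch.ones(1, 10, dtype=torch.long),
            "pixel_values": torch.randn(256, 1176),
            "image_grid_thw": torch.tensor([[1, 16, 16]]),
        },
        {
            "input_ids": torch.ones(1, 7, dtype=torch.long),
            "attention_mask": torch.ones(1, 7, dtype=torch.long),
            "loss_mask": torch.ones(1, 7, dtype=torch.long),
            "pixel_values": torch.randn(64, 1176),
            "image_grid_thw": torch.tensor([[1, 8, 8]]),
        },
    ]

    batch = VLMDataCollatorWithPadding()(features)
    print(batch["input_ids"].shape)       # torch.Size([2, 10])
    print(batch["pixel_values"].shape)    # torch.Size([2, 256, 1176])
    print(batch["image_grid_thw"].shape)  # torch.Size([2, 1, 3])
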
