Skip to content

Commit 7382a63

Browse files
authored
Support VLM Eagle3 Training (#141)
1 parent 7194c99 commit 7382a63

37 files changed

+2542
-734
lines changed

angelslim/compressor/speculative/__init__.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,9 @@
1414

1515
from .benchmark import BenchmarkConfig, BenchmarkEngine, BenchmarkMode
1616
from .train import (
17-
DataCollatorWithPadding,
1817
DatasetManager,
1918
DraftModelConfig,
20-
OfflineEagle3Trainer,
21-
OnlineEagle3Trainer,
19+
Eagle3TrainerFactory,
2220
TargetHead,
2321
convert_sharegpt_data,
2422
convert_ultrachat_data,
@@ -35,10 +33,8 @@
3533
"create_draft_model",
3634
"DraftModelConfig",
3735
"create_target_model",
38-
"OnlineEagle3Trainer",
39-
"OfflineEagle3Trainer",
36+
"Eagle3TrainerFactory",
4037
"data_generation_work_flow",
41-
"DataCollatorWithPadding",
4238
"convert_sharegpt_data",
4339
"convert_ultrachat_data",
4440
"DatasetManager",

angelslim/compressor/speculative/train/__init__.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from .data import (
2-
DataCollatorWithPadding,
32
DatasetManager,
43
convert_sharegpt_data,
54
convert_ultrachat_data,
@@ -12,16 +11,14 @@
1211
create_draft_model,
1312
create_target_model,
1413
)
15-
from .trainer import OfflineEagle3Trainer, OnlineEagle3Trainer
14+
from .trainer import Eagle3TrainerFactory
1615

1716
__all__ = [
1817
"create_draft_model",
1918
"DraftModelConfig",
2019
"create_target_model",
21-
"OnlineEagle3Trainer",
22-
"OfflineEagle3Trainer",
20+
"Eagle3TrainerFactory",
2321
"data_generation_work_flow",
24-
"DataCollatorWithPadding",
2522
"convert_sharegpt_data",
2623
"convert_ultrachat_data",
2724
"DatasetManager",
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
{
2+
"architectures": [
3+
"Eagle3LlamaForCausalLM"
4+
],
5+
"model_type": "llama",
6+
"attention_bias": false,
7+
"attention_dropout": 0.0,
8+
"bos_token_id": 151643,
9+
"dtype": "bfloat16",
10+
"eos_token_id": 151645,
11+
"head_dim": 128,
12+
"hidden_act": "silu",
13+
"hidden_size": 2560,
14+
"initializer_range": 0.02,
15+
"intermediate_size": 9728,
16+
"max_position_embeddings": 262144,
17+
"num_attention_heads": 32,
18+
"num_hidden_layers": 1,
19+
"num_key_value_heads": 8,
20+
"rms_norm_eps": 1e-06,
21+
"rope_scaling": {
22+
"type": "default",
23+
"rope_type": "default",
24+
"mrope_interleaved": true,
25+
"mrope_section": [
26+
24,
27+
20,
28+
20
29+
]
30+
},
31+
"rope_theta": 5000000,
32+
"use_cache": true,
33+
"vocab_size": 151936,
34+
"tie_word_embeddings": true,
35+
"transformers_version": "4.57.1",
36+
"image_token_id": 151655,
37+
"video_token_id": 151656,
38+
"vision_end_token_id": 151653,
39+
"vision_start_token_id": 151652,
40+
"draft_vocab_size": 32000,
41+
"modal_type": "VLM"
42+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
{
2+
"architectures": [
3+
"Eagle3LlamaForCausalLM"
4+
],
5+
"model_type": "llama",
6+
"attention_bias": false,
7+
"attention_dropout": 0.0,
8+
"bos_token_id": 151643,
9+
"dtype": "bfloat16",
10+
"eos_token_id": 151645,
11+
"head_dim": 128,
12+
"hidden_act": "silu",
13+
"hidden_size": 2560,
14+
"initializer_range": 0.02,
15+
"intermediate_size": 9728,
16+
"max_position_embeddings": 262144,
17+
"num_attention_heads": 32,
18+
"num_hidden_layers": 1,
19+
"num_key_value_heads": 8,
20+
"rms_norm_eps": 1e-06,
21+
"rope_scaling": null,
22+
"rope_theta": 5000000,
23+
"use_cache": true,
24+
"vocab_size": 151936,
25+
"tie_word_embeddings": true,
26+
"transformers_version": "4.57.1",
27+
"image_token_id": 151655,
28+
"video_token_id": 151656,
29+
"vision_end_token_id": 151653,
30+
"vision_start_token_id": 151652,
31+
"draft_vocab_size": 32000,
32+
"modal_type": "VLM"
33+
}

angelslim/compressor/speculative/train/data/__init__.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,11 @@
1414

1515
from .chat_templates import get_supported_chat_template_type_strings
1616
from .data_generation import data_generation_work_flow
17-
from .data_utils import (
18-
DataCollatorWithPadding,
19-
convert_sharegpt_data,
20-
convert_ultrachat_data,
21-
)
17+
from .data_utils import convert_sharegpt_data, convert_ultrachat_data
2218
from .dataset import DatasetManager
2319

2420
__all__ = [
2521
"DatasetManager",
26-
"DataCollatorWithPadding",
2722
"convert_sharegpt_data",
2823
"convert_ultrachat_data",
2924
"data_generation_work_flow",

angelslim/compressor/speculative/train/data/chat_templates.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,14 @@ class ChatTemplateType(Enum):
2828

2929
QWEN3 = "qwen3"
3030
HUNYUAN = "hunyuan"
31+
QWEN3_VL = "qwen3_vl"
3132

3233

3334
# String to ChatTemplateType mapping
3435
CHAT_TEMPLATE_TYPE_MAPPING = {
3536
"qwen3": ChatTemplateType.QWEN3,
3637
"hunyuan": ChatTemplateType.HUNYUAN,
38+
"qwen3_vl": ChatTemplateType.QWEN3_VL,
3739
}
3840

3941

@@ -93,6 +95,27 @@ def _initialize_templates(self) -> Dict[ChatTemplateType, ChatTemplate]:
9395
"please don't share false information."
9496
),
9597
),
98+
ChatTemplateType.QWEN3_VL: ChatTemplate(
99+
user_header="<|im_start|>user\n",
100+
assistant_header="<|im_start|>assistant\n",
101+
system_prompt=[
102+
{
103+
"type": "text",
104+
"text": (
105+
"You are a helpful, respectful and honest assistant. "
106+
"Always answer as helpfully as possible, while being safe. "
107+
"Your answers should not include any harmful, unethical, "
108+
"racist, sexist, toxic, dangerous, or illegal content. "
109+
"Please ensure that your responses are socially unbiased "
110+
"and positive in nature.\n\nIf a question does not make "
111+
"any sense, or is not factually coherent, explain why "
112+
"instead of answering something not correct. If you "
113+
"don't know the answer to a question, please don't share "
114+
"false information."
115+
),
116+
}
117+
],
118+
),
96119
}
97120

98121
def get_template(self, chat_template_type: ChatTemplateType) -> ChatTemplate:

angelslim/compressor/speculative/train/data/data_utils.py

Lines changed: 107 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
"convert_sharegpt_data",
2222
"convert_ultrachat_data",
2323
"DataCollatorWithPadding",
24+
"VLMDataCollatorWithPadding",
2425
]
2526

2627

@@ -100,33 +101,58 @@ def process_token_dict_to_mappings(
100101
return d2t, t2d
101102

102103

104+
def paddingtensor(intensors, N):
105+
B, n, S = intensors.shape
106+
# padding_tensor = torch.zeros(B, N - n, S,dtype=intensors.dtype)
107+
padding_tensor = torch.zeros(B, N - n, S, dtype=intensors.dtype)
108+
outtensors = torch.cat((intensors, padding_tensor), dim=1)
109+
return outtensors
110+
111+
112+
def paddingtensor2D(intensors, N):
113+
B, n = intensors.shape
114+
padding_tensor = torch.zeros(B, N - n, dtype=intensors.dtype)
115+
outtensors = torch.cat((intensors, padding_tensor), dim=1)
116+
return outtensors
117+
118+
119+
def paddingtensor3D_CBN(tensor_list):
120+
N = max(tensor.shape[-1] for tensor in tensor_list)
121+
out_tensor_list = []
122+
for tensor in tensor_list:
123+
c, b, n = tensor.shape
124+
outtensor = torch.zeros(c, b, N, dtype=tensor_list[0].dtype)
125+
outtensor[:, :, :n] = tensor
126+
out_tensor_list.append(outtensor)
127+
return torch.cat(out_tensor_list, dim=1)
128+
129+
130+
def paddingtensor3D_BHW(tensor_list):
131+
max_h = max(tensor.shape[-2] for tensor in tensor_list)
132+
max_w = max(tensor.shape[-1] for tensor in tensor_list)
133+
out_tensor_list = []
134+
for tensor in tensor_list:
135+
if tensor.ndim == 2:
136+
tensor = tensor.unsqueeze(0)
137+
b, h, w = tensor.shape
138+
outtensor = torch.zeros(b, max_h, max_w, dtype=tensor.dtype)
139+
outtensor[:, :h, :w] = tensor
140+
out_tensor_list.append(outtensor)
141+
return torch.cat(out_tensor_list)
142+
143+
103144
class DataCollatorWithPadding:
104-
def paddingtensor(self, intensors, N):
105-
B, n, S = intensors.shape
106-
# padding_tensor = torch.zeros(B, N - n, S,dtype=intensors.dtype)
107-
padding_tensor = torch.zeros(B, N - n, S, dtype=intensors.dtype)
108-
outtensors = torch.cat((intensors, padding_tensor), dim=1)
109-
return outtensors
110-
111-
def paddingtensor2D(self, intensors, N):
112-
B, n = intensors.shape
113-
padding_tensor = torch.zeros(B, N - n, dtype=intensors.dtype)
114-
outtensors = torch.cat((intensors, padding_tensor), dim=1)
115-
return outtensors
116145

117146
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
118147
max_length = max(item["input_ids"].shape[1] for item in features)
119148
batch_input_ids = torch.cat(
120-
[self.paddingtensor2D(item["input_ids"], max_length) for item in features]
149+
[paddingtensor2D(item["input_ids"], max_length) for item in features]
121150
)
122151
batch_attention_mask = torch.cat(
123-
[
124-
self.paddingtensor2D(item["attention_mask"], max_length)
125-
for item in features
126-
]
152+
[paddingtensor2D(item["attention_mask"], max_length) for item in features]
127153
)
128154
batch_loss_mask = torch.cat(
129-
[self.paddingtensor2D(item["loss_mask"], max_length) for item in features]
155+
[paddingtensor2D(item["loss_mask"], max_length) for item in features]
130156
)
131157

132158
batch = {
@@ -142,15 +168,70 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
142168
"hidden_states" in item and "target_hiddens" in item for item in features
143169
):
144170
batch["hidden_states"] = torch.cat(
145-
[
146-
self.paddingtensor(item["hidden_states"], max_length)
147-
for item in features
148-
]
171+
[paddingtensor(item["hidden_states"], max_length) for item in features]
149172
)
150173
batch["target_hiddens"] = torch.cat(
151-
[
152-
self.paddingtensor(item["target_hiddens"], max_length)
153-
for item in features
154-
]
174+
[paddingtensor(item["target_hiddens"], max_length) for item in features]
155175
)
156176
return batch
177+
178+
179+
class VLMDataCollatorWithPadding:
180+
181+
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
182+
max_length = max(item["input_ids"].shape[1] for item in features)
183+
batch_input_ids = torch.cat(
184+
[paddingtensor2D(item["input_ids"], max_length) for item in features]
185+
)
186+
batch_attention_mask = torch.cat(
187+
[paddingtensor2D(item["attention_mask"], max_length) for item in features]
188+
)
189+
batch_loss_mask = torch.cat(
190+
[paddingtensor2D(item["loss_mask"], max_length) for item in features]
191+
)
192+
193+
batch = {
194+
"input_ids": batch_input_ids,
195+
"attention_mask": batch_attention_mask,
196+
"loss_mask": batch_loss_mask,
197+
"hidden_states": None,
198+
"target_hiddens": None,
199+
"inputs_embeds": None,
200+
"position_ids": None,
201+
}
202+
203+
if "pixel_values" in features[0]:
204+
batch["pixel_values"] = paddingtensor3D_BHW(
205+
[item["pixel_values"] for item in features]
206+
)
207+
if "video_pixel_values" in features[0]:
208+
batch["video_pixel_values"] = paddingtensor3D_BHW(
209+
[item["video_pixel_values"] for item in features]
210+
)
211+
if "image_grid_thw" in features[0]:
212+
batch["image_grid_thw"] = paddingtensor3D_BHW(
213+
[item["image_grid_thw"] for item in features]
214+
)
215+
if "video_grid_thw" in features[0]:
216+
batch["video_grid_thw"] = paddingtensor3D_BHW(
217+
[item["video_grid_thw"] for item in features]
218+
)
219+
220+
# Check if both hidden_states and target_hiddens exist in all features
221+
if all(
222+
"hidden_states" in item and "target_hiddens" in item for item in features
223+
):
224+
batch["hidden_states"] = torch.cat(
225+
[paddingtensor(item["hidden_states"], max_length) for item in features]
226+
)
227+
batch["target_hiddens"] = torch.cat(
228+
[paddingtensor(item["target_hiddens"], max_length) for item in features]
229+
)
230+
batch["inputs_embeds"] = torch.cat(
231+
[paddingtensor(item["inputs_embeds"], max_length) for item in features]
232+
)
233+
batch["position_ids"] = paddingtensor3D_CBN(
234+
[item["position_ids"] for item in features]
235+
)
236+
237+
return batch

0 commit comments

Comments
 (0)