Skip to content

Commit e597014

Browse files
committed
add vlm train factory
1 parent 9eebed6 commit e597014

17 files changed

+1277
-578
lines changed

angelslim/compressor/speculative/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,9 @@
1414

1515
from .benchmark import BenchmarkConfig, BenchmarkEngine, BenchmarkMode
1616
from .train import (
17-
DataCollatorWithPadding,
1817
DatasetManager,
1918
DraftModelConfig,
20-
OnlineEagle3Trainer,
19+
Eagle3TrainerFactory,
2120
convert_sharegpt_data,
2221
convert_ultrachat_data,
2322
create_draft_model,
@@ -33,9 +32,8 @@
3332
"create_draft_model",
3433
"DraftModelConfig",
3534
"create_target_model",
36-
"OnlineEagle3Trainer",
35+
"Eagle3TrainerFactory",
3736
"data_generation_work_flow",
38-
"DataCollatorWithPadding",
3937
"convert_sharegpt_data",
4038
"convert_ultrachat_data",
4139
"DatasetManager",

angelslim/compressor/speculative/train/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,19 @@
11
from .data import (
2-
DataCollatorWithPadding,
32
DatasetManager,
43
convert_sharegpt_data,
54
convert_ultrachat_data,
65
data_generation_work_flow,
76
get_supported_chat_template_type_strings,
87
)
98
from .models import DraftModelConfig, create_draft_model, create_target_model
10-
from .trainer import OnlineEagle3Trainer
9+
from .trainer import Eagle3TrainerFactory
1110

1211
__all__ = [
1312
"create_draft_model",
1413
"DraftModelConfig",
1514
"create_target_model",
16-
"OnlineEagle3Trainer",
15+
"Eagle3TrainerFactory",
1716
"data_generation_work_flow",
18-
"DataCollatorWithPadding",
1917
"convert_sharegpt_data",
2018
"convert_ultrachat_data",
2119
"DatasetManager",

angelslim/compressor/speculative/train/data/__init__.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,11 @@
1414

1515
from .chat_templates import get_supported_chat_template_type_strings
1616
from .data_generation import data_generation_work_flow
17-
from .data_utils import (
18-
DataCollatorWithPadding,
19-
convert_sharegpt_data,
20-
convert_ultrachat_data,
21-
)
17+
from .data_utils import convert_sharegpt_data, convert_ultrachat_data
2218
from .dataset import DatasetManager
2319

2420
__all__ = [
2521
"DatasetManager",
26-
"DataCollatorWithPadding",
2722
"convert_sharegpt_data",
2823
"convert_ultrachat_data",
2924
"data_generation_work_flow",

angelslim/compressor/speculative/train/data/data_utils.py

Lines changed: 73 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
"convert_sharegpt_data",
2222
"convert_ultrachat_data",
2323
"DataCollatorWithPadding",
24+
"VLMDataCollatorWithPadding",
2425
]
2526

2627

@@ -100,38 +101,92 @@ def process_token_dict_to_mappings(
100101
return d2t, t2d
101102

102103

104+
def paddingtensor(intensors, N):
    """Right-pad a (B, n, S) tensor with zeros along dim 1 up to length N.

    Args:
        intensors: tensor of shape (B, n, S); n must satisfy n <= N, otherwise
            ``torch.zeros`` raises on the negative pad size.
        N: target length of dimension 1.

    Returns:
        Tensor of shape (B, N, S) with the original values in the first n
        positions and zeros (same dtype) after them.
    """
    B, n, S = intensors.shape
    # dtype is propagated explicitly so integer/bool inputs are not silently
    # promoted to the default float dtype by torch.zeros.
    padding_tensor = torch.zeros(B, N - n, S, dtype=intensors.dtype)
    outtensors = torch.cat((intensors, padding_tensor), dim=1)
    return outtensors
110+
111+
112+
def paddingtensor2D(intensors, N):
    """Right-pad a (B, n) tensor with zeros along dim 1 up to length N.

    The filler keeps the input's dtype; callers are expected to pass
    N >= n (N is the batch-wide maximum sequence length).
    """
    batch, cur_len = intensors.shape
    filler = torch.zeros(batch, N - cur_len, dtype=intensors.dtype)
    return torch.cat((intensors, filler), dim=1)
117+
118+
119+
def paddingtensor3D(tensor_list):
    """Zero-pad a list of 2-D/3-D tensors to a shared (H, W) and concat on dim 0.

    Each 2-D tensor is first promoted to 3-D via ``unsqueeze(0)``. Every tensor
    is then copied into the top-left corner of a zero canvas of shape
    (b, max_H, max_W) and the canvases are concatenated along dim 0.

    NOTE(review): padding entries with zeros assumes downstream consumers treat
    all-zero rows/cols as "absent" — verify for grid_thw-style inputs.
    """
    target_h = max(t.shape[-2] for t in tensor_list)
    target_w = max(t.shape[-1] for t in tensor_list)

    def _pad_one(t):
        # Promote (H, W) to (1, H, W) so every element concatenates uniformly.
        batched = t.unsqueeze(0) if t.ndim == 2 else t
        canvas = torch.zeros(
            batched.shape[0], target_h, target_w, dtype=batched.dtype
        )
        canvas[:, : batched.shape[-2], : batched.shape[-1]] = batched
        return canvas

    return torch.cat([_pad_one(t) for t in tensor_list])
131+
132+
103133
class DataCollatorWithPadding:
    """Batch collator that right-pads per-example token tensors with zeros.

    Each feature dict must provide "input_ids", "attention_mask" and
    "loss_mask" as (1, n) tensors; all three are padded to the longest
    "input_ids" length in the batch and concatenated along dim 0.
    """

    # The three per-token tensors every text example carries.
    _SEQ_KEYS = ("input_ids", "attention_mask", "loss_mask")

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        max_length = max(item["input_ids"].shape[1] for item in features)
        # One pad-and-concat pass per key; identical treatment for all three.
        return {
            key: torch.cat(
                [paddingtensor2D(item[key], max_length) for item in features]
            )
            for key in self._SEQ_KEYS
        }
153+
154+
155+
class VLMDataCollatorWithPadding:
    """Collator for vision-language examples: text padding plus optional media.

    Text keys ("input_ids", "attention_mask", "loss_mask") are (1, n) tensors
    padded to the batch max length. Media keys, when present, are padded to a
    shared (H, W) via ``paddingtensor3D`` and concatenated along dim 0.
    """

    # Per-token tensors every example carries.
    _SEQ_KEYS = ("input_ids", "attention_mask", "loss_mask")
    # Media tensors that may or may not be present in a batch.
    _MEDIA_KEYS = (
        "pixel_values",
        "video_pixel_values",
        "image_grid_thw",
        "video_grid_thw",
    )

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        max_length = max(item["input_ids"].shape[1] for item in features)
        batch = {
            key: torch.cat(
                [paddingtensor2D(item[key], max_length) for item in features]
            )
            for key in self._SEQ_KEYS
        }
        for key in self._MEDIA_KEYS:
            # NOTE(review): presence is probed on features[0] only — this
            # assumes a media key appears in every example of the batch or in
            # none; a mixed batch would raise KeyError below. Verify upstream.
            if key in features[0]:
                batch[key] = paddingtensor3D([item[key] for item in features])
        return batch

0 commit comments

Comments
 (0)