Commit 61cd1a2

support vlm offline hf training
1 parent 6ab0a40 commit 61cd1a2

10 files changed: 425 additions, 21 deletions


angelslim/compressor/speculative/train/data/data_utils.py

Lines changed: 37 additions & 5 deletions
@@ -116,7 +116,18 @@ def paddingtensor2D(intensors, N):
     return outtensors
 
 
-def paddingtensor3D(tensor_list):
+def paddingtensor3D_CBN(tensor_list):
+    N = max(tensor.shape[-1] for tensor in tensor_list)
+    out_tensor_list = []
+    for tensor in tensor_list:
+        c, b, n = tensor.shape
+        outtensor = torch.zeros(c, b, N, dtype=tensor_list[0].dtype)
+        outtensor[:, :, :n] = tensor
+        out_tensor_list.append(outtensor)
+    return torch.cat(out_tensor_list, dim=1)
+
+
+def paddingtensor3D_BHW(tensor_list):
     max_h = max(tensor.shape[-2] for tensor in tensor_list)
     max_w = max(tensor.shape[-1] for tensor in tensor_list)
     out_tensor_list = []
@@ -183,23 +194,44 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
             "input_ids": batch_input_ids,
             "attention_mask": batch_attention_mask,
             "loss_mask": batch_loss_mask,
+            "hidden_states": None,
+            "target_hiddens": None,
+            "inputs_embeds": None,
+            "position_ids": None,
         }
 
         if "pixel_values" in features[0]:
-            batch["pixel_values"] = paddingtensor3D(
+            batch["pixel_values"] = paddingtensor3D_BHW(
                 [item["pixel_values"] for item in features]
             )
         if "video_pixel_values" in features[0]:
-            batch["video_pixel_values"] = paddingtensor3D(
+            batch["video_pixel_values"] = paddingtensor3D_BHW(
                 [item["video_pixel_values"] for item in features]
             )
         if "image_grid_thw" in features[0]:
-            batch["image_grid_thw"] = paddingtensor3D(
+            batch["image_grid_thw"] = paddingtensor3D_BHW(
                 [item["image_grid_thw"] for item in features]
             )
         if "video_grid_thw" in features[0]:
-            batch["video_grid_thw"] = paddingtensor3D(
+            batch["video_grid_thw"] = paddingtensor3D_BHW(
                 [item["video_grid_thw"] for item in features]
             )
 
+        # Check if both hidden_states and target_hiddens exist in all features
+        if all(
+            "hidden_states" in item and "target_hiddens" in item for item in features
+        ):
+            batch["hidden_states"] = torch.cat(
+                [paddingtensor(item["hidden_states"], max_length) for item in features]
+            )
+            batch["target_hiddens"] = torch.cat(
+                [paddingtensor(item["target_hiddens"], max_length) for item in features]
+            )
+            batch["inputs_embeds"] = torch.cat(
+                [paddingtensor(item["inputs_embeds"], max_length) for item in features]
+            )
+            batch["position_ids"] = paddingtensor3D_CBN(
+                [item["position_ids"] for item in features]
+            )
+
         return batch
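
To illustrate the new helper: paddingtensor3D_CBN zero-pads a list of (C, B, N) tensors (such as the per-sample position_ids batched in the collator above) along the last dimension to the longest sequence in the batch, then concatenates them along the batch dimension. A minimal sketch with made-up shapes; only the import path and function name come from the diff above:

import torch

from angelslim.compressor.speculative.train.data.data_utils import paddingtensor3D_CBN

# Two per-sample position_ids tensors of shape (C, B, N) with different
# sequence lengths (3 channels, batch of 1, lengths 5 and 8 -- illustrative values).
a = torch.arange(15).reshape(3, 1, 5)
b = torch.arange(24).reshape(3, 1, 8)

padded = paddingtensor3D_CBN([a, b])
print(padded.shape)  # torch.Size([3, 2, 8]); the shorter sample is zero-padded along N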

angelslim/compressor/speculative/train/data/dataset_builder/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -14,12 +14,14 @@
 
 from .dataset_builder_factory import DatasetBuilderFactory
 from .offline_llm_dataset_builder import OfflineLLMDatasetBuilder
+from .offline_vlm_dataset_builder import OfflineVLMDatasetBuilder
 from .online_llm_dataset_builder import OnlineLLMDatasetBuilder
 from .online_vlm_dataset_builder import OnlineVLMDatasetBuilder
 
 __all__ = [
     "OnlineLLMDatasetBuilder",
     "OnlineVLMDatasetBuilder",
     "OfflineLLMDatasetBuilder",
+    "OfflineVLMDatasetBuilder",
     "DatasetBuilderFactory",
 ]
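
With the new export in place, the offline VLM builder can be imported from the package alongside the existing builders, e.g.:

from angelslim.compressor.speculative.train.data.dataset_builder import (
    OfflineVLMDatasetBuilder,
)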
angelslim/compressor/speculative/train/data/dataset_builder/offline_vlm_dataset_builder.py

Lines changed: 224 additions & 0 deletions
@@ -0,0 +1,224 @@
+# Copyright 2025 Tencent Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import torch
+from torch.utils.data import Dataset
+
+from angelslim.utils import rank0_print
+
+from ..data_utils import VLMDataCollatorWithPadding
+from .base_dataset_builder import DatasetBuilder
+from .dataset_builder_factory import DatasetBuilderFactory
+
+
+class OfflineVLMEagle3Dataset(Dataset):
+    """
+    Offline Dataset for EAGLE3 training.
+
+    Loads pre-computed hidden states, logits, and other data from .ckpt files.
+    Each .ckpt file contains a dictionary with keys: input_ids, target_logits,
+    hidden_states, and loss_mask.
+    """
+
+    def __init__(
+        self, data_dir: str, file_pattern: str = "*.ckpt", cache_in_memory: bool = False
+    ):
+        """
+        Initialize the OfflineVLMEagle3Dataset.
+
+        Args:
+            data_dir: Directory containing .ckpt files
+                (will search recursively in subdirectories)
+            file_pattern: Pattern to match checkpoint files (default: "*.ckpt")
+            cache_in_memory: Whether to cache all data in memory (default: False)
+        """
+        self.data_dir = Path(data_dir)
+        self.cache_in_memory = cache_in_memory
+
+        if not self.data_dir.exists():
+            raise ValueError(f"Data directory does not exist: {data_dir}")
+
+        # Recursively find all checkpoint files in subdirectories
+        self.ckpt_files = sorted(list(self.data_dir.rglob(file_pattern)))
+
+        if len(self.ckpt_files) == 0:
+            raise ValueError(
+                f"No checkpoint files found in {data_dir} "
+                f"(including subdirectories) with pattern {file_pattern}"
+            )
+
+        rank0_print(
+            f"Found {len(self.ckpt_files)} checkpoint files "
+            f"in {data_dir} (including subdirectories)"
+        )
+
+        # Track valid indices (files that can be loaded successfully)
+        self.valid_indices = list(range(len(self.ckpt_files)))
+
+        # Cache data in memory if requested
+        self.cached_data: Optional[List[Dict[str, torch.Tensor]]] = None
+        if self.cache_in_memory:
+            rank0_print("Caching all data in memory...")
+            self.cached_data = []
+            failed_count = 0
+            for i in range(len(self.ckpt_files)):
+                data = self._load_ckpt(i)
+                if data is not None:
+                    self.cached_data.append(data)
+                else:
+                    failed_count += 1
+
+            # Update valid indices based on successful loads
+            self.valid_indices = list(range(len(self.cached_data)))
+
+            if failed_count > 0:
+                rank0_print(
+                    f"Data caching completed. "
+                    f"Successfully loaded {len(self.cached_data)} files, "
+                    f"failed to load {failed_count} files"
+                )
+            else:
+                rank0_print("Data caching completed")
+
+    def _load_ckpt(self, idx: int) -> Optional[Dict[str, torch.Tensor]]:
+        """
+        Load a checkpoint file.
+
+        Args:
+            idx: Index of the checkpoint file
+
+        Returns:
+            Dictionary containing input_ids, target_hiddens,
+            hidden_states, and loss_mask, or None if loading fails
+        """
+        ckpt_path = self.ckpt_files[idx]
+
+        try:
+            data = torch.load(ckpt_path, map_location="cpu")
+        except Exception as e:
+            warnings.warn(
+                f"Failed to load checkpoint {ckpt_path}: {e}. Skipping this file.",
+                RuntimeWarning,
+                stacklevel=2,
+            )
+            return None
+
+        # Validate required keys
+        required_keys = [
+            "input_ids",  # B, N
+            "target_hiddens",  # B, N, D
+            "hidden_states",  # B, N, 3*D
+            "loss_mask",  # B, N
+            "inputs_embeds",  # B, N, D
+            "position_ids",  # B, N
+        ]
+        missing_keys = [key for key in required_keys if key not in data]
+
+        if missing_keys:
+            warnings.warn(
+                f"Checkpoint {ckpt_path} is missing required keys: {missing_keys}. "
+                f"Skipping this file.",
+                RuntimeWarning,
+                stacklevel=2,
+            )
+            return None
+
+        # Validate tensor types
+        for key in required_keys:
+            if not isinstance(data[key], torch.Tensor):
+                warnings.warn(
+                    f"Value for key '{key}' in {ckpt_path} is not a torch.Tensor. "
+                    f"Skipping this file.",
+                    RuntimeWarning,
+                    stacklevel=2,
+                )
+                return None
+
+        attention_mask = torch.ones_like(data["input_ids"])
+        data["attention_mask"] = attention_mask  # B, N
+        return data
+
+    def __len__(self) -> int:
+        """Return the number of valid samples in the dataset."""
+        if self.cached_data is not None:
+            return len(self.cached_data)
+        return len(self.valid_indices)
+
+    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
+        """
+        Get a sample from the dataset.
+
+        Args:
+            idx: Index of the sample
+
+        Returns:
+            Dictionary containing:
+                - input_ids: Token IDs (torch.Tensor)
+                - target_logits: Pre-computed logits from target
+                  model (torch.Tensor)
+                - hidden_states: Pre-computed hidden states from
+                  target model (torch.Tensor)
+                - loss_mask: Mask for loss computation (torch.Tensor)
+        """
+        if self.cached_data is not None:
+            return self.cached_data[idx]
+        else:
+            # Try to load the checkpoint, retry with next valid index if fails
+            max_retries = len(self.valid_indices)
+            for _attempt in range(max_retries):
+                actual_idx = self.valid_indices[idx % len(self.valid_indices)]
+                data = self._load_ckpt(actual_idx)
+                if data is not None:
+                    return data
+                else:
+                    # Remove failed index from valid_indices
+                    self.valid_indices.remove(actual_idx)
+                    if len(self.valid_indices) == 0:
+                        raise RuntimeError(
+                            "All checkpoint files failed to load. "
+                            "Cannot continue training."
+                        )
+                    # Try next index
+                    idx += 1
+
+            # If all retries failed, raise error
+            raise RuntimeError(
+                f"Failed to load any valid checkpoint after {max_retries} attempts"
+            )
+
+
+@DatasetBuilderFactory.register("offline", "VLM")
+class OfflineVLMDatasetBuilder(DatasetBuilder):
+    def __init__(
+        self, file_pattern: str = "*.ckpt", cache_in_memory: bool = False, **kwargs: Any
+    ):
+        self.file_pattern = file_pattern
+        self.cache_in_memory = cache_in_memory
+
+    def build_dataset(self, datapath: str, **kwargs: Any) -> Dataset:
+        """
+        Create offline datasets from pre-computed .ckpt files.
+        """
+        return OfflineVLMEagle3Dataset(
+            data_dir=datapath,
+            file_pattern=self.file_pattern,
+            cache_in_memory=self.cache_in_memory,
+        )
+
+    def get_data_collator(self) -> Any:
+        return VLMDataCollatorWithPadding()
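
A minimal usage sketch for the new builder; the checkpoint directory and batch size are placeholders, and only the class and method names come from the file above:

from torch.utils.data import DataLoader

from angelslim.compressor.speculative.train.data.dataset_builder import (
    OfflineVLMDatasetBuilder,
)

builder = OfflineVLMDatasetBuilder(file_pattern="*.ckpt", cache_in_memory=False)
dataset = builder.build_dataset(datapath="/path/to/offline_ckpts")  # placeholder path
collator = builder.get_data_collator()  # VLMDataCollatorWithPadding

loader = DataLoader(dataset, batch_size=2, collate_fn=collator)
batch = next(iter(loader))
# When every sample carries the pre-computed offline tensors, the collator also
# fills batch["hidden_states"], batch["target_hiddens"], batch["inputs_embeds"],
# and batch["position_ids"] in addition to input_ids / attention_mask / loss_mask.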

angelslim/compressor/speculative/train/models/target/target_head.py

Lines changed: 10 additions & 1 deletion
@@ -64,7 +64,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
     @classmethod
     def from_pretrained(
-        cls, model_name_or_path: str, lm_head_key: str = "lm_head.weight"
+        cls,
+        model_name_or_path: str,
+        lm_head_key: str = "lm_head.weight",
+        sub_config_name=None,
     ):
         """
         Load TargetHead from a pretrained model efficiently.
@@ -82,6 +85,12 @@ def from_pretrained(
         """
         # Load model config to get architecture info
        config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
+        if hasattr(config, sub_config_name):
+            config = getattr(config, sub_config_name)
+        else:
+            raise ValueError(
+                f"Config {config} has no sub-config named {sub_config_name}"
+            )
 
         # Get model dimensions
         hidden_size = config.hidden_size
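
As a usage sketch of the new sub_config_name argument: for a multimodal checkpoint whose AutoConfig nests the language-model settings under a sub-config attribute, the head can now read model dimensions such as hidden_size from that sub-config. The checkpoint path and the "text_config" attribute name below are assumptions, not part of the commit, and the import path is inferred from the file path above:

from angelslim.compressor.speculative.train.models.target.target_head import TargetHead

# A string sub_config_name is expected here: the hasattr/getattr branch added above
# runs unconditionally, so the default of None would not resolve to a sub-config.
head = TargetHead.from_pretrained(
    "/path/to/vlm_checkpoint",      # placeholder path
    lm_head_key="lm_head.weight",
    sub_config_name="text_config",  # assumed attribute name on the VLM config
)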
