Skip to content

Commit f3cad9f

Browse files
authored
[Feature] support offline eagle3 training (#142)
1 parent 9f6a5ce commit f3cad9f

File tree

18 files changed

+2007
-360
lines changed

18 files changed

+2007
-360
lines changed

angelslim/compressor/speculative/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,9 @@
1717
DataCollatorWithPadding,
1818
DatasetManager,
1919
DraftModelConfig,
20+
OfflineEagle3Trainer,
2021
OnlineEagle3Trainer,
22+
TargetHead,
2123
convert_sharegpt_data,
2224
convert_ultrachat_data,
2325
create_draft_model,
@@ -34,10 +36,12 @@
3436
"DraftModelConfig",
3537
"create_target_model",
3638
"OnlineEagle3Trainer",
39+
"OfflineEagle3Trainer",
3740
"data_generation_work_flow",
3841
"DataCollatorWithPadding",
3942
"convert_sharegpt_data",
4043
"convert_ultrachat_data",
4144
"DatasetManager",
4245
"get_supported_chat_template_type_strings",
46+
"TargetHead",
4347
]

angelslim/compressor/speculative/train/__init__.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -6,18 +6,25 @@
66
data_generation_work_flow,
77
get_supported_chat_template_type_strings,
88
)
9-
from .models import DraftModelConfig, create_draft_model, create_target_model
10-
from .trainer import OnlineEagle3Trainer
9+
from .models import (
10+
DraftModelConfig,
11+
TargetHead,
12+
create_draft_model,
13+
create_target_model,
14+
)
15+
from .trainer import OfflineEagle3Trainer, OnlineEagle3Trainer
1116

1217
__all__ = [
1318
"create_draft_model",
1419
"DraftModelConfig",
1520
"create_target_model",
1621
"OnlineEagle3Trainer",
22+
"OfflineEagle3Trainer",
1723
"data_generation_work_flow",
1824
"DataCollatorWithPadding",
1925
"convert_sharegpt_data",
2026
"convert_ultrachat_data",
2127
"DatasetManager",
2228
"get_supported_chat_template_type_strings",
29+
"TargetHead",
2330
]

angelslim/compressor/speculative/train/data/data_utils.py

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -133,5 +133,24 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
133133
"input_ids": batch_input_ids,
134134
"attention_mask": batch_attention_mask,
135135
"loss_mask": batch_loss_mask,
136+
"hidden_states": None,
137+
"target_hiddens": None,
136138
}
139+
140+
# Check if both hidden_states and target_hiddens exist in all features
141+
if all(
142+
"hidden_states" in item and "target_hiddens" in item for item in features
143+
):
144+
batch["hidden_states"] = torch.cat(
145+
[
146+
self.paddingtensor(item["hidden_states"], max_length)
147+
for item in features
148+
]
149+
)
150+
batch["target_hiddens"] = torch.cat(
151+
[
152+
self.paddingtensor(item["target_hiddens"], max_length)
153+
for item in features
154+
]
155+
)
137156
return batch

0 commit comments

Comments
 (0)