8 changes: 2 additions & 6 deletions angelslim/compressor/speculative/__init__.py
@@ -14,11 +14,9 @@

from .benchmark import BenchmarkConfig, BenchmarkEngine, BenchmarkMode
from .train import (
DataCollatorWithPadding,
DatasetManager,
DraftModelConfig,
OfflineEagle3Trainer,
OnlineEagle3Trainer,
Eagle3TrainerFactory,
TargetHead,
convert_sharegpt_data,
convert_ultrachat_data,
@@ -35,10 +33,8 @@
"create_draft_model",
"DraftModelConfig",
"create_target_model",
"OnlineEagle3Trainer",
"OfflineEagle3Trainer",
"Eagle3TrainerFactory",
"data_generation_work_flow",
"DataCollatorWithPadding",
"convert_sharegpt_data",
"convert_ultrachat_data",
"DatasetManager",
7 changes: 2 additions & 5 deletions angelslim/compressor/speculative/train/__init__.py
@@ -1,5 +1,4 @@
from .data import (
DataCollatorWithPadding,
DatasetManager,
convert_sharegpt_data,
convert_ultrachat_data,
@@ -12,16 +11,14 @@
create_draft_model,
create_target_model,
)
from .trainer import OfflineEagle3Trainer, OnlineEagle3Trainer
from .trainer import Eagle3TrainerFactory

__all__ = [
"create_draft_model",
"DraftModelConfig",
"create_target_model",
"OnlineEagle3Trainer",
"OfflineEagle3Trainer",
"Eagle3TrainerFactory",
"data_generation_work_flow",
"DataCollatorWithPadding",
"convert_sharegpt_data",
"convert_ultrachat_data",
"DatasetManager",
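Both `__init__.py` files stop exporting the two trainer classes and expose a single Eagle3TrainerFactory instead. The factory's implementation is not part of this diff, so the sketch below is only an assumption of the usual registry-dispatch shape such a refactor takes; the mode names and constructor signature are hypothetical.

# Hypothetical sketch of a registry-style trainer factory; the real
# Eagle3TrainerFactory implementation is not shown in this diff.
from .trainer import OfflineEagle3Trainer, OnlineEagle3Trainer  # assumed internal classes


class Eagle3TrainerFactory:
    _registry = {
        "online": OnlineEagle3Trainer,    # assumed: target model queried during training
        "offline": OfflineEagle3Trainer,  # assumed: trains from pre-generated hidden states
    }

    @classmethod
    def create(cls, mode: str, **kwargs):
        try:
            return cls._registry[mode](**kwargs)
        except KeyError:
            raise ValueError(f"unknown trainer mode: {mode!r}") from None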
@@ -0,0 +1,42 @@
{
"architectures": [
"Eagle3LlamaForCausalLM"
],
"model_type": "llama",
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 151643,
"dtype": "bfloat16",
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 2560,
"initializer_range": 0.02,
"intermediate_size": 9728,
"max_position_embeddings": 262144,
"num_attention_heads": 32,
"num_hidden_layers": 1,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-06,
"rope_scaling": {
"type": "default",
"rope_type": "default",
"mrope_interleaved": true,
"mrope_section": [
24,
20,
20
]
},
"rope_theta": 5000000,
"use_cache": true,
"vocab_size": 151936,
"tie_word_embeddings": true,
"transformers_version": "4.57.1",
"image_token_id": 151655,
"video_token_id": 151656,
"vision_end_token_id": 151653,
"vision_start_token_id": 151652,
"draft_vocab_size": 32000,
"modal_type": "VLM"
}
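The config above enables multimodal RoPE: mrope_section splits the rotary frequency slots into temporal, height, and width groups. Rotary embeddings act on channel pairs, so under the standard M-RoPE convention the three sections must tile head_dim / 2, which the values here do:

# Consistency check for the M-RoPE settings in the config above
# (standard convention: sections tile the head_dim // 2 rotary slots).
head_dim = 128
mrope_section = [24, 20, 20]  # temporal, height, width groups

assert sum(mrope_section) == head_dim // 2  # 24 + 20 + 20 == 64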
@@ -0,0 +1,33 @@
{
"architectures": [
"Eagle3LlamaForCausalLM"
],
"model_type": "llama",
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 151643,
"dtype": "bfloat16",
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 2560,
"initializer_range": 0.02,
"intermediate_size": 9728,
"max_position_embeddings": 262144,
"num_attention_heads": 32,
"num_hidden_layers": 1,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-06,
"rope_scaling": null,
"rope_theta": 5000000,
"use_cache": true,
"vocab_size": 151936,
"tie_word_embeddings": true,
"transformers_version": "4.57.1",
"image_token_id": 151655,
"video_token_id": 151656,
"vision_end_token_id": 151653,
"vision_start_token_id": 151652,
"draft_vocab_size": 32000,
"modal_type": "VLM"
}
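Both draft configs pair the full vocab_size of 151936 with a draft_vocab_size of 32000: the single-layer draft model scores only a reduced vocabulary, and the d2t/t2d tables built by process_token_dict_to_mappings (see data_utils.py below) translate between the two id spaces. A minimal sketch under the common EAGLE-3 convention that d2t holds per-slot offsets; the all-zero table is a placeholder, not the real mapping:

import torch

# Assumed EAGLE-3 convention: target_id = draft_id + d2t[draft_id].
draft_vocab_size, vocab_size = 32_000, 151_936
d2t = torch.zeros(draft_vocab_size, dtype=torch.long)  # placeholder offsets

draft_ids = torch.tensor([0, 5, 31_999])  # e.g. argmax over the draft head's logits
target_ids = draft_ids + d2t[draft_ids]   # ids in the full 151936-token vocab
assert int(target_ids.max()) < vocab_size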
7 changes: 1 addition & 6 deletions angelslim/compressor/speculative/train/data/__init__.py
@@ -14,16 +14,11 @@

from .chat_templates import get_supported_chat_template_type_strings
from .data_generation import data_generation_work_flow
from .data_utils import (
DataCollatorWithPadding,
convert_sharegpt_data,
convert_ultrachat_data,
)
from .data_utils import convert_sharegpt_data, convert_ultrachat_data
from .dataset import DatasetManager

__all__ = [
"DatasetManager",
"DataCollatorWithPadding",
"convert_sharegpt_data",
"convert_ultrachat_data",
"data_generation_work_flow",
23 changes: 23 additions & 0 deletions angelslim/compressor/speculative/train/data/chat_templates.py
@@ -28,12 +28,14 @@ class ChatTemplateType(Enum):

QWEN3 = "qwen3"
HUNYUAN = "hunyuan"
QWEN3_VL = "qwen3_vl"


# String to ChatTemplateType mapping
CHAT_TEMPLATE_TYPE_MAPPING = {
"qwen3": ChatTemplateType.QWEN3,
"hunyuan": ChatTemplateType.HUNYUAN,
"qwen3_vl": ChatTemplateType.QWEN3_VL,
}


@@ -93,6 +95,27 @@ def _initialize_templates(self) -> Dict[ChatTemplateType, ChatTemplate]:
"please don't share false information."
),
),
ChatTemplateType.QWEN3_VL: ChatTemplate(
user_header="<|im_start|>user\n",
assistant_header="<|im_start|>assistant\n",
system_prompt=[
{
"type": "text",
"text": (
"You are a helpful, respectful and honest assistant. "
"Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, "
"racist, sexist, toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased "
"and positive in nature.\n\nIf a question does not make "
"any sense, or is not factually coherent, explain why "
"instead of answering something not correct. If you "
"don't know the answer to a question, please don't share "
"false information."
),
}
],
),
}

def get_template(self, chat_template_type: ChatTemplateType) -> ChatTemplate:
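Unlike the text-only templates, the new qwen3_vl entry stores its system_prompt as a list of typed content parts rather than a plain string, matching the multimodal message schema that VLM processors consume. A sketch of that schema (the image path is hypothetical):

# Multimodal message schema targeted by the qwen3_vl template: content is a
# list of {"type": ...} parts, not a bare string.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "file:///path/to/example.jpg"},  # hypothetical path
            {"type": "text", "text": "Describe this image."},
        ],
    },
]
# A processor's chat-template rendering then wraps each turn with headers
# such as "<|im_start|>user\n", matching user_header/assistant_header above.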
133 changes: 107 additions & 26 deletions angelslim/compressor/speculative/train/data/data_utils.py
@@ -21,6 +21,7 @@
"convert_sharegpt_data",
"convert_ultrachat_data",
"DataCollatorWithPadding",
"VLMDataCollatorWithPadding",
]


@@ -100,33 +101,58 @@ def process_token_dict_to_mappings(
return d2t, t2d


def paddingtensor(intensors, N):
    # Zero-pad a (B, n, S) tensor along dim 1 up to length N.
    B, n, S = intensors.shape
    padding_tensor = torch.zeros(B, N - n, S, dtype=intensors.dtype)
    outtensors = torch.cat((intensors, padding_tensor), dim=1)
    return outtensors


def paddingtensor2D(intensors, N):
    # Zero-pad a (B, n) tensor along dim 1 up to length N.
    B, n = intensors.shape
padding_tensor = torch.zeros(B, N - n, dtype=intensors.dtype)
outtensors = torch.cat((intensors, padding_tensor), dim=1)
return outtensors


def paddingtensor3D_CBN(tensor_list):
    # Pad (C, B, n) tensors to the longest n, then concatenate on the batch dim.
    N = max(tensor.shape[-1] for tensor in tensor_list)
out_tensor_list = []
for tensor in tensor_list:
c, b, n = tensor.shape
outtensor = torch.zeros(c, b, N, dtype=tensor_list[0].dtype)
outtensor[:, :, :n] = tensor
out_tensor_list.append(outtensor)
return torch.cat(out_tensor_list, dim=1)


def paddingtensor3D_BHW(tensor_list):
    # Zero-pad (B, H, W) tensors (2D inputs gain a leading batch dim) to the
    # largest H and W in the list, then concatenate along dim 0.
    max_h = max(tensor.shape[-2] for tensor in tensor_list)
    max_w = max(tensor.shape[-1] for tensor in tensor_list)
out_tensor_list = []
for tensor in tensor_list:
if tensor.ndim == 2:
tensor = tensor.unsqueeze(0)
b, h, w = tensor.shape
outtensor = torch.zeros(b, max_h, max_w, dtype=tensor.dtype)
outtensor[:, :h, :w] = tensor
out_tensor_list.append(outtensor)
return torch.cat(out_tensor_list)


class DataCollatorWithPadding:
def paddingtensor(self, intensors, N):
B, n, S = intensors.shape
# padding_tensor = torch.zeros(B, N - n, S,dtype=intensors.dtype)
padding_tensor = torch.zeros(B, N - n, S, dtype=intensors.dtype)
outtensors = torch.cat((intensors, padding_tensor), dim=1)
return outtensors

def paddingtensor2D(self, intensors, N):
B, n = intensors.shape
padding_tensor = torch.zeros(B, N - n, dtype=intensors.dtype)
outtensors = torch.cat((intensors, padding_tensor), dim=1)
return outtensors

def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
max_length = max(item["input_ids"].shape[1] for item in features)
batch_input_ids = torch.cat(
[self.paddingtensor2D(item["input_ids"], max_length) for item in features]
[paddingtensor2D(item["input_ids"], max_length) for item in features]
)
batch_attention_mask = torch.cat(
[
self.paddingtensor2D(item["attention_mask"], max_length)
for item in features
]
[paddingtensor2D(item["attention_mask"], max_length) for item in features]
)
batch_loss_mask = torch.cat(
[self.paddingtensor2D(item["loss_mask"], max_length) for item in features]
[paddingtensor2D(item["loss_mask"], max_length) for item in features]
)

batch = {
@@ -142,15 +168,70 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
"hidden_states" in item and "target_hiddens" in item for item in features
):
batch["hidden_states"] = torch.cat(
[
self.paddingtensor(item["hidden_states"], max_length)
for item in features
]
[paddingtensor(item["hidden_states"], max_length) for item in features]
)
batch["target_hiddens"] = torch.cat(
[
self.paddingtensor(item["target_hiddens"], max_length)
for item in features
]
[paddingtensor(item["target_hiddens"], max_length) for item in features]
)
return batch


class VLMDataCollatorWithPadding:
    # Pads token-level fields to the longest sequence in the batch and
    # image/video tensors to the largest spatial extent in the batch.

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
max_length = max(item["input_ids"].shape[1] for item in features)
batch_input_ids = torch.cat(
[paddingtensor2D(item["input_ids"], max_length) for item in features]
)
batch_attention_mask = torch.cat(
[paddingtensor2D(item["attention_mask"], max_length) for item in features]
)
batch_loss_mask = torch.cat(
[paddingtensor2D(item["loss_mask"], max_length) for item in features]
)

batch = {
"input_ids": batch_input_ids,
"attention_mask": batch_attention_mask,
"loss_mask": batch_loss_mask,
"hidden_states": None,
"target_hiddens": None,
"inputs_embeds": None,
"position_ids": None,
}

if "pixel_values" in features[0]:
batch["pixel_values"] = paddingtensor3D_BHW(
[item["pixel_values"] for item in features]
)
if "video_pixel_values" in features[0]:
batch["video_pixel_values"] = paddingtensor3D_BHW(
[item["video_pixel_values"] for item in features]
)
if "image_grid_thw" in features[0]:
batch["image_grid_thw"] = paddingtensor3D_BHW(
[item["image_grid_thw"] for item in features]
)
if "video_grid_thw" in features[0]:
batch["video_grid_thw"] = paddingtensor3D_BHW(
[item["video_grid_thw"] for item in features]
)

# Check if both hidden_states and target_hiddens exist in all features
if all(
"hidden_states" in item and "target_hiddens" in item for item in features
):
batch["hidden_states"] = torch.cat(
[paddingtensor(item["hidden_states"], max_length) for item in features]
)
batch["target_hiddens"] = torch.cat(
[paddingtensor(item["target_hiddens"], max_length) for item in features]
)
batch["inputs_embeds"] = torch.cat(
[paddingtensor(item["inputs_embeds"], max_length) for item in features]
)
batch["position_ids"] = paddingtensor3D_CBN(
[item["position_ids"] for item in features]
)

return batch
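A minimal usage sketch for the new collator, with made-up shapes: input_ids and the masks are [1, seq_len] per sample and get padded to the longest sequence, while image_grid_thw rows pass through paddingtensor3D_BHW (2-D inputs first gain a leading batch dim):

import torch

features = [
    {
        "input_ids": torch.ones(1, 5, dtype=torch.long),
        "attention_mask": torch.ones(1, 5, dtype=torch.long),
        "loss_mask": torch.ones(1, 5, dtype=torch.long),
        "image_grid_thw": torch.tensor([[1, 16, 16]]),  # one image: t, h, w patch counts
    },
    {
        "input_ids": torch.ones(1, 8, dtype=torch.long),
        "attention_mask": torch.ones(1, 8, dtype=torch.long),
        "loss_mask": torch.ones(1, 8, dtype=torch.long),
        "image_grid_thw": torch.tensor([[1, 24, 32]]),
    },
]

batch = VLMDataCollatorWithPadding()(features)
print(batch["input_ids"].shape)       # torch.Size([2, 8]) -- padded to the longest sequence
print(batch["image_grid_thw"].shape)  # torch.Size([2, 1, 3]) -- stacked along dim 0
# hidden_states / target_hiddens / inputs_embeds / position_ids stay None here,
# because no feature carries precomputed hidden states.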