55 changes: 15 additions & 40 deletions examples/quantization_w8a8_fp8/granite4_example.py
@@ -1,67 +1,42 @@
from compressed_tensors.utils import replace_module
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.granitemoehybrid.modeling_granitemoehybrid import (
GraniteMoeHybridParallelExperts,
)

from llmcompressor import oneshot
from llmcompressor.modeling.granite4 import GraniteMoeHybridParallelExpertsLinear
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.utils import dispatch_for_generation
from llmcompressor.modeling import replace_modules_for_calibration

"""Please see details in `README_granite4.md`."""
MODEL_ID = "ibm-granite/granite-4.0-h-small"

MODEL_ID = "ibm-granite/granite-4.0-tiny-preview"

# Load model.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

skip_router_only = True # assume we want to quantize input/output moe layers
ignore_lay = [
"lm_head",
]
if skip_router_only:
# swap moe linears to a custom class
for n, m in model.named_modules():
if isinstance(m, GraniteMoeHybridParallelExperts):
new_mod = GraniteMoeHybridParallelExpertsLinear.from_3d_expert(m)
replace_module(model, n, new_mod)
ignore_lay += ["re:.*block_sparse_moe.router"]
SAVE_DIR = "ibm-granite-4-tiny-fp8-dynamic-skipMoeRouter"
else:
# Skip all .input_linear, .output_linear, and router layers.
ignore_lay += ["re:.*block_sparse_moe"]
SAVE_DIR = "ibm-granite-4-tiny-fp8-dynamic-skipMoe"
model = replace_modules_for_calibration(model)

ignore_lay = ["lm_head"]

recipe = QuantizationModifier(
targets=["Linear", "GraniteMoeHybridParallelExpertsLinear"],
targets=["Linear"],
scheme="FP8_DYNAMIC",
ignore=ignore_lay,
)

# Apply quantization.
oneshot(model=model, recipe=recipe)

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
input_ids = tokenizer(
"What is your favorite TV show?", return_tensors="pt"
).input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=20)
"Describe Large Language Model", return_tensors="pt"
).input_ids.to(model.device)
output = model.generate(input_ids, max_new_tokens=35)
print(tokenizer.decode(output[0]))
print("==========================================")

# Revert weights of MoE experts to 3D format (num_experts, output_size, input_size)
for n, m in model.named_modules():
if isinstance(m, GraniteMoeHybridParallelExpertsLinear):
# NOTE: could instead assert device.type != "meta", since a meta device indicates offloading
assert m.weight.device.type == "cuda", (
"Found some offloaded weights. This is not compatible with reshaping "
"experts to 3D prior to saving the model. Ensure the model is fully on cuda."
)
)
m.to_3d_expert()
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-FP8-dynamic"
print(f"Saving to {SAVE_DIR}")

model.save_pretrained(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)
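
To sanity-check the exported checkpoint, a minimal serving sketch with vLLM (assuming the installed vLLM build supports compressed-tensors FP8 checkpoints; the path is hypothetical and should be the SAVE_DIR printed by the script above):

from vllm import LLM, SamplingParams

llm = LLM(model="./granite-4.0-h-small-FP8-dynamic")  # hypothetical path; use the SAVE_DIR printed above
params = SamplingParams(max_tokens=35)
print(llm.generate(["Describe Large Language Model"], params)[0].outputs[0].text)
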
42 changes: 42 additions & 0 deletions examples/quantization_w8a8_fp8/granite4_fp8_block_example.py
@@ -0,0 +1,42 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.utils import dispatch_for_generation
from llmcompressor.modeling import replace_modules_for_calibration

MODEL_ID = "ibm-granite/granite-4.0-h-small"

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

model = replace_modules_for_calibration(model)

ignore_lay = ["lm_head", "re:.*block_sparse_moe.router", "re:.*mamba.in_proj", "re:.*shared_mlp.input_linear"]

recipe = QuantizationModifier(
targets=["Linear"],
scheme="FP8_BLOCK",
ignore=ignore_lay,
)

oneshot(model=model, recipe=recipe)

print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
input_ids = tokenizer(
"Describe Large Language Model", return_tensors="pt"
).input_ids.to(model.device)
output = model.generate(input_ids, max_new_tokens=35)
print(tokenizer.decode(output[0]))
print("==========================================")

SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-FP8-block"
print(f"Saving to {SAVE_DIR}")

model.save_pretrained(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)
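
After saving, the recorded scheme can be confirmed from the exported config, a minimal sketch assuming the SAVE_DIR produced above and the standard compressed-tensors quantization_config entry in config.json:

import json

with open("granite-4.0-h-small-FP8-block/config.json") as f:  # SAVE_DIR from the script above
    quant_cfg = json.load(f).get("quantization_config", {})
print(json.dumps(quant_cfg, indent=2))
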
250 changes: 250 additions & 0 deletions src/llmcompressor/modeling/granite4.py
@@ -1,8 +1,258 @@
import os
import torch
import torch.nn as nn
from transformers.models.granitemoehybrid.configuration_granitemoehybrid import (
    GraniteMoeHybridConfig,
)
from transformers.models.granitemoehybrid.modeling_granitemoehybrid import (
GraniteMoeHybridParallelExperts,
GraniteMoeHybridMoE,
)

from llmcompressor.modeling.moe_context import MoECalibrationModule

class SequentialGraniteMoeExperts(nn.Module):
"""
Unpacked version of GraniteMoeHybridParallelExperts with individual expert layers.

This module:
1. Unpacks the packed expert weights (3D -> individual Linear layers)
2. Processes experts sequentially
3. Is compatible with FP8 block quantization and vLLM
"""

def __init__(
self,
original: GraniteMoeHybridParallelExperts,
calibrate_all_experts: bool = True,
):
super().__init__()
self.num_experts = original.num_experts
self.input_size = original.input_size
self.output_size = original.output_size
self.calibrate_all_experts = calibrate_all_experts

# Create individual linear layers for each expert
self.experts = nn.ModuleList([
nn.Linear(self.input_size, self.output_size, bias=False)
for _ in range(self.num_experts)
])

# Copy weights from the original 3D tensor
# Original format: [num_experts, output_size, input_size]
for i in range(self.num_experts):
self.experts[i].weight.data = original.weight.data[i].clone()

def forward(self, inputs, expert_size, batch_index=None):
"""
Forward pass using individual expert layers.

Args:
inputs: Input tensor to be processed by experts
expert_size: List containing the size of inputs for each expert
batch_index: Token indices for routing (needed in calibration mode)

Returns:
Concatenated output from all experts
"""
if self.calibrate_all_experts:
# During calibration, process all inputs through each expert
# but only keep the outputs corresponding to tokens routed to that expert
output_list = []
start_idx = 0
for i in range(self.num_experts):
end_idx = start_idx + expert_size[i]
# Get token indices assigned to this expert
expert_token_indices = batch_index[start_idx:end_idx]
# Process ALL tokens through this expert
expert_out_all = self.experts[i](inputs)
# Only keep outputs for tokens assigned to this expert
expert_out = expert_out_all[expert_token_indices]
output_list.append(expert_out)
start_idx = end_idx
results = torch.cat(output_list, dim=0)
else:
# Normal routing: only process tokens assigned to this expert
input_list = inputs.split(expert_size, dim=0)
output_list = []
for i in range(self.num_experts):
output_list.append(self.experts[i](input_list[i]))
results = torch.cat(output_list, dim=0)

return results
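
# A minimal sketch of the unpacking equivalence, assuming the packed layout
# [num_experts, output_size, input_size] documented above (shapes are
# arbitrary and for illustration only):
#
#     packed = torch.randn(4, 8, 16)
#     experts = [nn.Linear(16, 8, bias=False) for _ in range(4)]
#     for i, lin in enumerate(experts):
#         lin.weight.data = packed[i].clone()
#     x = torch.randn(3, 16)
#     assert torch.allclose(experts[0](x), x @ packed[0].T)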


@MoECalibrationModule.register("GraniteMoeHybridMoE")
class CalibrationGraniteMoeHybridMoE(MoECalibrationModule):
"""
Calibration version of GraniteMoeHybridMoE that unpacks both input_linear and output_linear experts.

This module:
1. Replaces both GraniteMoeHybridParallelExperts modules with unpacked versions
2. Optionally sends all tokens to all experts during calibration
3. Stays in unpacked form (permanent) for vLLM compatibility and FP8 block quantization
"""

is_permanent = True

def __init__(
self,
original: GraniteMoeHybridMoE,
config: GraniteMoeHybridConfig,
calibrate_all_experts: bool = True,
):
super().__init__()
self.input_size = original.input_size
self.hidden_size = original.hidden_size
self.activation = original.activation
self.calibrate_all_experts = calibrate_all_experts

# Replace input_linear and output_linear with unpacked versions
self.input_linear = SequentialGraniteMoeExperts(
original.input_linear,
calibrate_all_experts=calibrate_all_experts,
)
self.output_linear = SequentialGraniteMoeExperts(
original.output_linear,
calibrate_all_experts=calibrate_all_experts,
)

# Keep the router unchanged
self.router = original.router

def forward(self, layer_input):
"""
Forward pass of the MoE layer.

Args:
layer_input: Input tensor of shape [batch_size, seq_len, hidden_size]

Returns:
Tuple of (output tensor, router_logits) where:
- output tensor has shape [batch_size, seq_len, hidden_size]
- router_logits has shape [batch_size * seq_len, num_experts]
"""
bsz, length, emb_size = layer_input.size()
layer_input_flat = layer_input.reshape(-1, emb_size)

# Router determines expert assignments
_, batch_index, batch_gates, expert_size, router_logits = self.router(layer_input_flat)

if self.calibrate_all_experts:
# During calibration, send all tokens to all experts
# Pass batch_index so experts know which outputs to keep
hidden_states = self.input_linear(layer_input_flat, expert_size, batch_index)

# Apply activation (SwiGLU-style)
chunked_hidden_states = hidden_states.chunk(2, dim=-1)
hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1]

# Process through output_linear experts
expert_outputs = self.output_linear(hidden_states, expert_size, batch_index)

# Apply gating weights
expert_outputs_gated = expert_outputs * batch_gates[:, None]
else:
# Normal routing: only send tokens to assigned experts
expert_inputs = layer_input_flat[batch_index]

# Process through input_linear experts
hidden_states = self.input_linear(expert_inputs, expert_size)

# Apply activation (SwiGLU-style)
chunked_hidden_states = hidden_states.chunk(2, dim=-1)
hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1]

# Process through output_linear experts
expert_outputs = self.output_linear(hidden_states, expert_size)

# Apply gating weights
expert_outputs_gated = expert_outputs * batch_gates[:, None]

# Aggregate expert outputs
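# batch_index maps each expert-token row back to its flat token position, so
# index_add sums a token's gated top-k expert outputs into that position.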
zeros = torch.zeros(
(bsz * length, self.input_size),
dtype=expert_outputs_gated.dtype,
device=expert_outputs_gated.device
)
layer_output = zeros.index_add(0, batch_index, expert_outputs_gated)
layer_output = layer_output.view(bsz, length, self.input_size)

return layer_output, router_logits


# Legacy function for backward compatibility with prepare.py
def replace(
config: GraniteMoeHybridConfig,
module: GraniteMoeHybridMoE,
calibrate_all_experts: bool,
):
"""
Legacy replacement function for use with prepare.py.

This function is deprecated. Use moe_calibration_context instead:

Example:
from llmcompressor.modeling.moe_context import moe_calibration_context

with moe_calibration_context(model, calibrate_all_experts=True):
# Run calibration
pass

Args:
config: The GraniteMoeHybridConfig for the model
module: The GraniteMoeHybridMoE module to replace
calibrate_all_experts: Whether to calibrate all experts

Returns:
CalibrationGraniteMoeHybridMoE calibration module
"""
return CalibrationGraniteMoeHybridMoE(
module,
config,
calibrate_all_experts=calibrate_all_experts,
)


def replace_granite_moe_with_linear_experts(model):
"""
Legacy replacement function that recursively replaces all GraniteMoeHybridMoE modules.

This function is deprecated. Use moe_calibration_context instead:

Example:
from llmcompressor.modeling.moe_context import moe_calibration_context

with moe_calibration_context(model, calibrate_all_experts=True):
# Run calibration
pass

Args:
model: The model containing GraniteMoeHybridMoE modules

Returns:
The modified model with replaced expert modules
"""
def replace_moe_modules(module, name=''):
for child_name, child in module.named_children():
full_name = f"{name}.{child_name}" if name else child_name

if child.__class__.__name__ == 'GraniteMoeHybridMoE':
# Create replacement module with unpacked experts
calibrated = CalibrationGraniteMoeHybridMoE(
original=child,
config=model.config,
calibrate_all_experts=True,
)
# Replace the module
setattr(module, child_name, calibrated)
print(f"Replaced {full_name}: GraniteMoeHybridMoE with unpacked experts")
else:
# Recursively process children
replace_moe_modules(child, full_name)

replace_moe_modules(model)
return model



class GraniteMoeHybridParallelExpertsLinear(torch.nn.Linear):
def __init__(self, num_experts: int, input_size: int, output_size: int) -> None: