Granite Four #13550

Draft: wants to merge 138 commits into base: master

Commits (138)
271104c
wip: llama : separate recurrent states from the KV cache
compilade Apr 3, 2024
8db1e4d
llama : use std::find for seq_nodes in llama_rs_cache
compilade Apr 4, 2024
0028010
llama : state checkpoints for recurrent models
compilade Apr 8, 2024
0c8b3b2
llama : correctly handle more edge cases for the rs cache
compilade Apr 9, 2024
d66849f
Merge branch 'master' into compilade/refactor-kv-cache
compilade Apr 10, 2024
a09db95
llama : rename many llama_kv_cache_* functions
compilade Apr 29, 2024
c460ff1
Merge branch 'master' into compilade/refactor-kv-cache
compilade Apr 29, 2024
b6fafd1
llama : remove useless return value for some llama_cache_* functions
compilade Apr 29, 2024
b7ec12e
Merge branch 'master' into compilade/refactor-kv-cache
compilade May 12, 2024
3b57b55
Merge branch 'master' into compilade/refactor-kv-cache
compilade May 22, 2024
7e13f19
llama : rethink recurrent state cell counts
compilade May 24, 2024
cbc743e
llama : support Jamba
compilade May 24, 2024
0fd13e9
Merge branch 'master' into compilade/refactor-kv-cache
compilade May 24, 2024
61a88a1
llama : fix BERT inference without KV cache
compilade May 25, 2024
ea2e63e
convert-hf : check for unprocessed Jamba experts
compilade May 25, 2024
fc59407
convert-hf : support Mini-Jamba conversion
compilade May 25, 2024
181dadf
llama : fix Jamba quantization sanity checks
compilade May 28, 2024
3a414b0
llama : sequence-length-aware batch splitting
compilade May 28, 2024
4e4c41e
Merge branch 'master' into compilade/refactor-kv-cache
compilade May 28, 2024
3587a94
llama : use equal-sequence-length sub-batches for recurrent models
compilade Jun 1, 2024
5d3c7b9
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jun 1, 2024
72eea49
llama : fix batch split output count for embeddings
compilade Jun 1, 2024
18d1c14
llama : minimize swaps when reordering logits
compilade Jun 1, 2024
61200ef
llama : fix edge case finding batch seq_id of split recurrent cell
compilade Jun 1, 2024
eb589d5
llama : avoid copies for simple batch splits
compilade Jun 2, 2024
8fb57ac
llama : use im2col and mul_mat to perform convolution for Mamba
compilade Jun 3, 2024
17f6c1e
llama : fix .base() compilation error on Windows
compilade Jun 3, 2024
fee3c1d
llama : allow doing the equivalent of SSM_CONV with SUM_ROWS and MUL
compilade Jun 3, 2024
6840ac0
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jun 8, 2024
372482d
llama : rename llama_cache to llama_past
compilade Jun 8, 2024
43d8d4b
examples : replace llama_kv_cache_seq_* with llama_past_seq_*
compilade Jun 10, 2024
ff794f5
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jun 12, 2024
33425a7
mamba : fix non-contiguous usage of ggml_silu
compilade Jun 12, 2024
10c3c41
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jun 30, 2024
9b38f8b
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jul 4, 2024
1f0fea7
llama : initial Mamba-2 support
compilade Aug 1, 2024
dceff23
ggml : SIMD ggml_ssm_scan for Mamba-2
compilade Aug 19, 2024
2bfe9de
llama : support running Mamba-Codestral-7B-v0.1
compilade Aug 19, 2024
aff9692
llama : fix Mamba-2 conv state saving
compilade Aug 21, 2024
e04910d
llama : remove unused variable
compilade Aug 22, 2024
fa358e7
llama : add missing break
compilade Aug 22, 2024
38913dc
convert_hf : prefer SentencePiece tokenizer for Mamba-2 when present
compilade Aug 22, 2024
bc320ef
Merge branch 'master' into compilade/refactor-kv-cache
compilade Sep 1, 2024
fcb889c
llama : session saving and reloading for hybrid models
compilade Sep 2, 2024
a03e32a
Merge branch 'master' into compilade/refactor-kv-cache
compilade Sep 2, 2024
9d3f44d
convert_hf : fix Jamba conversion
compilade Sep 2, 2024
5f62db7
llama : fix mixed signedness comparison
compilade Sep 2, 2024
375de5b
llama : use unused n_embd_k_gqa in k_shift
compilade Sep 2, 2024
4bb4b22
llama : begin renaming llama_past back to llama_kv_cache
compilade Sep 14, 2024
63ac36b
Merge branch 'master' into compilade/refactor-kv-cache
compilade Sep 14, 2024
0e601ca
Merge branch 'master' into compilade/mamba2
compilade Sep 18, 2024
273e7a4
llama : avoid redundant state copy for Mamba 1 and 2
compilade Sep 30, 2024
7d6cb36
Merge branch 'master' into compilade/mamba2
compilade Oct 1, 2024
2c77d79
metal : attempt to adapt SSM_SCAN for Mamba-2
compilade Oct 2, 2024
87b97d0
metal : fix SSM_SCAN pipeline scope
compilade Oct 2, 2024
03d0e6e
metal : use log and exp instead of log1pf and expf in SSM_SCAN
compilade Oct 2, 2024
7a351ab
metal : remove unused arguments for SSM_SCAN
compilade Oct 2, 2024
8b15bc6
metal : add back n_seqs to SSM_SCAN args
compilade Oct 2, 2024
5b8ec2b
metal : fix SSM_SCAN state head offset
compilade Oct 2, 2024
62b09b3
metal : fix wrong number of tokens per sequence in SSM_SCAN
compilade Oct 3, 2024
124c222
Merge branch 'master' into compilade/refactor-kv-cache
compilade Oct 12, 2024
038d958
Merge branch 'master' into compilade/mamba2
compilade Oct 12, 2024
805512a
ggml : remove unused fast broadcast path in GGML_MUL
compilade Oct 12, 2024
7d16e1b
Merge branch 'master' into compilade/mamba2
compilade Nov 1, 2024
3bc7103
ggml : avoid multiply by D in GGML_OP_SSM_SCAN
compilade Nov 4, 2024
8d8f065
Merge branch 'master' into compilade/mamba2
compilade Nov 4, 2024
b4e9c59
convert : fix flake8 lint
compilade Nov 4, 2024
8006f3b
llama : remove implicit recurrent state rollbacks
compilade Nov 25, 2024
691698e
Merge branch 'master' into compilade/refactor-kv-cache
compilade Nov 25, 2024
e3fe612
llama : partially apply clang-format style
compilade Nov 25, 2024
1ee6c48
Merge branch 'master' into compilade/mamba2
compilade Nov 25, 2024
c9ecf62
Merge branch 'master' into compilade/mamba2
compilade Feb 26, 2025
35d06fa
Merge branch 'master' into compilade/mamba2
compilade May 1, 2025
cf4f0a4
metal : fix confusion between ; and ,
compilade May 1, 2025
6def5cd
metal : add missing args for nb references in ssm_scan_f32_group
compilade May 1, 2025
791998b
metal : single-user mamba2 inference works
compilade May 2, 2025
94c3d53
kv-cache : remove const_cast when setting inputs for s_copy
compilade May 2, 2025
929fe85
Merge branch 'master' into compilade/mamba2
compilade May 2, 2025
d55b0d0
convert : avoid AutoConfig for Mamba and Mamba2 hparams
compilade May 2, 2025
e94f393
kv-cache : allow context shift for recurrent models
compilade May 2, 2025
9864bfc
Merge branch 'master' into compilade/mamba2
compilade Jun 10, 2025
2fa5f2c
graph : fix recurrent state copies when avoiding copies
compilade Jun 11, 2025
757aa62
ggml : fix mamba2 ssm scan when compiled with SVE
compilade Jun 11, 2025
0b6f6be
ggml-cpu : reorder SVE FMA for consistency with other SIMD arches
compilade Jun 11, 2025
a42f239
Merge branch 'master' into compilade/mamba2
compilade Jun 19, 2025
f8c7cae
cuda : implement ssm scan for Mamba2
compilade May 15, 2025
830e554
Merge branch 'master' into compilade/mamba2
compilade Jun 19, 2025
afdb669
Merge branch 'master' into compilade/mamba2
compilade Jun 23, 2025
28881af
feat: Add conversion for Bamba models
gabe-l-hart May 13, 2025
c43259b
feat: Add Granite 4 conversion
gabe-l-hart May 9, 2025
26816fd
feat: Plumb bamba through llama-arch
gabe-l-hart May 9, 2025
b901947
feat: Add bamba to llama_arch_is_hybrid_recurrent
gabe-l-hart May 20, 2025
fc56325
feat: Add optional mamba ssm_in bias tensor
gabe-l-hart May 13, 2025
b3453dc
feat: Add template specialization for get_arr to load a vector<uint32…
gabe-l-hart May 13, 2025
13e8d3d
feat: Use an explicit bool to determine mamaba vs mamba2
gabe-l-hart Jun 12, 2025
b435dce
feat: Isolate mamba(2) and granite attention layer building in static…
gabe-l-hart Jun 18, 2025
3d4c36b
fix: Use per-layer sizes in granite build_attention_layer
gabe-l-hart May 14, 2025
0d28bf6
feat: First (broken) pass at end-to-end Bamba implementation
gabe-l-hart May 14, 2025
ed6216a
fix: Only do Granite multipliers if set
gabe-l-hart May 14, 2025
a6f9f90
refactor: Pull granite ffn portion into a static function and reuse i…
gabe-l-hart May 14, 2025
de4d870
feat(py): Allow gguf duplicate keys if they match by value and type
gabe-l-hart May 14, 2025
7c2b0b8
refactor(py): Simplify granitemoehybrid conversion to use parents better
gabe-l-hart May 14, 2025
915f1e3
feat: Add GRANITE_MOE_HYBRID through llama-arch
gabe-l-hart May 14, 2025
d0d3723
feat: Support GRANITE_MOE_HYBRID in llama-model
gabe-l-hart May 14, 2025
2ca3416
style: Fix flake8 errors
gabe-l-hart May 14, 2025
3c22e1d
fix: Fix recurrent cache get after rebase
gabe-l-hart May 28, 2025
08493bf
fix: Fix hybrid granite implementation for signature changes in build…
gabe-l-hart May 29, 2025
ed15012
refactor: Refactor relationship between non-hybrid classes and hybrid…
gabe-l-hart Jun 26, 2025
40e2346
refactor: Implement the full copy-paste version to duplicate the laye…
gabe-l-hart Jun 26, 2025
a9dcc84
refactor: Rename llm_build_hybrid_mamba -> llm_build_granite_hybrid
gabe-l-hart Jun 26, 2025
dc1d109
mamba : fix mismatched new and delete size for llm_build_mamba
compilade Jun 26, 2025
fdc9a8d
Merge remote-tracking branch 'origin/compilade/mamba2' into mamba2-sync
gabe-l-hart Jun 27, 2025
2b263e6
Merge branch 'mamba2-sync' into GraniteFour
gabe-l-hart Jun 27, 2025
66a7a43
memory : correctly handle failure in apply()
ggerganov Jun 29, 2025
8cb4df5
Merge remote-tracking branch 'origin/master' into GraniteFour
gabe-l-hart Jun 30, 2025
f13f5bc
Merge remote-tracking branch 'origin/gg/memory-is-fail' into GraniteFour
gabe-l-hart Jun 30, 2025
6cac586
Merge remote-tracking branch 'origin/master' into GraniteFour
gabe-l-hart Jun 30, 2025
28361c4
Merge remote-tracking branch 'origin/master' into GraniteFour
gabe-l-hart Jul 1, 2025
bb2bb37
Merge remote-tracking branch 'origin/master' into GraniteFour
gabe-l-hart Jul 2, 2025
8f9b513
style: Remove TODO for adding first hybrid models to the switch
gabe-l-hart Jul 2, 2025
eaec9c6
fix: Fix bad merge in tensor_mapping.py w/ SSM_NORM
gabe-l-hart Jul 2, 2025
1085cf9
fix: Fix bad merge resolution with variable renames/moves in llm_buil…
gabe-l-hart Jul 2, 2025
b6d772f
docs: Fix comment about duplicate key check
gabe-l-hart Jul 2, 2025
bb590f2
fix: Conform to standard way of initializing inp_out_ids
gabe-l-hart Jul 2, 2025
1c21a04
Merge remote-tracking branch 'origin/master' into GraniteFour
gabe-l-hart Jul 2, 2025
2bcaf64
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jul 3, 2025
908e655
convert : fix jamba conv1d shape squeezing
compilade Jul 3, 2025
d7f4d73
Merge remote-tracking branch 'origin/master' into GraniteFour
gabe-l-hart Jul 3, 2025
e100153
Merge remote-tracking branch 'origin/compilade/refactor-kv-cache' int…
gabe-l-hart Jul 3, 2025
4b5f673
fix: Fix input initialization in granite_hybrid after removal of hybr…
gabe-l-hart Jul 3, 2025
0796726
fix: Use llm_graph_context_mamba in llm_build_granite_hybrid
gabe-l-hart Jul 3, 2025
f7fa1b1
refactor: Refactor mamba2/granite/jamba/granite_hybrid relationships …
gabe-l-hart Jul 3, 2025
4682e21
Merge branch 'master' into compilade/refactor-kv-cache
compilade Jul 3, 2025
20f8e43
graph : add back hybrid memory graph input
compilade Jul 3, 2025
07c252f
model : add Jamba to Mamba-specific hparams printing
compilade Jul 3, 2025
2e1431f
Merge remote-tracking branch 'origin/compilade/refactor-kv-cache' int…
gabe-l-hart Jul 7, 2025
5c32e80
fix: Fix input setup after upstream merge
gabe-l-hart Jul 7, 2025
f9d6dd1
Merge remote-tracking branch 'origin/master' into GraniteFour
gabe-l-hart Jul 7, 2025
285 changes: 269 additions & 16 deletions convert_hf_to_gguf.py
@@ -4872,6 +4872,9 @@ def __init__(self, dir_model: Path, *args, **kwargs):
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
hparams = json.load(f)
super().__init__(dir_model, *args, hparams=hparams, **kwargs)
self.d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
Comment from @gabe-l-hart (Contributor, Author), Jul 2, 2025:
I pulled these into the class so that they can be set differently by derived conversion classes and then used in the common methods below

self.d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * self.d_model
self.n_group = self.find_hparam(["n_groups"], optional=True) or 1

def set_vocab(self):
vocab_size = self.hparams["vocab_size"]
@@ -4894,30 +4897,27 @@ def set_vocab(self):
self._set_vocab_builtin("gpt-neox", vocab_size)

def set_gguf_parameters(self):
d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
head_dim = self.find_hparam(["head_dim"], optional=True) or 64
n_group = self.find_hparam(["n_groups"], optional=True) or 1
d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
head_dim = self.find_hparam(["head_dim"], optional=True) or 64

rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5

# Fail early for models which don't have a block expansion factor of 2
# TODO: does this really matter?
assert d_inner == 2 * d_model
assert d_inner % head_dim == 0
assert self.d_inner == 2 * self.d_model
assert self.d_inner % head_dim == 0

self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
self.gguf_writer.add_embedding_length(d_model)
self.gguf_writer.add_embedding_length(self.d_model)
self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_ssm_conv_kernel(d_conv)
self.gguf_writer.add_ssm_inner_size(d_inner)
self.gguf_writer.add_ssm_inner_size(self.d_inner)
self.gguf_writer.add_ssm_state_size(d_state)
self.gguf_writer.add_ssm_time_step_rank(d_inner // head_dim)
self.gguf_writer.add_ssm_group_count(n_group)
self.gguf_writer.add_ssm_time_step_rank(self.d_inner // head_dim)
self.gguf_writer.add_ssm_group_count(self.n_group)
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
self.gguf_writer.add_file_type(self.ftype)

@@ -4942,10 +4942,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
# (D is also unsqueezed, but for more straightforward broadcast internally)
data_torch = data_torch.reshape((*data_torch.shape, 1))
elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
n_group = self.hparams.get("n_groups", 1)
data_torch = data_torch.reshape((n_group, d_inner // n_group))
data_torch = data_torch.reshape((self.n_group, self.d_inner // self.n_group))

if name.endswith(".A_log"):
logger.debug("A_log --> A ==> " + new_name)
@@ -4954,6 +4951,229 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
yield (new_name, data_torch)


@ModelBase.register("BambaForCausalLM")
class BambaModel(Mamba2Model):
"""Bamba is a hybrid SSM + Attention model that uses Mamba2 SSM layers"""
model_arch = gguf.MODEL_ARCH.BAMBA
undo_permute = True

def __init__(self, *args, **kwargs):

# Hybrid mamba models use a prefix for the mamba-specific params.
# TODO: Extend this if the prefix(es) need to be configurable
self.hparam_prefixes = ["mamba"]

super().__init__(*args, **kwargs)

# Use Llama conversion for attention
self._transformer_model_class: type[TextModel] = LlamaModel

# Lists of which layers use ssm vs attention
self._attn_layers = self.get_attn_layres()
self._ssm_layers = [
i for i in range(self.block_count)
if i not in self._attn_layers
]

# n_group and d_inner are used during reshape_tensors for mamaba2
self.d_model = self.find_hparam(["hidden_size", "d_model"])
self.n_group = self.find_hparam(["n_groups"])
self.d_inner = self.find_hparam(["expand"]) * self.d_model

def get_attn_layres(self) -> list[int]:
attn_layers = self.hparams.get("attn_layer_indices", [])
if not attn_layers:
attn_period = self.hparams.get("attn_layer_period")
assert attn_period, "Didn't find attn_layer_indices or attn_layer_period"
attn_offset = self.hparams.get("attn_layer_offset")
assert attn_offset is not None, "No attention layer offset set with attn_layer_period"
attn_layers = [
i for i in range(self.block_count)
if i % attn_period == attn_offset
]
return attn_layers

def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
prefixed = []
for pfx in self.hparam_prefixes:
prefixed.extend(
"_".join([pfx, k])
for k in keys
)
keys = list(keys) + prefixed
return super().find_hparam(keys, *args, **kwargs)

def set_gguf_parameters(self):

## General Params ##
self.gguf_writer.add_embedding_length(self.d_model)
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0))
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])

## Mamba mixer params ##
self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
self.gguf_writer.add_ssm_group_count(self.n_group)
self.gguf_writer.add_ssm_inner_size(self.d_inner)
# NOTE: The mamba_dt_rank is _not_ the right field for how this is used
# in llama.cpp
self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))

## Attention params ##
self.gguf_writer.add_attn_layer_indices(self._attn_layers)
if rope_dim := self.hparams.get("attn_rotary_emb"):
self.gguf_writer.add_rope_dimension_count(rope_dim)
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
self.gguf_writer.add_head_count_kv(self.find_hparam(["num_key_value_heads", "n_head_kv"]))

## Feed Forward Params ##
self.gguf_writer.add_layer_norm_rms_eps(
self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
)

## Validation ##
d_head = self.find_hparam(["d_head"], optional=True) or 64
assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}"

def modify_tensors(
self, data_torch: Tensor, name: str, bid: int | None
) -> Iterable[tuple[str, Tensor]]:

# Determine whether this is a mamaba layer or an attention layer
if bid in self._ssm_layers:
for mamba_new_name, data_torch in super().modify_tensors(
data_torch, name, bid
):
yield mamba_new_name, data_torch
elif bid in self._attn_layers:
for llama_new_name, data_torch in self._transformer_model_class.modify_tensors(
self, data_torch, name, bid
):
yield llama_new_name, data_torch
else:
yield self.map_tensor_name(name), data_torch


@ModelBase.register("JambaForCausalLM")
class JambaModel(TextModel):
model_arch = gguf.MODEL_ARCH.JAMBA

def get_vocab_base_pre(self, tokenizer) -> str:
del tokenizer # unused

return "gpt-2"

def set_vocab(self):
if (self.dir_model / "tokenizer.model").is_file():
# Using Jamba's tokenizer.json causes errors on model load
# (something about "byte not found in vocab"),
# but there's a working tokenizer.model
self._set_vocab_sentencepiece()
else:
# Some Jamba models only have a tokenizer.json, which works.
self._set_vocab_gpt2()

def set_gguf_parameters(self):
d_model = self.find_hparam(["hidden_size", "mamba_d_model"])
d_conv = self.find_hparam(["mamba_d_conv"], optional=True) or 4
d_inner = self.hparams["mamba_expand"] * d_model
d_state = self.find_hparam(["mamba_d_state"], optional=True) or 16
# ceiling division
# ref: https://stackoverflow.com/a/17511341/22827863
# ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
dt_rank = self.find_hparam(["mamba_dt_rank"], optional=True) or -(d_model // -16)
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-6
n_kv_head = self.hparams["num_key_value_heads"]
attn_offset = self.hparams["attn_layer_offset"]
attn_period = self.hparams["attn_layer_period"]
n_kv_vec = [0 for _ in range(attn_offset)] + [
n_kv_head if (i - attn_offset) % attn_period == 0 else 0 for i in range(attn_offset, self.block_count)
]

self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_context_length(self.find_hparam(["max_position_embeddings", "n_ctx"]))
self.gguf_writer.add_embedding_length(d_model)
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
self.gguf_writer.add_head_count_kv(n_kv_vec)
self.gguf_writer.add_ssm_conv_kernel(d_conv)
self.gguf_writer.add_ssm_inner_size(d_inner)
self.gguf_writer.add_ssm_state_size(d_state)
self.gguf_writer.add_ssm_time_step_rank(dt_rank)
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
self.gguf_writer.add_expert_count(self.hparams["num_experts"])
self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
self.gguf_writer.add_file_type(self.ftype)

_experts: list[dict[str, Tensor]] | None = None

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:

# Mini-Jamba
name = name.replace(".moe.", ".feed_forward.")
if bid is not None:
moe_offset = self.hparams["expert_layer_offset"]
moe_period = self.hparams["expert_layer_period"]

if not (bid >= moe_offset and (bid - moe_offset) % moe_period == 0):
name = name.replace(".experts.0.", ".")

# process the experts separately
if ".feed_forward.experts." in name:
n_experts = self.hparams["num_experts"]

assert bid is not None

if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]

self._experts[bid][name] = data_torch

if len(self._experts[bid]) >= n_experts * 3:

# merge the experts into a single 3d tensor
for wid in ["down_proj", "gate_proj", "up_proj"]:
datas: list[Tensor] = []

for xid in range(n_experts):
ename = f"model.layers.{bid}.feed_forward.experts.{xid}.{wid}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]

data_torch = torch.stack(datas, dim=0)

# using the same merged name as qwen2moe
merged_name = f"model.layers.{bid}.mlp.experts.{wid}.weight"

new_name = self.map_tensor_name(merged_name)

yield new_name, data_torch
return

new_name = self.map_tensor_name(name)

if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_CONV1D, bid):
data_torch = data_torch.squeeze()

if name.endswith(".A_log"):
logger.debug("A_log --> A ==> " + new_name)
data_torch = -torch.exp(data_torch)

yield (new_name, data_torch)

def prepare_tensors(self):
super().prepare_tensors()

if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")


@ModelBase.register("CohereForCausalLM")
class CommandR2Model(TextModel):
model_arch = gguf.MODEL_ARCH.COMMAND_R
@@ -6327,6 +6547,39 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
return super().modify_tensors(data_torch, name, bid)


@ModelBase.register("GraniteMoeHybridForCausalLM")
class GraniteMoeHybridModel(BambaModel, GraniteMoeModel):
"""GraniteMoeHybrid is a hybrid SSM + MoE Attention model that uses Mamba2
SSM layers"""
model_arch = gguf.MODEL_ARCH.GRANITE_MOE_HYBRID

def get_attn_layres(self):
if layer_types := self.hparams.get("layer_types"):
return [
i for i, typ in enumerate(layer_types)
if typ == "attention"
]
return super().get_attn_layres()

def modify_tensors(
self, data_torch: Tensor, name: str, bid: int | None
) -> Iterable[tuple[str, Tensor]]:
if (
name.endswith("block_sparse_moe.input_linear.weight")
or name.endswith("shared_mlp.input_linear.weight")
):
return GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
return super().modify_tensors(data_torch, name, bid)

def set_gguf_parameters(self):
GraniteMoeModel.set_gguf_parameters(self)
BambaModel.set_gguf_parameters(self)

def set_vocab(self):
self.hparams["pad_vocab_size_multiple"] = 8
super().set_vocab()


@ModelBase.register("BailingMoeForCausalLM")
class BailingMoeModel(TextModel):
model_arch = gguf.MODEL_ARCH.BAILINGMOE