@@ -189,10 +189,10 @@ def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Call
189189 return tensors
190190
191191 prefix = "model" if not self.is_mistral_format else "consolidated"
192- part_names: list[str] = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
192+ part_names: set[str] = set(ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors"))
193193 is_safetensors: bool = len(part_names) > 0
194194 if not is_safetensors:
195- part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
195+ part_names = set(ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin"))
196196
197197 tensor_names_from_index: set[str] = set()
198198
@@ -209,6 +209,7 @@ def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Call
209209 if weight_map is None or not isinstance(weight_map, dict):
210210 raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
211211 tensor_names_from_index.update(weight_map.keys())
212+ part_names |= set(weight_map.values())
212213 else:
213214 weight_map = {}
214215 else:
@@ -825,6 +826,15 @@ def set_gguf_parameters(self):
825826 self.gguf_writer.add_expert_group_used_count(n_group_used)
826827 logger.info(f"gguf: expert groups used count = {n_group_used}")
827828
829+ if (score_func := self.find_hparam(["score_function", "scoring_func", "score_func"], optional=True)) is not None:
830+ if score_func == "sigmoid":
831+ self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
832+ elif score_func == "softmax":
833+ self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
834+ else:
835+ raise ValueError(f"Unsupported expert score gating function value: {score_func}")
836+ logger.info(f"gguf: expert score gating function = {score_func}")
837+
828838 if (head_dim := self.hparams.get("head_dim")) is not None:
829839 self.gguf_writer.add_key_length(head_dim)
830840 self.gguf_writer.add_value_length(head_dim)
@@ -1124,6 +1134,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
11241134 if chkhsh == "a1e163ecab2e718a4c829d1148b6e86824ec36163bb71941c3dca9cd5ac25756":
11251135 # ref: https://huggingface.co/JetBrains/Mellum-4b-base
11261136 res = "mellum"
1137+ if chkhsh == "49fc0303c9e0d2c2c565c510f64b2d9b271276acdcdadff733249eda9f7d59df":
1138+ # ref: https://huggingface.co/arcee-ai/Trinity-Tokenizer
1139+ res = "afmoe"
11271140 if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
11281141 # ref: https://huggingface.co/inclusionAI/Ling-mini-base-2.0
11291142 res = "bailingmoe2"
@@ -2533,6 +2546,72 @@ def set_gguf_parameters(self):
25332546 self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
25342547
25352548
2549+ @ModelBase.register("AfmoeForCausalLM")
2550+ class AfmoeModel(LlamaModel):
2551+ model_arch = gguf.MODEL_ARCH.AFMOE
2552+
2553+ def set_gguf_parameters(self):
2554+ super().set_gguf_parameters()
2555+
2556+ # MoE parameters
2557+ if (n_experts := self.hparams.get("num_experts")) is not None:
2558+ self.gguf_writer.add_expert_count(n_experts)
2559+ if (n_shared_experts := self.hparams.get("num_shared_experts")) is not None:
2560+ self.gguf_writer.add_expert_shared_count(n_shared_experts)
2561+ if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
2562+ self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
2563+ if (n_dense_layers := self.hparams.get("num_dense_layers")) is not None:
2564+ self.gguf_writer.add_leading_dense_block_count(n_dense_layers)
2565+
2566+ # Route normalization and scaling
2567+ if (route_norm := self.hparams.get("route_norm")) is not None:
2568+ self.gguf_writer.add_expert_weights_norm(route_norm)
2569+ if (route_scale := self.hparams.get("route_scale")) is not None:
2570+ self.gguf_writer.add_expert_weights_scale(route_scale)
2571+
2572+ # Sliding window attention
2573+ if (sliding_window := self.hparams.get("sliding_window")) is not None:
2574+ self.gguf_writer.add_sliding_window(sliding_window)
2575+
2576+ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
2577+ # Handle expert weights - they're already merged in the HF format
2578+ # process the experts separately
2579+ if name.find("mlp.experts") != -1:
2580+ n_experts = self.hparams["num_experts"]
2581+ assert bid is not None
2582+
2583+ if self._experts is None:
2584+ self._experts = [{} for _ in range(self.block_count)]
2585+
2586+ self._experts[bid][name] = data_torch
2587+
2588+ if len(self._experts[bid]) >= n_experts * 3:
2589+ tensors: list[tuple[str, Tensor]] = []
2590+
2591+ # merge the experts into a single 3d tensor
2592+ for w_name in ["gate_proj", "up_proj", "down_proj"]:
2593+ datas: list[Tensor] = []
2594+
2595+ for xid in range(n_experts):
2596+ ename_to_retrieve = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
2597+ datas.append(self._experts[bid][ename_to_retrieve])
2598+ del self._experts[bid][ename_to_retrieve]
2599+
2600+ data_torch = torch.stack(datas, dim=0)
2601+ merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
2602+ new_name = self.map_tensor_name(merged_name)
2603+ tensors.append((new_name, data_torch))
2604+
2605+ return tensors
2606+ else:
2607+ return []
2608+
2609+ if name.endswith(".expert_bias"):
2610+ name = name.replace(".expert_bias", ".expert_bias.bias")
2611+
2612+ return [(self.map_tensor_name(name), data_torch)]
2613+
2614+
25362615@ModelBase.register(
25372616 "LlavaForConditionalGeneration", # pixtral
25382617 "Mistral3ForConditionalGeneration", # mistral small 3.1
@@ -7104,13 +7183,6 @@ def set_gguf_parameters(self):
71047183 self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
71057184 self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
71067185
7107- if hparams["scoring_func"] == "sigmoid":
7108- self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
7109- elif hparams["scoring_func"] == "softmax":
7110- self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
7111- else:
7112- raise ValueError(f"Unsupported scoring_func value: {hparams['scoring_func']}")
7113-
71147186 self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
71157187
71167188 rope_scaling = self.hparams.get("rope_scaling") or {}
@@ -7216,12 +7288,6 @@ def __init__(self, *args, **kwargs):
72167288
72177289 def set_gguf_parameters(self):
72187290 super().set_gguf_parameters()
7219- if self.hparams["scoring_func"] == "sigmoid":
7220- self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
7221- elif self.hparams["scoring_func"] == "softmax":
7222- self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
7223- else:
7224- raise ValueError(f"Unsupported scoring_func value: {self.hparams['scoring_func']}")
72257291
72267292 self.gguf_writer.add_expert_feed_forward_length(self.find_hparam(["intermediate_size"]))
72277293 self.gguf_writer.add_rope_dimension_count(self.find_hparam(["rotary_dim"]))
@@ -7314,11 +7380,6 @@ def set_gguf_parameters(self):
73147380 self.gguf_writer.add_expert_weights_scale(self.hparams["routed_scaling_factor"])
73157381 self.gguf_writer.add_expert_weights_norm(self.hparams["norm_topk_prob"])
73167382
7317- if self.hparams["scoring_func"] == "noaux_tc":
7318- self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
7319- else:
7320- raise ValueError(f"Unsupported scoring_func value: {self.hparams['scoring_func']}")
7321-
73227383 def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
73237384 if name.endswith("e_score_correction_bias"):
73247385 name = name.replace("e_score_correction_bias", "e_score_correction.bias")
@@ -8639,13 +8700,6 @@ def set_gguf_parameters(self):
86398700 self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
86408701 self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
86418702
8642- if hparams["score_function"] == "sigmoid":
8643- self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
8644- elif hparams["score_function"] == "softmax":
8645- self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
8646- else:
8647- raise ValueError(f"Unsupported score_function value: {hparams['score_function']}")
8648-
86498703 if (nextn_layers := self.hparams.get("num_nextn_predict_layers")) is not None:
86508704 self.gguf_writer.add_nextn_predict_layers(nextn_layers)
86518705
0 commit comments