Skip to content

Commit 3fe0e39

Browse files
committed
Merge commit '4dca015b7e019d5bfa9d3872b19ad4cf97859c22' into concedo_experimental
# Conflicts: # .github/copilot-instructions.md # README.md # docs/ops.md # docs/ops/CPU.csv # docs/ops/CUDA.csv # docs/ops/Vulkan.csv # ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp # src/CMakeLists.txt # tests/test-backend-ops.cpp
2 parents 85060da + 4dca015 commit 3fe0e39

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

69 files changed

+2790
-848
lines changed

common/common.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -363,11 +363,7 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD
363363
}
364364

365365
void common_init() {
366-
llama_log_set([](ggml_log_level level, const char * text, void * /*user_data*/) {
367-
if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) {
368-
common_log_add(common_log_main(), level, "%s", text);
369-
}
370-
}, NULL);
366+
llama_log_set(common_log_default_callback, NULL);
371367

372368
#ifdef NDEBUG
373369
const char * build_type = "";

common/log.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,3 +442,9 @@ void common_log_set_prefix(struct common_log * log, bool prefix) {
442442
void common_log_set_timestamps(struct common_log * log, bool timestamps) {
443443
log->set_timestamps(timestamps);
444444
}
445+
446+
void common_log_default_callback(enum ggml_log_level level, const char * text, void * /*user_data*/) {
447+
if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) {
448+
common_log_add(common_log_main(), level, "%s", text);
449+
}
450+
}

common/log.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ extern int common_log_verbosity_thold;
3636

3737
void common_log_set_verbosity_thold(int verbosity); // not thread-safe
3838

39+
void common_log_default_callback(enum ggml_log_level level, const char * text, void * user_data);
40+
3941
// the common_log uses an internal worker thread to print/write log messages
4042
// when the worker thread is paused, incoming log messages are discarded
4143
struct common_log;

convert_hf_to_gguf.py

Lines changed: 81 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -189,10 +189,10 @@ def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Call
189189
return tensors
190190

191191
prefix = "model" if not self.is_mistral_format else "consolidated"
192-
part_names: list[str] = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
192+
part_names: set[str] = set(ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors"))
193193
is_safetensors: bool = len(part_names) > 0
194194
if not is_safetensors:
195-
part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
195+
part_names = set(ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin"))
196196

197197
tensor_names_from_index: set[str] = set()
198198

@@ -209,6 +209,7 @@ def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Call
209209
if weight_map is None or not isinstance(weight_map, dict):
210210
raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
211211
tensor_names_from_index.update(weight_map.keys())
212+
part_names |= set(weight_map.values())
212213
else:
213214
weight_map = {}
214215
else:
@@ -825,6 +826,15 @@ def set_gguf_parameters(self):
825826
self.gguf_writer.add_expert_group_used_count(n_group_used)
826827
logger.info(f"gguf: expert groups used count = {n_group_used}")
827828

829+
if (score_func := self.find_hparam(["score_function", "scoring_func", "score_func"], optional=True)) is not None:
830+
if score_func == "sigmoid":
831+
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
832+
elif score_func == "softmax":
833+
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
834+
else:
835+
raise ValueError(f"Unsupported expert score gating function value: {score_func}")
836+
logger.info(f"gguf: expert score gating function = {score_func}")
837+
828838
if (head_dim := self.hparams.get("head_dim")) is not None:
829839
self.gguf_writer.add_key_length(head_dim)
830840
self.gguf_writer.add_value_length(head_dim)
@@ -1124,6 +1134,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
11241134
if chkhsh == "a1e163ecab2e718a4c829d1148b6e86824ec36163bb71941c3dca9cd5ac25756":
11251135
# ref: https://huggingface.co/JetBrains/Mellum-4b-base
11261136
res = "mellum"
1137+
if chkhsh == "49fc0303c9e0d2c2c565c510f64b2d9b271276acdcdadff733249eda9f7d59df":
1138+
# ref: https://huggingface.co/arcee-ai/Trinity-Tokenizer
1139+
res = "afmoe"
11271140
if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
11281141
# ref: https://huggingface.co/inclusionAI/Ling-mini-base-2.0
11291142
res = "bailingmoe2"
@@ -2533,6 +2546,72 @@ def set_gguf_parameters(self):
25332546
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
25342547

25352548

2549+
@ModelBase.register("AfmoeForCausalLM")
class AfmoeModel(LlamaModel):
    # Converter registered for the "AfmoeForCausalLM" architecture name.
    model_arch = gguf.MODEL_ARCH.AFMOE

    def set_gguf_parameters(self):
        # Start from the parent's metadata, then add the AFMoE-specific keys.
        # Every hparam below is optional: it is written only when present and not None.
        super().set_gguf_parameters()

        hparams = self.hparams
        writer = self.gguf_writer

        # MoE parameters
        n_experts = hparams.get("num_experts")
        if n_experts is not None:
            writer.add_expert_count(n_experts)

        n_shared_experts = hparams.get("num_shared_experts")
        if n_shared_experts is not None:
            writer.add_expert_shared_count(n_shared_experts)

        moe_intermediate_size = hparams.get("moe_intermediate_size")
        if moe_intermediate_size is not None:
            writer.add_expert_feed_forward_length(moe_intermediate_size)

        n_dense_layers = hparams.get("num_dense_layers")
        if n_dense_layers is not None:
            writer.add_leading_dense_block_count(n_dense_layers)

        # Route normalization and scaling
        route_norm = hparams.get("route_norm")
        if route_norm is not None:
            writer.add_expert_weights_norm(route_norm)

        route_scale = hparams.get("route_scale")
        if route_scale is not None:
            writer.add_expert_weights_scale(route_scale)

        # Sliding window attention
        sliding_window = hparams.get("sliding_window")
        if sliding_window is not None:
            writer.add_sliding_window(sliding_window)

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        # Tensors that are not per-expert weights are mapped one-to-one.
        if "mlp.experts" not in name:
            if name.endswith(".expert_bias"):
                name = name.replace(".expert_bias", ".expert_bias.bias")
            return [(self.map_tensor_name(name), data_torch)]

        # Per-expert weights arrive one tensor at a time; buffer them per block
        # until all gate/up/down projections of every expert have been seen.
        n_experts = self.hparams["num_experts"]
        assert bid is not None

        if self._experts is None:
            self._experts = [{} for _ in range(self.block_count)]

        pending = self._experts[bid]
        pending[name] = data_torch

        if len(pending) < n_experts * 3:
            # still waiting for more expert tensors of this block
            return []

        # Merge the experts into a single 3d tensor per projection, removing
        # each consumed tensor from the buffer as it is collected.
        merged: list[tuple[str, Tensor]] = []
        for proj in ("gate_proj", "up_proj", "down_proj"):
            stacked = torch.stack(
                [pending.pop(f"model.layers.{bid}.mlp.experts.{xid}.{proj}.weight") for xid in range(n_experts)],
                dim=0,
            )
            merged.append((self.map_tensor_name(f"model.layers.{bid}.mlp.experts.{proj}.weight"), stacked))

        return merged
2613+
2614+
25362615
@ModelBase.register(
25372616
"LlavaForConditionalGeneration", # pixtral
25382617
"Mistral3ForConditionalGeneration", # mistral small 3.1
@@ -7104,13 +7183,6 @@ def set_gguf_parameters(self):
71047183
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
71057184
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
71067185

7107-
if hparams["scoring_func"] == "sigmoid":
7108-
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
7109-
elif hparams["scoring_func"] == "softmax":
7110-
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
7111-
else:
7112-
raise ValueError(f"Unsupported scoring_func value: {hparams['scoring_func']}")
7113-
71147186
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
71157187

71167188
rope_scaling = self.hparams.get("rope_scaling") or {}
@@ -7216,12 +7288,6 @@ def __init__(self, *args, **kwargs):
72167288

72177289
def set_gguf_parameters(self):
72187290
super().set_gguf_parameters()
7219-
if self.hparams["scoring_func"] == "sigmoid":
7220-
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
7221-
elif self.hparams["scoring_func"] == "softmax":
7222-
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
7223-
else:
7224-
raise ValueError(f"Unsupported scoring_func value: {self.hparams['scoring_func']}")
72257291

72267292
self.gguf_writer.add_expert_feed_forward_length(self.find_hparam(["intermediate_size"]))
72277293
self.gguf_writer.add_rope_dimension_count(self.find_hparam(["rotary_dim"]))
@@ -7314,11 +7380,6 @@ def set_gguf_parameters(self):
73147380
self.gguf_writer.add_expert_weights_scale(self.hparams["routed_scaling_factor"])
73157381
self.gguf_writer.add_expert_weights_norm(self.hparams["norm_topk_prob"])
73167382

7317-
if self.hparams["scoring_func"] == "noaux_tc":
7318-
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
7319-
else:
7320-
raise ValueError(f"Unsupported scoring_func value: {self.hparams['scoring_func']}")
7321-
73227383
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
73237384
if name.endswith("e_score_correction_bias"):
73247385
name = name.replace("e_score_correction_bias", "e_score_correction.bias")
@@ -8639,13 +8700,6 @@ def set_gguf_parameters(self):
86398700
self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
86408701
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
86418702

8642-
if hparams["score_function"] == "sigmoid":
8643-
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
8644-
elif hparams["score_function"] == "softmax":
8645-
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
8646-
else:
8647-
raise ValueError(f"Unsupported score_function value: {hparams['score_function']}")
8648-
86498703
if (nextn_layers := self.hparams.get("num_nextn_predict_layers")) is not None:
86508704
self.gguf_writer.add_nextn_predict_layers(nextn_layers)
86518705

convert_hf_to_gguf_update.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ class TOKENIZER_TYPE(IntEnum):
139139
{"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
140140
{"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
141141
{"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
142+
{"name": "afmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/arcee-ai/Trinity-Tokenizer", },
142143
{"name": "bailingmoe2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", },
143144
{"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
144145
{"name": "minimax-m2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/MiniMaxAI/MiniMax-M2", },

ggml/include/ggml.h

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,7 @@ extern "C" {
481481
GGML_OP_COS,
482482
GGML_OP_SUM,
483483
GGML_OP_SUM_ROWS,
484+
GGML_OP_CUMSUM,
484485
GGML_OP_MEAN,
485486
GGML_OP_ARGMAX,
486487
GGML_OP_COUNT_EQUAL,
@@ -536,6 +537,8 @@ extern "C" {
536537
GGML_OP_TIMESTEP_EMBEDDING,
537538
GGML_OP_ARGSORT,
538539
GGML_OP_LEAKY_RELU,
540+
GGML_OP_TRI,
541+
GGML_OP_FILL,
539542

540543
GGML_OP_FLASH_ATTN_EXT,
541544
GGML_OP_FLASH_ATTN_BACK,
@@ -548,6 +551,7 @@ extern "C" {
548551
GGML_OP_RWKV_WKV6,
549552
GGML_OP_GATED_LINEAR_ATTN,
550553
GGML_OP_RWKV_WKV7,
554+
GGML_OP_SOLVE_TRI,
551555

552556
GGML_OP_UNARY,
553557

@@ -594,6 +598,8 @@ extern "C" {
594598
GGML_UNARY_OP_HARDSWISH,
595599
GGML_UNARY_OP_HARDSIGMOID,
596600
GGML_UNARY_OP_EXP,
601+
GGML_UNARY_OP_EXPM1,
602+
GGML_UNARY_OP_SOFTPLUS,
597603
GGML_UNARY_OP_GELU_ERF,
598604
GGML_UNARY_OP_XIELU,
599605
GGML_UNARY_OP_FLOOR,
@@ -638,6 +644,13 @@ extern "C" {
638644
GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
639645
};
640646

647+
// Selects which triangular part of a matrix is meant.
// NOTE(review): the *_DIAG variants presumably include the main diagonal and
// the plain variants are the strict forms - confirm against the implementation.
enum ggml_tri_type {
    GGML_TRI_TYPE_UPPER_DIAG = 0,
    GGML_TRI_TYPE_UPPER      = 1,
    GGML_TRI_TYPE_LOWER_DIAG = 2,
    GGML_TRI_TYPE_LOWER      = 3
};
653+
641654
struct ggml_init_params {
642655
// memory pool
643656
size_t mem_size; // bytes
@@ -982,6 +995,22 @@ extern "C" {
982995
struct ggml_context * ctx,
983996
struct ggml_tensor * a);
984997

998+
GGML_API struct ggml_tensor * ggml_expm1(
999+
struct ggml_context * ctx,
1000+
struct ggml_tensor * a);
1001+
1002+
GGML_API struct ggml_tensor * ggml_expm1_inplace(
1003+
struct ggml_context * ctx,
1004+
struct ggml_tensor * a);
1005+
1006+
GGML_API struct ggml_tensor * ggml_softplus(
1007+
struct ggml_context * ctx,
1008+
struct ggml_tensor * a);
1009+
1010+
GGML_API struct ggml_tensor * ggml_softplus_inplace(
1011+
struct ggml_context * ctx,
1012+
struct ggml_tensor * a);
1013+
9851014
GGML_API struct ggml_tensor * ggml_sin(
9861015
struct ggml_context * ctx,
9871016
struct ggml_tensor * a);
@@ -1008,6 +1037,10 @@ extern "C" {
10081037
struct ggml_context * ctx,
10091038
struct ggml_tensor * a);
10101039

1040+
GGML_API struct ggml_tensor * ggml_cumsum(
1041+
struct ggml_context * ctx,
1042+
struct ggml_tensor * a);
1043+
10111044
// mean along rows
10121045
GGML_API struct ggml_tensor * ggml_mean(
10131046
struct ggml_context * ctx,
@@ -2212,6 +2245,23 @@ extern "C" {
22122245
int shift2,
22132246
int shift3);
22142247

2248+
// Convert matrix into a triangular one (upper, strict upper, lower or strict lower) by writing
2249+
// zeroes everywhere outside the masked area
2250+
GGML_API struct ggml_tensor * ggml_tri(
2251+
struct ggml_context * ctx,
2252+
struct ggml_tensor * a,
2253+
enum ggml_tri_type type);
2254+
2255+
// Fill tensor a with constant c
2256+
GGML_API struct ggml_tensor * ggml_fill(
2257+
struct ggml_context * ctx,
2258+
struct ggml_tensor * a,
2259+
float c);
2260+
2261+
GGML_API struct ggml_tensor * ggml_fill_inplace(
2262+
struct ggml_context * ctx,
2263+
struct ggml_tensor * a,
2264+
float c);
22152265

22162266
// Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
22172267
// timesteps: [N,]
@@ -2381,6 +2431,27 @@ extern "C" {
23812431
struct ggml_tensor * b,
23822432
struct ggml_tensor * state);
23832433

2434+
/* Solves a specific equation of the form Ax=B, where A is a triangular matrix
2435+
* without zeroes on the diagonal (i.e. invertible).
2436+
* B can have any number of columns, but must have the same number of rows as A
2437+
* If A is [n, n] and B is [n, m], then the result will be [n, m] as well
2438+
* Has O(n^3) complexity (unlike most matrix ops out there), so use on cases
2439+
* where n > 100 sparingly, pre-chunk if necessary.
2440+
*
2441+
* If left = false, solves xA=B instead
2442+
* If lower = false, assumes upper triangular instead
2443+
* If uni = true, assumes diagonal of A to be all ones (will override actual values)
2444+
*
2445+
* TODO: currently only lower, right, non-unitriangular variant is implemented
2446+
*/
2447+
GGML_API struct ggml_tensor * ggml_solve_tri(
2448+
struct ggml_context * ctx,
2449+
struct ggml_tensor * a,
2450+
struct ggml_tensor * b,
2451+
bool left,
2452+
bool lower,
2453+
bool uni);
2454+
23842455
// custom operators
23852456

23862457
typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);

ggml/src/ggml-backend.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1704,8 +1704,6 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
17041704
GGML_ASSERT(sched);
17051705
GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs);
17061706

1707-
ggml_backend_sched_reset(sched);
1708-
17091707
ggml_backend_sched_synchronize(sched);
17101708

17111709
ggml_backend_sched_split_graph(sched, measure_graph);

0 commit comments

Comments
 (0)