model : add GroveMoE support #15510

Draft: wants to merge 2 commits into master

4 changes: 2 additions & 2 deletions common/arg.cpp
@@ -2415,7 +2415,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--cpu-moe", "-cmoe"},
"keep all Mixture of Experts (MoE) weights in the CPU",
[](common_params & params) {
params.tensor_buft_overrides.push_back({"\\.ffn_(up|down|gate)_exps", ggml_backend_cpu_buffer_type()});
params.tensor_buft_overrides.push_back({"\\.ffn_(up|down|gate)_(ch|)exps", ggml_backend_cpu_buffer_type()});
}
).set_env("LLAMA_ARG_CPU_MOE"));
add_opt(common_arg(
@@ -2428,7 +2428,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
for (int i = 0; i < value; ++i) {
// keep strings alive and avoid leaking memory by storing them in a static vector
static std::list<std::string> buft_overrides;
buft_overrides.push_back(string_format("blk\\.%d\\.ffn_(up|down|gate)_exps", i));
buft_overrides.push_back(string_format("blk\\.%d\\.ffn_(up|down|gate)_(ch|)exps", i));
params.tensor_buft_overrides.push_back({buft_overrides.back().c_str(), ggml_backend_cpu_buffer_type()});
}
}
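A quick sanity check of the widened override pattern, as a standalone Python sketch (not part of the PR): the new `(ch|)` group lets the same `--cpu-moe` / `--n-cpu-moe` overrides catch GroveMoE's chunk-expert tensors as well as the regular expert tensors. The tensor names below are illustrative examples following the `blk.%d.ffn_*` naming these overrides target.

```python
import re

# New pattern from this PR vs. the previous one (Python equivalents of the C++ strings).
new_pattern = re.compile(r"\.ffn_(up|down|gate)_(ch|)exps")
old_pattern = re.compile(r"\.ffn_(up|down|gate)_exps")

# Illustrative tensor names only.
names = [
    "blk.0.ffn_up_exps.weight",      # regular MoE expert weights
    "blk.0.ffn_down_chexps.weight",  # GroveMoE chunk-expert weights
    "blk.0.ffn_gate.weight",         # dense FFN weight, matched by neither
]
for name in names:
    print(f"{name}: old={bool(old_pattern.search(name))} new={bool(new_pattern.search(name))}")
# blk.0.ffn_up_exps.weight:     old=True  new=True
# blk.0.ffn_down_chexps.weight: old=False new=True
# blk.0.ffn_gate.weight:        old=False new=False
```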
115 changes: 115 additions & 0 deletions convert_hf_to_gguf.py
@@ -7654,6 +7654,121 @@ def prepare_tensors(self):
raise ValueError(f"Unprocessed experts: {experts}")


@ModelBase.register("GroveMoeForCausalLM", "modeling_grove_moe.GroveMoeForCausalLM")
class GroveMoeModel(TextModel):
model_arch = gguf.MODEL_ARCH.GROVEMOE

def set_gguf_parameters(self):
super().set_gguf_parameters()
if (n_experts := self.hparams.get("num_experts")) is not None:
self.gguf_writer.add_expert_count(n_experts)
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}")
# FIXME?: Hardcoded https://huggingface.co/inclusionAI/GroveMoE-Inst/blob/c4c69e5970d18907b5e6ddccdfd55176fe292df1/modeling_grove_moe.py#L299
self.gguf_writer.add_expert_chunk_feed_forward_length(self.hparams.get("head_dim") or 128)
# FIXME?: Hardcoded https://huggingface.co/inclusionAI/GroveMoE-Inst/blob/c4c69e5970d18907b5e6ddccdfd55176fe292df1/modeling_grove_moe.py#L298
self.gguf_writer.add_experts_per_group(2)
# FIXME?: Hardcoded https://huggingface.co/inclusionAI/GroveMoE-Inst/blob/c4c69e5970d18907b5e6ddccdfd55176fe292df1/modeling_grove_moe.py#L376
self.gguf_writer.add_expert_group_scale(0.05)
# YaRN is not enabled by default
# To enable it, please refer to this guide: https://huggingface.co/Qwen/Qwen3-30B-A3B#processing-long-texts
rope_scaling = self.hparams.get("rope_scaling") or {}
if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])

_experts: list[dict[str, Tensor]] | None = None
_chunk_experts: list[dict[str, Tensor]] | None = None

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name.endswith(".expert_bias"):
# FIXME?: Unused https://huggingface.co/inclusionAI/GroveMoE-Inst/blob/c4c69e5970d18907b5e6ddccdfd55176fe292df1/modeling_grove_moe.py#L303
return []

# process the experts separately
if name.find("chunk_experts") != -1:
n_experts = self.hparams["num_experts"] // 2 # see add_experts_per_group
assert bid is not None

if self._chunk_experts is None:
self._chunk_experts = [{} for _ in range(self.block_count)]

self._chunk_experts[bid][name] = data_torch

if len(self._chunk_experts[bid]) >= n_experts * 3:
tensors: list[tuple[str, Tensor]] = []

# merge the experts into a single 3d tensor
for w_name in ["down_proj", "gate_proj", "up_proj"]:
datas: list[Tensor] = []

for xid in range(n_experts):
ename = f"model.layers.{bid}.mlp.chunk_experts.{xid}.{w_name}.weight"
datas.append(self._chunk_experts[bid][ename])
del self._chunk_experts[bid][ename]

data_torch = torch.stack(datas, dim=0)

merged_name = f"model.layers.{bid}.mlp.chunk_experts.{w_name}.weight"

new_name = self.map_tensor_name(merged_name)

tensors.append((new_name, data_torch))
return tensors
else:
return []
elif name.find("experts") != -1:
n_experts = self.hparams["num_experts"]
assert bid is not None

if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]

self._experts[bid][name] = data_torch

if len(self._experts[bid]) >= n_experts * 3:
tensors: list[tuple[str, Tensor]] = []

# merge the experts into a single 3d tensor
for w_name in ["down_proj", "gate_proj", "up_proj"]:
datas: list[Tensor] = []

for xid in range(n_experts):
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]

data_torch = torch.stack(datas, dim=0)

merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"

new_name = self.map_tensor_name(merged_name)

tensors.append((new_name, data_torch))
return tensors
else:
return []

return [(self.map_tensor_name(name), data_torch)]

def prepare_tensors(self):
super().prepare_tensors()

if self._chunk_experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
chunk_experts = [k for d in self._chunk_experts for k in d.keys()]
if len(chunk_experts) > 0:
raise ValueError(f"Unprocessed adjugate experts: {chunk_experts}")

if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")


@ModelBase.register("ChameleonForConditionalGeneration")
@ModelBase.register("ChameleonForCausalLM") # obsolete
class ChameleonModel(TextModel):
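For readers unfamiliar with the expert-merge step above, here is a minimal standalone sketch of what `modify_tensors` does for both `experts` and `chunk_experts`: per-expert 2-D weights are buffered per layer and, once all of them have arrived, stacked into one 3-D tensor per projection. Shapes here are illustrative, not the real GroveMoE dimensions.

```python
import torch

# Illustrative sizes; the real values come from the checkpoint's config.json.
n_experts, n_ff, n_embd = 4, 32, 16

# Per-expert tensors as they appear in the HF checkpoint for layer 0.
per_expert = {
    f"model.layers.0.mlp.experts.{i}.up_proj.weight": torch.randn(n_ff, n_embd)
    for i in range(n_experts)
}

# Merge into a single 3-D tensor, mirroring the torch.stack call in modify_tensors().
merged = torch.stack(
    [per_expert[f"model.layers.0.mlp.experts.{i}.up_proj.weight"] for i in range(n_experts)],
    dim=0,
)
print(merged.shape)  # torch.Size([4, 32, 16]) -> (n_expert, n_ff, n_embd)
```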
10 changes: 10 additions & 0 deletions ggml/include/ggml.h
@@ -916,6 +916,16 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);
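
    // integer division between an I32 tensor and a scalar:
    // ggml_div_scalar_i32      computes a[i] / b
    // ggml_div_scalar_left_i32 computes a    / b[i]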

GGML_API struct ggml_tensor * ggml_div_scalar_i32(
struct ggml_context * ctx,
struct ggml_tensor * a,
int32_t b);

GGML_API struct ggml_tensor * ggml_div_scalar_left_i32(
struct ggml_context * ctx,
int32_t a,
struct ggml_tensor * b);

GGML_API struct ggml_tensor * ggml_sqr(
struct ggml_context * ctx,
struct ggml_tensor * a);
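As a semantics-only sketch (NumPy, not ggml; the expert-grouping use case is an assumption, while the truncating division matches the CPU kernel added below), the two new ops divide an I32 tensor by a plain integer element-wise, with the scalar on either side:

```python
import numpy as np

def div_scalar_i32(a: np.ndarray, b: int) -> np.ndarray:
    # semantics of ggml_div_scalar_i32(ctx, a, b): out[i] = a[i] / b (C truncating division)
    return np.trunc(a / b).astype(np.int32)

def div_scalar_left_i32(a: int, b: np.ndarray) -> np.ndarray:
    # semantics of ggml_div_scalar_left_i32(ctx, a, b): out[i] = a / b[i]
    return np.trunc(a / b).astype(np.int32)

# Presumed GroveMoE-style use: with 2 experts per group, dividing selected expert
# ids by 2 yields the index of the shared chunk ("adjugate") expert.
expert_ids = np.array([0, 1, 4, 5, 7], dtype=np.int32)
print(div_scalar_i32(expert_ids, 2))           # [0 0 2 2 3]
print(div_scalar_left_i32(8, expert_ids[1:]))  # [8 2 1 1]
```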
7 changes: 6 additions & 1 deletion ggml/src/ggml-cann/ggml-cann.cpp
@@ -2471,7 +2471,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
case GGML_OP_ADD1:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_RMS_NORM:
case GGML_OP_SQR:
case GGML_OP_SQRT:
@@ -2494,6 +2493,12 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_COUNT_EQUAL:
return true;
case GGML_OP_DIV:
{
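            // only the CPU backend implements the new scalar / I32 division variants,
            // so reject them here and let the scheduler fall back to the CPU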
struct ggml_tensor * a = op->src[0];
struct ggml_tensor * b = op->src[1];
return a && b && a->type != GGML_TYPE_I32 && b->type != GGML_TYPE_I32;
} break;
case GGML_OP_SCALE:
float bias;
memcpy(&bias, (float*)op->op_params + 1, sizeof(float));
39 changes: 38 additions & 1 deletion ggml/src/ggml-cpu/binary-ops.cpp
@@ -115,13 +115,50 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds
}
}

static void apply_scalar_div_op(const ggml_compute_params * params, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
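    // for the scalar variants one src is NULL: src0 set -> tensor / scalar, src1 set -> scalar / tensor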
const ggml_tensor * src = src0 ? src0 : src1;
const int32_t scalar = ggml_get_op_params_i32(dst, 0);

GGML_ASSERT(ggml_are_same_shape(src, dst));

GGML_TENSOR_BINARY_OP_LOCALS

const auto [ir0, ir1] = get_thread_range(params, src);

for (int64_t ir = ir0; ir < ir1; ++ir) {
const int64_t i03 = ir/(ne02*ne01);
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);

int32_t * dst_ptr = (int32_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
const int32_t * src_ptr = (const int32_t *) ((const char *) src->data + i03*nb03 + i02*nb02 + i01*nb01);

for (int i = 0; i < ne00; i++) {
dst_ptr[i] = src0 ? src_ptr[i] / scalar : scalar / src_ptr[i];
}
}
}

// TODO: Use the 'traits' lookup table (for type conversion fns), instead of a mass of 'if' conditions with long templates
template <float (*op)(float, float)>
static void binary_op(const ggml_compute_params * params, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];

/* */ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32
/* */ if (!src0 || !src1) { // scalar
if (dst->type == GGML_TYPE_I32) {
if (op == op_div) {
apply_scalar_div_op(params, dst);
} else {
GGML_ABORT("%s: unsupported op\n", __func__);
}
} else {
GGML_ABORT("%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32
apply_binary_op<op, float, float, float>(params, dst);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16
apply_binary_op<op, ggml_fp16_t, ggml_fp16_t, ggml_fp16_t>(params, dst);
7 changes: 6 additions & 1 deletion ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3434,7 +3434,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_ADD1:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_SCALE:
case GGML_OP_SQR:
case GGML_OP_SQRT:
@@ -3443,6 +3442,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_CLAMP:
case GGML_OP_LOG:
return true;
case GGML_OP_DIV:
{
struct ggml_tensor * a = op->src[0];
struct ggml_tensor * b = op->src[1];
return a && b && a->type != GGML_TYPE_I32 && b->type != GGML_TYPE_I32;
} break;
case GGML_OP_SSM_SCAN: {
if (op->src[3]->ne[0] == 1) {
// Mamba2
7 changes: 6 additions & 1 deletion ggml/src/ggml-metal/ggml-metal.m
@@ -1814,9 +1814,14 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_ADD_ID:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_DIV:
{
struct ggml_tensor * a = op->src[0];
struct ggml_tensor * b = op->src[1];
return a && b && a->type == GGML_TYPE_F32 && b->type == GGML_TYPE_F32;
} break;
case GGML_OP_ACC:
case GGML_OP_REPEAT:
case GGML_OP_SCALE:
10 changes: 9 additions & 1 deletion ggml/src/ggml-opencl/ggml-opencl.cpp
@@ -2608,11 +2608,19 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
}
}
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_SUB:
return (op->src[0]->type == op->src[1]->type) &&
(op->src[0]->type == op->type) &&
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
case GGML_OP_DIV:
{
struct ggml_tensor * a = op->src[0];
struct ggml_tensor * b = op->src[1];
return (a && b) &&
(a->type == b->type) &&
(a->type == op->type) &&
(a->type == GGML_TYPE_F32 || a->type == GGML_TYPE_F16);
} break;
case GGML_OP_ADD_ID:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_UNARY:
7 changes: 6 additions & 1 deletion ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -4349,9 +4349,14 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_ADD1:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_REPEAT:
return true;
case GGML_OP_DIV:
{
struct ggml_tensor * a = op->src[0];
struct ggml_tensor * b = op->src[1];
return a && b && a->type != GGML_TYPE_I32 && b->type != GGML_TYPE_I32;
} break;
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_SIN:
10 changes: 9 additions & 1 deletion ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -11397,10 +11397,18 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
(op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
case GGML_OP_DIV:
{
struct ggml_tensor * a = op->src[0];
struct ggml_tensor * b = op->src[1];
return (a && b) &&
(a->type == GGML_TYPE_F32 || a->type == GGML_TYPE_F16) &&
(b->type == GGML_TYPE_F32 || b->type == GGML_TYPE_F16) &&
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
} break;
case GGML_OP_ADD_ID:
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->src[2]->type == GGML_TYPE_I32 &&
op->type == GGML_TYPE_F32;
30 changes: 30 additions & 0 deletions ggml/src/ggml.c
@@ -2172,6 +2172,36 @@ struct ggml_tensor * ggml_div_inplace(
return ggml_div_impl(ctx, a, b, true);
}
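
// ggml_div_scalar_i32 / ggml_div_scalar_left_i32
// the I32 scalar operand is stored in op_params[0] and the missing src slot is left
// NULL; the CPU backend (ggml-cpu/binary-ops.cpp) uses this to tell the scalar
// variants apart from the regular tensor-tensor division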

struct ggml_tensor * ggml_div_scalar_i32(
struct ggml_context * ctx,
struct ggml_tensor * a,
int32_t b) {
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);

ggml_set_op_params_i32(result, 0, b);

result->op = GGML_OP_DIV;
result->src[0] = a;
result->src[1] = NULL;

return result;
}

struct ggml_tensor * ggml_div_scalar_left_i32(
struct ggml_context * ctx,
int32_t a,
struct ggml_tensor * b) {
struct ggml_tensor * result = ggml_dup_tensor(ctx, b);

ggml_set_op_params_i32(result, 0, a);

result->op = GGML_OP_DIV;
result->src[0] = NULL;
result->src[1] = b;

return result;
}

// ggml_sqr

static struct ggml_tensor * ggml_sqr_impl(