Skip to content

Commit 012fb71

Browse files
committed
tests : add rms_norm + mul + add test
ggml-ci
1 parent 4d568cb commit 012fb71

File tree

2 files changed

+13
-13
lines changed

2 files changed

+13
-13
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,13 +138,14 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
138138

139139
if (ctx->mtl_device_ref_count == 0) {
140140
if (ctx->debug_fusion > 0) {
141+
fprintf(stderr, "%s: fusion stats:\n", __func__);
141142
for (int i = 0; i < GGML_OP_COUNT; i++) {
142143
if (ctx->fuse_cnt[i] == 0) {
143144
continue;
144145
}
145146

146147
// note: cannot use ggml_log here
147-
fprintf(stderr, "%s: %s: %" PRIu64 "\n", __func__, ggml_op_name((enum ggml_op) i), ctx->fuse_cnt[i]);
148+
fprintf(stderr, "%s: - %s: %" PRIu64 "\n", __func__, ggml_op_name((enum ggml_op) i), ctx->fuse_cnt[i]);
148149
}
149150
}
150151

@@ -2212,8 +2213,6 @@ static int ggml_metal_encode_node(
22122213
}
22132214
}
22142215

2215-
//GGML_LOG_INFO("%s: XXXXXXXXXXXXXXXXXXX n_fuse = %d\n", __func__, n_fuse);
2216-
22172216
if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
22182217
GGML_ASSERT(ggml_is_contiguous(src0));
22192218

@@ -4335,8 +4334,6 @@ static int ggml_metal_encode_node(
43354334
}
43364335
}
43374336

4338-
//GGML_LOG_INFO("%s: RRRRRRRRRRRRRRRRRRRRRRRRRRRRR n_fuse = %d\n", __func__, n_fuse);
4339-
43404337
if (n_fuse > 1) {
43414338
id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
43424339
}

tests/test-backend-ops.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2636,15 +2636,15 @@ struct test_rms_norm_back : public test_case {
26362636
}
26372637
};
26382638

2639-
// GGML_OP_RMS_NORM + GGML_OP_MUL
2640-
struct test_rms_norm_mul : public test_case {
2639+
// GGML_OP_RMS_NORM + GGML_OP_MUL + GGML_OP_ADD
2640+
struct test_rms_norm_mul_add : public test_case {
26412641
const ggml_type type;
26422642
const std::array<int64_t, 4> ne;
26432643
const float eps;
26442644

26452645
std::string op_desc(ggml_tensor * t) override {
26462646
GGML_UNUSED(t);
2647-
return "RMS_NORM_MUL";
2647+
return "RMS_NORM_MUL_ADD";
26482648
}
26492649

26502650
bool run_whole_graph() override { return true; }
@@ -2653,22 +2653,25 @@ struct test_rms_norm_mul : public test_case {
26532653
return VARS_TO_STR3(type, ne, eps);
26542654
}
26552655

2656-
test_rms_norm_mul(ggml_type type = GGML_TYPE_F32,
2656+
test_rms_norm_mul_add(ggml_type type = GGML_TYPE_F32,
26572657
std::array<int64_t, 4> ne = {64, 5, 4, 3},
26582658
float eps = 1e-6f)
26592659
: type(type), ne(ne), eps(eps) {}
26602660

26612661
ggml_tensor * build_graph(ggml_context * ctx) override {
26622662
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
26632663
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
2664+
ggml_tensor * c = ggml_new_tensor(ctx, type, 4, ne.data());
26642665
ggml_set_param(a);
26652666
ggml_set_name(a, "a");
26662667
ggml_set_param(b);
26672668
ggml_set_name(b, "b");
2669+
ggml_set_param(c);
2670+
ggml_set_name(c, "c");
26682671

2669-
// Use a and b early, so we don't end up with an OP_NONE between rms_norm and mul
2670-
a = ggml_add(ctx, a, b);
2671-
ggml_tensor * out = ggml_mul(ctx, ggml_rms_norm(ctx, a, eps), b);
2672+
// Use a, b and c early, so we don't end up with an OP_NONE between rms_norm, mul and add
2673+
a = ggml_add(ctx, ggml_add(ctx, a, b), c);
2674+
ggml_tensor * out = ggml_add(ctx, ggml_mul(ctx, ggml_rms_norm(ctx, a, eps), b), c);
26722675
ggml_set_name(out, "out");
26732676

26742677
return out;
@@ -5188,7 +5191,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
51885191
test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
51895192
}
51905193
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) {
5191-
test_cases.emplace_back(new test_rms_norm_mul(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
5194+
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
51925195
}
51935196

51945197
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));

0 commit comments

Comments
 (0)