@@ -2636,15 +2636,15 @@ struct test_rms_norm_back : public test_case {
2636
2636
}
2637
2637
};
2638
2638
2639
- // GGML_OP_RMS_NORM + GGML_OP_MUL
2640
- struct test_rms_norm_mul : public test_case {
2639
+ // GGML_OP_RMS_NORM + GGML_OP_MUL + GGML_OP_ADD
2640
+ struct test_rms_norm_mul_add : public test_case {
2641
2641
const ggml_type type;
2642
2642
const std::array<int64_t , 4 > ne;
2643
2643
const float eps;
2644
2644
2645
2645
std::string op_desc (ggml_tensor * t) override {
2646
2646
GGML_UNUSED (t);
2647
- return " RMS_NORM_MUL " ;
2647
+ return " RMS_NORM_MUL_ADD " ;
2648
2648
}
2649
2649
2650
2650
bool run_whole_graph () override { return true ; }
@@ -2653,22 +2653,25 @@ struct test_rms_norm_mul : public test_case {
2653
2653
return VARS_TO_STR3 (type, ne, eps);
2654
2654
}
2655
2655
2656
- test_rms_norm_mul (ggml_type type = GGML_TYPE_F32,
2656
+ test_rms_norm_mul_add (ggml_type type = GGML_TYPE_F32,
2657
2657
std::array<int64_t , 4 > ne = {64 , 5 , 4 , 3 },
2658
2658
float eps = 1e-6f )
2659
2659
: type(type), ne(ne), eps(eps) {}
2660
2660
2661
2661
ggml_tensor * build_graph (ggml_context * ctx) override {
2662
2662
ggml_tensor * a = ggml_new_tensor (ctx, type, 4 , ne.data ());
2663
2663
ggml_tensor * b = ggml_new_tensor (ctx, type, 4 , ne.data ());
2664
+ ggml_tensor * c = ggml_new_tensor (ctx, type, 4 , ne.data ());
2664
2665
ggml_set_param (a);
2665
2666
ggml_set_name (a, " a" );
2666
2667
ggml_set_param (b);
2667
2668
ggml_set_name (b, " b" );
2669
+ ggml_set_param (c);
2670
+ ggml_set_name (c, " c" );
2668
2671
2669
- // Use a and b early, so we don't end up with an OP_NONE between rms_norm and mul
2670
- a = ggml_add (ctx, a, b);
2671
- ggml_tensor * out = ggml_mul (ctx, ggml_rms_norm (ctx, a, eps), b);
2672
+ // Use a, b and c early, so we don't end up with an OP_NONE between rms_norm, mul and add
2673
+ a = ggml_add (ctx, ggml_add (ctx, a, b), c );
2674
+ ggml_tensor * out = ggml_add (ctx, ggml_mul (ctx, ggml_rms_norm (ctx, a, eps), b), c );
2672
2675
ggml_set_name (out, " out" );
2673
2676
2674
2677
return out;
@@ -5188,7 +5191,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
5188
5191
test_cases.emplace_back (new test_l2_norm (GGML_TYPE_F32, {64 , 5 , 4 , 3 }, eps));
5189
5192
}
5190
5193
for (float eps : {0 .0f , 1e-6f , 1e-4f , 1e-1f , 1 .0f }) {
5191
- test_cases.emplace_back (new test_rms_norm_mul (GGML_TYPE_F32, {64 , 5 , 4 , 3 }, eps));
5194
+ test_cases.emplace_back (new test_rms_norm_mul_add (GGML_TYPE_F32, {64 , 5 , 4 , 3 }, eps));
5192
5195
}
5193
5196
5194
5197
test_cases.emplace_back (new test_l2_norm (GGML_TYPE_F32, {64 , 5 , 4 , 3 }, 1e-12f ));
0 commit comments