Skip to content

Commit 559a8e8

Browse files
authored
rotaryembed layer (#6407)
* add deepseek_v3 attention test * fuse non interleaved rotary embed * fuse t5 layernorm without gamma * test qwen2 attention * sdpa pattern ++
1 parent 11fb997 commit 559a8e8

23 files changed

+937
-7
lines changed

docs/developer-guide/operators.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
* [Reshape](#reshape)
7777
* [RMSNorm](#rmsnorm)
7878
* [RNN](#rnn)
79+
* [RotaryEmbed](#rotaryembed)
7980
* [Scale](#scale)
8081
* [SDPA](#sdpa)
8182
* [SELU](#selu)
@@ -1778,6 +1779,18 @@ Direction flag:
17781779
- 1 = reverse only
17791780
- 2 = bidirectional
17801781

1782+
# RotaryEmbed
1783+
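Apply rotary positional embeddings to the input using precomputed cos and sin caches. The layer takes three input blobs: x (w = embed_dim, h = seqlen, c = num_heads), cos_cache and sin_cache (each w = embed_dim / 2, h = seqlen).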
1784+
1785+
```
1786+
y1 = x1 * cos - x2 * sin
1787+
y2 = x1 * sin + x2 * cos
1788+
```
1789+
1790+
| param id | name | type | default | description |
1791+
| --------- | ------------- | ----- | --------- | ----------------- |
1792+
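| 0 | interleaved | int | 0 | 0 = x1 / x2 are the first / second half of each row, 1 = x1 / x2 are adjacent even / odd elements of each row |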
1793+
17811794
# Scale
17821795
```
17831796
if scale_data_size == -233 y = x0 * x1

src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ ncnn_add_layer(Spectrogram)
172172
ncnn_add_layer(InverseSpectrogram)
173173
ncnn_add_layer(Flip)
174174
ncnn_add_layer(SDPA)
175+
ncnn_add_layer(RotaryEmbed)
175176

176177
if(NCNN_VULKAN)
177178
ncnn_add_shader(${CMAKE_CURRENT_SOURCE_DIR}/convert_ycbcr.comp)

src/layer/rotaryembed.cpp

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
// Copyright 2025 Tencent
2+
// SPDX-License-Identifier: BSD-3-Clause
3+
4+
#include "rotaryembed.h"
5+
6+
namespace ncnn {
7+
8+
// Construct the layer with `interleaved` set to 0, the same value load_param()
// falls back to, so the member is never read in an indeterminate state when
// forward() runs on a layer whose load_param() has not been called yet.
RotaryEmbed::RotaryEmbed()
    : interleaved(0)
{
}
11+
12+
int RotaryEmbed::load_param(const ParamDict& pd)
13+
{
14+
interleaved = pd.get(0, 0);
15+
16+
return 0;
17+
}
18+
19+
int RotaryEmbed::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
20+
{
21+
// assert bottom_blobs.size() == 3
22+
23+
const Mat& bottom_blob = bottom_blobs[0];
24+
const Mat& cos_cache = bottom_blobs[1];
25+
const Mat& sin_cache = bottom_blobs[2];
26+
27+
const int embed_dim = bottom_blob.w;
28+
const int seqlen = bottom_blob.h;
29+
const int num_heads = bottom_blob.c;
30+
31+
Mat& top_blob = top_blobs[0];
32+
top_blob.create_like(bottom_blob, opt.blob_allocator);
33+
if (top_blob.empty())
34+
return -100;
35+
36+
#pragma omp parallel for num_threads(opt.num_threads)
37+
for (int q = 0; q < num_heads; q++)
38+
{
39+
const Mat head = bottom_blob.channel(q);
40+
Mat out_head = top_blob.channel(q);
41+
42+
for (int i = 0; i < seqlen; i++)
43+
{
44+
if (interleaved)
45+
{
46+
const float* ptr = head.row(i);
47+
const float* cos_ptr = cos_cache.row(i);
48+
const float* sin_ptr = sin_cache.row(i);
49+
float* outptr = out_head.row(i);
50+
51+
for (int j = 0; j < embed_dim / 2; j++)
52+
{
53+
const float x0 = ptr[0];
54+
const float x1 = ptr[1];
55+
const float cos_val = *cos_ptr++;
56+
const float sin_val = *sin_ptr++;
57+
outptr[0] = x0 * cos_val - x1 * sin_val;
58+
outptr[1] = x0 * sin_val + x1 * cos_val;
59+
ptr += 2;
60+
outptr += 2;
61+
}
62+
}
63+
else
64+
{
65+
const float* ptr0 = head.row(i);
66+
const float* ptr1 = ptr0 + embed_dim / 2;
67+
const float* sin_ptr = sin_cache.row(i);
68+
const float* cos_ptr = cos_cache.row(i);
69+
float* outptr0 = out_head.row(i);
70+
float* outptr1 = outptr0 + embed_dim / 2;
71+
72+
for (int j = 0; j < embed_dim / 2; j++)
73+
{
74+
const float x0 = *ptr0++;
75+
const float x1 = *ptr1++;
76+
const float cos_val = *cos_ptr++;
77+
const float sin_val = *sin_ptr++;
78+
*outptr0++ = x0 * cos_val - x1 * sin_val;
79+
*outptr1++ = x0 * sin_val + x1 * cos_val;
80+
}
81+
}
82+
}
83+
}
84+
85+
return 0;
86+
}
87+
88+
} // namespace ncnn

src/layer/rotaryembed.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Copyright 2025 Tencent
2+
// SPDX-License-Identifier: BSD-3-Clause
3+
4+
#ifndef LAYER_ROTARYEMBED_H
5+
#define LAYER_ROTARYEMBED_H
6+
7+
#include "layer.h"
8+
9+
namespace ncnn {
10+
11+
class RotaryEmbed : public Layer
12+
{
13+
public:
14+
RotaryEmbed();
15+
16+
virtual int load_param(const ParamDict& pd);
17+
18+
virtual int forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const;
19+
20+
public:
21+
int interleaved;
22+
};
23+
24+
} // namespace ncnn
25+
26+
#endif // LAYER_ROTARYEMBED_H

tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ ncnn_add_layer_test(RMSNorm)
150150
ncnn_add_layer_test(RNN)
151151
ncnn_add_layer_test(ROIPooling)
152152
ncnn_add_layer_test(ROIAlign)
153+
ncnn_add_layer_test(RotaryEmbed)
153154
ncnn_add_layer_test(Scale)
154155
ncnn_add_layer_test(SDPA)
155156
ncnn_add_layer_test(SELU)

tests/test_rotaryembed.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
// Copyright 2025 Tencent
2+
// SPDX-License-Identifier: BSD-3-Clause
3+
4+
#include "testutil.h"
5+
6+
// Run the generic layer test for RotaryEmbed on input `a` with the given
// `interleaved` parameter value.
//
// The cos and sin caches are random matrices with embed_dim / 2 columns and
// one row per sequence position, where embed_dim = a.w and seqlen = a.h.
//
// Returns the test_layer() result; a non-zero result is also reported on stderr.
static int test_rotaryembed(const ncnn::Mat& a, int interleaved)
{
    const int embed_dim = a.w;
    const int seqlen = a.h;

    ncnn::Mat cos_cache = RandomMat(embed_dim / 2, seqlen);
    ncnn::Mat sin_cache = RandomMat(embed_dim / 2, seqlen);

    ncnn::ParamDict pd;
    pd.set(0, interleaved);

    // the layer has no weights
    std::vector<ncnn::Mat> weights(0);

    std::vector<ncnn::Mat> as(3);
    as[0] = a;
    as[1] = cos_cache;
    as[2] = sin_cache;

    int ret = test_layer("RotaryEmbed", pd, weights, as, 1);
    if (ret != 0)
    {
        fprintf(stderr, "test_rotaryembed failed a=(%d %d %d) interleaved=%d\n", a.w, a.h, a.c, interleaved);
    }

    return ret;
}
33+
34+
static int test_rotaryembed_0()
35+
{
36+
return 0
37+
|| test_rotaryembed(RandomMat(32, 66, 8), 0)
38+
|| test_rotaryembed(RandomMat(26, 64, 8), 1)
39+
|| test_rotaryembed(RandomMat(64, 28, 12), 0)
40+
|| test_rotaryembed(RandomMat(48, 22, 12), 1)
41+
|| test_rotaryembed(RandomMat(44, 28, 64), 0)
42+
|| test_rotaryembed(RandomMat(12, 27, 64), 1)
43+
|| test_rotaryembed(RandomMat(28, 17, 15), 0)
44+
|| test_rotaryembed(RandomMat(28, 17, 15), 1);
45+
}
46+
47+
int main()
48+
{
49+
SRAND(7767517);
50+
51+
return test_rotaryembed_0();
52+
}

tests/test_rotaryembed_oom.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// Copyright 2025 Tencent
2+
// SPDX-License-Identifier: BSD-3-Clause
3+
4+
#include "testutil.h"
5+
6+
// Run the out-of-memory layer test for RotaryEmbed on input `a` with the
// given `interleaved` parameter value.
//
// The cos and sin caches are random matrices with embed_dim / 2 columns and
// one row per sequence position, where embed_dim = a.w and seqlen = a.h.
//
// Returns the test_layer_oom() result; a non-zero result is also reported on stderr.
static int test_rotaryembed_oom(const ncnn::Mat& a, int interleaved)
{
    const int embed_dim = a.w;
    const int seqlen = a.h;

    ncnn::Mat cos_cache = RandomMat(embed_dim / 2, seqlen);
    ncnn::Mat sin_cache = RandomMat(embed_dim / 2, seqlen);

    ncnn::ParamDict pd;
    pd.set(0, interleaved);

    // the layer has no weights
    std::vector<ncnn::Mat> weights(0);

    std::vector<ncnn::Mat> as(3);
    as[0] = a;
    as[1] = cos_cache;
    as[2] = sin_cache;

    int ret = test_layer_oom("RotaryEmbed", pd, weights, as, 1);
    if (ret != 0)
    {
        fprintf(stderr, "test_rotaryembed_oom failed a=(%d %d %d) interleaved=%d\n", a.w, a.h, a.c, interleaved);
    }

    return ret;
}
33+
34+
static int test_rotaryembed_0()
35+
{
36+
return 0
37+
|| test_rotaryembed_oom(RandomMat(32, 66, 8), 0)
38+
|| test_rotaryembed_oom(RandomMat(28, 17, 15), 1);
39+
}
40+
41+
int main()
42+
{
43+
SRAND(7767517);
44+
45+
return test_rotaryembed_0();
46+
}

tools/modelwriter.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@
9292
#include "layer/rnn.h"
9393
#include "layer/roialign.h"
9494
#include "layer/roipooling.h"
95+
#include "layer/rotaryembed.h"
9596
#include "layer/scale.h"
9697
#include "layer/sdpa.h"
9798
#include "layer/shufflechannel.h"
@@ -2407,6 +2408,13 @@ int ModelWriter::save(const char* parampath, const char* binpath)
24072408
fprintf_param_value(" 1=%d", pooled_height)
24082409
fprintf_param_value(" 2=%e", spatial_scale)
24092410
}
2411+
else if (layer->type == "RotaryEmbed")
2412+
{
2413+
ncnn::RotaryEmbed* op = (ncnn::RotaryEmbed*)layer;
2414+
ncnn::RotaryEmbed* op_default = (ncnn::RotaryEmbed*)layer_default;
2415+
2416+
fprintf_param_value(" 0=%d", interleaved)
2417+
}
24102418
else if (layer->type == "Scale")
24112419
{
24122420
ncnn::Scale* op = (ncnn::Scale*)layer;

tools/pnnx/src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,7 @@ set(pnnx_pass_ncnn_SRCS
430430
pass_ncnn/eliminate_output.cpp
431431
pass_ncnn/expand_expression.cpp
432432
pass_ncnn/fuse_convert_shufflechannel_slice.cpp
433+
pass_ncnn/fuse_convert_rotaryembed.cpp
433434
pass_ncnn/insert_split.cpp
434435
pass_ncnn/chain_multi_output.cpp
435436
pass_ncnn/solve_batch_index.cpp

tools/pnnx/src/pass_level5/fuse_rmsnorm.cpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,71 @@ pnnx.Output output 1 0 out
5454
}
5555
};
5656

57+
class fuse_rmsnorm_pass_without_gamma : public GraphRewriterPass
58+
{
59+
public:
60+
const char* match_pattern_graph() const
61+
{
62+
return R"PNNXIR(7767517
63+
5 4
64+
pnnx.Input input 0 1 input
65+
pnnx.Expression op_0 1 1 input sq expr=pow(@0,2)
66+
torch.mean op_1 1 1 sq sqmean dim=(-1) keepdim=True
67+
pnnx.Expression op_2 2 1 input sqmean out expr=mul(@0,rsqrt(add(@1,%eps)))
68+
pnnx.Output output 1 0 out
69+
)PNNXIR";
70+
}
71+
72+
const char* type_str() const
73+
{
74+
return "nn.RMSNorm";
75+
}
76+
77+
const char* name_str() const
78+
{
79+
return "t5ln";
80+
}
81+
82+
bool match(const std::map<std::string, const Operator*>& matched_operators, const std::map<std::string, Parameter>& /*captured_params*/, const std::map<std::string, Attribute>& /*captured_attrs*/) const
83+
{
84+
const Operator* op_0 = matched_operators.at("op_0");
85+
const std::vector<int>& shape = op_0->inputs[0]->shape;
86+
if (shape.empty())
87+
{
88+
// unknown normalized_shape
89+
return false;
90+
}
91+
92+
return true;
93+
}
94+
95+
void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
96+
{
97+
const std::vector<int>& shape = op->inputs[0]->shape;
98+
const int c = shape[shape.size() - 1];
99+
100+
op->params["elementwise_affine"] = false;
101+
op->params["eps"] = captured_params.at("eps");
102+
op->params["normalized_shape"] = std::vector<int>{c};
103+
}
104+
};
105+
106+
class fuse_rmsnorm_pass_without_gamma_1 : public fuse_rmsnorm_pass_without_gamma
107+
{
108+
public:
109+
const char* match_pattern_graph() const
110+
{
111+
return R"PNNXIR(7767517
112+
5 4
113+
pnnx.Input input 0 1 input
114+
pnnx.Expression op_0 1 1 input sq expr=pow(@0,2)
115+
torch.mean op_1 1 1 sq sqmean dim=(-1) keepdim=True
116+
pnnx.Expression op_2 2 1 input sqmean out expr=mul(@0,reciprocal(sqrt(add(@1,%eps))))
117+
pnnx.Output output 1 0 out
118+
)PNNXIR";
119+
}
120+
};
121+
57122
class fuse_rmsnorm_pass_onnx : public fuse_rmsnorm_pass
58123
{
59124
public:
@@ -75,11 +140,15 @@ void fuse_rmsnorm(Graph& graph)
75140
{
76141
fuse_rmsnorm_pass a;
77142
fuse_rmsnorm_pass_1 a1;
143+
fuse_rmsnorm_pass_without_gamma a2;
144+
fuse_rmsnorm_pass_without_gamma_1 a3;
78145
fuse_rmsnorm_pass_onnx b;
79146
int opindex = 0;
80147

81148
pnnx_graph_rewrite(graph, &a, opindex);
82149
pnnx_graph_rewrite(graph, &a1, opindex);
150+
pnnx_graph_rewrite(graph, &a2, opindex);
151+
pnnx_graph_rewrite(graph, &a3, opindex);
83152
pnnx_graph_rewrite(graph, &b, opindex);
84153
}
85154

0 commit comments

Comments
 (0)