Skip to content

Commit 7c156df

Browse files
authored
opencl: support pad_ext (ggml-org#15888)
1 parent 16b0ca0 commit 7c156df

File tree

2 files changed

+80
-36
lines changed

2 files changed

+80
-36
lines changed

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 51 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2889,10 +2889,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
28892889
case GGML_OP_REPEAT:
28902890
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
28912891
case GGML_OP_PAD:
2892-
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 &&
2893-
op->src[0]->ne[3] == 1 && op->ne[3] == 1 &&
2894-
(ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
2895-
(ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
2892+
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
28962893
case GGML_OP_UPSCALE:
28972894
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
28982895
case GGML_OP_CONV_2D:
@@ -5881,7 +5878,6 @@ static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_t
58815878
GGML_ASSERT(dst->extra);
58825879
GGML_ASSERT(src0->type == GGML_TYPE_F32);
58835880
GGML_ASSERT(dst->type == GGML_TYPE_F32);
5884-
GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1);
58855881

58865882
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
58875883

@@ -5899,28 +5895,67 @@ static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_t
58995895
const int s_ne0 = src0->ne[0];
59005896
const int s_ne1 = src0->ne[1];
59015897
const int s_ne2 = src0->ne[2];
5898+
const int s_ne3 = src0->ne[3];
5899+
5900+
const int s_nb0 = src0->nb[0];
5901+
const int s_nb1 = src0->nb[1];
5902+
const int s_nb2 = src0->nb[2];
5903+
const int s_nb3 = src0->nb[3];
59025904

59035905
const int d_ne0 = dst->ne[0];
59045906
const int d_ne1 = dst->ne[1];
59055907
const int d_ne2 = dst->ne[2];
5908+
const int d_ne3 = dst->ne[3];
5909+
5910+
const int d_nb0 = dst->nb[0];
5911+
const int d_nb1 = dst->nb[1];
5912+
const int d_nb2 = dst->nb[2];
5913+
const int d_nb3 = dst->nb[3];
5914+
5915+
const int lp0 = ((const int*)(dst->op_params))[0];
5916+
const int rp0 = ((const int*)(dst->op_params))[1];
5917+
const int lp1 = ((const int*)(dst->op_params))[2];
5918+
const int rp1 = ((const int*)(dst->op_params))[3];
5919+
const int lp2 = ((const int*)(dst->op_params))[4];
5920+
const int rp2 = ((const int*)(dst->op_params))[5];
5921+
const int lp3 = ((const int*)(dst->op_params))[6];
5922+
const int rp3 = ((const int*)(dst->op_params))[7];
59065923

59075924
cl_kernel kernel = backend_ctx->kernel_pad;
59085925

5909-
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device));
5910-
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0));
5911-
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device));
5912-
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst));
5913-
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &s_ne0));
5914-
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &s_ne1));
5915-
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &s_ne2));
5916-
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &d_ne0));
5917-
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &d_ne1));
5918-
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &d_ne2));
5926+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device));
5927+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0));
5928+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device));
5929+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst));
5930+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &s_ne0));
5931+
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &s_ne1));
5932+
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &s_ne2));
5933+
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &s_ne3));
5934+
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &s_nb0));
5935+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &s_nb1));
5936+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &s_nb2));
5937+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &s_nb3));
5938+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &d_ne0));
5939+
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &d_ne1));
5940+
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &d_ne2));
5941+
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &d_ne3));
5942+
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &d_nb0));
5943+
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &d_nb1));
5944+
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &d_nb2));
5945+
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &d_nb3));
5946+
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &lp0));
5947+
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &rp0));
5948+
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &lp1));
5949+
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &rp1));
5950+
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &lp2));
5951+
CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &rp2));
5952+
CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int), &lp3));
5953+
CL_CHECK(clSetKernelArg(kernel, 27, sizeof(int), &rp3));
59195954

59205955
size_t lws0 = 64;
59215956
size_t gws0 = (( (size_t)d_ne0 + lws0 - 1 ) / lws0) * lws0;
59225957

5923-
size_t global_work_size[] = { gws0, (size_t)d_ne1, (size_t)d_ne2 };
5958+
size_t global_work_size[] = { gws0, (size_t)d_ne1, (size_t)d_ne2*d_ne3 };
59245959
size_t local_work_size[] = { lws0, 1, 1 };
59255960

59265961
size_t * local_work_size_ptr = local_work_size;
Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,39 @@
11
// kernel_pad: copy a 4-D float tensor into a (larger) destination tensor and
// fill every destination element outside the source window with 0.0f.
//
// One work-item writes exactly one destination element:
//   global id 0 -> i0 (destination index along dim 0)
//   global id 1 -> i1
//   global id 2 -> i2 and i3 folded together as i3*ne2 + i2
//
// Arguments:
//   src0, offset0   source buffer and byte offset of the tensor inside it
//   dst,  offsetd   destination buffer and byte offset of the tensor inside it
//   ne00..ne03      source extents (not read by the kernel; kept so the
//                   argument layout stays unchanged)
//   nb00..nb03      source strides, applied as byte offsets
//   ne0..ne3        destination extents
//   nb0..nb3        destination strides, applied as byte offsets
//   lpK, rpK        left / right padding along dimension K; a destination
//                   index iK maps to source index iK - lpK and is inside the
//                   source window when lpK <= iK < neK - rpK
kernel void kernel_pad(
        global void * src0,
        ulong offset0,
        global void * dst,
        ulong offsetd,
        int ne00, int ne01, int ne02, int ne03,
        ulong nb00, ulong nb01, ulong nb02, ulong nb03,
        int ne0, int ne1, int ne2, int ne3,
        ulong nb0, ulong nb1, ulong nb2, ulong nb3,
        int lp0, int rp0,
        int lp1, int rp1,
        int lp2, int rp2,
        int lp3, int rp3
) {
    // Work on byte pointers: all strides below are byte strides.
    global const char * src_base = (global const char *)src0 + offset0;
    global char       * dst_base = (global char *)dst + offsetd;

    // Use global ids for every dimension. For a local size of 1 in dims 1/2
    // this is identical to the group id, and it stays correct if the runtime
    // ever launches with a larger local size in those dimensions.
    const int i0  = get_global_id(0);
    const int i1  = get_global_id(1);
    const int i23 = get_global_id(2);
    const int i2  = i23 % ne2;
    const int i3  = i23 / ne2;

    // Dim 0 of the launch grid is rounded up to the work-group size, so
    // surplus work-items must exit before touching memory.
    if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
        return;
    }

    // 64-bit byte offset: a 32-bit offset would wrap for tensors whose byte
    // size exceeds 4 GiB.
    const ulong dst_off = (ulong)i3*nb3 + (ulong)i2*nb2 + (ulong)i1*nb1 + (ulong)i0*nb0;
    global float * dst_ptr = (global float *)(dst_base + dst_off);

    const bool in_src_bounds = (i0 >= lp0 && i0 < ne0 - rp0) &&
                               (i1 >= lp1 && i1 < ne1 - rp1) &&
                               (i2 >= lp2 && i2 < ne2 - rp2) &&
                               (i3 >= lp3 && i3 < ne3 - rp3);

    if (!in_src_bounds) {
        // Border element: nothing to read from the source.
        *dst_ptr = 0.0f;
        return;
    }

    // Only computed when in bounds, so every (iK - lpK) is non-negative and
    // the unsigned multiplication cannot wrap through a negative index.
    const ulong src_off = (ulong)(i3 - lp3)*nb03 + (ulong)(i2 - lp2)*nb02 +
                          (ulong)(i1 - lp1)*nb01 + (ulong)(i0 - lp0)*nb00;

    *dst_ptr = *(global const float *)(src_base + src_off);
}

0 commit comments

Comments
 (0)