@@ -2889,10 +2889,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
2889
2889
case GGML_OP_REPEAT:
2890
2890
return op->src [0 ]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
2891
2891
case GGML_OP_PAD:
2892
- return op->src [0 ]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 &&
2893
- op->src [0 ]->ne [3 ] == 1 && op->ne [3 ] == 1 &&
2894
- (ggml_get_op_params_i32 (op, 0 ) == 0 ) && (ggml_get_op_params_i32 (op, 2 ) == 0 ) &&
2895
- (ggml_get_op_params_i32 (op, 4 ) == 0 ) && (ggml_get_op_params_i32 (op, 6 ) == 0 );
2892
+ return op->src [0 ]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
2896
2893
case GGML_OP_UPSCALE:
2897
2894
return op->src [0 ]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
2898
2895
case GGML_OP_CONV_2D:
@@ -5881,7 +5878,6 @@ static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_t
5881
5878
GGML_ASSERT (dst->extra );
5882
5879
GGML_ASSERT (src0->type == GGML_TYPE_F32);
5883
5880
GGML_ASSERT (dst->type == GGML_TYPE_F32);
5884
- GGML_ASSERT (src0->ne [3 ] == 1 && dst->ne [3 ] == 1 );
5885
5881
5886
5882
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context ;
5887
5883
@@ -5899,28 +5895,67 @@ static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_t
5899
5895
const int s_ne0 = src0->ne [0 ];
5900
5896
const int s_ne1 = src0->ne [1 ];
5901
5897
const int s_ne2 = src0->ne [2 ];
5898
+ const int s_ne3 = src0->ne [3 ];
5899
+
5900
+ const int s_nb0 = src0->nb [0 ];
5901
+ const int s_nb1 = src0->nb [1 ];
5902
+ const int s_nb2 = src0->nb [2 ];
5903
+ const int s_nb3 = src0->nb [3 ];
5902
5904
5903
5905
const int d_ne0 = dst->ne [0 ];
5904
5906
const int d_ne1 = dst->ne [1 ];
5905
5907
const int d_ne2 = dst->ne [2 ];
5908
+ const int d_ne3 = dst->ne [3 ];
5909
+
5910
+ const int d_nb0 = dst->nb [0 ];
5911
+ const int d_nb1 = dst->nb [1 ];
5912
+ const int d_nb2 = dst->nb [2 ];
5913
+ const int d_nb3 = dst->nb [3 ];
5914
+
5915
+ const int lp0 = ((const int *)(dst->op_params ))[0 ];
5916
+ const int rp0 = ((const int *)(dst->op_params ))[1 ];
5917
+ const int lp1 = ((const int *)(dst->op_params ))[2 ];
5918
+ const int rp1 = ((const int *)(dst->op_params ))[3 ];
5919
+ const int lp2 = ((const int *)(dst->op_params ))[4 ];
5920
+ const int rp2 = ((const int *)(dst->op_params ))[5 ];
5921
+ const int lp3 = ((const int *)(dst->op_params ))[6 ];
5922
+ const int rp3 = ((const int *)(dst->op_params ))[7 ];
5906
5923
5907
5924
cl_kernel kernel = backend_ctx->kernel_pad ;
5908
5925
5909
- CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra_src0->data_device ));
5910
- CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &off_src0));
5911
- CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extra_dst->data_device ));
5912
- CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &off_dst));
5913
- CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (int ), &s_ne0));
5914
- CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (int ), &s_ne1));
5915
- CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &s_ne2));
5916
- CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (int ), &d_ne0));
5917
- CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (int ), &d_ne1));
5918
- CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (int ), &d_ne2));
5926
+ CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra_src0->data_device ));
5927
+ CL_CHECK (clSetKernelArg (kernel, 1 , sizeof (cl_ulong), &off_src0));
5928
+ CL_CHECK (clSetKernelArg (kernel, 2 , sizeof (cl_mem), &extra_dst->data_device ));
5929
+ CL_CHECK (clSetKernelArg (kernel, 3 , sizeof (cl_ulong), &off_dst));
5930
+ CL_CHECK (clSetKernelArg (kernel, 4 , sizeof (int ), &s_ne0));
5931
+ CL_CHECK (clSetKernelArg (kernel, 5 , sizeof (int ), &s_ne1));
5932
+ CL_CHECK (clSetKernelArg (kernel, 6 , sizeof (int ), &s_ne2));
5933
+ CL_CHECK (clSetKernelArg (kernel, 7 , sizeof (int ), &s_ne3));
5934
+ CL_CHECK (clSetKernelArg (kernel, 8 , sizeof (cl_ulong), &s_nb0));
5935
+ CL_CHECK (clSetKernelArg (kernel, 9 , sizeof (cl_ulong), &s_nb1));
5936
+ CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (cl_ulong), &s_nb2));
5937
+ CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (cl_ulong), &s_nb3));
5938
+ CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (int ), &d_ne0));
5939
+ CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (int ), &d_ne1));
5940
+ CL_CHECK (clSetKernelArg (kernel, 14 , sizeof (int ), &d_ne2));
5941
+ CL_CHECK (clSetKernelArg (kernel, 15 , sizeof (int ), &d_ne3));
5942
+ CL_CHECK (clSetKernelArg (kernel, 16 , sizeof (cl_ulong), &d_nb0));
5943
+ CL_CHECK (clSetKernelArg (kernel, 17 , sizeof (cl_ulong), &d_nb1));
5944
+ CL_CHECK (clSetKernelArg (kernel, 18 , sizeof (cl_ulong), &d_nb2));
5945
+ CL_CHECK (clSetKernelArg (kernel, 19 , sizeof (cl_ulong), &d_nb3));
5946
+ CL_CHECK (clSetKernelArg (kernel, 20 , sizeof (int ), &lp0));
5947
+ CL_CHECK (clSetKernelArg (kernel, 21 , sizeof (int ), &rp0));
5948
+ CL_CHECK (clSetKernelArg (kernel, 22 , sizeof (int ), &lp1));
5949
+ CL_CHECK (clSetKernelArg (kernel, 23 , sizeof (int ), &rp1));
5950
+ CL_CHECK (clSetKernelArg (kernel, 24 , sizeof (int ), &lp2));
5951
+ CL_CHECK (clSetKernelArg (kernel, 25 , sizeof (int ), &rp2));
5952
+ CL_CHECK (clSetKernelArg (kernel, 26 , sizeof (int ), &lp3));
5953
+ CL_CHECK (clSetKernelArg (kernel, 27 , sizeof (int ), &rp3));
5919
5954
5920
5955
size_t lws0 = 64 ;
5921
5956
size_t gws0 = (( (size_t )d_ne0 + lws0 - 1 ) / lws0) * lws0;
5922
5957
5923
- size_t global_work_size[] = { gws0, (size_t )d_ne1, (size_t )d_ne2 };
5958
+ size_t global_work_size[] = { gws0, (size_t )d_ne1, (size_t )d_ne2*d_ne3 };
5924
5959
size_t local_work_size[] = { lws0, 1 , 1 };
5925
5960
5926
5961
size_t * local_work_size_ptr = local_work_size;
0 commit comments