@@ -252,10 +252,10 @@ struct webgpu_context_struct {
252252 webgpu_pipeline get_rows_pipeline[30 ];
253253 webgpu_pipeline get_rows_f32_no_vec_pipeline;
254254 webgpu_pipeline cpy_pipeline[2 ][2 ]; // src type, dst type
255- webgpu_pipeline add_pipeline[2 ][2 ][ 2 ] ; // type, inplace, overlap
256- webgpu_pipeline sub_pipeline[2 ][2 ][ 2 ] ; // type, inplace, overlap
257- webgpu_pipeline mul_pipeline[2 ][2 ][ 2 ] ; // type, inplace, overlap
258- webgpu_pipeline div_pipeline[2 ][2 ][ 2 ] ; // type, inplace, overlap
255+ webgpu_pipeline add_pipeline[2 ][2 ]; // type, inplace
256+ webgpu_pipeline sub_pipeline[2 ][2 ]; // type, inplace
257+ webgpu_pipeline mul_pipeline[2 ][2 ]; // type, inplace
258+ webgpu_pipeline div_pipeline[2 ][2 ]; // type, inplace
259259 webgpu_pipeline rms_norm_pipeline[2 ]; // inplace
260260 webgpu_pipeline rope_pipeline[2 ][2 ][2 ]; // type, ff, inplace
261261 webgpu_pipeline glu_pipeline[7 ][2 ][2 ]; // glu-op, type, split
@@ -677,12 +677,9 @@ static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, ggml_tensor
677677 return offset & ~(ctx->limits .minStorageBufferOffsetAlignment - 1 );
678678}
679679
680- static size_t ggml_webgpu_tensor_align_binding_size (size_t size) {
681- return (size + WEBGPU_STORAGE_BUF_BINDING_MULT - 1 ) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1 );
682- }
683-
684680static size_t ggml_webgpu_tensor_binding_size (webgpu_context & ctx, ggml_tensor * t) {
685- return ggml_webgpu_tensor_align_binding_size (ggml_nbytes (t) + ggml_webgpu_tensor_misalignment (ctx, t));
681+ return (ggml_nbytes (t) + ggml_webgpu_tensor_misalignment (ctx, t) + WEBGPU_STORAGE_BUF_BINDING_MULT - 1 ) &
682+ ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1 );
686683}
687684
688685// Used to determine if two tensors are the same for in-place operations
@@ -691,12 +688,6 @@ static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
691688 (ggml_webgpu_tensor_offset (a) == ggml_webgpu_tensor_offset (b));
692689}
693690
694- static bool ggml_webgpu_tensor_overlap (ggml_tensor * a, ggml_tensor * b) {
695- return (ggml_webgpu_tensor_buf (a).Get () == ggml_webgpu_tensor_buf (b).Get ()) &&
696- ggml_webgpu_tensor_offset (a) < (ggml_webgpu_tensor_offset (b) + ggml_nbytes (b)) &&
697- ggml_webgpu_tensor_offset (b) < (ggml_webgpu_tensor_offset (a) + ggml_nbytes (a));
698- }
699-
700691static webgpu_command ggml_webgpu_cpy (webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
701692 uint32_t ne = (uint32_t ) ggml_nelements (dst);
702693
@@ -879,27 +870,16 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
879870 return ggml_backend_webgpu_build (ctx, ctx->mul_mat_pipeline [src0->type ][src1->type ], params, entries, wg_x);
880871}
881872
882- template <size_t a, size_t b, size_t c>
883- static webgpu_command ggml_webgpu_binary_op (webgpu_context & ctx,
884- ggml_tensor * src0,
885- ggml_tensor * src1,
886- ggml_tensor * dst,
887- webgpu_pipeline (&pipelines)[a][b][c]) {
888- int inplace = ggml_webgpu_tensor_equal (src0, dst);
889- int overlap = ggml_webgpu_tensor_overlap (src0, src1);
890- webgpu_pipeline pipeline = pipelines[dst->type ][inplace][overlap];
891-
892- uint32_t src1_offset = ggml_webgpu_tensor_misalignment (ctx, src1) / ggml_type_size (src1->type );
893- if (overlap) {
894- // when overlapped, bind a single buffer covering both src0 and src1
895- // TODO: Do other operations need this?
896- src1_offset = (uint32_t ) ((ggml_webgpu_tensor_offset (src1) - ggml_webgpu_tensor_align_offset (ctx, src0)) /
897- ggml_type_size (src1->type ));
898- }
873+ static webgpu_command ggml_webgpu_binary_op (webgpu_context & ctx,
874+ ggml_tensor * src0,
875+ ggml_tensor * src1,
876+ ggml_tensor * dst,
877+ webgpu_pipeline & pipeline,
878+ bool inplace) {
899879 std::vector<uint32_t > params = {
900880 (uint32_t ) ggml_nelements (dst),
901881 (uint32_t ) (ggml_webgpu_tensor_misalignment (ctx, src0) / ggml_type_size (src0->type )),
902- src1_offset ,
882+ ( uint32_t ) ( ggml_webgpu_tensor_misalignment (ctx, src1) / ggml_type_size (src1-> type )) ,
903883 (uint32_t ) (ggml_webgpu_tensor_misalignment (ctx, dst) / ggml_type_size (dst->type )),
904884 (uint32_t ) (src1->nb [0 ] / ggml_type_size (src1->type )),
905885 (uint32_t ) (src1->nb [1 ] / ggml_type_size (src1->type )),
@@ -914,36 +894,25 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
914894 (uint32_t ) src1->ne [3 ],
915895 };
916896
917- size_t src0_binding_size = ggml_webgpu_tensor_binding_size (ctx, src0);
918- if (overlap) {
919- const uint64_t base_align = ggml_webgpu_tensor_align_offset (ctx, src0);
920- // assume end of src1 is >= end of src0
921- const uint64_t max_end = ggml_webgpu_tensor_offset (src1) + ggml_nbytes (src1);
922- src0_binding_size = ggml_webgpu_tensor_align_binding_size (max_end - base_align);
923- }
924897 std::vector<wgpu::BindGroupEntry> entries = {
925898 { .binding = 0 ,
926899 .buffer = ggml_webgpu_tensor_buf (src0),
927900 .offset = ggml_webgpu_tensor_align_offset (ctx, src0),
928- .size = src0_binding_size }
901+ .size = ggml_webgpu_tensor_binding_size (ctx, src0) },
902+ { .binding = 1 ,
903+ .buffer = ggml_webgpu_tensor_buf (src1),
904+ .offset = ggml_webgpu_tensor_align_offset (ctx, src1),
905+ .size = ggml_webgpu_tensor_binding_size (ctx, src1) }
929906 };
930- uint32_t binding_num = 1 ;
931- if (!overlap) {
932- entries.push_back ({ .binding = binding_num,
933- .buffer = ggml_webgpu_tensor_buf (src1),
934- .offset = ggml_webgpu_tensor_align_offset (ctx, src1),
935- .size = ggml_webgpu_tensor_binding_size (ctx, src1) });
936- binding_num++;
937- }
938907 if (!inplace) {
939- entries.push_back ({ .binding = binding_num ,
908+ entries.push_back ({ .binding = 2 ,
940909 .buffer = ggml_webgpu_tensor_buf (dst),
941910 .offset = ggml_webgpu_tensor_align_offset (ctx, dst),
942911 .size = ggml_webgpu_tensor_binding_size (ctx, dst) });
943912 }
944913
945- size_t max_wg_size = ctx->max_wg_size_x ;
946- uint32_t wg_x = (ggml_nelements (dst) + max_wg_size - 1 ) / max_wg_size;
914+ size_t max_wg_size = ctx->max_wg_size_x ;
915+ uint32_t wg_x = (ggml_nelements (dst) + max_wg_size - 1 ) / max_wg_size;
947916 return ggml_backend_webgpu_build (ctx, pipeline, params, entries, wg_x);
948917}
949918
@@ -1263,13 +1232,25 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
12631232 case GGML_OP_MUL_MAT:
12641233 return ggml_webgpu_mul_mat (ctx, src0, src1, node);
12651234 case GGML_OP_ADD:
1266- return ggml_webgpu_binary_op (ctx, src0, src1, node, ctx->add_pipeline );
1235+ {
1236+ int inplace = ggml_webgpu_tensor_equal (src0, node);
1237+ return ggml_webgpu_binary_op (ctx, src0, src1, node, ctx->add_pipeline [node->type ][inplace], inplace);
1238+ }
12671239 case GGML_OP_SUB:
1268- return ggml_webgpu_binary_op (ctx, src0, src1, node, ctx->sub_pipeline );
1240+ {
1241+ int inplace = ggml_webgpu_tensor_equal (src0, node);
1242+ return ggml_webgpu_binary_op (ctx, src0, src1, node, ctx->sub_pipeline [node->type ][inplace], inplace);
1243+ }
12691244 case GGML_OP_MUL:
1270- return ggml_webgpu_binary_op (ctx, src0, src1, node, ctx->mul_pipeline );
1245+ {
1246+ int inplace = ggml_webgpu_tensor_equal (src0, node);
1247+ return ggml_webgpu_binary_op (ctx, src0, src1, node, ctx->mul_pipeline [node->type ][inplace], inplace);
1248+ }
12711249 case GGML_OP_DIV:
1272- return ggml_webgpu_binary_op (ctx, src0, src1, node, ctx->div_pipeline );
1250+ {
1251+ int inplace = ggml_webgpu_tensor_equal (src0, node);
1252+ return ggml_webgpu_binary_op (ctx, src0, src1, node, ctx->div_pipeline [node->type ][inplace], inplace);
1253+ }
12731254 case GGML_OP_RMS_NORM:
12741255 return ggml_webgpu_rms_norm (ctx, src0, node);
12751256 case GGML_OP_ROPE:
@@ -1719,82 +1700,50 @@ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
17191700
17201701static void ggml_webgpu_init_add_pipeline (webgpu_context & webgpu_ctx) {
17211702 std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry (webgpu_ctx->max_wg_size_x );
1722- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->add_pipeline [GGML_TYPE_F32][0 ][ 0 ] , wgsl_add_f32,
1723- " add_f32 " , constants);
1724- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->add_pipeline [GGML_TYPE_F16][0 ][ 0 ] , wgsl_add_f16,
1725- " add_f16 " , constants);
1726- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->add_pipeline [GGML_TYPE_F32][1 ][ 0 ] , wgsl_add_f32_inplace,
1703+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->add_pipeline [GGML_TYPE_F32][0 ], wgsl_add_f32, " add_f32 " ,
1704+ constants);
1705+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->add_pipeline [GGML_TYPE_F16][0 ], wgsl_add_f16, " add_f16 " ,
1706+ constants);
1707+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->add_pipeline [GGML_TYPE_F32][1 ], wgsl_add_f32_inplace,
17271708 " add_f32_inplace" , constants);
1728- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->add_pipeline [GGML_TYPE_F16][1 ][ 0 ] , wgsl_add_f16_inplace,
1709+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->add_pipeline [GGML_TYPE_F16][1 ], wgsl_add_f16_inplace,
17291710 " add_f16_inplace" , constants);
1730- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->add_pipeline [GGML_TYPE_F32][0 ][1 ], wgsl_add_f32_overlap,
1731- " add_f32_overlap" , constants);
1732- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->add_pipeline [GGML_TYPE_F32][1 ][1 ],
1733- wgsl_add_f32_inplace_overlap, " add_f32_inplace_overlap" , constants);
1734- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->add_pipeline [GGML_TYPE_F16][0 ][1 ], wgsl_add_f16_overlap,
1735- " add_f16_overlap" , constants);
1736- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->add_pipeline [GGML_TYPE_F16][1 ][1 ],
1737- wgsl_add_f16_inplace_overlap, " add_f16_inplace_overlap" , constants);
17381711}
17391712
17401713static void ggml_webgpu_init_sub_pipeline (webgpu_context & webgpu_ctx) {
17411714 std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry (webgpu_ctx->max_wg_size_x );
1742- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->sub_pipeline [GGML_TYPE_F32][0 ][ 0 ] , wgsl_sub_f32,
1743- " sub_f32 " , constants);
1744- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->sub_pipeline [GGML_TYPE_F16][0 ][ 0 ] , wgsl_sub_f16,
1745- " sub_f16 " , constants);
1746- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->sub_pipeline [GGML_TYPE_F32][1 ][ 0 ] , wgsl_sub_f32_inplace,
1715+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->sub_pipeline [GGML_TYPE_F32][0 ], wgsl_sub_f32, " sub_f32 " ,
1716+ constants);
1717+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->sub_pipeline [GGML_TYPE_F16][0 ], wgsl_sub_f16, " sub_f16 " ,
1718+ constants);
1719+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->sub_pipeline [GGML_TYPE_F32][1 ], wgsl_sub_f32_inplace,
17471720 " sub_f32_inplace" , constants);
1748- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->sub_pipeline [GGML_TYPE_F16][1 ][ 0 ] , wgsl_sub_f16_inplace,
1721+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->sub_pipeline [GGML_TYPE_F16][1 ], wgsl_sub_f16_inplace,
17491722 " sub_f16_inplace" , constants);
1750- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->sub_pipeline [GGML_TYPE_F32][0 ][1 ], wgsl_sub_f32_overlap,
1751- " sub_f32_overlap" , constants);
1752- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->sub_pipeline [GGML_TYPE_F32][1 ][1 ],
1753- wgsl_sub_f32_inplace_overlap, " sub_f32_inplace_overlap" , constants);
1754- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->sub_pipeline [GGML_TYPE_F16][0 ][1 ], wgsl_sub_f16_overlap,
1755- " sub_f16_overlap" , constants);
1756- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->sub_pipeline [GGML_TYPE_F16][1 ][1 ],
1757- wgsl_sub_f16_inplace_overlap, " sub_f16_inplace_overlap" , constants);
17581723}
17591724
17601725static void ggml_webgpu_init_mul_pipeline (webgpu_context & webgpu_ctx) {
17611726 std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry (webgpu_ctx->max_wg_size_x );
1762- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->mul_pipeline [GGML_TYPE_F32][0 ][ 0 ] , wgsl_mul_f32,
1763- " mul_f32 " , constants);
1764- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->mul_pipeline [GGML_TYPE_F16][0 ][ 0 ] , wgsl_mul_f16,
1765- " mul_f16 " , constants);
1766- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->mul_pipeline [GGML_TYPE_F32][1 ][ 0 ] , wgsl_mul_f32_inplace,
1727+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->mul_pipeline [GGML_TYPE_F32][0 ], wgsl_mul_f32, " mul_f32 " ,
1728+ constants);
1729+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->mul_pipeline [GGML_TYPE_F16][0 ], wgsl_mul_f16, " mul_f16 " ,
1730+ constants);
1731+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->mul_pipeline [GGML_TYPE_F32][1 ], wgsl_mul_f32_inplace,
17671732 " mul_f32_inplace" , constants);
1768- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->mul_pipeline [GGML_TYPE_F16][1 ][ 0 ] , wgsl_mul_f16_inplace,
1733+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->mul_pipeline [GGML_TYPE_F16][1 ], wgsl_mul_f16_inplace,
17691734 " mul_f16_inplace" , constants);
1770- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->mul_pipeline [GGML_TYPE_F32][0 ][1 ], wgsl_mul_f32_overlap,
1771- " mul_f32_overlap" , constants);
1772- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->mul_pipeline [GGML_TYPE_F32][1 ][1 ],
1773- wgsl_mul_f32_inplace_overlap, " mul_f32_inplace_overlap" , constants);
1774- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->mul_pipeline [GGML_TYPE_F16][0 ][1 ], wgsl_mul_f16_overlap,
1775- " mul_f16_overlap" , constants);
1776- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->mul_pipeline [GGML_TYPE_F16][1 ][1 ],
1777- wgsl_mul_f16_inplace_overlap, " mul_f16_inplace_overlap" , constants);
17781735}
17791736
17801737static void ggml_webgpu_init_div_pipeline (webgpu_context & webgpu_ctx) {
17811738 std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry (webgpu_ctx->max_wg_size_x );
1782- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->div_pipeline [GGML_TYPE_F32][0 ][ 0 ] , wgsl_div_f32,
1783- " div_f32 " , constants);
1784- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->div_pipeline [GGML_TYPE_F16][0 ][ 0 ] , wgsl_div_f16,
1785- " div_f16 " , constants);
1786- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->div_pipeline [GGML_TYPE_F32][1 ][ 0 ] , wgsl_div_f32_inplace,
1739+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->div_pipeline [GGML_TYPE_F32][0 ], wgsl_div_f32, " div_f32 " ,
1740+ constants);
1741+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->div_pipeline [GGML_TYPE_F16][0 ], wgsl_div_f16, " div_f16 " ,
1742+ constants);
1743+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->div_pipeline [GGML_TYPE_F32][1 ], wgsl_div_f32_inplace,
17871744 " div_f32_inplace" , constants);
1788- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->div_pipeline [GGML_TYPE_F16][1 ][ 0 ] , wgsl_div_f16_inplace,
1745+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->div_pipeline [GGML_TYPE_F16][1 ], wgsl_div_f16_inplace,
17891746 " div_f16_inplace" , constants);
1790- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->div_pipeline [GGML_TYPE_F32][0 ][1 ], wgsl_div_f32_overlap,
1791- " div_f32_overlap" , constants);
1792- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->div_pipeline [GGML_TYPE_F32][1 ][1 ],
1793- wgsl_div_f32_inplace_overlap, " div_f32_inplace_overlap" , constants);
1794- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->div_pipeline [GGML_TYPE_F16][0 ][1 ], wgsl_div_f16_overlap,
1795- " div_f16_overlap" , constants);
1796- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->div_pipeline [GGML_TYPE_F16][1 ][1 ],
1797- wgsl_div_f16_inplace_overlap, " div_f16_inplace_overlap" , constants);
17981747}
17991748
18001749static void ggml_webgpu_init_rms_norm_pipeline (webgpu_context & webgpu_ctx) {
@@ -2203,9 +2152,9 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
22032152 // TODO: Don't enable for WASM builds, they won't have an effect anyways
22042153 // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
22052154 // only for native performance?
2206- const char * const deviceEnabledToggles[] = { " disable_robustness" , " disable_workgroup_init" ,
2207- " disable_polyfills_on_integer_div_and_mod" };
2208- const char * const deviceDisabledToggles[] = { " timestamp_quantization" };
2155+ const char * const deviceEnabledToggles[] = { " skip_validation " , " disable_robustness" , " disable_workgroup_init" ,
2156+ " disable_polyfills_on_integer_div_and_mod" };
2157+ const char * const deviceDisabledToggles[] = { " timestamp_quantization" };
22092158 wgpu::DawnTogglesDescriptor deviceTogglesDesc;
22102159 deviceTogglesDesc.enabledToggles = deviceEnabledToggles;
22112160 deviceTogglesDesc.enabledToggleCount = 4 ;
0 commit comments