Skip to content

Commit b319672

Browse files
committed
Revert "Implement overlap binary operators"
This reverts commit ed710b3.
1 parent ed710b3 commit b319672

File tree

4 files changed

+159
-436
lines changed

4 files changed

+159
-436
lines changed

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 64 additions & 115 deletions
Original file line number | Diff line number | Diff line change
@@ -252,10 +252,10 @@ struct webgpu_context_struct {
252252
webgpu_pipeline get_rows_pipeline[30];
253253
webgpu_pipeline get_rows_f32_no_vec_pipeline;
254254
webgpu_pipeline cpy_pipeline[2][2]; // src type, dst type
255-
webgpu_pipeline add_pipeline[2][2][2]; // type, inplace, overlap
256-
webgpu_pipeline sub_pipeline[2][2][2]; // type, inplace, overlap
257-
webgpu_pipeline mul_pipeline[2][2][2]; // type, inplace, overlap
258-
webgpu_pipeline div_pipeline[2][2][2]; // type, inplace, overlap
255+
webgpu_pipeline add_pipeline[2][2]; // type, inplace
256+
webgpu_pipeline sub_pipeline[2][2]; // type, inplace
257+
webgpu_pipeline mul_pipeline[2][2]; // type, inplace
258+
webgpu_pipeline div_pipeline[2][2]; // type, inplace
259259
webgpu_pipeline rms_norm_pipeline[2]; // inplace
260260
webgpu_pipeline rope_pipeline[2][2][2]; // type, ff, inplace
261261
webgpu_pipeline glu_pipeline[7][2][2]; // glu-op, type, split
@@ -677,12 +677,9 @@ static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, ggml_tensor
677677
return offset & ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
678678
}
679679

680-
static size_t ggml_webgpu_tensor_align_binding_size(size_t size) {
681-
return (size + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1);
682-
}
683-
684680
static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) {
685-
return ggml_webgpu_tensor_align_binding_size(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t));
681+
return (ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t) + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) &
682+
~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1);
686683
}
687684

688685
// Used to determine if two tensors are the same for in-place operations
@@ -691,12 +688,6 @@ static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
691688
(ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));
692689
}
693690

694-
static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) {
695-
return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
696-
ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) &&
697-
ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a));
698-
}
699-
700691
static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
701692
uint32_t ne = (uint32_t) ggml_nelements(dst);
702693

@@ -879,27 +870,16 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
879870
return ggml_backend_webgpu_build(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x);
880871
}
881872

882-
template <size_t a, size_t b, size_t c>
883-
static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
884-
ggml_tensor * src0,
885-
ggml_tensor * src1,
886-
ggml_tensor * dst,
887-
webgpu_pipeline (&pipelines)[a][b][c]) {
888-
int inplace = ggml_webgpu_tensor_equal(src0, dst);
889-
int overlap = ggml_webgpu_tensor_overlap(src0, src1);
890-
webgpu_pipeline pipeline = pipelines[dst->type][inplace][overlap];
891-
892-
uint32_t src1_offset = ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type);
893-
if (overlap) {
894-
// when overlapped, bind a single buffer covering both src0 and src1
895-
// TODO: Do other operations need this?
896-
src1_offset = (uint32_t) ((ggml_webgpu_tensor_offset(src1) - ggml_webgpu_tensor_align_offset(ctx, src0)) /
897-
ggml_type_size(src1->type));
898-
}
873+
static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
874+
ggml_tensor * src0,
875+
ggml_tensor * src1,
876+
ggml_tensor * dst,
877+
webgpu_pipeline & pipeline,
878+
bool inplace) {
899879
std::vector<uint32_t> params = {
900880
(uint32_t) ggml_nelements(dst),
901881
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
902-
src1_offset,
882+
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
903883
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
904884
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
905885
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
@@ -914,36 +894,25 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
914894
(uint32_t) src1->ne[3],
915895
};
916896

917-
size_t src0_binding_size = ggml_webgpu_tensor_binding_size(ctx, src0);
918-
if (overlap) {
919-
const uint64_t base_align = ggml_webgpu_tensor_align_offset(ctx, src0);
920-
// assume end of src1 is >= end of src0
921-
const uint64_t max_end = ggml_webgpu_tensor_offset(src1) + ggml_nbytes(src1);
922-
src0_binding_size = ggml_webgpu_tensor_align_binding_size(max_end - base_align);
923-
}
924897
std::vector<wgpu::BindGroupEntry> entries = {
925898
{ .binding = 0,
926899
.buffer = ggml_webgpu_tensor_buf(src0),
927900
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
928-
.size = src0_binding_size }
901+
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
902+
{ .binding = 1,
903+
.buffer = ggml_webgpu_tensor_buf(src1),
904+
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
905+
.size = ggml_webgpu_tensor_binding_size(ctx, src1) }
929906
};
930-
uint32_t binding_num = 1;
931-
if (!overlap) {
932-
entries.push_back({ .binding = binding_num,
933-
.buffer = ggml_webgpu_tensor_buf(src1),
934-
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
935-
.size = ggml_webgpu_tensor_binding_size(ctx, src1) });
936-
binding_num++;
937-
}
938907
if (!inplace) {
939-
entries.push_back({ .binding = binding_num,
908+
entries.push_back({ .binding = 2,
940909
.buffer = ggml_webgpu_tensor_buf(dst),
941910
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
942911
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
943912
}
944913

945-
size_t max_wg_size = ctx->max_wg_size_x;
946-
uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size;
914+
size_t max_wg_size = ctx->max_wg_size_x;
915+
uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size;
947916
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
948917
}
949918

@@ -1263,13 +1232,25 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
12631232
case GGML_OP_MUL_MAT:
12641233
return ggml_webgpu_mul_mat(ctx, src0, src1, node);
12651234
case GGML_OP_ADD:
1266-
return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline);
1235+
{
1236+
int inplace = ggml_webgpu_tensor_equal(src0, node);
1237+
return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type][inplace], inplace);
1238+
}
12671239
case GGML_OP_SUB:
1268-
return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipeline);
1240+
{
1241+
int inplace = ggml_webgpu_tensor_equal(src0, node);
1242+
return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipeline[node->type][inplace], inplace);
1243+
}
12691244
case GGML_OP_MUL:
1270-
return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline);
1245+
{
1246+
int inplace = ggml_webgpu_tensor_equal(src0, node);
1247+
return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type][inplace], inplace);
1248+
}
12711249
case GGML_OP_DIV:
1272-
return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipeline);
1250+
{
1251+
int inplace = ggml_webgpu_tensor_equal(src0, node);
1252+
return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipeline[node->type][inplace], inplace);
1253+
}
12731254
case GGML_OP_RMS_NORM:
12741255
return ggml_webgpu_rms_norm(ctx, src0, node);
12751256
case GGML_OP_ROPE:
@@ -1719,82 +1700,50 @@ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
17191700

17201701
static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
17211702
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
1722-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][0][0], wgsl_add_f32,
1723-
"add_f32", constants);
1724-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][0][0], wgsl_add_f16,
1725-
"add_f16", constants);
1726-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][1][0], wgsl_add_f32_inplace,
1703+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][0], wgsl_add_f32, "add_f32",
1704+
constants);
1705+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][0], wgsl_add_f16, "add_f16",
1706+
constants);
1707+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][1], wgsl_add_f32_inplace,
17271708
"add_f32_inplace", constants);
1728-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][1][0], wgsl_add_f16_inplace,
1709+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][1], wgsl_add_f16_inplace,
17291710
"add_f16_inplace", constants);
1730-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][0][1], wgsl_add_f32_overlap,
1731-
"add_f32_overlap", constants);
1732-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][1][1],
1733-
wgsl_add_f32_inplace_overlap, "add_f32_inplace_overlap", constants);
1734-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][0][1], wgsl_add_f16_overlap,
1735-
"add_f16_overlap", constants);
1736-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][1][1],
1737-
wgsl_add_f16_inplace_overlap, "add_f16_inplace_overlap", constants);
17381711
}
17391712

17401713
static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) {
17411714
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
1742-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][0][0], wgsl_sub_f32,
1743-
"sub_f32", constants);
1744-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][0][0], wgsl_sub_f16,
1745-
"sub_f16", constants);
1746-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][1][0], wgsl_sub_f32_inplace,
1715+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][0], wgsl_sub_f32, "sub_f32",
1716+
constants);
1717+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][0], wgsl_sub_f16, "sub_f16",
1718+
constants);
1719+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][1], wgsl_sub_f32_inplace,
17471720
"sub_f32_inplace", constants);
1748-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][1][0], wgsl_sub_f16_inplace,
1721+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][1], wgsl_sub_f16_inplace,
17491722
"sub_f16_inplace", constants);
1750-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][0][1], wgsl_sub_f32_overlap,
1751-
"sub_f32_overlap", constants);
1752-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][1][1],
1753-
wgsl_sub_f32_inplace_overlap, "sub_f32_inplace_overlap", constants);
1754-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][0][1], wgsl_sub_f16_overlap,
1755-
"sub_f16_overlap", constants);
1756-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][1][1],
1757-
wgsl_sub_f16_inplace_overlap, "sub_f16_inplace_overlap", constants);
17581723
}
17591724

17601725
static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
17611726
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
1762-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][0][0], wgsl_mul_f32,
1763-
"mul_f32", constants);
1764-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][0][0], wgsl_mul_f16,
1765-
"mul_f16", constants);
1766-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][1][0], wgsl_mul_f32_inplace,
1727+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][0], wgsl_mul_f32, "mul_f32",
1728+
constants);
1729+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][0], wgsl_mul_f16, "mul_f16",
1730+
constants);
1731+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][1], wgsl_mul_f32_inplace,
17671732
"mul_f32_inplace", constants);
1768-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][1][0], wgsl_mul_f16_inplace,
1733+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][1], wgsl_mul_f16_inplace,
17691734
"mul_f16_inplace", constants);
1770-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][0][1], wgsl_mul_f32_overlap,
1771-
"mul_f32_overlap", constants);
1772-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][1][1],
1773-
wgsl_mul_f32_inplace_overlap, "mul_f32_inplace_overlap", constants);
1774-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][0][1], wgsl_mul_f16_overlap,
1775-
"mul_f16_overlap", constants);
1776-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][1][1],
1777-
wgsl_mul_f16_inplace_overlap, "mul_f16_inplace_overlap", constants);
17781735
}
17791736

17801737
static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) {
17811738
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
1782-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][0][0], wgsl_div_f32,
1783-
"div_f32", constants);
1784-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][0][0], wgsl_div_f16,
1785-
"div_f16", constants);
1786-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][1][0], wgsl_div_f32_inplace,
1739+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][0], wgsl_div_f32, "div_f32",
1740+
constants);
1741+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][0], wgsl_div_f16, "div_f16",
1742+
constants);
1743+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][1], wgsl_div_f32_inplace,
17871744
"div_f32_inplace", constants);
1788-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][1][0], wgsl_div_f16_inplace,
1745+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][1], wgsl_div_f16_inplace,
17891746
"div_f16_inplace", constants);
1790-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][0][1], wgsl_div_f32_overlap,
1791-
"div_f32_overlap", constants);
1792-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][1][1],
1793-
wgsl_div_f32_inplace_overlap, "div_f32_inplace_overlap", constants);
1794-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][0][1], wgsl_div_f16_overlap,
1795-
"div_f16_overlap", constants);
1796-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][1][1],
1797-
wgsl_div_f16_inplace_overlap, "div_f16_inplace_overlap", constants);
17981747
}
17991748

18001749
static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
@@ -2203,9 +2152,9 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
22032152
// TODO: Don't enable for WASM builds, they won't have an effect anyways
22042153
// TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
22052154
// only for native performance?
2206-
const char * const deviceEnabledToggles[] = { "disable_robustness", "disable_workgroup_init",
2207-
"disable_polyfills_on_integer_div_and_mod" };
2208-
const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
2155+
const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init",
2156+
"disable_polyfills_on_integer_div_and_mod" };
2157+
const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
22092158
wgpu::DawnTogglesDescriptor deviceTogglesDesc;
22102159
deviceTogglesDesc.enabledToggles = deviceEnabledToggles;
22112160
deviceTogglesDesc.enabledToggleCount = 4;

0 commit comments

Comments
 (0)