@@ -451,6 +451,16 @@ static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_t
451451 ggml_backend_webgpu_build_and_enqueue (ctx, ctx->mul_mat_pipeline , params, entries, wg_x);
452452}
453453
454+ // sample test
455+ // ADD(type=f32, ne=[10,5,4,3], nr=[2,1,1,1], nf=1)
456+ // ne: number of elements in each dimension of tensor b
457+ // nr: number of repetitions in each dimension
458+ // tensor b is the smaller tensor, and is broadcasted with repetitions to match the size of a
459+ // broadcasted with ne * nr
460+ // 10*2, 5*1, 4*1, 3*1 = [20, 5, 4, 3] is the shape of dst and a
461+ // essentially, if nr[x] is > 1, that dimension of b is repeated
462+ // nf: number of fused operations (1 means a single addition)
463+
454464// adds src0 and src1 and puts in dst
455465static void ggml_webgpu_add (webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
456466 // each tensor in GGML is stored inside a buffer on the GPU
@@ -464,27 +474,27 @@ static void ggml_webgpu_add(webgpu_context & ctx, ggml_tensor * src0, ggml_tenso
464474 src0_offset &= ~(ctx->limits .minStorageBufferOffsetAlignment - 1 );
465475
466476 size_t src1_offset = ggml_backend_webgpu_tensor_offset (src1);
467- // assumes power of 2 offset alignment
468477 size_t src1_misalignment = src1_offset & (ctx->limits .minStorageBufferOffsetAlignment - 1 );
469- // align to minimum offset alignment
470478 src1_offset &= ~(ctx->limits .minStorageBufferOffsetAlignment - 1 );
471479
472480 size_t dst_offset = ggml_backend_webgpu_tensor_offset (dst);
473481 size_t dst_misalignment = dst_offset & (ctx->limits .minStorageBufferOffsetAlignment - 1 );
474482 dst_offset &= ~(ctx->limits .minStorageBufferOffsetAlignment - 1 );
475-
483+
476484 // set up parameters
477485 std::vector<uint32_t > params = {
478486 // number of elements-- determines how many threads to dispatch (one for each addition operation)
479487 (uint32_t ) ggml_nelements (dst),
480488
481489 // even though tensors are 4d, the actual data is stored linearly
482490 // stride = how many elements (or bytes) we must skip in memory to move from one value to another along a certain dimension
483- // i.e.
484- // nb[0] = 1 // each element is next to the previous
485- // nb[1] = nb[0] * ne[0] = 5 // to move to next row, skip 5 elements
486- // nb[2] = nb[1] * ne[1] = 20 // to next matrix, skip 20 elements
487- // nb[3] = nb[2] * ne[2] = 60 // to next batch, skip 60 elements
491+ // i.e. tensor: [5, 6, 3, 2], ggml_type_size: 4 (each number is 4 bytes)
492+ // (nb = number of bytes to skip for each element (stride))
493+ // (ne = number of elements in that dimension)
494+ // nb[0] = 4 // each element is next to the previous, so only 4 bytes in between
495+ // nb[1] = nb[0] * ne[0] = 4 * 5 = 20 // to move to next row, skip 20 bytes
496+ // nb[2] = nb[1] * ne[1] = 20 * 6 = 120 // to next matrix, skip 120 bytes
497+ // nb[3] = nb[2] * ne[2] = 120 * 3 = 360 // to next batch, skip 360 bytes
488498
489499 // calculate element strides for each tensor
490500 (uint32_t ) (src0->nb [0 ] / ggml_type_size (src0->type )),
@@ -502,16 +512,24 @@ static void ggml_webgpu_add(webgpu_context & ctx, ggml_tensor * src0, ggml_tenso
502512 (uint32_t ) (dst->nb [2 ] / ggml_type_size (dst->type )),
503513 (uint32_t ) (dst->nb [3 ] / ggml_type_size (dst->type )),
504514
505- // number of elements in each dimension
515+ // number of elements in each dimension of larger tensors (src0 and dst)
506516 (uint32_t ) dst->ne [0 ],
507517 (uint32_t ) dst->ne [1 ],
508518 (uint32_t ) dst->ne [2 ],
509519 (uint32_t ) dst->ne [3 ],
510520
521+ // number of elements in each dimension of smaller tensor to be broadcasted (src1)
522+ (uint32_t ) src1->ne [0 ],
523+ (uint32_t ) src1->ne [1 ],
524+ (uint32_t ) src1->ne [2 ],
525+ (uint32_t ) src1->ne [3 ],
526+
511527 // offsets in terms of elements instead of bytes
512528 (uint32_t ) (src0_misalignment / ggml_type_size (src0->type )),
513529 (uint32_t ) (src1_misalignment / ggml_type_size (src1->type )),
514530 (uint32_t ) (dst_misalignment / ggml_type_size (dst->type )),
531+
532+
515533 };
516534
517535 // bind group = groups together several GPU resources that shaders will use (e.g., buffers holding tensor data)
0 commit comments