Skip to content

Commit 9edfcc9

Browse files
committed
Working subgroup matrix code for (semi)generic sizes
1 parent 0f6e38d commit 9edfcc9

File tree

3 files changed

+148
-91
lines changed

3 files changed

+148
-91
lines changed

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 100 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,25 @@ struct ggml_backend_webgpu_buffer_context {
355355

356356
/* WebGPU object initializations */
357357

358+
// Expand `{{KEY}}` placeholders in a WGSL shader source string.
//
// src:   NUL-terminated shader text; a null pointer yields an empty string.
// repls: ordered (key, value) pairs. Every occurrence of "{{" + key + "}}" is
//        replaced by value.
//
// Keys are applied one after another in the order given, so a later key also
// matches text inserted by an earlier one. Within a single key the inserted
// value is never rescanned, which means a value that contains its own
// placeholder cannot cause an endless loop. Placeholders whose key is not
// listed in `repls` are left in the output untouched.
static std::string ggml_webgpu_process_shader_repls(const char * src,
                                                    const std::vector<std::pair<std::string, std::string>> & repls) {
    if (!src) {
        return std::string();
    }

    std::string text = src;
    for (const auto & kv : repls) {
        const std::string placeholder = "{{" + kv.first + "}}";

        size_t hit = text.find(placeholder);
        if (hit == std::string::npos) {
            // Nothing to substitute for this key; keep `text` as it is.
            continue;
        }

        // Rebuild the string in one forward pass rather than calling
        // std::string::replace per hit, which would shift the whole tail
        // every time a placeholder is expanded.
        std::string expanded;
        expanded.reserve(text.size());

        size_t cursor = 0;
        do {
            expanded.append(text, cursor, hit - cursor);  // text before the placeholder
            expanded.append(kv.second);                   // substituted value
            cursor = hit + placeholder.length();
            hit    = text.find(placeholder, cursor);
        } while (hit != std::string::npos);

        expanded.append(text, cursor, std::string::npos);  // remaining tail
        text.swap(expanded);
    }
    return text;
}
376+
358377
static void ggml_webgpu_create_pipeline(wgpu::Device & device,
359378
webgpu_pipeline & pipeline,
360379
const char * shader_code,
@@ -1749,40 +1768,45 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
17491768
wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32");
17501769

17511770
if (webgpu_ctx->supports_subgroup_matrix) {
1752-
std::vector<wgpu::ConstantEntry> mul_mat_sg_mat_constants(7);
1753-
mul_mat_sg_mat_constants[0].key = "TILE_K";
1754-
mul_mat_sg_mat_constants[0].value = WEBGPU_MUL_MAT_TILE_K;
1755-
mul_mat_sg_mat_constants[1].key = "SUBGROUP_M";
1756-
mul_mat_sg_mat_constants[1].value = WEBGPU_MUL_MAT_SUBGROUP_M;
1757-
mul_mat_sg_mat_constants[2].key = "SUBGROUP_N";
1758-
mul_mat_sg_mat_constants[2].value = WEBGPU_MUL_MAT_SUBGROUP_N;
1759-
mul_mat_sg_mat_constants[3].key = "SUBGROUP_MATRIX_M_SIZE";
1760-
mul_mat_sg_mat_constants[3].value = static_cast<double>(webgpu_ctx->subgroup_matrix_config.M);
1761-
mul_mat_sg_mat_constants[4].key = "SUBGROUP_MATRIX_N_SIZE";
1762-
mul_mat_sg_mat_constants[4].value = static_cast<double>(webgpu_ctx->subgroup_matrix_config.N);
1763-
mul_mat_sg_mat_constants[5].key = "SUBGROUP_SIZE";
1764-
mul_mat_sg_mat_constants[5].value = static_cast<double>(webgpu_ctx->subgroup_size);
1765-
mul_mat_sg_mat_constants[6].key = "SUBGROUP_MATRIX_K_SIZE";
1766-
mul_mat_sg_mat_constants[6].value = static_cast<double>(webgpu_ctx->subgroup_matrix_config.K);
1771+
std::vector<std::pair<std::string, std::string>> sg_matrix_repls;
1772+
sg_matrix_repls.emplace_back("WEBGPU_SUBGROUP_SIZE", std::to_string(webgpu_ctx->subgroup_size));
1773+
sg_matrix_repls.emplace_back("WEBGPU_TILE_K", std::to_string(WEBGPU_MUL_MAT_TILE_K));
1774+
sg_matrix_repls.emplace_back("WEBGPU_SUBGROUP_M", std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M));
1775+
sg_matrix_repls.emplace_back("WEBGPU_SUBGROUP_N", std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N));
1776+
sg_matrix_repls.emplace_back("WEBGPU_SUBGROUP_MATRIX_M", std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M));
1777+
sg_matrix_repls.emplace_back("WEBGPU_SUBGROUP_MATRIX_N", std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N));
1778+
sg_matrix_repls.emplace_back("WEBGPU_SG_MAT_M_SIZE", std::to_string(webgpu_ctx->subgroup_matrix_config.M));
1779+
sg_matrix_repls.emplace_back("WEBGPU_SG_MAT_N_SIZE", std::to_string(webgpu_ctx->subgroup_matrix_config.N));
1780+
sg_matrix_repls.emplace_back("WEBGPU_SG_MAT_K_SIZE", std::to_string(webgpu_ctx->subgroup_matrix_config.K));
1781+
1782+
std::string proc_mul_mat_subgroup_matrix_f32_f32 =
1783+
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls);
1784+
std::string proc_mul_mat_subgroup_matrix_f32_f32_vec =
1785+
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls);
1786+
std::string proc_mul_mat_subgroup_matrix_f16_f32 =
1787+
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls);
1788+
std::string proc_mul_mat_subgroup_matrix_f16_f32_vec =
1789+
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32_vec, sg_matrix_repls);
1790+
std::string proc_mul_mat_subgroup_matrix_f16_f16 =
1791+
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls);
1792+
std::string proc_mul_mat_subgroup_matrix_f16_f16_vec =
1793+
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls);
17671794

1768-
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] =
1769-
ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_subgroup_matrix_f32_f32,
1770-
"mul_mat_subgroup_matrix_f32_f32", mul_mat_sg_mat_constants);
1795+
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
1796+
webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f32_f32.c_str(), "mul_mat_subgroup_matrix_f32_f32");
17711797
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] =
1772-
ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_subgroup_matrix_f32_f32_vec,
1773-
"mul_mat_subgroup_matrix_f32_f32_vec", mul_mat_sg_mat_constants);
1774-
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] =
1775-
ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_subgroup_matrix_f16_f32,
1776-
"mul_mat_subgroup_matrix_f16_f32", mul_mat_sg_mat_constants);
1798+
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f32_f32_vec.c_str(),
1799+
"mul_mat_subgroup_matrix_f32_f32_vec");
1800+
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
1801+
webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f32.c_str(), "mul_mat_subgroup_matrix_f16_f32");
17771802
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] =
1778-
ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_subgroup_matrix_f16_f32_vec,
1779-
"mul_mat_subgroup_matrix_f16_f32_vec", mul_mat_sg_mat_constants);
1780-
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] =
1781-
ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_subgroup_matrix_f16_f16,
1782-
"mul_mat_subgroup_matrix_f16_f16", mul_mat_sg_mat_constants);
1803+
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f32_vec.c_str(),
1804+
"mul_mat_subgroup_matrix_f16_f32_vec");
1805+
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2(
1806+
webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f16.c_str(), "mul_mat_subgroup_matrix_f16_f16");
17831807
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] =
1784-
ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_subgroup_matrix_f16_f16_vec,
1785-
"mul_mat_subgroup_matrix_f16_f16_vec", mul_mat_sg_mat_constants);
1808+
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f16_vec.c_str(),
1809+
"mul_mat_subgroup_matrix_f16_f16_vec");
17861810
} else {
17871811
std::vector<wgpu::ConstantEntry> mul_mat_reg_tile_constants(3);
17881812
mul_mat_reg_tile_constants[0].key = "TILE_K";
@@ -1792,20 +1816,42 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
17921816
mul_mat_reg_tile_constants[2].key = "WORKGROUP_SIZE_N";
17931817
mul_mat_reg_tile_constants[2].value = WEBGPU_MUL_MAT_WG_SIZE_N;
17941818

1795-
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
1796-
webgpu_ctx->device, wgsl_mul_mat_reg_tile_f32_f32, "mul_mat_reg_tile_f32_f32", mul_mat_reg_tile_constants);
1819+
std::vector<std::pair<std::string, std::string>> reg_repls;
1820+
reg_repls.emplace_back("WEBGPU_TILE_M", std::to_string(WEBGPU_MUL_MAT_TILE_M));
1821+
reg_repls.emplace_back("WEBGPU_TILE_N", std::to_string(WEBGPU_MUL_MAT_TILE_N));
1822+
1823+
// Process each reg-tile shader with tile replacements.
1824+
// Keep the processed strings in-scope so .c_str() remains valid.
1825+
std::string proc_mul_mat_reg_tile_f32_f32 =
1826+
ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls);
1827+
std::string proc_mul_mat_reg_tile_f32_f32_vec =
1828+
ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls);
1829+
std::string proc_mul_mat_reg_tile_f16_f32 =
1830+
ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls);
1831+
std::string proc_mul_mat_reg_tile_f16_f32_vec =
1832+
ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls);
1833+
std::string proc_mul_mat_reg_tile_f16_f16 =
1834+
ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls);
1835+
std::string proc_mul_mat_reg_tile_f16_f16_vec =
1836+
ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls);
1837+
1838+
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] =
1839+
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f32_f32.c_str(),
1840+
"mul_mat_reg_tile_f32_f32", mul_mat_reg_tile_constants);
17971841
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] =
1798-
ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_reg_tile_f32_f32_vec,
1842+
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f32_f32_vec.c_str(),
17991843
"mul_mat_reg_tile_f32_f32_vec", mul_mat_reg_tile_constants);
1800-
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
1801-
webgpu_ctx->device, wgsl_mul_mat_reg_tile_f16_f32, "mul_mat_reg_tile_f16_f32", mul_mat_reg_tile_constants);
1844+
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] =
1845+
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f32.c_str(),
1846+
"mul_mat_reg_tile_f16_f32", mul_mat_reg_tile_constants);
18021847
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] =
1803-
ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_reg_tile_f16_f32_vec,
1848+
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f32_vec.c_str(),
18041849
"mul_mat_reg_tile_f16_f32_vec", mul_mat_reg_tile_constants);
1805-
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2(
1806-
webgpu_ctx->device, wgsl_mul_mat_reg_tile_f16_f16, "mul_mat_reg_tile_f16_f16", mul_mat_reg_tile_constants);
1850+
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] =
1851+
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f16.c_str(),
1852+
"mul_mat_reg_tile_f16_f16", mul_mat_reg_tile_constants);
18071853
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] =
1808-
ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_reg_tile_f16_f16_vec,
1854+
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f16_vec.c_str(),
18091855
"mul_mat_reg_tile_f16_f16_vec", mul_mat_reg_tile_constants);
18101856
}
18111857

@@ -2354,18 +2400,30 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
23542400
std::cout << " Result Type: " << static_cast<int>(config.resultComponentType) << "\n";
23552401
}
23562402

2357-
ctx->subgroup_matrix_config = *subgroup_matrix_configs.configs;
23582403
wgpu::SupportedFeatures features;
23592404
ctx->adapter.GetFeatures(&features);
23602405
// we require f16 support
23612406
GGML_ASSERT(ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
23622407

2408+
// Only support square f16 matrices of size 8 or 16 for now
2409+
bool valid_subgroup_matrix_config = false;
2410+
for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
2411+
const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
2412+
if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
2413+
config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
2414+
config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
2415+
ctx->subgroup_matrix_config = config;
2416+
valid_subgroup_matrix_config = true;
2417+
break;
2418+
}
2419+
}
23632420
// For subgroup matrix code to be workable, we really need a consistent subgroup size.
23642421
// Unfortunately, WebGPU allows info.subgroup{Min/Max}Size to be different, and even on devices
23652422
// where it is consistent, e.g., Apple M-series GPUs, the min/max sizes report different values.
23662423
// Therefore, hardcoding the subgroup size to 32 for now for development.
2367-
ctx->subgroup_size = 32;
2368-
ctx->supports_subgroup_matrix = ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
2424+
ctx->subgroup_size = 32;
2425+
ctx->supports_subgroup_matrix =
2426+
valid_subgroup_matrix_config && ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
23692427

23702428
// Initialize device
23712429
std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16,

ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,8 @@ fn get_local_m(thread_id: u32) -> u32 {
187187

188188
// Warning: cannot be overrides, must match values in ggml-webgpu.cpp
189189
// TILE_M must be multiple of 4 for vec4 loads
190-
const TILE_M = 4u;
191-
const TILE_N = 4u;
190+
const TILE_M = {{WEBGPU_TILE_M}}u;
191+
const TILE_N = {{WEBGPU_TILE_N}}u;
192192

193193
override WORKGROUP_SIZE_M: u32;
194194
override WORKGROUP_SIZE_N: u32;

0 commit comments

Comments
 (0)