@@ -355,6 +355,25 @@ struct ggml_backend_webgpu_buffer_context {
355355
356356/* WebGPU object initializations */
357357
358+ // Process a WGSL shader string, replacing tokens of the form {{KEY}} with
359+ // the corresponding values provided in `repls`.
360+ static std::string ggml_webgpu_process_shader_repls (const char * src,
361+ const std::vector<std::pair<std::string, std::string>> & repls) {
362+ if (!src) {
363+ return std::string ();
364+ }
365+ std::string s = src;
366+ for (const auto & kv : repls) {
367+ std::string token = " {{" + kv.first + " }}" ;
368+ size_t pos = 0 ;
369+ while ((pos = s.find (token, pos)) != std::string::npos) {
370+ s.replace (pos, token.length (), kv.second );
371+ pos += kv.second .length ();
372+ }
373+ }
374+ return s;
375+ }
376+
358377static void ggml_webgpu_create_pipeline (wgpu::Device & device,
359378 webgpu_pipeline & pipeline,
360379 const char * shader_code,
@@ -1749,40 +1768,45 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
17491768 wgsl_mul_mat_iq4_xs_f32, " mul_mat_iq4_xs_f32" );
17501769
17511770 if (webgpu_ctx->supports_subgroup_matrix ) {
1752- std::vector<wgpu::ConstantEntry> mul_mat_sg_mat_constants (7 );
1753- mul_mat_sg_mat_constants[0 ].key = " TILE_K" ;
1754- mul_mat_sg_mat_constants[0 ].value = WEBGPU_MUL_MAT_TILE_K;
1755- mul_mat_sg_mat_constants[1 ].key = " SUBGROUP_M" ;
1756- mul_mat_sg_mat_constants[1 ].value = WEBGPU_MUL_MAT_SUBGROUP_M;
1757- mul_mat_sg_mat_constants[2 ].key = " SUBGROUP_N" ;
1758- mul_mat_sg_mat_constants[2 ].value = WEBGPU_MUL_MAT_SUBGROUP_N;
1759- mul_mat_sg_mat_constants[3 ].key = " SUBGROUP_MATRIX_M_SIZE" ;
1760- mul_mat_sg_mat_constants[3 ].value = static_cast <double >(webgpu_ctx->subgroup_matrix_config .M );
1761- mul_mat_sg_mat_constants[4 ].key = " SUBGROUP_MATRIX_N_SIZE" ;
1762- mul_mat_sg_mat_constants[4 ].value = static_cast <double >(webgpu_ctx->subgroup_matrix_config .N );
1763- mul_mat_sg_mat_constants[5 ].key = " SUBGROUP_SIZE" ;
1764- mul_mat_sg_mat_constants[5 ].value = static_cast <double >(webgpu_ctx->subgroup_size );
1765- mul_mat_sg_mat_constants[6 ].key = " SUBGROUP_MATRIX_K_SIZE" ;
1766- mul_mat_sg_mat_constants[6 ].value = static_cast <double >(webgpu_ctx->subgroup_matrix_config .K );
1771+ std::vector<std::pair<std::string, std::string>> sg_matrix_repls;
1772+ sg_matrix_repls.emplace_back (" WEBGPU_SUBGROUP_SIZE" , std::to_string (webgpu_ctx->subgroup_size ));
1773+ sg_matrix_repls.emplace_back (" WEBGPU_TILE_K" , std::to_string (WEBGPU_MUL_MAT_TILE_K));
1774+ sg_matrix_repls.emplace_back (" WEBGPU_SUBGROUP_M" , std::to_string (WEBGPU_MUL_MAT_SUBGROUP_M));
1775+ sg_matrix_repls.emplace_back (" WEBGPU_SUBGROUP_N" , std::to_string (WEBGPU_MUL_MAT_SUBGROUP_N));
1776+ sg_matrix_repls.emplace_back (" WEBGPU_SUBGROUP_MATRIX_M" , std::to_string (WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M));
1777+ sg_matrix_repls.emplace_back (" WEBGPU_SUBGROUP_MATRIX_N" , std::to_string (WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N));
1778+ sg_matrix_repls.emplace_back (" WEBGPU_SG_MAT_M_SIZE" , std::to_string (webgpu_ctx->subgroup_matrix_config .M ));
1779+ sg_matrix_repls.emplace_back (" WEBGPU_SG_MAT_N_SIZE" , std::to_string (webgpu_ctx->subgroup_matrix_config .N ));
1780+ sg_matrix_repls.emplace_back (" WEBGPU_SG_MAT_K_SIZE" , std::to_string (webgpu_ctx->subgroup_matrix_config .K ));
1781+
1782+ std::string proc_mul_mat_subgroup_matrix_f32_f32 =
1783+ ggml_webgpu_process_shader_repls (wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls);
1784+ std::string proc_mul_mat_subgroup_matrix_f32_f32_vec =
1785+ ggml_webgpu_process_shader_repls (wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls);
1786+ std::string proc_mul_mat_subgroup_matrix_f16_f32 =
1787+ ggml_webgpu_process_shader_repls (wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls);
1788+ std::string proc_mul_mat_subgroup_matrix_f16_f32_vec =
1789+ ggml_webgpu_process_shader_repls (wgsl_mul_mat_subgroup_matrix_f16_f32_vec, sg_matrix_repls);
1790+ std::string proc_mul_mat_subgroup_matrix_f16_f16 =
1791+ ggml_webgpu_process_shader_repls (wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls);
1792+ std::string proc_mul_mat_subgroup_matrix_f16_f16_vec =
1793+ ggml_webgpu_process_shader_repls (wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls);
17671794
1768- webgpu_ctx->mul_mat_pipelines [GGML_TYPE_F32][GGML_TYPE_F32][0 ] =
1769- ggml_webgpu_create_pipeline2 (webgpu_ctx->device , wgsl_mul_mat_subgroup_matrix_f32_f32,
1770- " mul_mat_subgroup_matrix_f32_f32" , mul_mat_sg_mat_constants);
1795+ webgpu_ctx->mul_mat_pipelines [GGML_TYPE_F32][GGML_TYPE_F32][0 ] = ggml_webgpu_create_pipeline2 (
1796+ webgpu_ctx->device , proc_mul_mat_subgroup_matrix_f32_f32.c_str (), " mul_mat_subgroup_matrix_f32_f32" );
17711797 webgpu_ctx->mul_mat_pipelines [GGML_TYPE_F32][GGML_TYPE_F32][1 ] =
1772- ggml_webgpu_create_pipeline2 (webgpu_ctx->device , wgsl_mul_mat_subgroup_matrix_f32_f32_vec,
1773- " mul_mat_subgroup_matrix_f32_f32_vec" , mul_mat_sg_mat_constants);
1774- webgpu_ctx->mul_mat_pipelines [GGML_TYPE_F16][GGML_TYPE_F32][0 ] =
1775- ggml_webgpu_create_pipeline2 (webgpu_ctx->device , wgsl_mul_mat_subgroup_matrix_f16_f32,
1776- " mul_mat_subgroup_matrix_f16_f32" , mul_mat_sg_mat_constants);
1798+ ggml_webgpu_create_pipeline2 (webgpu_ctx->device , proc_mul_mat_subgroup_matrix_f32_f32_vec.c_str (),
1799+ " mul_mat_subgroup_matrix_f32_f32_vec" );
1800+ webgpu_ctx->mul_mat_pipelines [GGML_TYPE_F16][GGML_TYPE_F32][0 ] = ggml_webgpu_create_pipeline2 (
1801+ webgpu_ctx->device , proc_mul_mat_subgroup_matrix_f16_f32.c_str (), " mul_mat_subgroup_matrix_f16_f32" );
17771802 webgpu_ctx->mul_mat_pipelines [GGML_TYPE_F16][GGML_TYPE_F32][1 ] =
1778- ggml_webgpu_create_pipeline2 (webgpu_ctx->device , wgsl_mul_mat_subgroup_matrix_f16_f32_vec,
1779- " mul_mat_subgroup_matrix_f16_f32_vec" , mul_mat_sg_mat_constants);
1780- webgpu_ctx->mul_mat_pipelines [GGML_TYPE_F16][GGML_TYPE_F16][0 ] =
1781- ggml_webgpu_create_pipeline2 (webgpu_ctx->device , wgsl_mul_mat_subgroup_matrix_f16_f16,
1782- " mul_mat_subgroup_matrix_f16_f16" , mul_mat_sg_mat_constants);
1803+ ggml_webgpu_create_pipeline2 (webgpu_ctx->device , proc_mul_mat_subgroup_matrix_f16_f32_vec.c_str (),
1804+ " mul_mat_subgroup_matrix_f16_f32_vec" );
1805+ webgpu_ctx->mul_mat_pipelines [GGML_TYPE_F16][GGML_TYPE_F16][0 ] = ggml_webgpu_create_pipeline2 (
1806+ webgpu_ctx->device , proc_mul_mat_subgroup_matrix_f16_f16.c_str (), " mul_mat_subgroup_matrix_f16_f16" );
17831807 webgpu_ctx->mul_mat_pipelines [GGML_TYPE_F16][GGML_TYPE_F16][1 ] =
1784- ggml_webgpu_create_pipeline2 (webgpu_ctx->device , wgsl_mul_mat_subgroup_matrix_f16_f16_vec ,
1785- " mul_mat_subgroup_matrix_f16_f16_vec" , mul_mat_sg_mat_constants );
1808+ ggml_webgpu_create_pipeline2 (webgpu_ctx->device , proc_mul_mat_subgroup_matrix_f16_f16_vec. c_str () ,
1809+ " mul_mat_subgroup_matrix_f16_f16_vec" );
17861810 } else {
17871811 std::vector<wgpu::ConstantEntry> mul_mat_reg_tile_constants (3 );
17881812 mul_mat_reg_tile_constants[0 ].key = " TILE_K" ;
@@ -1792,20 +1816,42 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
17921816 mul_mat_reg_tile_constants[2 ].key = " WORKGROUP_SIZE_N" ;
17931817 mul_mat_reg_tile_constants[2 ].value = WEBGPU_MUL_MAT_WG_SIZE_N;
17941818
1795- webgpu_ctx->mul_mat_pipelines [GGML_TYPE_F32][GGML_TYPE_F32][0 ] = ggml_webgpu_create_pipeline2 (
1796- webgpu_ctx->device , wgsl_mul_mat_reg_tile_f32_f32, " mul_mat_reg_tile_f32_f32" , mul_mat_reg_tile_constants);
1819+ std::vector<std::pair<std::string, std::string>> reg_repls;
1820+ reg_repls.emplace_back (" WEBGPU_TILE_M" , std::to_string (WEBGPU_MUL_MAT_TILE_M));
1821+ reg_repls.emplace_back (" WEBGPU_TILE_N" , std::to_string (WEBGPU_MUL_MAT_TILE_N));
1822+
1823+ // Process each reg-tile shader with tile replacements.
1824+ // Keep the processed strings in-scope so .c_str() remains valid.
1825+ std::string proc_mul_mat_reg_tile_f32_f32 =
1826+ ggml_webgpu_process_shader_repls (wgsl_mul_mat_reg_tile_f32_f32, reg_repls);
1827+ std::string proc_mul_mat_reg_tile_f32_f32_vec =
1828+ ggml_webgpu_process_shader_repls (wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls);
1829+ std::string proc_mul_mat_reg_tile_f16_f32 =
1830+ ggml_webgpu_process_shader_repls (wgsl_mul_mat_reg_tile_f16_f32, reg_repls);
1831+ std::string proc_mul_mat_reg_tile_f16_f32_vec =
1832+ ggml_webgpu_process_shader_repls (wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls);
1833+ std::string proc_mul_mat_reg_tile_f16_f16 =
1834+ ggml_webgpu_process_shader_repls (wgsl_mul_mat_reg_tile_f16_f16, reg_repls);
1835+ std::string proc_mul_mat_reg_tile_f16_f16_vec =
1836+ ggml_webgpu_process_shader_repls (wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls);
1837+
1838+ webgpu_ctx->mul_mat_pipelines [GGML_TYPE_F32][GGML_TYPE_F32][0 ] =
1839+ ggml_webgpu_create_pipeline2 (webgpu_ctx->device , proc_mul_mat_reg_tile_f32_f32.c_str (),
1840+ " mul_mat_reg_tile_f32_f32" , mul_mat_reg_tile_constants);
17971841 webgpu_ctx->mul_mat_pipelines [GGML_TYPE_F32][GGML_TYPE_F32][1 ] =
1798- ggml_webgpu_create_pipeline2 (webgpu_ctx->device , wgsl_mul_mat_reg_tile_f32_f32_vec ,
1842+ ggml_webgpu_create_pipeline2 (webgpu_ctx->device , proc_mul_mat_reg_tile_f32_f32_vec. c_str () ,
17991843 " mul_mat_reg_tile_f32_f32_vec" , mul_mat_reg_tile_constants);
1800- webgpu_ctx->mul_mat_pipelines [GGML_TYPE_F16][GGML_TYPE_F32][0 ] = ggml_webgpu_create_pipeline2 (
1801- webgpu_ctx->device , wgsl_mul_mat_reg_tile_f16_f32, " mul_mat_reg_tile_f16_f32" , mul_mat_reg_tile_constants);
1844+ webgpu_ctx->mul_mat_pipelines [GGML_TYPE_F16][GGML_TYPE_F32][0 ] =
1845+ ggml_webgpu_create_pipeline2 (webgpu_ctx->device , proc_mul_mat_reg_tile_f16_f32.c_str (),
1846+ " mul_mat_reg_tile_f16_f32" , mul_mat_reg_tile_constants);
18021847 webgpu_ctx->mul_mat_pipelines [GGML_TYPE_F16][GGML_TYPE_F32][1 ] =
1803- ggml_webgpu_create_pipeline2 (webgpu_ctx->device , wgsl_mul_mat_reg_tile_f16_f32_vec ,
1848+ ggml_webgpu_create_pipeline2 (webgpu_ctx->device , proc_mul_mat_reg_tile_f16_f32_vec. c_str () ,
18041849 " mul_mat_reg_tile_f16_f32_vec" , mul_mat_reg_tile_constants);
1805- webgpu_ctx->mul_mat_pipelines [GGML_TYPE_F16][GGML_TYPE_F16][0 ] = ggml_webgpu_create_pipeline2 (
1806- webgpu_ctx->device , wgsl_mul_mat_reg_tile_f16_f16, " mul_mat_reg_tile_f16_f16" , mul_mat_reg_tile_constants);
1850+ webgpu_ctx->mul_mat_pipelines [GGML_TYPE_F16][GGML_TYPE_F16][0 ] =
1851+ ggml_webgpu_create_pipeline2 (webgpu_ctx->device , proc_mul_mat_reg_tile_f16_f16.c_str (),
1852+ " mul_mat_reg_tile_f16_f16" , mul_mat_reg_tile_constants);
18071853 webgpu_ctx->mul_mat_pipelines [GGML_TYPE_F16][GGML_TYPE_F16][1 ] =
1808- ggml_webgpu_create_pipeline2 (webgpu_ctx->device , wgsl_mul_mat_reg_tile_f16_f16_vec ,
1854+ ggml_webgpu_create_pipeline2 (webgpu_ctx->device , proc_mul_mat_reg_tile_f16_f16_vec. c_str () ,
18091855 " mul_mat_reg_tile_f16_f16_vec" , mul_mat_reg_tile_constants);
18101856 }
18111857
@@ -2354,18 +2400,30 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
23542400 std::cout << " Result Type: " << static_cast <int >(config.resultComponentType ) << " \n " ;
23552401 }
23562402
2357- ctx->subgroup_matrix_config = *subgroup_matrix_configs.configs ;
23582403 wgpu::SupportedFeatures features;
23592404 ctx->adapter .GetFeatures (&features);
23602405 // we require f16 support
23612406 GGML_ASSERT (ctx->adapter .HasFeature (wgpu::FeatureName::ShaderF16));
23622407
2408+ // Only support square f16 matrices of size 8 or 16 for now
2409+ bool valid_subgroup_matrix_config = false ;
2410+ for (size_t i = 0 ; i < subgroup_matrix_configs.configCount ; i++) {
2411+ const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs [i];
2412+ if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16 ) &&
2413+ config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
2414+ config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
2415+ ctx->subgroup_matrix_config = config;
2416+ valid_subgroup_matrix_config = true ;
2417+ break ;
2418+ }
2419+ }
23632420 // For subgroup matrix code to be workable, we really need a consistent subgroup size.
23642421 // Unfortunately, WebGPU allows info.subgroup{Min/Max}Size to be different, and even on devices
23652422 // where it is consistent, e.g., Apple M-series GPUs, the min/max sizes report different values.
23662423 // Therefore, hardcoding the subgroup size to 32 for now for development.
2367- ctx->subgroup_size = 32 ;
2368- ctx->supports_subgroup_matrix = ctx->adapter .HasFeature (wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
2424+ ctx->subgroup_size = 32 ;
2425+ ctx->supports_subgroup_matrix =
2426+ valid_subgroup_matrix_config && ctx->adapter .HasFeature (wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
23692427
23702428 // Initialize device
23712429 std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16,
0 commit comments