Skip to content

Commit f5001d8

Browse files
committed
Only check for subgroup matrix configs if they are supported
1 parent f538ca3 commit f5001d8

File tree

1 file changed

+14
-11
lines changed

1 file changed

+14
-11
lines changed

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2374,7 +2374,9 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
23742374

23752375
wgpu::AdapterInfo info{};
23762376
wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{};
2377-
info.nextInChain = &subgroup_matrix_configs;
2377+
if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
2378+
info.nextInChain = &subgroup_matrix_configs;
2379+
}
23782380
ctx->adapter.GetInfo(&info);
23792381

23802382
wgpu::SupportedFeatures features;
@@ -2384,22 +2386,23 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
23842386

23852387
// Only support square f16 matrices of size 8 or 16 for now
23862388
bool valid_subgroup_matrix_config = false;
2387-
for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
2388-
const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
2389-
if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
2390-
config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
2391-
config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
2392-
ctx->subgroup_matrix_config = config;
2393-
valid_subgroup_matrix_config = true;
2394-
break;
2389+
if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
2390+
for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
2391+
const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
2392+
if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
2393+
config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
2394+
config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
2395+
ctx->subgroup_matrix_config = config;
2396+
valid_subgroup_matrix_config = true;
2397+
break;
2398+
}
23952399
}
23962400
}
23972401

23982402
// For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
23992403
// Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
24002404
ctx->subgroup_size = info.subgroupMaxSize;
2401-
ctx->supports_subgroup_matrix =
2402-
valid_subgroup_matrix_config && ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
2405+
ctx->supports_subgroup_matrix = valid_subgroup_matrix_config;
24032406

24042407
// Initialize device
24052408
std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16,

0 commit comments

Comments
 (0)