Commit 00ab75e

fix(benchmarks): correct sdpa_backend inconsistency and attn_implementation for continuous batching (#42339)
This commit fixes two bugs in BenchmarkConfig reported in issue #42211:

1. **sdpa_backend inconsistency (line 105)**: The warning message states "sdpa_backend must be None", but the code was setting it to "math". Changed to None to match the warning message. This lets PyTorch auto-select the appropriate SDPA backend rather than forcing one globally, which is correct for continuous batching with custom attention masks.

2. **Invalid attn_implementation (line 243)**: Changed from "paged|sdpa" to "sdpa". Using "paged|sdpa" directly bypassed the validation logic at lines 91-105, since that check only matches the exact string "sdpa". The "paged|" prefix is added automatically by init_continuous_batching() in continuous_api.py, so the config should use plain "sdpa" for consistency with the other configs.

Both bugs were introduced in commit 069684e (PR #41916).

Fixes #42211
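The second bug is a classic exact-string-match pitfall. The sketch below illustrates it with simplified stand-ins for the pieces named in the commit message (`BenchmarkConfigSketch` and the guard body are illustrative, not the actual transformers source); only the shape of the check and the "paged|" prefixing behavior are taken from the description above.

```python
# Illustrative sketch of the validation pitfall described in the commit
# message. The guard special-cases the exact string "sdpa", so a config
# created with "paged|sdpa" skips the continuous-batching guard entirely.
# (Simplified stand-ins; not the actual transformers implementation.)
import logging

logger = logging.getLogger(__name__)


class BenchmarkConfigSketch:
    def __init__(self, attn_implementation, continuous_batching=False, sdpa_backend="math"):
        self.attn_implementation = attn_implementation
        self.continuous_batching = continuous_batching
        self.sdpa_backend = sdpa_backend

    def check_validity(self):
        # Fires only for the exact value "sdpa" -- "paged|sdpa" slips past.
        if self.attn_implementation == "sdpa" and self.continuous_batching:
            logger.warning(
                "when continuous batching is enabled, sdpa_backend must be None, setting it to None"
            )
            self.sdpa_backend = None  # the fix: None instead of "math"


def init_continuous_batching(config):
    # Per the commit message, the "paged|" prefix is added downstream,
    # so configs should store the plain implementation name.
    return "paged|" + config.attn_implementation


good = BenchmarkConfigSketch("sdpa", continuous_batching=True)
good.check_validity()
print(good.sdpa_backend)              # None -- the guard applied
print(init_continuous_batching(good)) # paged|sdpa -- prefix added downstream

bad = BenchmarkConfigSketch("paged|sdpa", continuous_batching=True)
bad.check_validity()
print(bad.sdpa_backend)               # math -- guard bypassed, the reported bug
```

Storing plain "sdpa" in the config and letting the continuous-batching path add the prefix keeps a single code path responsible for the "paged|" convention, so the validation logic never has to know about it.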
1 parent 3410ba9 commit 00ab75e

File tree

1 file changed: 2 additions, 2 deletions


benchmark_v2/framework/benchmark_config.py

Lines changed: 2 additions & 2 deletions
@@ -102,7 +102,7 @@ def check_validity(self, skip_validity_check: bool = False) -> None:
             logger.warning(
                 "when continuous batching is enabled, sdpa_backend must be None because of the attention mask, setting it to None"
             )
-            self.sdpa_backend = "math"
+            self.sdpa_backend = None

     @property
     def hash(self) -> str:
@@ -240,5 +240,5 @@ def get_config_by_level(level: int) -> list[BenchmarkConfig]:
     configs.append(BenchmarkConfig(attn_implementation="sdpa", compile_mode="default"))
     configs.append(BenchmarkConfig(attn_implementation="flex_attention", compile_mode="default", kernelize=True))
     configs.append(BenchmarkConfig(attn_implementation="flash_attention_2", kernelize=True))
-    configs.append(BenchmarkConfig(attn_implementation="paged|sdpa", continuous_batching=True))
+    configs.append(BenchmarkConfig(attn_implementation="sdpa", continuous_batching=True))
     return configs
