Skip to content

Commit bd66253

Browse files
committed
fix chunked local attn
Signed-off-by: Benjamin Chislett <[email protected]>
1 parent 3044195 commit bd66253

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

vllm/attention/layers/chunked_local_attention.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
import functools
4-
from typing import ClassVar
54

65
import torch
76

@@ -12,11 +11,16 @@
1211
from vllm.model_executor.layers.quantization import QuantizationConfig
1312
from vllm.v1.attention.backends.utils import (
1413
AttentionCGSupport,
14+
AttentionMetadataBuilder,
1515
CommonAttentionMetadata,
1616
make_local_attention_virtual_batches,
1717
subclass_attention_backend,
1818
)
19-
from vllm.v1.kv_cache_interface import ChunkedLocalAttentionSpec, KVCacheSpec
19+
from vllm.v1.kv_cache_interface import (
20+
AttentionSpec,
21+
ChunkedLocalAttentionSpec,
22+
KVCacheSpec,
23+
)
2024

2125
from ..layer import Attention
2226

@@ -30,9 +34,18 @@ def create_chunked_local_attention_backend(
3034
prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"
3135

3236
underlying_builder = underlying_attn_backend.get_builder_cls()
37+
assert issubclass(underlying_builder, AttentionMetadataBuilder)
3338

3439
class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore
35-
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
40+
@classmethod
41+
def get_cudagraph_support(
42+
cls: type["AttentionMetadataBuilder"],
43+
vllm_config: VllmConfig,
44+
kv_cache_spec: AttentionSpec,
45+
) -> AttentionCGSupport:
46+
# Explicit override in case the underlying builder specialized this getter.
47+
# @override is omitted only because of a mypy limitation involving the type variable.
48+
return AttentionCGSupport.NEVER
3649

3750
def build(
3851
self,

0 commit comments

Comments
 (0)