Skip to content

Commit 871e18a

Browse files
authored
Merge branch 'main' into fix_aiter_mha_min_seqlen_q
2 parents fdb8793 + 9fc81ec commit 871e18a

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

vllm/platforms/tpu.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,25 @@
99

1010
from vllm.inputs import ProcessorInputs, PromptType
1111
from vllm.logger import init_logger
12-
from vllm.sampling_params import SamplingParams, SamplingType
1312

1413
from .interface import Platform, PlatformEnum
1514

1615
if TYPE_CHECKING:
16+
from typing import TypeAlias
17+
1718
from vllm.attention.backends.registry import AttentionBackendEnum
1819
from vllm.config import VllmConfig
1920
from vllm.config.cache import BlockSize
2021
from vllm.pooling_params import PoolingParams
22+
from vllm.sampling_params import SamplingParams
23+
24+
ParamsType: TypeAlias = SamplingParams | PoolingParams
2125
else:
2226
BlockSize = None
2327
VllmConfig = None
2428
PoolingParams = None
2529
AttentionBackendEnum = None
30+
ParamsType = None
2631

2732
logger = init_logger(__name__)
2833

@@ -203,10 +208,12 @@ def get_device_communicator_cls(cls) -> str:
203208
def validate_request(
204209
cls,
205210
prompt: PromptType,
206-
params: SamplingParams | PoolingParams,
211+
params: ParamsType,
207212
processed_inputs: ProcessorInputs,
208213
) -> None:
209214
"""Raises if this request is unsupported on this platform"""
215+
from vllm.sampling_params import SamplingParams, SamplingType
216+
210217
if (
211218
isinstance(params, SamplingParams)
212219
and params.sampling_type == SamplingType.RANDOM_SEED

0 commit comments

Comments
 (0)