diff --git a/.github/workflows/tests-integration.yaml b/.github/workflows/tests-integration.yaml
index 66a317b9..7fa53529 100644
--- a/.github/workflows/tests-integration.yaml
+++ b/.github/workflows/tests-integration.yaml
@@ -27,7 +27,7 @@ jobs:
       - name: Set expected providers for local integration tests
         id: set_local_providers
-        run: echo "expected_providers=ollama,llamacpp,llamafile" >> $GITHUB_OUTPUT
+        run: echo "expected_providers=ollama,llamacpp,llamafile,vllm" >> $GITHUB_OUTPUT
 
   determine-jobs-to-run:
     needs: expected-providers
@@ -144,6 +144,14 @@ jobs:
           restore-keys: |
             ${{ runner.os }}-llamafile-
 
+      - uses: actions/cache@v4
+        if: github.event.inputs.filter == '' || contains(github.event.inputs.filter, 'vllm')
+        with:
+          path: ~/.cache/huggingface
+          key: ${{ runner.os }}-vllm-models-${{ hashFiles('tests/conftest.py') }}
+          restore-keys: |
+            ${{ runner.os }}-vllm-models-
+
       - name: Setup Ollama
         if: github.event.inputs.filter == '' || contains(github.event.inputs.filter, 'ollama')
        uses: ai-action/setup-ollama@v1
@@ -192,6 +200,19 @@ jobs:
         run: |
           timeout 60 bash -c 'until curl -s http://localhost:8090/health >/dev/null; do sleep 1; done'
 
+      - name: Install and Run vLLM (CPU)
+        if: github.event.inputs.filter == '' || contains(github.event.inputs.filter, 'vllm')
+        run: |
+          uv pip install vllm --extra-index-url https://wheels.vllm.ai/nightly/cpu --prerelease=allow --index-strategy unsafe-best-match --torch-backend cpu
+          export VLLM_CPU_KVCACHE_SPACE=4
+          vllm serve Qwen/Qwen2.5-0.5B-Instruct --dtype bfloat16 --max-model-len 2048 --port 8080 > vllm.log 2>&1 &
+          echo $! > vllm.pid
+
+      - name: Wait for vLLM to be ready
+        if: github.event.inputs.filter == '' || contains(github.event.inputs.filter, 'vllm')
+        run: |
+          timeout 600 bash -c 'until curl -s http://localhost:8080/v1/models >/dev/null; do echo "Waiting for vLLM..."; sleep 5; done' || (echo "=== vLLM logs ===" && cat vllm.log && exit 1)
+
       - name: Run Local Provider Integration tests
         env:
           INCLUDE_LOCAL_PROVIDERS: "true"
@@ -211,6 +232,14 @@ jobs:
             rm llamafile.pid
           fi
 
+      - name: Cleanup vLLM process
+        if: always()
+        run: |
+          if [ -f vllm.pid ]; then
+            kill $(cat vllm.pid) || true
+            rm vllm.pid
+          fi
+
       - name: Upload coverage reports to Codecov
         if: always()
         uses: codecov/codecov-action@v5
diff --git a/pyproject.toml b/pyproject.toml
index d1faaa7b..8f5e3fca 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,7 +18,7 @@ dependencies = [
 
 [project.optional-dependencies]
 all = [
-    "any-llm-sdk[mistral,anthropic,huggingface,gemini,vertexai,cohere,cerebras,fireworks,groq,bedrock,azure,azureopenai,watsonx,together,sambanova,ollama,moonshot,nebius,xai,databricks,deepseek,inception,openai,openrouter,portkey,lmstudio,llama,voyage,perplexity,platform,llamafile,llamacpp,sagemaker,gateway,zai,minimax]"
+    "any-llm-sdk[mistral,anthropic,huggingface,gemini,vertexai,cohere,cerebras,fireworks,groq,bedrock,azure,azureopenai,watsonx,together,sambanova,ollama,moonshot,nebius,xai,databricks,deepseek,inception,openai,openrouter,portkey,lmstudio,llama,voyage,perplexity,platform,llamafile,llamacpp,sagemaker,gateway,zai,minimax,vllm]"
 ]
 
 platform = [
@@ -109,6 +109,7 @@ openrouter = []
 portkey = []
 sambanova = []
 minimax = []
+vllm = []
 gateway = [
     "fastapi>=0.115.0",
     "uvicorn[standard]>=0.30.0",
 ]
diff --git a/src/any_llm/constants.py b/src/any_llm/constants.py
index 27b9e9ff..a79ef8f6 100644
--- a/src/any_llm/constants.py
+++ b/src/any_llm/constants.py
@@ -45,6 +45,7 @@ class LLMProvider(StrEnum):
     SAGEMAKER = "sagemaker"
     TOGETHER = "together"
     VERTEXAI = "vertexai"
+    VLLM = "vllm"
     VOYAGE = "voyage"
     WATSONX = "watsonx"
     XAI = "xai"
diff --git a/src/any_llm/providers/vllm/__init__.py b/src/any_llm/providers/vllm/__init__.py
new file mode 100644
index 00000000..e52a1b59
--- /dev/null
+++ b/src/any_llm/providers/vllm/__init__.py
@@ -0,0 +1,3 @@
+from .vllm import VllmProvider
+
+__all__ = ["VllmProvider"]
diff --git a/src/any_llm/providers/vllm/vllm.py b/src/any_llm/providers/vllm/vllm.py
new file mode 100644
index 00000000..b6c256bb
--- /dev/null
+++ b/src/any_llm/providers/vllm/vllm.py
@@ -0,0 +1,18 @@
+from any_llm.providers.openai.base import BaseOpenAIProvider
+
+
+class VllmProvider(BaseOpenAIProvider):
+    API_BASE = "http://localhost:8000/v1"
+    ENV_API_KEY_NAME = "VLLM_API_KEY"
+    PROVIDER_NAME = "vllm"
+    PROVIDER_DOCUMENTATION_URL = "https://docs.vllm.ai/"
+
+    SUPPORTS_EMBEDDING = True
+    SUPPORTS_COMPLETION_REASONING = True
+    SUPPORTS_COMPLETION_STREAMING = True
+    SUPPORTS_COMPLETION_PDF = False
+
+    def _verify_and_set_api_key(self, api_key: str | None = None) -> str | None:
+        # vLLM server by default doesn't require an API key
+        # but can be configured to use one via --api-key flag
+        return api_key or ""
diff --git a/tests/conftest.py b/tests/conftest.py
index ec4c5175..79ff2fba 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -23,6 +23,7 @@ def provider_reasoning_model_map() -> dict[LLMProvider, str]:
         LLMProvider.OPENROUTER: "google/gemini-2.5-flash-lite",
         LLMProvider.LLAMAFILE: "N/A",
         LLMProvider.LLAMACPP: "N/A",
+        LLMProvider.VLLM: "N/A",
         LLMProvider.LMSTUDIO: "openai/gpt-oss-20b",  # You must have LM Studio running and the server enabled
         LLMProvider.AZUREOPENAI: "azure/",
         LLMProvider.CEREBRAS: "gpt-oss-120b",
@@ -60,6 +61,7 @@ def provider_model_map() -> dict[LLMProvider, str]:
         LLMProvider.OLLAMA: "llama3.2:1b",
         LLMProvider.LLAMAFILE: "N/A",
         LLMProvider.LMSTUDIO: "google/gemma-3n-e4b",  # You must have LM Studio running and the server enabled
+        LLMProvider.VLLM: "Qwen/Qwen2.5-0.5B-Instruct",
         LLMProvider.COHERE: "command-a-03-2025",
         LLMProvider.CEREBRAS: "llama-3.3-70b",
         LLMProvider.HUGGINGFACE: "huggingface/tgi",  # This is the syntax used in `litellm` when using HF Inference Endpoints (https://docs.litellm.ai/docs/providers/huggingface#dedicated-inference-endpoints)
@@ -131,6 +133,7 @@ def provider_client_config() -> dict[LLMProvider, dict[str, Any]]:
         LLMProvider.OPENAI: {"timeout": 100},
         LLMProvider.HUGGINGFACE: {"api_base": "https://oze7k8n86bjfzgjk.us-east-1.aws.endpoints.huggingface.cloud/v1"},
         LLMProvider.LLAMACPP: {"api_base": "http://127.0.0.1:8090/v1"},
+        LLMProvider.VLLM: {"api_base": "http://127.0.0.1:8080/v1"},
         LLMProvider.MISTRAL: {"timeout_ms": 100000},
         LLMProvider.NEBIUS: {"api_base": "https://api.studio.nebius.com/v1/"},
         LLMProvider.OPENAI: {"timeout": 10},
diff --git a/tests/constants.py b/tests/constants.py
index b768575a..9740acfa 100644
--- a/tests/constants.py
+++ b/tests/constants.py
@@ -8,6 +8,7 @@
     LLMProvider.LMSTUDIO,
     LLMProvider.LLAMAFILE,
     LLMProvider.GATEWAY,
+    LLMProvider.VLLM,
 ]
 
 EXPECTED_PROVIDERS = os.environ.get("EXPECTED_PROVIDERS", "").split(",")
diff --git a/tests/unit/providers/test_vllm_provider.py b/tests/unit/providers/test_vllm_provider.py
new file mode 100644
index 00000000..863a7231
--- /dev/null
+++ b/tests/unit/providers/test_vllm_provider.py
@@ -0,0 +1,13 @@
+from any_llm.providers.vllm.vllm import VllmProvider
+
+
+def test_provider_without_api_key() -> None:
+    provider = VllmProvider()
+    assert provider.PROVIDER_NAME == "vllm"
+    assert provider.API_BASE == "http://localhost:8000/v1"
+    assert provider.ENV_API_KEY_NAME == "VLLM_API_KEY"
+
+
+def test_provider_with_api_key() -> None:
+    provider = VllmProvider(api_key="test-api-key")
+    assert provider.PROVIDER_NAME == "vllm"
diff --git a/tests/unit/test_provider.py b/tests/unit/test_provider.py
index ee93635e..f6c86807 100644
--- a/tests/unit/test_provider.py
+++ b/tests/unit/test_provider.py
@@ -147,6 +147,7 @@ def test_providers_raise_MissingApiKeyError(provider: LLMProvider) -> None:
         LLMProvider.OLLAMA,
         LLMProvider.SAGEMAKER,
         LLMProvider.VERTEXAI,
+        LLMProvider.VLLM,
         LLMProvider.GATEWAY,
     ):
         pytest.skip("This provider handles `api_key` differently.")
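
Usage sketch (not part of the diff): once a local vLLM server is running, the new provider is just another OpenAI-compatible backend. The snippet below is illustrative only; it assumes the top-level completion() helper accepts a "<provider>/<model>" string and an api_base override, mirroring how the other local providers are pointed at their servers in tests/conftest.py.

# Start the server first (same invocation as the CI workflow above):
#   vllm serve Qwen/Qwen2.5-0.5B-Instruct --dtype bfloat16 --max-model-len 2048 --port 8080
from any_llm import completion  # assumed public entry point

response = completion(
    model="vllm/Qwen/Qwen2.5-0.5B-Instruct",  # "<provider>/<model>" form is an assumption
    messages=[{"role": "user", "content": "Say hello."}],
    api_base="http://127.0.0.1:8080/v1",  # same override used in tests/conftest.py
)
print(response.choices[0].message.content)  # assumes an OpenAI-style response object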