
feat(ChatVllm): Add vLLM support #24


Draft · wants to merge 5 commits into main

1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Added `ChatOpenRouter()` for chatting via [Open Router](https://openrouter.ai/). (#148)
* Added `ChatHuggingFace()` for chatting via [Hugging Face](https://huggingface.co/). (#144)
* Added `ChatPortkey()` for chatting via [Portkey AI](https://portkey.ai/). (#143)
* Added `ChatVllm()` for chatting via [vLLM](https://docs.vllm.ai/en/latest/). (#24)

### Bug fixes

2 changes: 2 additions & 0 deletions chatlas/__init__.py
@@ -19,6 +19,7 @@
from ._provider_perplexity import ChatPerplexity
from ._provider_portkey import ChatPortkey
from ._provider_snowflake import ChatSnowflake
from ._provider_vllm import ChatVllm
from ._tokens import token_usage
from ._tools import Tool, ToolRejectError
from ._turn import Turn
@@ -46,6 +47,7 @@
"ChatPortkey",
"ChatSnowflake",
"ChatVertex",
"ChatVllm",
"Chat",
"content_image_file",
"content_image_plot",
6 changes: 3 additions & 3 deletions chatlas/_provider_anthropic.py
@@ -286,7 +286,7 @@ def _chat_perform_args(
kwargs: Optional["SubmitInputArgs"] = None,
) -> "SubmitInputArgs":
tool_schemas = [
self._anthropic_tool_schema(tool.schema) for tool in tools.values()
self._tool_schema_json(tool.schema) for tool in tools.values()
]

# If data extraction is requested, add a "mock" tool with parameters inferred from the data model
@@ -306,7 +306,7 @@ def _structured_tool_call(**kwargs: Any):
},
}

tool_schemas.append(self._anthropic_tool_schema(data_model_tool.schema))
tool_schemas.append(self._tool_schema_json(data_model_tool.schema))

if stream:
stream = False
@@ -542,7 +542,7 @@ def _as_content_block(content: Content) -> "ContentBlockParam":
raise ValueError(f"Unknown content type: {type(content)}")

@staticmethod
def _anthropic_tool_schema(schema: "ChatCompletionToolParam") -> "ToolParam":
def _tool_schema_json(schema: "ChatCompletionToolParam") -> "ToolParam":
fn = schema["function"]
name = fn["name"]

9 changes: 8 additions & 1 deletion chatlas/_provider_openai.py
@@ -34,6 +34,7 @@
ChatCompletion,
ChatCompletionChunk,
ChatCompletionMessageParam,
ChatCompletionToolParam,
)
from openai.types.chat.chat_completion_assistant_message_param import (
ContentArrayOfContentPart,
@@ -276,7 +277,7 @@ def _chat_perform_args(
data_model: Optional[type[BaseModel]] = None,
kwargs: Optional["SubmitInputArgs"] = None,
) -> "SubmitInputArgs":
tool_schemas = [tool.schema for tool in tools.values()]
tool_schemas = [self._tool_schema_json(tool.schema) for tool in tools.values()]

kwargs_full: "SubmitInputArgs" = {
"stream": stream,
@@ -514,6 +515,12 @@ def _as_message_param(turns: list[Turn]) -> list["ChatCompletionMessageParam"]:

return res

@staticmethod
def _tool_schema_json(
schema: "ChatCompletionToolParam",
) -> "ChatCompletionToolParam":
return schema

def _as_turn(
self, completion: "ChatCompletion", has_data_model: bool
) -> Turn[ChatCompletion]:
3 changes: 1 addition & 2 deletions chatlas/_provider_portkey.py
@@ -78,11 +78,10 @@ def ChatPortkey(
Chat
A chat object that retains the state of the conversation.

Notes
Note
-----
This function is a lightweight wrapper around [](`~chatlas.ChatOpenAI`) with
the defaults tweaked for PortkeyAI.

"""
if model is None:
model = log_model_default("gpt-4.1")
121 changes: 121 additions & 0 deletions chatlas/_provider_vllm.py
@@ -0,0 +1,121 @@
import os
from typing import TYPE_CHECKING, Optional

import requests

from ._chat import Chat
from ._provider_openai import OpenAIProvider

if TYPE_CHECKING:
from openai.types.chat import ChatCompletionToolParam

from .types.openai import ChatClientArgs


def ChatVllm(
*,
base_url: str,
system_prompt: Optional[str] = None,
model: Optional[str] = None,
api_key: Optional[str] = None,
seed: Optional[int] = None,
kwargs: Optional["ChatClientArgs"] = None,
) -> Chat:
"""
Chat with a model hosted by vLLM

[vLLM](https://docs.vllm.ai/en/latest/) is an open-source library that
provides an efficient and convenient way to serve LLMs. You can use
`ChatVllm()` to connect to endpoints powered by vLLM.

Prerequisites
-------------

::: {.callout-note}
## vLLM runtime

`ChatVllm` requires a vLLM server to be running somewhere (either on your
machine or a remote server). If you want to run a vLLM server locally, see
the [vLLM documentation](https://docs.vllm.ai/en/stable/getting_started/quickstart.html).
:::


Parameters
----------
base_url
Base URL of the vLLM server (e.g., "http://localhost:8000/v1").
system_prompt
A system prompt to set the behavior of the assistant.
model
Model identifier to use.
api_key
API key for authentication. If not provided, the `VLLM_API_KEY` environment
variable will be used.
seed
Random seed for reproducibility.
kwargs
Additional arguments to pass to the LLM client.

Returns
-------
Chat
A chat object that retains the state of the conversation.

Note
-----
This function is a lightweight wrapper around [](`~chatlas.ChatOpenAI`) with
the defaults tweaked for vLLM.
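
Examples
--------
A minimal sketch, assuming a vLLM server is already running and serving a
model; the base URL and model name below are placeholders for your own
deployment:

```python
from chatlas import ChatVllm

chat = ChatVllm(
    base_url="http://localhost:8000/v1",
    model="llama3",
    system_prompt="You are a helpful assistant.",
)
chat.chat("What is 1 + 1?")
```

If `model` is omitted, `ChatVllm()` raises an error listing the models
reported by the server's `/models` endpoint.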
"""

if api_key is None:
api_key = get_vllm_key()

if model is None:
models = get_vllm_models(base_url, api_key)
available_models = ", ".join(models)
raise ValueError(f"Must specify model. Available models: {available_models}")

return Chat(
provider=VLLMProvider(
base_url=base_url,
model=model,
seed=seed,
api_key=api_key,
kwargs=kwargs,
),
system_prompt=system_prompt,
)


class VLLMProvider(OpenAIProvider):
# Just like OpenAI, except tool schemas are not marked `strict`,
# since vLLM's OpenAI-compatible endpoints may not support strict mode.
@staticmethod
def _tool_schema_json(
schema: "ChatCompletionToolParam",
) -> "ChatCompletionToolParam":
schema["function"]["strict"] = False
return schema


def get_vllm_key() -> str:
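# Accept either VLLM_API_KEY or, as a fallback, VLLM_KEY.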
key = os.getenv("VLLM_API_KEY", os.getenv("VLLM_KEY"))
if not key:
raise ValueError("VLLM_API_KEY environment variable not set")
return key


def get_vllm_models(base_url: str, api_key: Optional[str] = None) -> list[str]:
if api_key is None:
api_key = get_vllm_key()

headers = {"Authorization": f"Bearer {api_key}"}
# `base_url` is expected to already include the "/v1" prefix (see `ChatVllm()`),
# so append only "/models" here.
response = requests.get(f"{base_url.rstrip('/')}/models", headers=headers)
response.raise_for_status()
data = response.json()

return [model["id"] for model in data["data"]]


# def chat_vllm_test(**kwargs) -> Chat:
# """Create a test chat instance with default parameters."""
# return ChatVllm(base_url="https://llm.nrp-nautilus.io/", model="llama3", **kwargs)
1 change: 1 addition & 0 deletions docs/_quarto.yml
@@ -129,6 +129,7 @@ quartodoc:
- ChatPortkey
- ChatSnowflake
- ChatVertex
- ChatVllm
- title: The chat object
desc: Methods and attributes available on a chat instance
contents:
106 changes: 106 additions & 0 deletions tests/test_provider_vllm.py
@@ -0,0 +1,106 @@
import os

import pytest

do_test = os.getenv("TEST_VLLM", "true")
if do_test.lower() == "false":
pytest.skip("Skipping vLLM tests", allow_module_level=True)
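
# These tests expect a live vLLM server: set TEST_VLLM_BASE_URL to its base URL
# (e.g. "http://localhost:8000/v1") and optionally TEST_VLLM_MODEL (defaults to
# "llama3"); individual tests skip themselves when TEST_VLLM_BASE_URL is unset.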

from chatlas import ChatVllm

from .conftest import (
assert_tools_async,
assert_tools_simple,
assert_turns_existing,
assert_turns_system,
)


def test_vllm_simple_request():
# This test assumes you have a vLLM server running locally
# Skip if TEST_VLLM_BASE_URL is not set
base_url = os.getenv("TEST_VLLM_BASE_URL")
if base_url is None:
pytest.skip("TEST_VLLM_BASE_URL is not set; skipping vLLM tests")

model = os.getenv("TEST_VLLM_MODEL", "llama3")

chat = ChatVllm(
base_url=base_url,
model=model,
system_prompt="Be as terse as possible; no punctuation",
)
chat.chat("What is 1 + 1?")
turn = chat.get_last_turn()
assert turn is not None
assert turn.tokens is not None
assert len(turn.tokens) == 3
assert turn.tokens[0] >= 10 # More lenient assertion for vLLM
assert turn.finish_reason == "stop"


@pytest.mark.asyncio
async def test_vllm_simple_streaming_request():
base_url = os.getenv("TEST_VLLM_BASE_URL")
if base_url is None:
pytest.skip("TEST_VLLM_BASE_URL is not set; skipping vLLM tests")

model = os.getenv("TEST_VLLM_MODEL", "llama3")

chat = ChatVllm(
base_url=base_url,
model=model,
system_prompt="Be as terse as possible; no punctuation",
)
res = []
async for x in await chat.stream_async("What is 1 + 1?"):
res.append(x)
assert "2" in "".join(res)
turn = chat.get_last_turn()
assert turn is not None
assert turn.finish_reason == "stop"


def test_vllm_respects_turns_interface():
base_url = os.getenv("TEST_VLLM_BASE_URL")
if base_url is None:
pytest.skip("TEST_VLLM_BASE_URL is not set; skipping vLLM tests")

model = os.getenv("TEST_VLLM_MODEL", "llama3")

def chat_fun(**kwargs):
return ChatVllm(base_url=base_url, model=model, **kwargs)

assert_turns_system(chat_fun)
assert_turns_existing(chat_fun)


def test_vllm_tool_variations():
base_url = os.getenv("TEST_VLLM_BASE_URL")
if base_url is None:
pytest.skip("TEST_VLLM_BASE_URL is not set; skipping vLLM tests")

model = os.getenv("TEST_VLLM_MODEL", "llama3")

def chat_fun(**kwargs):
return ChatVllm(base_url=base_url, model=model, **kwargs)

assert_tools_simple(chat_fun)


@pytest.mark.asyncio
async def test_vllm_tool_variations_async():
base_url = os.getenv("TEST_VLLM_BASE_URL")
if base_url is None:
pytest.skip("TEST_VLLM_BASE_URL is not set; skipping vLLM tests")

model = os.getenv("TEST_VLLM_MODEL", "llama3")

def chat_fun(**kwargs):
return ChatVllm(base_url=base_url, model=model, **kwargs)

await assert_tools_async(chat_fun)


# Note: vLLM support for data extraction and images depends on the specific model
# and configuration, so we skip those tests for now