diff --git a/.gitignore b/.gitignore index 7378e72..4e3eb6e 100644 --- a/.gitignore +++ b/.gitignore @@ -14,4 +14,5 @@ build/ .mypy_cache/ .venv/ venv/ -generated/ \ No newline at end of file +generated/ +.vscode/settings.json diff --git a/README.md b/README.md index 1835f13..ba67610 100644 --- a/README.md +++ b/README.md @@ -29,32 +29,41 @@ This project provides a flexible Python framework for interacting with VLLM (or ``` . ├── src/ -│ └── tframex/ # Core library package -│ ├── __init__.py # Makes 'tframex' importable +│ └── tframex/ # Core library package +│ ├── __init__.py │ ├── agents/ -│ │ ├── __init__.py # Exposes agent classes (e.g., BasicAgent) -│ │ ├── agent_logic.py # BaseAgent and shared logic -│ │ └── agents.py # Concrete agent implementations +│ │ ├── __init__.py +│ │ ├── agent_logic.py # BaseAgent logic +│ │ └── agents.py # BasicAgent, ContextAgent │ ├── model/ -│ │ ├── __init__.py # Exposes model classes (e.g., VLLMModel) -│ │ └── model_logic.py # BaseModel, VLLMModel implementation +│ │ ├── __init__.py +│ │ ├── base.py # Shared model interfaces +│ │ ├── model_wrapper.py # Unified wrapper for model routing +│ │ ├── openai_model.py # OpenAI-specific model logic +│ │ └── vllm_model.py # vLLM-specific model logic │ └── systems/ -│ ├── __init__.py # Exposes system classes (e.g., ChainOfAgents, MultiCallSystem) -│ ├── chain_of_agents.py # Sequential summarization system -│ └── multi_call_system.py # Parallel sampling/generation system +│ ├── __init__.py +│ ├── chain_of_agents.py # Chunked summarization system +│ └── multi_call_system.py # Parallel generation system │ -├── examples/ # Example usage scripts (separate from the library) +├── examples/ │ ├── website_builder/ │ │ └── html.py -│ ├── context.txt # Sample input file -│ ├── example.py # Main example script -│ └── longtext.txt # Sample input file +│ ├── context.txt +│ ├── example.py +│ └── longtext.txt │ -├── .env copy # Example environment file template +├── tests/ +│ ├── integration/ +│ └── unit/ +│ └── model/ +│ +├── .env copy ├── .gitignore -├── README.md # This file -├── requirements.txt # Core library dependencies -└── pyproject.toml # Build system and package configuration +├── README.md +├── requirements.txt +└── pyproject.toml + ``` * **`tframex/`**: The main directory containing the library source code. 
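Before the code diffs, a quick orientation on the new model layer: all provider-specific classes now sit behind a single `ModelWrapper`. The sketch below is assembled from the updated `examples/standard/example.py` further down in this diff; the environment variable names (`OPENAI_API_KEY`, `OPENAI_MODEL_NAME`) follow that script and are assumed to be set in your local `.env`.

```python
# Minimal sketch: route through ModelWrapper and drive a BasicAgent.
# Mirrors the usage in examples/standard/example.py; nothing here is new API.
import asyncio
import os

from dotenv import load_dotenv
from tframex.model import ModelWrapper
from tframex.agents import BasicAgent

load_dotenv()  # pulls OPENAI_API_KEY / OPENAI_MODEL_NAME from .env

async def main():
    model = ModelWrapper(
        provider="openai",            # or "vllm", which additionally expects api_url=...
        model_name=os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo"),
        api_key=os.getenv("OPENAI_API_KEY"),
        default_max_tokens=4096,
        default_temperature=0.7,
    )
    agent = BasicAgent(agent_id="quickstart_001", model=model)
    print(await agent.run("Explain async programming in one sentence."))
    await model.close_client()

asyncio.run(main())
```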
diff --git a/examples/standard/example.py b/examples/standard/example.py index 3a9bd01..fe3f165 100644 --- a/examples/standard/example.py +++ b/examples/standard/example.py @@ -1,6 +1,6 @@ # example.py (run with python -m examples.standard.example --example <*> where * is example number) # Import Model, Agents, and Systems -from tframex.model import VLLMModel # NEW +from tframex.model import ModelWrapper from tframex.agents import BasicAgent, ContextAgent # NEW from tframex.systems import ChainOfAgents, MultiCallSystem # NEW import asyncio @@ -9,11 +9,10 @@ import time from dotenv import load_dotenv import argparse + # Load .env into environment load_dotenv() - - logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) @@ -32,8 +31,10 @@ API_URL = os.getenv("API_URL") API_KEY = os.getenv("API_KEY") MODEL_NAME = os.getenv("MODEL_NAME") -MAX_TOKENS = int(os.getenv("MAX_TOKENS", 32000)) +MAX_TOKENS = int(os.getenv("MAX_TOKENS", 4096)) TEMPERATURE = float(os.getenv("TEMPERATURE", 0.7)) +OPENAI_MODEL_NAME = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo") +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") # --- File Paths --- CONTEXT_FILE = os.path.join("examples", "standard", "context.txt") @@ -62,10 +63,21 @@ async def main(): # 0. Create the Model Instance logger.info("Creating VLLM Model instance...") - vllm_model = VLLMModel( - model_name=MODEL_NAME, - api_url=API_URL, - api_key=API_KEY, + + # VLLM Model Wrapper + # vllm_model = ModelWrapper( + # provider="vllm", + # model_name=MODEL_NAME, + # api_key=API_KEY, + # api_url=API_URL, + # default_max_tokens=MAX_TOKENS, + # default_temperature=TEMPERATURE + # ) + + openai_model = ModelWrapper( + provider="openai", + model_name=OPENAI_MODEL_NAME, + api_key=OPENAI_API_KEY, default_max_tokens=MAX_TOKENS, default_temperature=TEMPERATURE ) @@ -73,12 +85,12 @@ async def main(): # --- Example 1: Basic Agent --- if args.example in [None, 1]: logger.info("\n--- Example 1: Basic Agent ---") - basic_agent = BasicAgent(agent_id="basic_001", model=vllm_model) + basic_agent = BasicAgent(agent_id="basic_001", model=openai_model) basic_prompt = "Explain the difference between synchronous and asynchronous programming using a simple analogy." basic_output_file = "ex1_basic_agent_output.txt" print(f"Running BasicAgent with prompt: '{basic_prompt}'") - basic_response = await basic_agent.run(basic_prompt, max_tokens=32000) # Override default tokens + basic_response = await basic_agent.run(basic_prompt, max_tokens=4096) # Override default tokens print(f"BasicAgent Response:\n{basic_response[:200]}...") # Print preview save_output(basic_output_file, f"Prompt:\n{basic_prompt}\n\nResponse:\n{basic_response}") @@ -97,7 +109,7 @@ async def main(): context_content = "The user is interested in Python programming best practices." save_output(CONTEXT_FILE, context_content, directory=".") # Create dummy file - context_agent = ContextAgent(agent_id="context_001", model=vllm_model, context=context_content) + context_agent = ContextAgent(agent_id="context_001", model=openai_model, context=context_content) context_prompt = "What are 3 key recommendations for writing clean code?" 
context_output_file = "ex2_context_agent_output.txt" print(f"Running ContextAgent with prompt: '{context_prompt}'") @@ -124,13 +136,13 @@ async def main(): "Asynchronous programming with asyncio provides concurrency for I/O-bound tasks without needing multiple threads.") save_output(LONG_TEXT_FILE, long_text_content, directory=".") # Create dummy file - chain_system = ChainOfAgents(system_id="chain_summarizer_01", model=vllm_model, chunk_size=200, chunk_overlap=50) + chain_system = ChainOfAgents(system_id="chain_summarizer_01", model=openai_model, chunk_size=200, chunk_overlap=50) chain_prompt = "Based on the provided text, explain the implications of Python's dynamic typing and the GIL." chain_output_file = "ex3_chain_system_output.txt" print(f"Running ChainOfAgents system with prompt: '{chain_prompt}'") # Reduce max_tokens for intermediate summaries if needed via kwargs - chain_response = await chain_system.run(initial_prompt=chain_prompt, long_text=long_text_content, max_tokens=32000) # kwargs passed down + chain_response = await chain_system.run(initial_prompt=chain_prompt, long_text=long_text_content, max_tokens=4096) # kwargs passed down print(f"ChainOfAgents Response:\n{chain_response[:200]}...") save_output(chain_output_file, f"Initial Prompt:\n{chain_prompt}\n\nLong Text Input (preview):\n{long_text_content[:300]}...\n\nFinal Response:\n{chain_response}") @@ -139,7 +151,7 @@ async def main(): # --- Example 4: Multi Call System --- if args.example in [None, 4]: logger.info("\n--- Example 4: Multi Call System ---") - multi_call_system = MultiCallSystem(system_id="multi_haiku_01", model=vllm_model) + multi_call_system = MultiCallSystem(system_id="multi_haiku_01", model=openai_model) multi_call_prompt = "Make the best looking website for a html css js tailwind coffee shop landing page." num_calls = 15 # Use a smaller number for testing, change to 120 if needed # num_calls = 120 @@ -151,7 +163,7 @@ async def main(): num_calls=num_calls, output_dir=multi_call_output_dir, base_filename="website", - max_tokens=35000 # Keep haikus short + max_tokens=4096 # Keep haikus short ) print(f"MultiCallSystem finished. 
Results saved in '{multi_call_output_dir}'.") @@ -162,7 +174,7 @@ async def main(): # --- Cleanup --- logger.info("\n--- Closing Model Client ---") - await vllm_model.close_client() + await openai_model.close_client() end_time = time.time() logger.info(f"--- All examples finished in {end_time - start_time:.2f} seconds ---") diff --git a/src/tframex/agents/agent_logic.py b/src/tframex/agents/agent_logic.py index 79352ff..bf130e0 100644 --- a/src/tframex/agents/agent_logic.py +++ b/src/tframex/agents/agent_logic.py @@ -1,7 +1,7 @@ # agent_logic.py import logging from abc import ABC, abstractmethod -from tframex.model.model_logic import BaseModel # NEW +from tframex.model.base import BaseModel # NEW from typing import Any, List, Dict logger = logging.getLogger(__name__) diff --git a/src/tframex/agents/agents.py b/src/tframex/agents/agents.py index 4b926dc..9fd4533 100644 --- a/src/tframex/agents/agents.py +++ b/src/tframex/agents/agents.py @@ -1,7 +1,7 @@ # agents.py import logging from tframex.agents.agent_logic import BaseAgent # NEW -from tframex.model.model_logic import BaseModel # NEWBaseModel +from tframex.model.base import BaseModel # NEWBaseModel logger = logging.getLogger(__name__) diff --git a/src/tframex/model/__init__.py b/src/tframex/model/__init__.py index 7df8b6f..56fcdb5 100644 --- a/src/tframex/model/__init__.py +++ b/src/tframex/model/__init__.py @@ -1,7 +1,5 @@ # TAF/tframex/model/__init__.py -# Import the classes you want to expose directly from the 'model' package -from .model_logic import BaseModel, VLLMModel +from .model_wrapper import ModelWrapper -# Optional: Define __all__ to control 'from tframex.model import *' behaviour -__all__ = ['BaseModel', 'VLLMModel'] \ No newline at end of file +__all__ = ["ModelWrapper"] diff --git a/src/tframex/model/base.py b/src/tframex/model/base.py new file mode 100644 index 0000000..3361a99 --- /dev/null +++ b/src/tframex/model/base.py @@ -0,0 +1,34 @@ +# base.py + +import logging +from abc import ABC, abstractmethod +from typing import AsyncGenerator, Dict, List + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +class BaseModel(ABC): + """Abstract base class for language models.""" + def __init__(self, model_id: str): + self.model_id = model_id + logger.info(f"Initializing base model structure for ID: {model_id}") + + @abstractmethod + async def call_stream(self, messages: List[Dict[str, str]], **kwargs) -> AsyncGenerator[str, None]: + """ + Calls the language model (now expecting chat format) and streams response chunks. + Must be implemented by subclasses. + + Args: + messages (List[Dict[str, str]]): A list of message dictionaries, + e.g., [{"role": "user", "content": "Hello"}]. + Yields: + str: Chunks of the generated text content. 
+ """ + raise NotImplementedError + yield "" # Required for async generator typing + + @abstractmethod + async def close_client(self): + """Closes any underlying network clients.""" + raise NotImplementedError \ No newline at end of file diff --git a/src/tframex/model/model_wrapper.py b/src/tframex/model/model_wrapper.py new file mode 100644 index 0000000..fec57c0 --- /dev/null +++ b/src/tframex/model/model_wrapper.py @@ -0,0 +1,21 @@ +from .base import BaseModel +from .vllm_model import VLLMModel +from .openai_model import OpenAIModel + +class ModelWrapper(BaseModel): + def __init__(self, provider: str, **kwargs): + if provider == "vllm": + self.model = VLLMModel(**kwargs) + elif provider == "openai": + self.model = OpenAIModel(**kwargs) + else: + raise ValueError(f"Unsupported provider: {provider}") + + super().__init__(model_id=self.model.model_id) + + async def call_stream(self, messages, **kwargs): + async for chunk in self.model.call_stream(messages, **kwargs): + yield chunk + + async def close_client(self): + await self.model.close_client() diff --git a/src/tframex/model/openai_model.py b/src/tframex/model/openai_model.py new file mode 100644 index 0000000..35d04ef --- /dev/null +++ b/src/tframex/model/openai_model.py @@ -0,0 +1,76 @@ +import logging +from typing import AsyncGenerator, Dict, List + +from openai import AsyncOpenAI +from .base import BaseModel + +logger = logging.getLogger(__name__) + + +class OpenAIModel(BaseModel): + """ + Represents a connection to OpenAI's official hosted models using the OpenAI SDK. + Designed to be used with any OpenAI-compatible chat model (e.g., gpt-3.5-turbo, gpt-4). + """ + def __init__(self, + model_name: str, + api_key: str, + default_max_tokens: int = 1024, + default_temperature: float = 0.7): + """ + Initializes the OpenAIModel with credentials and default generation parameters. + + Args: + model_name (str): The model ID to use (e.g., "gpt-4"). + api_key (str): Your OpenAI API key. + default_max_tokens (int): Default max tokens to generate if not overridden. + default_temperature (float): Default temperature for sampling. + """ + super().__init__(model_id=f"openai_{model_name}") # Call BaseModel constructor and set model_id + self.client = AsyncOpenAI(api_key=api_key) # Initialize OpenAI async client with API key + self.model_name = model_name # Store the model name (e.g., "gpt-4") + self.default_max_tokens = default_max_tokens # Set default max_tokens if not overridden + self.default_temperature = default_temperature # Set default temperature if not overridden + logger.info(f"OpenAIModel '{self.model_id}' initialized.") # Log successful model setup + + + async def call_stream(self, messages: List[Dict[str, str]], **kwargs) -> AsyncGenerator[str, None]: + """ + Calls OpenAI's chat completions API with the given message history and streams the response. + + Args: + messages (List[Dict[str, str]]): The conversation history/prompt. + **kwargs: Optional overrides like 'max_tokens', 'temperature'. + + Yields: + str: Chunks of the generated text content. 
+ """ + try: + # Send the request to OpenAI with streaming enabled + response = await self.client.chat.completions.create( + model=self.model_name, # Model name (e.g., "gpt-4") + messages=messages, # Chat history in OpenAI format + max_tokens=kwargs.get("max_tokens", self.default_max_tokens), # Use passed or default token limit + temperature=kwargs.get("temperature", self.default_temperature), # Use passed or default temperature + stream=True, # Enable streaming response + **{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]} # Pass other kwargs + ) + + # Iterate through streamed response chunks + async for chunk in response: + if chunk.choices: # Ensure choices exist in the chunk + delta = chunk.choices[0].delta # Get the delta object (partial update) + if delta.content is not None: # Yield content if present + yield delta.content + + except Exception as e: + # Log and yield error string if anything goes wrong + logger.error(f"[{self.model_id}] OpenAI call failed: {e}", exc_info=True) + yield f"ERROR: {str(e)}" + + + async def close_client(self): + """ + Placeholder for SDK compatibility with BaseModel. No-op for OpenAI client. + """ + logger.info(f"[{self.model_id}] OpenAI client closed (noop).") diff --git a/src/tframex/model/model_logic.py b/src/tframex/model/vllm_model.py similarity index 87% rename from src/tframex/model/model_logic.py rename to src/tframex/model/vllm_model.py index 252052d..a2d5735 100644 --- a/src/tframex/model/model_logic.py +++ b/src/tframex/model/vllm_model.py @@ -1,41 +1,16 @@ -# model_logic.py +# vllm_model.py + import httpx import json import asyncio import logging -from abc import ABC, abstractmethod from typing import AsyncGenerator, Dict, List +from .base import BaseModel logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) -class BaseModel(ABC): - """Abstract base class for language models.""" - def __init__(self, model_id: str): - self.model_id = model_id - logger.info(f"Initializing base model structure for ID: {model_id}") - - @abstractmethod - async def call_stream(self, messages: List[Dict[str, str]], **kwargs) -> AsyncGenerator[str, None]: - """ - Calls the language model (now expecting chat format) and streams response chunks. - Must be implemented by subclasses. - - Args: - messages (List[Dict[str, str]]): A list of message dictionaries, - e.g., [{"role": "user", "content": "Hello"}]. - Yields: - str: Chunks of the generated text content. - """ - raise NotImplementedError - yield "" # Required for async generator typing - - @abstractmethod - async def close_client(self): - """Closes any underlying network clients.""" - raise NotImplementedError - class VLLMModel(BaseModel): """ Represents a connection to a VLLM OpenAI-compatible endpoint. 
diff --git a/src/tframex/systems/chain_of_agents.py b/src/tframex/systems/chain_of_agents.py index 144e89f..6fd30fb 100644 --- a/src/tframex/systems/chain_of_agents.py +++ b/src/tframex/systems/chain_of_agents.py @@ -1,6 +1,6 @@ # chain_of_agents.py import logging -from tframex.model import BaseModel +from tframex.model.base import BaseModel from tframex.agents import BasicAgent # Using BasicAgent for summarization/final answer from typing import List diff --git a/src/tframex/systems/multi_call_system.py b/src/tframex/systems/multi_call_system.py index 0af0303..5981cb8 100644 --- a/src/tframex/systems/multi_call_system.py +++ b/src/tframex/systems/multi_call_system.py @@ -2,7 +2,7 @@ import asyncio import logging import os -from tframex.model import BaseModel +from tframex.model.base import BaseModel from typing import List, Dict logger = logging.getLogger(__name__) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/test_openai_integration.py b/tests/integration/test_openai_integration.py new file mode 100644 index 0000000..823efc9 --- /dev/null +++ b/tests/integration/test_openai_integration.py @@ -0,0 +1,32 @@ +import pytest +import os +from tframex.model import ModelWrapper +from dotenv import load_dotenv + +# Load variables from .env file +load_dotenv() + +@pytest.mark.asyncio +async def test_openai_model_real_response(capfd): + api_key = os.getenv("OPENAI_API_KEY") + model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo") + + assert api_key, "OPENAI_API_KEY must be set in .env to run this test." + + model = ModelWrapper( + provider="openai", + model_name=model_name, + api_key=api_key + ) + + messages = [{"role": "user", "content": "What is the capital of France?"}] + response = [chunk async for chunk in model.call_stream(messages)] + + full_response = "".join(response) + print("RESPONSE:", full_response, flush=True) + + await model.close_client() + + # Capture and display output + out, _ = capfd.readouterr() + assert "Paris" in full_response, f"Expected 'Paris' in response. Got:\n{out}" diff --git a/tests/integration/test_vllm_integration.py b/tests/integration/test_vllm_integration.py new file mode 100644 index 0000000..d271943 --- /dev/null +++ b/tests/integration/test_vllm_integration.py @@ -0,0 +1,29 @@ +import pytest +import os +from tframex.model import ModelWrapper +from dotenv import load_dotenv + +load_dotenv() + +@pytest.mark.asyncio +async def test_vllm_model_real_response(): + api_key = os.getenv("API_KEY") + api_url = os.getenv("API_URL") + model_name = os.getenv("MODEL_NAME") + + assert api_key and api_url and model_name, "API_URL, API_KEY, and MODEL_NAME must be set." 
+ + model = ModelWrapper( + provider="vllm", + model_name=model_name, + api_url=api_url, + api_key=api_key + ) + + messages = [{"role": "user", "content": "Say hello in Spanish"}] + response = [chunk async for chunk in model.call_stream(messages)] + + full_response = "".join(response).lower() + print("API_URL =", os.getenv("API_URL")) + assert "hola" in full_response + await model.close_client() diff --git a/tests/unit/model/test_model_wrapper.py b/tests/unit/model/test_model_wrapper.py new file mode 100644 index 0000000..a0feb21 --- /dev/null +++ b/tests/unit/model/test_model_wrapper.py @@ -0,0 +1,85 @@ +import pytest +from tframex.model import ModelWrapper + +# --- Dummy Setup --- + +class DummyResponse: + def __init__(self, lines=None, status_code=200): + self._lines = lines or [] + self.status_code = status_code + self.headers = {} + self.request = None + + async def __aenter__(self): return self + async def __aexit__(self, *args): pass + async def aread(self): return b"Internal error" + def aiter_lines(self): + async def gen(): + for line in self._lines: + yield line + return gen() + + +class DummyClient: + def __init__(self, response): self._response = response + def stream(self, *args, **kwargs): return self._response + + +# --- Tests --- + +@pytest.mark.asyncio +async def test_model_wrapper_stream_success(): + # Simulated streaming chunks from a VLLM-like model + lines = [ + 'data:{"choices":[{"delta":{"content":"streamed"}}]}', + 'data:{"choices":[{"delta":{"content":" output"}}]}', + 'data:[DONE]' + ] + dummy = DummyResponse(lines=lines) + + wrapper = ModelWrapper( + provider="vllm", + model_name="test", + api_url="http://localhost", + api_key="dummy" + ) + wrapper.model._client = DummyClient(dummy) + + chunks = [chunk async for chunk in wrapper.call_stream([{"role": "user", "content": "Say something"}])] + assert "".join(chunks).strip() == "streamed output" + + +@pytest.mark.asyncio +async def test_model_wrapper_stream_failure(): + # Simulate an API error (500) + dummy = DummyResponse(status_code=500) + + wrapper = ModelWrapper( + provider="vllm", + model_name="test", + api_url="http://localhost", + api_key="dummy" + ) + wrapper.model._client = DummyClient(dummy) + + chunks = [chunk async for chunk in wrapper.call_stream([{"role": "user", "content": "test"}])] + assert any("ERROR" in chunk for chunk in chunks) + + +@pytest.mark.asyncio +async def test_model_wrapper_close_client(): + called = {"closed": False} + + class DummyCloseClient: + async def aclose(self): called["closed"] = True + + wrapper = ModelWrapper( + provider="vllm", + model_name="test", + api_url="http://localhost", + api_key="dummy" + ) + wrapper.model._client = DummyCloseClient() + + await wrapper.close_client() + assert called["closed"] is True diff --git a/tests/unit/model/test_openai_model.py b/tests/unit/model/test_openai_model.py new file mode 100644 index 0000000..ec042b6 --- /dev/null +++ b/tests/unit/model/test_openai_model.py @@ -0,0 +1,45 @@ +from tframex.model import ModelWrapper +import pytest + +# --- Dummy OpenAI-style client + stream --- + +class DummyOpenAIChunk: + def __init__(self, content): + self.choices = [type("Choice", (), {"delta": type("Delta", (), {"content": content})()})] + +class DummyOpenAIStream: + def __init__(self, chunks): self._chunks = chunks + def __aiter__(self): + async def gen(): + for c in self._chunks: + yield DummyOpenAIChunk(c) + return gen() + + +class DummyOpenAIClient: + def __init__(self, stream): + self.chat = type("Chat", (), {})() + self.chat.completions = 
type("Completions", (), {})() + + async def create(*_args, **_kwargs): + return stream + + self.chat.completions.create = create + + +@pytest.mark.asyncio +async def test_model_wrapper_openai_stream(): + # Prepare dummy OpenAI-like streamed chunks + dummy_stream = DummyOpenAIStream(["hello", " world"]) + + wrapper = ModelWrapper( + provider="openai", + model_name="gpt-3.5-turbo", + api_key="dummy-key" + ) + + # Inject dummy OpenAI client + wrapper.model.client = DummyOpenAIClient(dummy_stream) + + chunks = [chunk async for chunk in wrapper.call_stream([{"role": "user", "content": "hi"}])] + assert "".join(chunks) == "hello world" diff --git a/tests/unit/model/test_vllm_model.py b/tests/unit/model/test_vllm_model.py new file mode 100644 index 0000000..b7dec42 --- /dev/null +++ b/tests/unit/model/test_vllm_model.py @@ -0,0 +1,87 @@ +import pytest +import json +from tframex.model import ModelWrapper + +# --- Helpers --- + +class DummyResponse: + def __init__(self, status_code, lines=None, headers=None, error_content=b""): + self.status_code = status_code + self._lines = lines or [] + self.headers = headers or {} + self._error_content = error_content + self.request = None + + async def __aenter__(self): return self + async def __aexit__(self, *args): pass + async def aread(self): return self._error_content + def aiter_lines(self): + async def gen(): + for line in self._lines: + yield line + return gen() + + +class DummyClient: + def __init__(self, response): self._response = response + def stream(self, *args, **kwargs): return self._response + + +# --- Tests --- + +@pytest.mark.asyncio +async def test_call_stream_success_with_wrapper(): + lines = [ + 'data:' + json.dumps({"choices": [{"delta": {"content": "hello"}}]}), + 'data:[DONE]' + ] + dummy_response = DummyResponse(status_code=200, lines=lines) + + # Wrap a VLLMModel via the wrapper + wrapper = ModelWrapper( + provider="vllm", + model_name="test-model", + api_url="API_URL", + api_key="API_KEY" + ) + wrapper.model._client = DummyClient(dummy_response) + + chunks = [chunk async for chunk in wrapper.call_stream([{"role": "user", "content": "ping"}])] + assert chunks == ["hello"] + + +@pytest.mark.asyncio +async def test_call_stream_api_error_with_wrapper(): + dummy_response = DummyResponse(status_code=500, error_content=b"something went wrong") + + wrapper = ModelWrapper( + provider="vllm", + model_name="test-model", + api_url="API_URL", + api_key="API_KEY" + ) + wrapper.model._client = DummyClient(dummy_response) + + chunks = [chunk async for chunk in wrapper.call_stream([{"role": "user", "content": "ping"}])] + assert len(chunks) == 1 + assert "ERROR" in chunks[0] + assert "500" in chunks[0] + assert "something went wrong" in chunks[0] + + +@pytest.mark.asyncio +async def test_close_client_with_wrapper(): + wrapper = ModelWrapper( + provider="vllm", + model_name="test-model", + api_url="API_URL", + api_key="API_KEY" + ) + + closed = {"done": False} + class DummyCloseClient: + async def aclose(self): closed["done"] = True + + wrapper.model._client = DummyCloseClient() + await wrapper.close_client() + assert closed["done"] is True