3 changes: 2 additions & 1 deletion .gitignore
@@ -14,4 +14,5 @@ build/
.mypy_cache/
.venv/
venv/
generated/
generated/
.vscode/settings.json
45 changes: 27 additions & 18 deletions README.md
@@ -29,32 +29,41 @@ This project provides a flexible Python framework for interacting with VLLM (or
```
.
├── src/
│ └── tframex/ # Core library package
│ ├── __init__.py # Makes 'tframex' importable
│ └── tframex/ # Core library package
│ ├── __init__.py
│ ├── agents/
│ │ ├── __init__.py # Exposes agent classes (e.g., BasicAgent)
│ │ ├── agent_logic.py # BaseAgent and shared logic
│ │ └── agents.py # Concrete agent implementations
│ │ ├── __init__.py
│ │ ├── agent_logic.py # BaseAgent logic
│ │ └── agents.py # BasicAgent, ContextAgent
│ ├── model/
│ │ ├── __init__.py # Exposes model classes (e.g., VLLMModel)
│ │ └── model_logic.py # BaseModel, VLLMModel implementation
│ │ ├── __init__.py
│ │ ├── base.py # Shared model interfaces
│ │ ├── model_wrapper.py # Unified wrapper for model routing
│ │ ├── openai_model.py # OpenAI-specific model logic
│ │ └── vllm_model.py # vLLM-specific model logic
│ └── systems/
│ ├── __init__.py # Exposes system classes (e.g., ChainOfAgents, MultiCallSystem)
│ ├── chain_of_agents.py # Sequential summarization system
│ └── multi_call_system.py # Parallel sampling/generation system
│ ├── __init__.py
│ ├── chain_of_agents.py # Chunked summarization system
│ └── multi_call_system.py # Parallel generation system
├── examples/ # Example usage scripts (separate from the library)
├── examples/
│ ├── website_builder/
│ │ └── html.py
│ ├── context.txt # Sample input file
│ ├── example.py # Main example script
│ └── longtext.txt # Sample input file
│ ├── context.txt
│ ├── example.py
│ └── longtext.txt
├── .env copy # Example environment file template
├── tests/
│ ├── integration/
│ └── unit/
│ └── model/
├── .env copy
├── .gitignore
├── README.md # This file
├── requirements.txt # Core library dependencies
└── pyproject.toml # Build system and package configuration
├── README.md
├── requirements.txt
└── pyproject.toml

```

* **`tframex/`**: The main directory containing the library source code.
44 changes: 28 additions & 16 deletions examples/standard/example.py
@@ -1,6 +1,6 @@
# example.py (run with python -m examples.standard.example --example <*> where * is example number)
# Import Model, Agents, and Systems
from tframex.model import VLLMModel # NEW
from tframex.model import ModelWrapper
from tframex.agents import BasicAgent, ContextAgent # NEW
from tframex.systems import ChainOfAgents, MultiCallSystem # NEW
import asyncio
@@ -9,11 +9,10 @@
import time
from dotenv import load_dotenv
import argparse

# Load .env into environment
load_dotenv()



logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

@@ -32,8 +31,10 @@
API_URL = os.getenv("API_URL")
API_KEY = os.getenv("API_KEY")
MODEL_NAME = os.getenv("MODEL_NAME")
MAX_TOKENS = int(os.getenv("MAX_TOKENS", 32000))
MAX_TOKENS = int(os.getenv("MAX_TOKENS", 4096))
TEMPERATURE = float(os.getenv("TEMPERATURE", 0.7))
OPENAI_MODEL_NAME = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

# --- File Paths ---
CONTEXT_FILE = os.path.join("examples", "standard", "context.txt")
@@ -62,23 +63,34 @@ async def main():

# 0. Create the Model Instance
logger.info("Creating VLLM Model instance...")
vllm_model = VLLMModel(
model_name=MODEL_NAME,
api_url=API_URL,
api_key=API_KEY,

# VLLM Model Wrapper
# vllm_model = ModelWrapper(
# provider="vllm",
# model_name=MODEL_NAME,
# api_key=API_KEY,
# api_url=API_URL,
# default_max_tokens=MAX_TOKENS,
# default_temperature=TEMPERATURE
# )

openai_model = ModelWrapper(
provider="openai",
model_name=OPENAI_MODEL_NAME,
api_key=OPENAI_API_KEY,
default_max_tokens=MAX_TOKENS,
default_temperature=TEMPERATURE
)

# --- Example 1: Basic Agent ---
if args.example in [None, 1]:
logger.info("\n--- Example 1: Basic Agent ---")
basic_agent = BasicAgent(agent_id="basic_001", model=vllm_model)
basic_agent = BasicAgent(agent_id="basic_001", model=openai_model)
basic_prompt = "Explain the difference between synchronous and asynchronous programming using a simple analogy."
basic_output_file = "ex1_basic_agent_output.txt"
print(f"Running BasicAgent with prompt: '{basic_prompt}'")

basic_response = await basic_agent.run(basic_prompt, max_tokens=32000) # Override default tokens
basic_response = await basic_agent.run(basic_prompt, max_tokens=4096) # Override default tokens

print(f"BasicAgent Response:\n{basic_response[:200]}...") # Print preview
save_output(basic_output_file, f"Prompt:\n{basic_prompt}\n\nResponse:\n{basic_response}")
@@ -97,7 +109,7 @@ async def main():
context_content = "The user is interested in Python programming best practices."
save_output(CONTEXT_FILE, context_content, directory=".") # Create dummy file

context_agent = ContextAgent(agent_id="context_001", model=vllm_model, context=context_content)
context_agent = ContextAgent(agent_id="context_001", model=openai_model, context=context_content)
context_prompt = "What are 3 key recommendations for writing clean code?"
context_output_file = "ex2_context_agent_output.txt"
print(f"Running ContextAgent with prompt: '{context_prompt}'")
@@ -124,13 +136,13 @@ async def main():
"Asynchronous programming with asyncio provides concurrency for I/O-bound tasks without needing multiple threads.")
save_output(LONG_TEXT_FILE, long_text_content, directory=".") # Create dummy file

chain_system = ChainOfAgents(system_id="chain_summarizer_01", model=vllm_model, chunk_size=200, chunk_overlap=50)
chain_system = ChainOfAgents(system_id="chain_summarizer_01", model=openai_model, chunk_size=200, chunk_overlap=50)
chain_prompt = "Based on the provided text, explain the implications of Python's dynamic typing and the GIL."
chain_output_file = "ex3_chain_system_output.txt"
print(f"Running ChainOfAgents system with prompt: '{chain_prompt}'")

# Reduce max_tokens for intermediate summaries if needed via kwargs
chain_response = await chain_system.run(initial_prompt=chain_prompt, long_text=long_text_content, max_tokens=32000) # kwargs passed down
chain_response = await chain_system.run(initial_prompt=chain_prompt, long_text=long_text_content, max_tokens=4096) # kwargs passed down

print(f"ChainOfAgents Response:\n{chain_response[:200]}...")
save_output(chain_output_file, f"Initial Prompt:\n{chain_prompt}\n\nLong Text Input (preview):\n{long_text_content[:300]}...\n\nFinal Response:\n{chain_response}")
@@ -139,7 +151,7 @@ async def main():
# --- Example 4: Multi Call System ---
if args.example in [None, 4]:
logger.info("\n--- Example 4: Multi Call System ---")
multi_call_system = MultiCallSystem(system_id="multi_haiku_01", model=vllm_model)
multi_call_system = MultiCallSystem(system_id="multi_haiku_01", model=openai_model)
multi_call_prompt = "Make the best looking website for a html css js tailwind coffee shop landing page."
num_calls = 15 # Use a smaller number for testing, change to 120 if needed
# num_calls = 120
@@ -151,7 +163,7 @@ async def main():
num_calls=num_calls,
output_dir=multi_call_output_dir,
base_filename="website",
max_tokens=35000 # Keep haikus short
max_tokens=4096 # Cap output length for each generated page
)

print(f"MultiCallSystem finished. Results saved in '{multi_call_output_dir}'.")
@@ -162,7 +174,7 @@

# --- Cleanup ---
logger.info("\n--- Closing Model Client ---")
await vllm_model.close_client()
await openai_model.close_client()

end_time = time.time()
logger.info(f"--- All examples finished in {end_time - start_time:.2f} seconds ---")
2 changes: 1 addition & 1 deletion src/tframex/agents/agent_logic.py
@@ -1,7 +1,7 @@
# agent_logic.py
import logging
from abc import ABC, abstractmethod
from tframex.model.model_logic import BaseModel # NEW
from tframex.model.base import BaseModel # NEW
from typing import Any, List, Dict

logger = logging.getLogger(__name__)
2 changes: 1 addition & 1 deletion src/tframex/agents/agents.py
@@ -1,7 +1,7 @@
# agents.py
import logging
from tframex.agents.agent_logic import BaseAgent # NEW
from tframex.model.model_logic import BaseModel # NEWBaseModel
from tframex.model.base import BaseModel # NEW

logger = logging.getLogger(__name__)

6 changes: 2 additions & 4 deletions src/tframex/model/__init__.py
@@ -1,7 +1,5 @@
# TAF/tframex/model/__init__.py

# Import the classes you want to expose directly from the 'model' package
from .model_logic import BaseModel, VLLMModel
from .model_wrapper import ModelWrapper

# Optional: Define __all__ to control 'from tframex.model import *' behaviour
__all__ = ['BaseModel', 'VLLMModel']
__all__ = ["ModelWrapper"]
34 changes: 34 additions & 0 deletions src/tframex/model/base.py
@@ -0,0 +1,34 @@
# base.py

import logging
from abc import ABC, abstractmethod
from typing import AsyncGenerator, Dict, List

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class BaseModel(ABC):
"""Abstract base class for language models."""
def __init__(self, model_id: str):
self.model_id = model_id
logger.info(f"Initializing base model structure for ID: {model_id}")

@abstractmethod
async def call_stream(self, messages: List[Dict[str, str]], **kwargs) -> AsyncGenerator[str, None]:
"""
Calls the language model (now expecting chat format) and streams response chunks.
Must be implemented by subclasses.

Args:
messages (List[Dict[str, str]]): A list of message dictionaries,
e.g., [{"role": "user", "content": "Hello"}].
Yields:
str: Chunks of the generated text content.
"""
raise NotImplementedError
yield "" # Required for async generator typing

@abstractmethod
async def close_client(self):
"""Closes any underlying network clients."""
raise NotImplementedError
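
For orientation, here is a minimal sketch of what a concrete subclass of this interface can look like. The `EchoModel` class below is hypothetical (not part of this PR) and exists only to illustrate the `call_stream`/`close_client` contract defined in `base.py`:

```python
# Hypothetical test double illustrating the BaseModel contract; not part of the PR.
from typing import AsyncGenerator, Dict, List

from tframex.model.base import BaseModel


class EchoModel(BaseModel):
    """Toy model that streams the last user message back one word at a time."""

    def __init__(self):
        super().__init__(model_id="echo_test")

    async def call_stream(self, messages: List[Dict[str, str]], **kwargs) -> AsyncGenerator[str, None]:
        last = messages[-1]["content"] if messages else ""
        for word in last.split():
            yield word + " "  # one chunk per word

    async def close_client(self):
        pass  # no underlying network client to close
```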
21 changes: 21 additions & 0 deletions src/tframex/model/model_wrapper.py
@@ -0,0 +1,21 @@
from .base import BaseModel
from .vllm_model import VLLMModel
from .openai_model import OpenAIModel

class ModelWrapper(BaseModel):
def __init__(self, provider: str, **kwargs):
if provider == "vllm":
self.model = VLLMModel(**kwargs)
elif provider == "openai":
self.model = OpenAIModel(**kwargs)
else:
raise ValueError(f"Unsupported provider: {provider}")

super().__init__(model_id=self.model.model_id)

async def call_stream(self, messages, **kwargs):
async for chunk in self.model.call_stream(messages, **kwargs):
yield chunk

async def close_client(self):
await self.model.close_client()
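
Assuming the constructor arguments shown in `examples/standard/example.py`, typical use of the wrapper might look like the sketch below; the model name and API key are placeholders, and `provider` selects which backend the wrapper instantiates:

```python
# Sketch of typical ModelWrapper usage, mirroring examples/standard/example.py;
# the model name and API key below are placeholders.
import asyncio

from tframex.model import ModelWrapper


async def demo():
    model = ModelWrapper(
        provider="openai",            # or "vllm" (with api_url), per the routing above
        model_name="gpt-3.5-turbo",
        api_key="sk-...",             # placeholder
        default_max_tokens=4096,
        default_temperature=0.7,
    )
    try:
        async for chunk in model.call_stream(
            [{"role": "user", "content": "Say hello."}]
        ):
            print(chunk, end="", flush=True)
    finally:
        await model.close_client()


if __name__ == "__main__":
    asyncio.run(demo())
```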
76 changes: 76 additions & 0 deletions src/tframex/model/openai_model.py
@@ -0,0 +1,76 @@
import logging
from typing import AsyncGenerator, Dict, List

from openai import AsyncOpenAI
from .base import BaseModel

logger = logging.getLogger(__name__)


class OpenAIModel(BaseModel):
"""
Represents a connection to OpenAI's official hosted models using the OpenAI SDK.
Designed to be used with any OpenAI-compatible chat model (e.g., gpt-3.5-turbo, gpt-4).
"""
def __init__(self,
model_name: str,
api_key: str,
default_max_tokens: int = 1024,
default_temperature: float = 0.7):
"""
Initializes the OpenAIModel with credentials and default generation parameters.

Args:
model_name (str): The model ID to use (e.g., "gpt-4").
api_key (str): Your OpenAI API key.
default_max_tokens (int): Default max tokens to generate if not overridden.
default_temperature (float): Default temperature for sampling.
"""
super().__init__(model_id=f"openai_{model_name}") # Call BaseModel constructor and set model_id
self.client = AsyncOpenAI(api_key=api_key) # Initialize OpenAI async client with API key
self.model_name = model_name # Store the model name (e.g., "gpt-4")
self.default_max_tokens = default_max_tokens # Set default max_tokens if not overridden
self.default_temperature = default_temperature # Set default temperature if not overridden
logger.info(f"OpenAIModel '{self.model_id}' initialized.") # Log successful model setup


async def call_stream(self, messages: List[Dict[str, str]], **kwargs) -> AsyncGenerator[str, None]:
"""
Calls OpenAI's chat completions API with the given message history and streams the response.

Args:
messages (List[Dict[str, str]]): The conversation history/prompt.
**kwargs: Optional overrides like 'max_tokens', 'temperature'.

Yields:
str: Chunks of the generated text content.
"""
try:
# Send the request to OpenAI with streaming enabled
response = await self.client.chat.completions.create(
model=self.model_name, # Model name (e.g., "gpt-4")
messages=messages, # Chat history in OpenAI format
max_tokens=kwargs.get("max_tokens", self.default_max_tokens), # Use passed or default token limit
temperature=kwargs.get("temperature", self.default_temperature), # Use passed or default temperature
stream=True, # Enable streaming response
**{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]} # Pass other kwargs
)

# Iterate through streamed response chunks
async for chunk in response:
if chunk.choices: # Ensure choices exist in the chunk
delta = chunk.choices[0].delta # Get the delta object (partial update)
if delta.content is not None: # Yield content if present
yield delta.content

except Exception as e:
# Log and yield error string if anything goes wrong
logger.error(f"[{self.model_id}] OpenAI call failed: {e}", exc_info=True)
yield f"ERROR: {str(e)}"


async def close_client(self):
"""
Placeholder for SDK compatibility with BaseModel. No-op for OpenAI client.
"""
logger.info(f"[{self.model_id}] OpenAI client closed (noop).")