Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions agent_assembly_line/llm_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,20 @@
"embeddings": "text-embedding-ada-002"
}
},
"anthropic": {
"claude-3-5-sonnet": {
"llm": "claude-3-5-sonnet-20241022",
"embeddings": "text-embedding-ada-002" # Use OpenAI for embeddings
},
"claude-3-5-haiku": {
"llm": "claude-3-5-haiku-20241022",
"embeddings": "text-embedding-ada-002"
},
"claude-3-opus": {
"llm": "claude-3-opus-20240229",
"embeddings": "text-embedding-ada-002"
}
},
# Add more mappings as needed
}

Expand Down Expand Up @@ -60,6 +74,54 @@ def create_llm_and_embeddings(config: Config):
embeddings = OpenAIEmbeddings(api_key=api_key, model=embeddings)
return llm, embeddings

    elif llm_type == "anthropic":
        # Anthropic provider branch. Imports are done lazily so the
        # langchain_anthropic / langchain_openai dependencies are only
        # required when this provider is actually selected.
        from langchain_anthropic import ChatAnthropic
        from langchain_openai.embeddings import OpenAIEmbeddings

        # Fail fast with a clear message when the API key is missing.
        api_key = os.getenv("ANTHROPIC_API_KEY")
        if not api_key:
            raise ValueError("ANTHROPIC_API_KEY not found in environment variables.")

        # Map short names to full model IDs.
        # NOTE(review): this duplicates the short-name -> full-ID data kept in
        # _llm_embeddings_mapping["anthropic"] above (and adds two claude-3
        # entries that table lacks); consider deriving one from the other so
        # the two cannot drift apart.
        model_mapping = {
            "claude-3-5-sonnet": "claude-3-5-sonnet-20241022",
            "claude-3-5-haiku": "claude-3-5-haiku-20241022",
            "claude-3-opus": "claude-3-opus-20240229",
            "claude-3-sonnet": "claude-3-sonnet-20240229",
            "claude-3-haiku": "claude-3-haiku-20240307",
        }

        # Unknown names pass through unchanged, so callers may also supply a
        # fully-versioned Anthropic model ID directly.
        full_model_name = model_mapping.get(model_name, model_name)

        llm = ChatAnthropic(
            api_key=api_key,
            model=full_model_name,
            timeout=config.timeout,
            # 4096 is used only when Config has no max_tokens attribute at
            # all; if the attribute exists but is None, None is passed through.
            # NOTE(review): confirm that is the intended behavior.
            max_tokens=config.max_tokens if hasattr(config, 'max_tokens') else 4096
        )

        # Anthropic doesn't provide embeddings, so use OpenAI or custom
        openai_key = os.getenv("OPENAI_API_KEY")
        if config.custom_embeddings:
            # A custom embeddings model configured by the caller wins.
            embeddings_model = config.custom_embeddings
        else:
            # Look up the per-model default; falls back to
            # "text-embedding-ada-002" when the model has no entry in
            # _llm_embeddings_mapping["anthropic"].
            embeddings_model = _llm_embeddings_mapping.get("anthropic", {}).get(model_name, {}).get("embeddings", "text-embedding-ada-002")

        if openai_key:
            embeddings = OpenAIEmbeddings(api_key=openai_key, model=embeddings_model)
        else:
            # Fallback to Ollama embeddings if available.
            # NOTE(review): embeddings_model computed above is ignored on this
            # path — the hard-coded local "nomic-embed-text" model is used.
            try:
                from langchain_ollama.embeddings import OllamaEmbeddings
                embeddings = OllamaEmbeddings(model="nomic-embed-text")
            except ImportError:
                raise ValueError(
                    "Anthropic does not provide embeddings. "
                    "Please set OPENAI_API_KEY for OpenAI embeddings or install langchain_ollama."
                )

        return llm, embeddings

elif llm_type == "runpod":
from langchain_runpod.llms import RunpodLLM
from langchain_runpod.embeddings import RunpodEmbeddings
Expand Down Expand Up @@ -87,4 +149,7 @@ def extract_response(response, config: Config):
return response["choices"][0]["text"]
    if config.model_name == "gpt-4o":
        # OpenAI-style completion payload: text lives in choices[0]["text"].
        return response["choices"][0]["text"]
    # Handle Anthropic response format
    # (message objects expose the generated text via a .content attribute;
    # presumably a ChatAnthropic AIMessage — verify against caller)
    if hasattr(response, 'content'):
        return response.content
    # Anything else is returned unchanged to the caller.
    return response
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dependencies = [
"langchain-chroma==0.2.2",
"langchain-ollama==0.3.0",
"langchain-openai==0.3.11",
"langchain-anthropic==0.3.10",
"pandas==2.2.3",
"numpy==1.26.4",
"pytest==8.3.5",
Expand Down