|
1 | 1 | import os |
2 | 2 | import json |
3 | | -from typing import Dict, List, Optional, Union |
4 | | -from openai import OpenAI |
| 3 | +import requests |
| 4 | +from typing import List, Dict, Any, Optional |
| 5 | +from openai import OpenAI # Add this import |
5 | 6 |
|
6 | 7 | class LlamaEdgeClient: |
7 | 8 | """Client for interacting with LlamaEdge OpenAI-compatible API""" |
8 | 9 |
|
9 | | - def __init__(self, api_key=None): |
| 10 | + def __init__(self, api_key=None, api_base=None, model=None, embed_model=None): |
| 11 | + """Initialize LlamaEdgeClient with API credentials |
| 12 | + |
| 13 | + Args: |
| 14 | + api_key: API key for LLM service |
| 15 | + api_base: Base URL for API (overrides LLM_API_BASE env var) |
| 16 | + model: Model name (overrides LLM_MODEL env var) |
| 17 | + embed_model: Embedding model name (overrides LLM_EMBED_MODEL env var) |
| 18 | + """ |
10 | 19 | self.api_key = api_key or os.getenv("LLM_API_KEY") |
11 | 20 | if not self.api_key: |
12 | 21 | raise ValueError("API key is required") |
13 | 22 |
|
14 | | - # Use environment variables with defaults |
15 | | - self.base_url = os.getenv("LLM_API_BASE", "http://localhost:8080/v1") |
16 | | - self.llm_model = os.getenv("LLM_MODEL", "Qwen2.5-Coder-3B-Instruct") |
17 | | - self.llm_embed_model = os.getenv("LLM_EMBED_MODEL", "gte-Qwen2-1.5B-instruct") # Fixed variable name |
| 23 | + # Use provided parameters with fallback to environment variables |
| 24 | + self.base_url = api_base or os.getenv("LLM_API_BASE", "http://localhost:8080/v1") |
| 25 | + self.llm_model = model or os.getenv("LLM_MODEL", "Qwen2.5-Coder-3B-Instruct") |
| 26 | + self.llm_embed_model = embed_model or os.getenv("LLM_EMBED_MODEL", "gte-Qwen2-1.5B-instruct") # Fixed variable name |
18 | 27 |
|
19 | 28 | # Initialize OpenAI client with custom base URL |
20 | 29 | self.client = OpenAI( |
|
0 commit comments