|
5 | 5 | from typing import Dict, List, Optional |
6 | 6 | from dotenv import load_dotenv |
7 | 7 | import tempfile |
| 8 | +import logging |
8 | 9 |
|
9 | 10 | # Load environment variables from .env file |
10 | 11 | load_dotenv() |
|
21 | 22 |
|
22 | 23 | app = FastAPI(title="Rust Project Generator API") |
23 | 24 |
|
| 25 | +class AppConfig: |
| 26 | + def __init__(self): |
| 27 | + self.api_key = os.getenv("LLM_API_KEY", "") |
| 28 | + self.skip_vector_search = os.getenv("SKIP_VECTOR_SEARCH", "").lower() == "true" |
| 29 | + self.embed_size = int(os.getenv("LLM_EMBED_SIZE", "1536")) |
| 30 | + # Add other config values |
| 31 | + |
| 32 | +config = AppConfig() |
| 33 | + |
24 | 34 | # Get API key from environment variable (make optional) |
25 | | -api_key = os.getenv("LLM_API_KEY", "") |
| 35 | +api_key = config.api_key |
26 | 36 | # Only validate if not using local setup |
27 | 37 | if not api_key and not (os.getenv("LLM_API_BASE", "").startswith("http://localhost") or |
28 | 38 | os.getenv("LLM_API_BASE", "").startswith("http://host.docker.internal")): |
29 | 39 | raise ValueError("LLM_API_KEY environment variable not set") |
30 | 40 |
|
31 | | -# Get embedding size from environment variable |
32 | | -llm_embed_size = int(os.getenv("LLM_EMBED_SIZE", "1536")) # Default to 1536 for compatibility |
33 | | - |
34 | 41 | # Initialize components |
35 | 42 | llm_client = LlamaEdgeClient(api_key=api_key) |
36 | 43 | prompt_gen = PromptGenerator() |
|
39 | 46 |
|
40 | 47 | # Initialize vector store |
41 | 48 | try: |
42 | | - vector_store = QdrantStore(embedding_size=llm_embed_size) |
43 | | - if os.getenv("SKIP_VECTOR_SEARCH", "").lower() != "true": |
| 49 | + vector_store = QdrantStore(embedding_size=config.embed_size) |
| 50 | + if config.skip_vector_search != "true": |
44 | 51 | vector_store.create_collection("project_examples") |
45 | 52 | vector_store.create_collection("error_examples") |
46 | 53 |
|
@@ -160,7 +167,21 @@ async def compile_rust(request: dict): |
160 | 167 |
|
161 | 168 | @app.post("/compile-and-fix") |
162 | 169 | async def compile_and_fix_rust(request: dict): |
163 | | - """Endpoint to compile and fix Rust code""" |
| 170 | + """ |
| 171 | + Compile Rust code and automatically fix compilation errors. |
| 172 | + |
| 173 | + Args: |
| 174 | + request (dict): Dictionary containing: |
| 175 | + - code (str): Multi-file Rust code with filename markers |
| 176 | + - description (str): Project description |
| 177 | + - max_attempts (int, optional): Maximum fix attempts (default: 10) |
| 178 | + |
| 179 | + Returns: |
| 180 | + JSONResponse: Result of compilation with fixed code if successful |
| 181 | + |
| 182 | + Raises: |
| 183 | + HTTPException: If required fields are missing or processing fails |
| 184 | + """ |
164 | 185 | if "code" not in request or "description" not in request: |
165 | 186 | raise HTTPException(status_code=400, detail="Missing required fields") |
166 | 187 |
|
@@ -225,7 +246,7 @@ async def compile_and_fix_rust(request: dict): |
225 | 246 |
|
226 | 247 | # Find similar errors in vector DB |
227 | 248 | similar_errors = [] |
228 | | - if vector_store is not None and os.getenv("SKIP_VECTOR_SEARCH", "").lower() != "true": |
| 249 | + if vector_store is not None and config.skip_vector_search != "true": |
229 | 250 | try: |
230 | 251 | # Find similar errors in vector DB |
231 | 252 | error_embedding = llm_client.get_embeddings([error_context["full_error"]])[0] |
@@ -309,7 +330,7 @@ async def handle_project_generation( |
309 | 330 |
|
310 | 331 | try: |
311 | 332 | # Skip vector search if environment variable is set |
312 | | - skip_vector_search = os.getenv("SKIP_VECTOR_SEARCH", "").lower() == "true" |
| 333 | + skip_vector_search = config.skip_vector_search |
313 | 334 |
|
314 | 335 | example_text = "" |
315 | 336 | if not skip_vector_search: |
@@ -397,7 +418,7 @@ async def handle_project_generation( |
397 | 418 | error_context = compiler.extract_error_context(output) |
398 | 419 |
|
399 | 420 | # Skip vector search if environment variable is set |
400 | | - skip_vector_search = os.getenv("SKIP_VECTOR_SEARCH", "").lower() == "true" |
| 421 | + skip_vector_search = config.skip_vector_search |
401 | 422 | similar_errors = [] |
402 | 423 |
|
403 | 424 | if not skip_vector_search: |
@@ -468,9 +489,11 @@ async def handle_project_generation( |
468 | 489 | }) |
469 | 490 | save_status(project_dir, status) |
470 | 491 |
|
471 | | -def save_status(project_dir: str, status: Dict): |
472 | | - """Save project status to file""" |
473 | | - with open(f"{project_dir}/status.json", 'w') as f: |
| 492 | +def save_status(project_dir, status): |
| 493 | + """Save project status to file with proper resource management""" |
| 494 | + status_path = f"{project_dir}/status.json" |
| 495 | + os.makedirs(os.path.dirname(status_path), exist_ok=True) |
| 496 | + with open(status_path, 'w') as f: |
474 | 497 | json.dump(status, f) |
475 | 498 |
|
476 | 499 | @app.get("/project/{project_id}/files/{file_path:path}") |
@@ -536,7 +559,7 @@ async def generate_project_sync(request: ProjectRequest): |
536 | 559 | similar_errors = [] |
537 | 560 |
|
538 | 561 | # Skip vector search if environment variable is set |
539 | | - skip_vector_search = os.getenv("SKIP_VECTOR_SEARCH", "").lower() == "true" |
| 562 | + skip_vector_search = config.skip_vector_search |
540 | 563 |
|
541 | 564 | if not skip_vector_search: |
542 | 565 | try: |
@@ -621,7 +644,7 @@ async def generate_project_sync(request: ProjectRequest): |
621 | 644 | error_context = compiler.extract_error_context(output) |
622 | 645 |
|
623 | 646 | # Skip vector search if environment variable is set |
624 | | - skip_vector_search = os.getenv("SKIP_VECTOR_SEARCH", "").lower() == "true" |
| 647 | + skip_vector_search = config.skip_vector_search |
625 | 648 |
|
626 | 649 | if not skip_vector_search: |
627 | 650 | try: |
@@ -700,3 +723,47 @@ async def generate_project_sync(request: ProjectRequest): |
700 | 723 |
|
701 | 724 | except Exception as e: |
702 | 725 | raise HTTPException(status_code=500, detail=f"Error generating project: {str(e)}") |
| 726 | + |
| 727 | +def find_similar_projects(description, vector_store, llm_client): |
| 728 | + """Find similar projects in vector store""" |
| 729 | + skip_vector_search = os.getenv("SKIP_VECTOR_SEARCH", "").lower() == "true" |
| 730 | + example_text = "" |
| 731 | + |
| 732 | + if not skip_vector_search: |
| 733 | + try: |
| 734 | + query_embedding = llm_client.get_embeddings([description])[0] |
| 735 | + similar_projects = vector_store.search("project_examples", query_embedding, limit=1) |
| 736 | + if similar_projects: |
| 737 | + example_text = f"\nHere's a similar project you can use as reference:\n{similar_projects[0]['example']}" |
| 738 | + except Exception as e: |
| 739 | + logger.warning(f"Vector search error (non-critical): {e}") |
| 740 | + |
| 741 | + return example_text |
| 742 | + |
| 743 | +logger = logging.getLogger(__name__) |
| 744 | + |
| 745 | +def extract_and_find_similar_errors(error_output, vector_store, llm_client): |
| 746 | + """Extract error context and find similar errors""" |
| 747 | + error_context = compiler.extract_error_context(error_output) |
| 748 | + similar_errors = [] |
| 749 | + |
| 750 | + if vector_store and not config.skip_vector_search: |
| 751 | + try: |
| 752 | + error_embedding = llm_client.get_embeddings([error_context["full_error"]])[0] |
| 753 | + similar_errors = vector_store.search("error_examples", error_embedding, limit=3) |
| 754 | + except Exception as e: |
| 755 | + logger.warning(f"Vector search error: {e}") |
| 756 | + |
| 757 | + return error_context, similar_errors |
| 758 | + |
| 759 | +# try: |
| 760 | +# # ...specific operation |
| 761 | +# except FileNotFoundError as e: |
| 762 | +# logger.error(f"File not found: {e}") |
| 763 | +# # Handle specifically |
| 764 | +# except PermissionError as e: |
| 765 | +# logger.error(f"Permission denied: {e}") |
| 766 | +# # Handle specifically |
| 767 | +# except Exception as e: |
| 768 | +# logger.exception(f"Unexpected error: {e}") |
| 769 | +# # Generic fallback |
0 commit comments