
Commit 33f5e1e

Merge pull request #18 from Acuspeedster/main

2 parents 42699af + 3f642bf

File tree

3 files changed (+154, -60 lines)

app/load_data.py

Lines changed: 37 additions & 42 deletions
@@ -1,65 +1,60 @@
-# app/load_data.py
 import json
 import os
 import uuid
 from glob import glob
 from app.llm_client import LlamaEdgeClient
 from app.vector_store import QdrantStore
+import logging

-def load_project_examples():
-    """Load project examples into vector database"""
-    vector_store = QdrantStore()
-    llm_client = LlamaEdgeClient()
-
+PROJECT_COLLECTION = "project_examples"
+ERROR_COLLECTION = "error_examples"
+PROJECT_DATA_PATH = "data/project_examples/*.json"
+ERROR_DATA_PATH = "data/error_examples/*.json"
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+def load_examples(vector_store, llm_client, collection_name, file_pattern, text_key):
+    """Load examples into vector database"""
     # Ensure collections exist
-    vector_store.create_collection("project_examples")
+    vector_store.create_collection(collection_name)

-    example_files = glob("data/project_examples/*.json")
+    example_files = glob(file_pattern)
+
+    # Collect all embeddings and metadata first
+    embeddings = []
+    metadata = []

     for file_path in example_files:
         with open(file_path, 'r') as f:
             example = json.load(f)

-        # Get embedding for query
-        embedding = llm_client.get_embeddings([example["query"]])[0]
-
-        # Store in vector DB with proper UUID
-        point_id = str(uuid.uuid4())  # Generate proper UUID
-
-        vector_store.upsert("project_examples",
-                            [{"id": point_id,  # Use UUID instead of filename
-                              "vector": embedding,
-                              "payload": example}])
-
-        print(f"Loaded project example: {example['query']}")
+        # Get embedding for query or error
+        try:
+            embedding = llm_client.get_embeddings([example[text_key]])[0]
+            embeddings.append(embedding)
+            metadata.append(example)
+            logger.info(f"Loaded {collection_name[:-1]} example: {example[text_key][:50]}...")
+        except Exception as e:
+            logger.error(f"Error loading {file_path}: {e}")
+
+    # Insert all documents in a single batch
+    if embeddings:
+        vector_store.insert_documents(collection_name, embeddings, metadata)
+
+def load_project_examples():
+    """Load project examples into vector database"""
+    vector_store = QdrantStore()
+    llm_client = LlamaEdgeClient()
+
+    load_examples(vector_store, llm_client, PROJECT_COLLECTION, PROJECT_DATA_PATH, "query")

 def load_error_examples():
     """Load compiler error examples into vector database"""
     vector_store = QdrantStore()
     llm_client = LlamaEdgeClient()

-    # Ensure collections exist
-    vector_store.create_collection("error_examples")
-
-    error_files = glob("data/error_examples/*.json")
-
-    for file_path in error_files:
-        with open(file_path, 'r') as f:
-            example = json.load(f)
-
-        # Get embedding for error
-        embedding = llm_client.get_embeddings([example["error"]])[0]
-
-        # Store in vector DB with proper UUID
-        point_id = str(uuid.uuid4())
-
-        # Store in vector DB
-        vector_store.upsert("error_examples",
-                            [{"id": point_id,
-                              "vector": embedding,
-                              "payload": example}])
-
-        print(f"Loaded error example: {example['error'][:50]}...")
+    load_examples(vector_store, llm_client, ERROR_COLLECTION, ERROR_DATA_PATH, "error")

 if __name__ == "__main__":
     load_project_examples()
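
For reference, the refactored loader expects one JSON object per file, embedded by its "query" key (projects) or "error" key (compiler errors); the rest of the object is stored as the point payload, and the search path in app/main.py later reads an "example" field from project payloads. A minimal sketch of preparing one data file and running the batch loader, assuming Qdrant and the embedding endpoint are reachable; the "example" value is an illustrative placeholder, not taken from this commit:

# Sketch: write one project-example file, then batch-load it.
# The "query" key is required by the loader; the "example" value is a
# placeholder (main.py's vector search reads this field from payloads).
import json
import os

os.makedirs("data/project_examples", exist_ok=True)
doc = {
    "query": "CLI tool that counts words in a file",
    "example": "fn main() { /* ... */ }",
}
with open("data/project_examples/word_count.json", "w") as f:
    json.dump(doc, f)

# load_examples() embeds each "query", collects all vectors, and calls
# insert_documents() once, instead of one upsert per file as before.
from app.load_data import load_project_examples
load_project_examples()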

app/main.py

Lines changed: 82 additions & 15 deletions
@@ -5,6 +5,7 @@
 from typing import Dict, List, Optional
 from dotenv import load_dotenv
 import tempfile
+import logging

 # Load environment variables from .env file
 load_dotenv()
@@ -21,16 +22,22 @@

 app = FastAPI(title="Rust Project Generator API")

+class AppConfig:
+    def __init__(self):
+        self.api_key = os.getenv("LLM_API_KEY", "")
+        self.skip_vector_search = os.getenv("SKIP_VECTOR_SEARCH", "").lower() == "true"
+        self.embed_size = int(os.getenv("LLM_EMBED_SIZE", "1536"))
+        # Add other config values
+
+config = AppConfig()
+
 # Get API key from environment variable (make optional)
-api_key = os.getenv("LLM_API_KEY", "")
+api_key = config.api_key
 # Only validate if not using local setup
 if not api_key and not (os.getenv("LLM_API_BASE", "").startswith("http://localhost") or
                         os.getenv("LLM_API_BASE", "").startswith("http://host.docker.internal")):
     raise ValueError("LLM_API_KEY environment variable not set")

-# Get embedding size from environment variable
-llm_embed_size = int(os.getenv("LLM_EMBED_SIZE", "1536"))  # Default to 1536 for compatibility
-
 # Initialize components
 llm_client = LlamaEdgeClient(api_key=api_key)
 prompt_gen = PromptGenerator()
@@ -39,8 +46,8 @@

 # Initialize vector store
 try:
-    vector_store = QdrantStore(embedding_size=llm_embed_size)
-    if os.getenv("SKIP_VECTOR_SEARCH", "").lower() != "true":
+    vector_store = QdrantStore(embedding_size=config.embed_size)
+    if not config.skip_vector_search:
         vector_store.create_collection("project_examples")
         vector_store.create_collection("error_examples")

@@ -160,7 +167,21 @@ async def compile_rust(request: dict):

 @app.post("/compile-and-fix")
 async def compile_and_fix_rust(request: dict):
-    """Endpoint to compile and fix Rust code"""
+    """
+    Compile Rust code and automatically fix compilation errors.
+
+    Args:
+        request (dict): Dictionary containing:
+            - code (str): Multi-file Rust code with filename markers
+            - description (str): Project description
+            - max_attempts (int, optional): Maximum fix attempts (default: 10)
+
+    Returns:
+        JSONResponse: Result of compilation with fixed code if successful
+
+    Raises:
+        HTTPException: If required fields are missing or processing fails
+    """
     if "code" not in request or "description" not in request:
         raise HTTPException(status_code=400, detail="Missing required fields")

@@ -225,7 +246,7 @@ async def compile_and_fix_rust(request: dict):

     # Find similar errors in vector DB
     similar_errors = []
-    if vector_store is not None and os.getenv("SKIP_VECTOR_SEARCH", "").lower() != "true":
+    if vector_store is not None and not config.skip_vector_search:
         try:
             # Find similar errors in vector DB
             error_embedding = llm_client.get_embeddings([error_context["full_error"]])[0]
@@ -309,7 +330,7 @@ async def handle_project_generation(

     try:
         # Skip vector search if environment variable is set
-        skip_vector_search = os.getenv("SKIP_VECTOR_SEARCH", "").lower() == "true"
+        skip_vector_search = config.skip_vector_search

         example_text = ""
         if not skip_vector_search:
@@ -397,7 +418,7 @@ async def handle_project_generation(
         error_context = compiler.extract_error_context(output)

         # Skip vector search if environment variable is set
-        skip_vector_search = os.getenv("SKIP_VECTOR_SEARCH", "").lower() == "true"
+        skip_vector_search = config.skip_vector_search
         similar_errors = []

         if not skip_vector_search:
@@ -468,9 +489,11 @@ async def handle_project_generation(
         })
         save_status(project_dir, status)

-def save_status(project_dir: str, status: Dict):
-    """Save project status to file"""
-    with open(f"{project_dir}/status.json", 'w') as f:
+def save_status(project_dir, status):
+    """Save project status to file with proper resource management"""
+    status_path = f"{project_dir}/status.json"
+    os.makedirs(os.path.dirname(status_path), exist_ok=True)
+    with open(status_path, 'w') as f:
         json.dump(status, f)

 @app.get("/project/{project_id}/files/{file_path:path}")
@@ -536,7 +559,7 @@ async def generate_project_sync(request: ProjectRequest):
     similar_errors = []

     # Skip vector search if environment variable is set
-    skip_vector_search = os.getenv("SKIP_VECTOR_SEARCH", "").lower() == "true"
+    skip_vector_search = config.skip_vector_search

     if not skip_vector_search:
         try:
@@ -621,7 +644,7 @@ async def generate_project_sync(request: ProjectRequest):
         error_context = compiler.extract_error_context(output)

         # Skip vector search if environment variable is set
-        skip_vector_search = os.getenv("SKIP_VECTOR_SEARCH", "").lower() == "true"
+        skip_vector_search = config.skip_vector_search

         if not skip_vector_search:
             try:
@@ -700,3 +723,47 @@ async def generate_project_sync(request: ProjectRequest):

     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Error generating project: {str(e)}")
+
+def find_similar_projects(description, vector_store, llm_client):
+    """Find similar projects in vector store"""
+    skip_vector_search = os.getenv("SKIP_VECTOR_SEARCH", "").lower() == "true"
+    example_text = ""
+
+    if not skip_vector_search:
+        try:
+            query_embedding = llm_client.get_embeddings([description])[0]
+            similar_projects = vector_store.search("project_examples", query_embedding, limit=1)
+            if similar_projects:
+                example_text = f"\nHere's a similar project you can use as reference:\n{similar_projects[0]['example']}"
+        except Exception as e:
+            logger.warning(f"Vector search error (non-critical): {e}")
+
+    return example_text
+
+logger = logging.getLogger(__name__)
+
+def extract_and_find_similar_errors(error_output, vector_store, llm_client):
+    """Extract error context and find similar errors"""
+    error_context = compiler.extract_error_context(error_output)
+    similar_errors = []
+
+    if vector_store and not config.skip_vector_search:
+        try:
+            error_embedding = llm_client.get_embeddings([error_context["full_error"]])[0]
+            similar_errors = vector_store.search("error_examples", error_embedding, limit=3)
+        except Exception as e:
+            logger.warning(f"Vector search error: {e}")
+
+    return error_context, similar_errors
+
+# try:
+#     # ...specific operation
+# except FileNotFoundError as e:
+#     logger.error(f"File not found: {e}")
+#     # Handle specifically
+# except PermissionError as e:
+#     logger.error(f"Permission denied: {e}")
+#     # Handle specifically
+# except Exception as e:
+#     logger.exception(f"Unexpected error: {e}")
+#     # Generic fallback
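
The expanded docstring on /compile-and-fix doubles as the endpoint's contract. A minimal client sketch against it, assuming the API is served at http://localhost:8000 and that filename markers look like [src/main.rs]; both are assumptions for illustration, neither is fixed by this commit:

# Sketch: call /compile-and-fix per the new docstring. The base URL and
# the "[src/main.rs]" filename-marker format are assumptions.
import requests

payload = {
    "code": '[src/main.rs]\nfn main() { println!("hello"); }',
    "description": "Minimal hello-world binary",
    "max_attempts": 3,  # optional; the docstring's default is 10
}
resp = requests.post("http://localhost:8000/compile-and-fix", json=payload)
resp.raise_for_status()
print(resp.json())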

app/vector_store.py

Lines changed: 35 additions & 3 deletions
@@ -4,6 +4,9 @@
 from qdrant_client import QdrantClient
 from qdrant_client.http import models as qmodels
 from qdrant_client import models  # Add this import
+from glob import glob
+import json
+from app.llm_client import LlamaEdgeClient  # Adjust the import based on your project structure

 class QdrantStore:
     """Interface for Qdrant vector database"""
@@ -55,9 +58,9 @@ def insert_documents(self, collection_name: str, embeddings: List[List[float]],
                          metadata: List[Dict[str, Any]]):
         """Insert documents with embeddings and metadata into collection"""
         points = []
-        for i, (embedding, meta) in enumerate(zip(embeddings, metadata)):
+        for embedding, meta in zip(embeddings, metadata):
             points.append(models.PointStruct(
-                id=i,
+                id=str(uuid.uuid4()),  # Using UUID instead of index
                 vector=embedding,
                 payload=meta
             ))
@@ -121,4 +124,33 @@ def count(self, collection_name: str) -> int:
             return collection_info.vectors_count
         except Exception as e:
             print(f"Error getting count for collection {collection_name}: {e}")
-            return 0
+            return 0
+
+def load_project_examples():
+    """Load project examples into vector database"""
+    vector_store = QdrantStore()
+    llm_client = LlamaEdgeClient()
+
+    # Ensure collections exist
+    vector_store.create_collection("project_examples")
+
+    example_files = glob("data/project_examples/*.json")
+
+    embeddings = []
+    metadata = []
+
+    for file_path in example_files:
+        with open(file_path, 'r') as f:
+            example = json.load(f)
+
+        # Get embedding for query
+        embedding = llm_client.get_embeddings([example["query"]])[0]
+
+        embeddings.append(embedding)
+        metadata.append(example)
+
+        print(f"Loaded project example: {example['query']}")
+
+    # Insert all documents in a single batch
+    if embeddings:
+        vector_store.insert_documents("project_examples", embeddings, metadata)
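
The switch from enumerate indices to uuid4 point IDs means a second batch insert can no longer silently overwrite points 0..n-1 written by an earlier run; repeated loads now accumulate points instead. A minimal sketch of the batch API under that change, assuming a running Qdrant instance and a 1536-dimension embedding model; the collection name and constructor keyword come from this diff, while the texts are made up:

# Sketch: batch-insert two documents, then count them. Collection name
# and embedding_size usage come from the diff; the texts are illustrative.
from app.llm_client import LlamaEdgeClient
from app.vector_store import QdrantStore

store = QdrantStore(embedding_size=1536)
client = LlamaEdgeClient()
store.create_collection("project_examples")

texts = ["parse a TOML config file", "stream a large CSV"]
embeddings = client.get_embeddings(texts)
metadata = [{"query": t} for t in texts]

# Each PointStruct now gets a fresh uuid4 id, so repeating this call
# adds new points rather than replacing ids 0 and 1.
store.insert_documents("project_examples", embeddings, metadata)
print(store.count("project_examples"))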
