112 changes: 66 additions & 46 deletions rag_agent.py
@@ -2,6 +2,7 @@
# Uses a model from HuggingFace with optional 4-bit quantization

import os
import re
import argparse
import torch
from transformers import pipeline
@@ -14,34 +15,50 @@


PROMPT_TEMPLATE = """
You are a highly intelligent Python coding assistant built for kids using the Sugar Learning Platform.
1. Focus on coding-related problems, errors, and explanations.
2. Use the knowledge from the provided Pygame, GTK, and Sugar Toolkit documentation.
3. Provide complete, clear and concise answers.
4. Your answer must be easy to understand for kids.
5. Always include Sugar-specific guidance when relevant to the question.
You are a smart and helpful assistant designed to answer coding questions using the given context.

1. The context below contains the relevant information needed to answer this specific question.
2. Use the information from this context to formulate your answer.
3. Prioritize and include any relevant details from the context.
4. Always answer in a clear, complete, and helpful way.

Context: {context}
Question: {question}

Answer:
"""

CHILD_FRIENDLY_PROMPT = """
Your task is to answer children's questions using simple language.
You will be given an answer, you will have to paraphrase it.
Explain any difficult words in a way a 5-12-years-old can understand.
You are a helpful assistant who rewrites answers so that children aged 3 to 10 can understand them.

Original answer: {original_answer}
Your task:
- ONLY rewrite the Original answer using simple words and short sentences.
- Do NOT explain what the answer means.
- Do NOT explain what you're doing.
- Do NOT say anything extra.
- Do NOT repeat ideas.
- Do NOT add tips or encouragement.

Child-friendly answer:
"""
Just give the rewritten answer. Nothing else.

Original answer:
{original_answer}

Child-friendly answer:
"""

def format_docs(docs):
"""Return all document content separated by two newlines."""
return "\n\n".join(doc.page_content for doc in docs)


def trim_incomplete_sentence(text):
    """Drop a trailing incomplete sentence, returning text up to the last
    period that ends a sentence (including one at the end of the string)."""
    matches = list(re.finditer(r'\.(?=\s|$)', text))
    if matches:
        last_complete = matches[-1].end()
        return text[:last_complete].strip()
    return text.strip()
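
# Trimming behavior, illustratively:
#
#   trim_incomplete_sentence("Sprites are images. They can move. And al")
#   # -> "Sprites are images. They can move."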

def combine_messages(x):
"""
If 'x' has a method to_messages, combine message content with newline.
@@ -92,23 +109,26 @@ def __init__(self, model="Qwen/Qwen2-1.5B-Instruct",
"text-generation",
model=model_obj,
tokenizer=tokenizer,
max_length=1024,
truncation=True,
max_new_tokens=1024,
temperature=0.3,
truncation=True
)

tokenizer2 = AutoTokenizer.from_pretrained(model)
self.simplify_model = pipeline(
"text-generation",
model=model_obj,
tokenizer=tokenizer2,
max_length=1024,
truncation=True,
max_new_tokens=1024,
temperature=0.3,
truncation=True
)
else:
self.model = pipeline(
"text-generation",
model=model,
max_length=1024,
max_new_tokens=1024,
temperature=0.3,
truncation=True,
torch_dtype=torch.float16,
device=0 if torch.cuda.is_available() else -1,
@@ -117,7 +137,8 @@ def __init__(self, model="Qwen/Qwen2-1.5B-Instruct",
self.simplify_model = pipeline(
"text-generation",
model=model,
max_length=1024,
max_new_tokens=1024,
temperature=0.3,
truncation=True,
torch_dtype=torch.float16,
device=0 if torch.cuda.is_available() else -1,
@@ -133,15 +154,17 @@ def set_model(self, model):
self.model = pipeline(
"text-generation",
model=model,
max_length=1024,
max_new_tokens=1024,
temperature=0.3,
truncation=True,
torch_dtype=torch.float16
)

self.simplify_model = pipeline(
"text-generation",
model=model,
max_length=1024,
max_new_tokens=1024,
temperature=0.3,
truncation=True,
torch_dtype=torch.float16
)
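
# Editor's note on the max_length -> max_new_tokens switch: max_length
# caps prompt + completion together, so a long retrieved context can
# squeeze out the answer entirely; max_new_tokens caps only the newly
# generated tokens. Also note that temperature only takes effect when
# do_sample=True is passed; otherwise generation is greedy and the
# setting is ignored. Hedged sketch with a made-up prompt:
#
#   out = self.model(long_rag_prompt, max_new_tokens=1024,
#                    do_sample=True, temperature=0.3)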
@@ -166,53 +189,47 @@ def setup_vectorstore(self, file_paths):
retriever = vector_store.as_retriever()
return retriever

def get_relevant_document(self, query, threshold=0.5):
def get_relevant_document(self, query):
results = self.retriever.invoke(query)

print(f"[DEBUG] Retrieved results: {results}")

if results:
top_result = results[0]
score = top_result.metadata.get("score", 0.0)
if score >= threshold:
return top_result, score
return results[:2], 1.0
return None, 0.0
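
# Return shape, illustratively (assuming an initialized agent instance
# named `agent`):
#
#   docs, score = agent.get_relevant_document("How do I draw a sprite?")
#   # docs  -> up to two Document objects (or None if nothing matched)
#   # score -> always 1.0 now; the old similarity threshold was removed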

def run(self, question):
"""
Build the QA chain and process the output from model generation.
Apply double prompting to make answers child-friendly.
"""
# Build the chain components:
chain_input = {
"context": self.retriever | format_docs,
"question": RunnablePassthrough()
}
# The chain applies: prompt -> combine messages -> model ->
# extract answer from output.
doc_result, _ = self.get_relevant_document(question)

context_text = format_docs(doc_result) if doc_result else ""

first_chain = (
chain_input
{"context": lambda x: context_text, "question": lambda x: x}
| self.prompt
| combine_messages
| self.model # Use the first model
| self.model
| extract_answer_from_output
)
doc_result, _ = self.get_relevant_document(question)
if doc_result:
first_response = first_chain.invoke({
"query": question,
"context": doc_result.page_content
})
else:
first_response = first_chain.invoke(question)

first_response = first_chain.invoke(question)

# The chain applies: prompt -> combine messages -> model ->
# extract answer from output.

second_chain = (
{"original_answer": lambda x: x}
| self.child_prompt
| combine_messages
| self.simplify_model
| extract_answer_from_output
)

final_response = second_chain.invoke(first_response)
return final_response

return trim_incomplete_sentence(final_response)
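
# End-to-end usage sketch (hypothetical paths; the class name is assumed
# from rag_agent.py and is not shown in this hunk):
#
#   agent = RAGAgent(model="Qwen/Qwen2-1.5B-Instruct")
#   agent.retriever = agent.setup_vectorstore(["docs/pygame_docs.txt"])
#   print(agent.run("How do I make a sprite move?"))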


def main():
Expand Down Expand Up @@ -257,7 +274,10 @@ def main():
print("Response:", response)
except Exception as e:
print(f"An error occurred: {e}")
import traceback
traceback.print_exc()


if __name__ == "__main__":
main()