diff --git a/rag_agent.py b/rag_agent.py
index 1656f08..5725460 100644
--- a/rag_agent.py
+++ b/rag_agent.py
@@ -2,6 +2,7 @@
 # Uses a model from HuggingFace with optional 4-bit quantization
 import os
+import re
 import argparse
 import torch
 from transformers import pipeline
@@ -14,34 +15,50 @@
 PROMPT_TEMPLATE = """
-You are a highly intelligent Python coding assistant built for kids using the Sugar Learning Platform.
-1. Focus on coding-related problems, errors, and explanations.
-2. Use the knowledge from the provided Pygame, GTK, and Sugar Toolkit documentation.
-3. Provide complete, clear and concise answers.
-4. Your answer must be easy to understand for kids.
-5. Always include Sugar-specific guidance when relevant to the question.
+You are a smart and helpful assistant designed to answer coding questions using the given context.
+1. The context below contains the relevant information needed to answer this specific question.
+2. Use the information from this context to formulate your answer.
+3. Prioritize and include any relevant details from the context.
+4. Always answer in a clear, complete and helpful way.
+
+Context: {context}
 
 Question: {question}
+
 Answer:
 """
 
 CHILD_FRIENDLY_PROMPT = """
-Your task is to answer children's questions using simple language.
-You will be given an answer, you will have to paraphrase it.
-Explain any difficult words in a way a 5-12-years-old can understand.
+You are a helpful assistant who rewrites answers so that children aged 3 to 10 can understand them.
 
-Original answer: {original_answer}
+Your task:
+- ONLY rewrite the Original answer using simple words and short sentences.
+- Do NOT explain what the answer means.
+- Do NOT explain what you're doing.
+- Do NOT say anything extra.
+- Do NOT repeat ideas.
+- Do NOT add tips or encouragement.
 
-Child-friendly answer:
-"""
+Just give the rewritten answer. Nothing else.
+
+Original answer:
+{original_answer}
+
+Child-friendly answer:
+"""
 
 
 def format_docs(docs):
     """Return all document content separated by two newlines."""
     return "\n\n".join(doc.page_content for doc in docs)
 
-
+def trim_incomplete_sentence(text):
+    """Trim a trailing sentence fragment, keeping text up to the last '. '."""
+    matches = list(re.finditer(r'\.\s', text))
+    if matches:
+        last_complete = matches[-1].end()
+        return text[:last_complete].strip()
+    else:
+        return text.strip()
+
 def combine_messages(x):
     """
     If 'x' has a method to_messages, combine message content with newline.
@@ -92,8 +109,9 @@ def __init__(self, model="Qwen/Qwen2-1.5B-Instruct",
                 "text-generation",
                 model=model_obj,
                 tokenizer=tokenizer,
-                max_length=1024,
-                truncation=True,
+                max_new_tokens=1024,
+                temperature=0.3,
+                truncation=True
             )
 
             tokenizer2 = AutoTokenizer.from_pretrained(model)
@@ -101,14 +119,16 @@ def __init__(self, model="Qwen/Qwen2-1.5B-Instruct",
                 "text-generation",
                 model=model_obj,
                 tokenizer=tokenizer2,
-                max_length=1024,
-                truncation=True,
+                max_new_tokens=1024,
+                temperature=0.3,
+                truncation=True
             )
         else:
             self.model = pipeline(
                 "text-generation",
                 model=model,
-                max_length=1024,
+                max_new_tokens=1024,
+                temperature=0.3,
                 truncation=True,
                 torch_dtype=torch.float16,
                 device=0 if torch.cuda.is_available() else -1,
@@ -117,7 +137,8 @@ def __init__(self, model="Qwen/Qwen2-1.5B-Instruct",
             self.simplify_model = pipeline(
                 "text-generation",
                 model=model,
-                max_length=1024,
+                max_new_tokens=1024,
+                temperature=0.3,
                 truncation=True,
                 torch_dtype=torch.float16,
                 device=0 if torch.cuda.is_available() else -1,
@@ -133,7 +154,8 @@ def set_model(self, model):
         self.model = pipeline(
             "text-generation",
             model=model,
-            max_length=1024,
+            max_new_tokens=1024,
+            temperature=0.3,
             truncation=True,
             torch_dtype=torch.float16
         )
@@ -141,7 +163,8 @@ def set_model(self, model):
         self.simplify_model = pipeline(
             "text-generation",
             model=model,
-            max_length=1024,
+            max_new_tokens=1024,
+            temperature=0.3,
             truncation=True,
             torch_dtype=torch.float16
         )
@@ -166,13 +189,13 @@ def setup_vectorstore(self, file_paths):
         retriever = vector_store.as_retriever()
         return retriever
 
-    def get_relevant_document(self, query, threshold=0.5):
+    def get_relevant_document(self, query):
         results = self.retriever.invoke(query)
+
+        print(f"[DEBUG] Retrieved results: {results}")
+
         if results:
-            top_result = results[0]
-            score = top_result.metadata.get("score", 0.0)
-            if score >= threshold:
-                return top_result, score
+            return results[:2], 1.0
         return None, 0.0
 
     def run(self, question):
@@ -180,29 +203,22 @@ def run(self, question):
         Build the QA chain and process the output from model generation.
         Apply double prompting to make answers child-friendly.
         """
-        # Build the chain components:
-        chain_input = {
-            "context": self.retriever | format_docs,
-            "question": RunnablePassthrough()
-        }
-        # The chain applies: prompt -> combine messages -> model ->
-        # extract answer from output.
+        doc_result, _ = self.get_relevant_document(question)
+
+        context_text = format_docs(doc_result) if doc_result else ""
+
+        # Each chain applies: prompt -> combine messages -> model ->
+        # extract answer from output.
         first_chain = (
-            chain_input
+            {"context": lambda x: context_text, "question": lambda x: x}
             | self.prompt
             | combine_messages
-            | self.model  # Use the first model
+            | self.model
             | extract_answer_from_output
         )
-        doc_result, _ = self.get_relevant_document(question)
-        if doc_result:
-            first_response = first_chain.invoke({
-                "query": question,
-                "context": doc_result.page_content
-            })
-        else:
-            first_response = first_chain.invoke(question)
-
+        first_response = first_chain.invoke(question)
+
+        second_chain = (
             {"original_answer": lambda x: x}
             | self.child_prompt
@@ -210,9 +226,10 @@ def run(self, question):
             | self.simplify_model
             | extract_answer_from_output
         )
-
+
         final_response = second_chain.invoke(first_response)
-        return final_response
+
+        return trim_incomplete_sentence(final_response)
 
 
 def main():
@@ -257,7 +274,10 @@ def main():
             print("Response:", response)
         except Exception as e:
             print(f"An error occurred: {e}")
+            import traceback
+            traceback.print_exc()
 
 
 if __name__ == "__main__":
     main()
+
\ No newline at end of file
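
A caveat on the repeated `temperature=0.3` additions: in Hugging Face `transformers`, sampling parameters are ignored (with a warning) unless `do_sample=True` is also passed, so as written every pipeline above still decodes greedily and the new temperature has no effect. A minimal sketch of the non-quantized pipeline call with sampling actually enabled; everything except `do_sample=True` mirrors the arguments in the diff:

```python
import torch
from transformers import pipeline

# Same arguments as the non-quantized branch in the diff, plus do_sample=True,
# without which transformers ignores `temperature` and emits a warning.
model = pipeline(
    "text-generation",
    model="Qwen/Qwen2-1.5B-Instruct",  # default model name from the diff
    max_new_tokens=1024,
    temperature=0.3,
    do_sample=True,
    truncation=True,
    torch_dtype=torch.float16,
    device=0 if torch.cuda.is_available() else -1,
)
```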
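The rewritten `get_relevant_document` drops the old score threshold and hardcodes `1.0` for every hit, so the returned score is now a placeholder. If real relevance filtering is wanted back, one option is a sketch like the following, assuming `setup_vectorstore` is also changed to keep the store itself on the instance (the diff only keeps the retriever) and that the store exposes `similarity_search_with_score`, as FAISS and Chroma do:

```python
def get_relevant_document(self, query, threshold=0.5, k=2):
    """Sketch: top-k docs whose score clears a threshold, else (None, 0.0).

    Assumes self.vector_store is saved in setup_vectorstore(); that attribute
    does not exist in the diff as written.
    """
    scored = self.vector_store.similarity_search_with_score(query, k=k)
    # NOTE: for distance-based stores (e.g. FAISS with L2) lower is better,
    # so this comparison may need to be inverted for that metric.
    kept = [(doc, score) for doc, score in scored if score >= threshold]
    if kept:
        docs, scores = zip(*kept)
        return list(docs), max(scores)
    return None, 0.0
```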
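A quick sanity check on the new `trim_incomplete_sentence` helper, with made-up sample strings. It behaves as intended on a dangling fragment, but note the edge case in the second call: an answer that already ends with `.` has no whitespace after its final period, so the last complete sentence gets trimmed too. Matching `r'\.(\s|$)'` instead would keep it.

```python
import re

def trim_incomplete_sentence(text):
    # Same logic as the helper added in this diff: keep everything up to
    # the last ". " boundary; if none exists, return the text unchanged.
    matches = list(re.finditer(r'\.\s', text))
    if matches:
        return text[:matches[-1].end()].strip()
    return text.strip()

print(trim_incomplete_sentence("Sprites are pictures. They can move fast an"))
# -> "Sprites are pictures."  (the dangling fragment is dropped)

print(trim_incomplete_sentence("Sprites are pictures. They can move."))
# -> "Sprites are pictures."  (the complete final sentence is lost)
```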