|
| 1 | +""" |
| 2 | +Modify the code from the mem0 project |
| 3 | +""" |

import argparse
import concurrent.futures
import json
import os
import threading
import time

from collections import defaultdict

import numpy as np
import tiktoken

from dotenv import load_dotenv
from jinja2 import Template
from openai import OpenAI
from tqdm import tqdm


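# Environment variables expected in .env: MODEL (chat completion model) and
# EMBEDDING_MODEL (embedding model); the OpenAI client also needs an API key.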
load_dotenv()

PROMPT = """
# Question:
{{QUESTION}}

# Context:
{{CONTEXT}}

# Short answer:
"""

TECHNIQUES = ["mem0", "rag"]


class RAGManager:
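    """Retrieval-augmented answering over LoCoMo conversations.

    Chunks each conversation, embeds the chunks with the OpenAI embeddings API,
    retrieves the top-k chunks per question, and asks the chat model for a short answer.
    """
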
    def __init__(self, data_path="data/locomo/locomo10_rag.json", chunk_size=500, k=2):
        self.model = os.getenv("MODEL")
        self.client = OpenAI()
        self.data_path = data_path
        self.chunk_size = chunk_size
        self.k = k

    def generate_response(self, question, context):
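        """Ask the chat model to answer `question` from `context`; returns (answer, latency in seconds)."""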
        template = Template(PROMPT)
        prompt = template.render(CONTEXT=context, QUESTION=question)

        max_retries = 3
        retries = 0

        while retries <= max_retries:
            try:
                t1 = time.time()
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {
                            "role": "system",
                            "content": "You are a helpful assistant that can answer "
                            "questions based on the provided context. "
                            "If the question involves timing, use the conversation date for reference. "
                            "Provide the shortest possible answer. "
                            "Use words directly from the conversation when possible. "
                            "Avoid using subjects in your answer.",
                        },
                        {"role": "user", "content": prompt},
                    ],
                    temperature=0,
                )
                t2 = time.time()
                if response and response.choices:
                    content = response.choices[0].message.content
                    if content is not None:
                        return content.strip(), t2 - t1
                    else:
                        return "No content returned", t2 - t1
                else:
                    return "Empty response", t2 - t1
            except Exception as e:
                retries += 1
                if retries > max_retries:
                    raise e
                time.sleep(1)  # Wait before retrying

    def clean_chat_history(self, chat_history):
        """Flatten a conversation into "timestamp | speaker: text" lines."""
        cleaned_chat_history = ""
        for c in chat_history:
            cleaned_chat_history += f"{c['timestamp']} | {c['speaker']}: {c['text']}\n"

        return cleaned_chat_history

    def calculate_embedding(self, document):
        """Embed a text with the model named by the EMBEDDING_MODEL environment variable."""
        response = self.client.embeddings.create(model=os.getenv("EMBEDDING_MODEL"), input=document)
        return response.data[0].embedding

    def calculate_similarity(self, embedding1, embedding2):
        """Cosine similarity between two embedding vectors."""
        return np.dot(embedding1, embedding2) / (
            np.linalg.norm(embedding1) * np.linalg.norm(embedding2)
        )

    def search(self, query, chunks, embeddings, k=1):
        """
        Search for the top-k most similar chunks to the query.

        Args:
            query: The query string
            chunks: List of text chunks
            embeddings: List of embeddings for each chunk
            k: Number of top chunks to return (default: 1)

        Returns:
            combined_chunks: The combined text of the top-k chunks
            search_time: Time taken for the search
        """
        t1 = time.time()
        query_embedding = self.calculate_embedding(query)
        similarities = [
            self.calculate_similarity(query_embedding, embedding) for embedding in embeddings
        ]

        # Get indices of top-k most similar chunks
        top_indices = [np.argmax(similarities)] if k == 1 else np.argsort(similarities)[-k:][::-1]
        # Combine the top-k chunks
        combined_chunks = "\n<->\n".join([chunks[i] for i in top_indices])

        t2 = time.time()
        return combined_chunks, t2 - t1

    def create_chunks(self, chat_history, chunk_size=500):
        """
        Create chunks using tiktoken for more accurate token counting
        """
        # Get the encoding for the model
        encoding = tiktoken.encoding_for_model(os.getenv("EMBEDDING_MODEL"))

        documents = self.clean_chat_history(chat_history)

        if chunk_size == -1:
            return [documents], []

        chunks = []

        # Encode the document
        tokens = encoding.encode(documents)

        # Split into chunks based on token count
        for i in range(0, len(tokens), chunk_size):
            chunk_tokens = tokens[i : i + chunk_size]
            chunk = encoding.decode(chunk_tokens)
            chunks.append(chunk)

        embeddings = []
        for chunk in chunks:
            embedding = self.calculate_embedding(chunk)
            embeddings.append(embedding)

        return chunks, embeddings

    def process_all_conversations(self, output_file_path):
        """Answer every question in every conversation and write the raw results to output_file_path."""
        with open(self.data_path) as f:
            data = json.load(f)

        # Make sure the output directory exists before the first (incremental) save.
        os.makedirs(os.path.dirname(output_file_path) or ".", exist_ok=True)

        final_results = defaultdict(list)
        for key, value in tqdm(data.items(), desc="Processing conversations"):
            chat_history = value["conversation"]
            questions = value["question"]

            chunks, embeddings = self.create_chunks(chat_history, self.chunk_size)

            for item in tqdm(questions, desc="Answering questions", leave=False):
                question = item["question"]
                answer = item.get("answer", "")
                category = item["category"]

                if self.chunk_size == -1:
                    context = chunks[0]
                    search_time = 0
                else:
                    context, search_time = self.search(question, chunks, embeddings, k=self.k)
                response, response_time = self.generate_response(question, context)

                final_results[key].append(
                    {
                        "question": question,
                        "answer": answer,
                        "category": category,
                        "context": context,
                        "response": response,
                        "search_time": search_time,
                        "response_time": response_time,
                    }
                )
            # Incremental checkpoint after each conversation
            with open(output_file_path, "w+") as f:
                json.dump(final_results, f, indent=4)

        # Save results
        with open(output_file_path, "w+") as f:
            json.dump(final_results, f, indent=4)
        print("The original RAG file has been generated!")


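# Illustrative standalone use of RAGManager (hypothetical `chat_history`: a list of
# {"timestamp", "speaker", "text"} dicts, the format consumed by clean_chat_history):
#
#   manager = RAGManager(chunk_size=500, k=2)
#   chunks, embeddings = manager.create_chunks(chat_history, manager.chunk_size)
#   context, _ = manager.search(question, chunks, embeddings, k=manager.k)
#   answer, _ = manager.generate_response(question, context)

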
class Experiment:
    """Minimal experiment descriptor; run() only logs the chosen configuration."""

    def __init__(self, technique_type, chunk_size):
        self.technique_type = technique_type
        self.chunk_size = chunk_size

    def run(self):
        print(
            f"Running experiment with technique: {self.technique_type}, chunk size: {self.chunk_size}"
        )


def process_item(item_data):
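    """Re-shape one conversation's QA results into the evaluation schema; category 5 items are skipped."""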
    k, v = item_data
    local_results = defaultdict(list)

    for item in tqdm(v):
        gt_answer = str(item["answer"])
        pred_answer = str(item["response"])
        category = str(item["category"])
        question = str(item["question"])
        search_time = str(item["search_time"])
        response_time = str(item["response_time"])
        search_context = str(item["context"])

        # Skip category 5
        if category == "5":
            continue

        local_results[k].append(
            {
                "question": question,
                "golden_answer": gt_answer,
                "answer": pred_answer,
                "category": int(category),
                "response_duration_ms": float(response_time) * 1000,
                "search_duration_ms": float(search_time) * 1000,
                "search_context": search_context,
                # "llm_score_std": np.std(llm_score)
            }
        )

    return local_results


def rename_json_keys(file_path):
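    """Rewrite the JSON file in place, prefixing each top-level key with 'locomo_exp_user_'."""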
    with open(file_path, encoding="utf-8") as f:
        data = json.load(f)

    new_data = {}
    for old_key in data:
        new_key = f"locomo_exp_user_{old_key}"
        new_data[new_key] = data[old_key]

    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(new_data, f, indent=2, ensure_ascii=False)


def generate_response_file(file_path):
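    """Convert the raw RAG responses at file_path into the evaluation format, in place."""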
    parser = argparse.ArgumentParser(description="Evaluate RAG results")

    parser.add_argument(
        "--output_folder",
        type=str,
        default="default_locomo_responses.json",
        help="Path to save the evaluation results",
    )
    parser.add_argument(
        "--max_workers", type=int, default=10, help="Maximum number of worker threads"
    )
    parser.add_argument("--chunk_size", type=int, default=2000, help="Chunk size for processing")
    parser.add_argument("--num_chunks", type=int, default=2, help="Number of chunks to process")

    # This function is called from main(), which defines its own flags, so ignore any
    # arguments this parser does not recognize instead of erroring out.
    args, _ = parser.parse_known_args()
    with open(file_path) as f:
        data = json.load(f)

    results = defaultdict(list)
    results_lock = threading.Lock()

    # Use ThreadPoolExecutor with specified workers
    with concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor:
        futures = [executor.submit(process_item, item_data) for item_data in data.items()]

        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
            local_results = future.result()
            with results_lock:
                for k, items in local_results.items():
                    results[k].extend(items)

    # Save results to JSON file (overwrites the raw responses with the evaluation format)
    with open(file_path, "w") as f:
        json.dump(results, f, indent=4)

    rename_json_keys(file_path)
    print(f"Results saved to {file_path}")


def main():
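    """Parse CLI arguments, run the RAG pipeline over LoCoMo, and produce the response file."""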
    parser = argparse.ArgumentParser(description="Run memory experiments")
    parser.add_argument(
        "--technique_type", choices=TECHNIQUES, default="rag", help="Memory technique to use"
    )
    parser.add_argument("--chunk_size", type=int, default=2000, help="Chunk size for processing")
    parser.add_argument(
        "--output_folder",
        type=str,
        default="results/locomo/mem0-default/",
        help="Output path for results",
    )
    parser.add_argument("--top_k", type=int, default=30, help="Number of top memories to retrieve")
    parser.add_argument("--num_chunks", type=int, default=2, help="Number of chunks to process")
    parser.add_argument("--frame", type=str, default="mem0")
    parser.add_argument("--version", type=str, default="default")

    args = parser.parse_args()

    response_path = f"{args.frame}_locomo_responses.json"

    if args.technique_type == "rag":
        output_file_path = os.path.join(args.output_folder, response_path)
        rag_manager = RAGManager(
            data_path="data/locomo/locomo10_rag.json", chunk_size=args.chunk_size, k=args.num_chunks
        )
        rag_manager.process_all_conversations(output_file_path)
        # Generate the response file
        generate_response_file(output_file_path)


if __name__ == "__main__":
    start = time.time()
    main()
    end = time.time()
    print(f"Execution time: {end - start:.2f} seconds")