diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..042b6fc --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +__pycache__/ +app.log \ No newline at end of file diff --git a/app.py b/app.py index a1571d6..60f165e 100644 --- a/app.py +++ b/app.py @@ -23,8 +23,7 @@ def search_tools(): formatted_results.append({ 'name': tool[0], 'description': tool[1], - 'link': tool[2] if len(tool) > 2 else '', - 'category': tool[3] if len(tool) > 3 else '' + 'link': tool[2] if len(tool) > 2 else '' }) return jsonify({'results': formatted_results}) except Exception as e: diff --git a/backend/main.py b/backend/main.py index 3345ebc..badffdd 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,69 +1,22 @@ -import csv -from pathlib import Path -from .semantic_search import search +import pandas as pd +from backend.semantic_search import SemanticSearch +# Initialize the search object +search_object = SemanticSearch() +df = search_object.df # Get the dataframe from the search object -def find_indices(primary_list, query_list): +def search_tool(query: str): """ - Find the indices of elements from query_list in primary_list. - - Args: - primary_list (list): The list to search in - query_list (list): The list of elements to search for - - Returns: - list: A list of indices where query elements are found in primary list - """ - indices = [] - for query_item in query_list: - try: - index = primary_list.index(query_item) - indices.append(index) - except ValueError: - pass - return indices - - -csv_path = Path("backend/database/tool_list_database.csv") -with csv_path.open(newline='', encoding="utf-8") as f: - tools = list(csv.reader(f)) # includes header - header, *tools = tools # split header / body if you want -print("Loaded", len(tools), "rows") - - -descriptions=[] -final_outputs_list=[] - - -for r in tools: - text=f"{r[0]} {r[1]}" - descriptions.append(text.lower()) - - - -def search_tool(query): - """ - Searches for tools based on a query and returns the matching tool data. - - Args: - query (str): The search query string. - - Returns: - list: A list of lists, where each inner list represents a matching tool's data. + Searches for tools in the database based on a query. """ - # Find matching tool descriptions based on the query - matching_descriptions = search(descriptions, query.lower()) - - # Find the indices of these matching descriptions in the main descriptions list - matching_indices = find_indices(descriptions, matching_descriptions) - - # Collect the full tool data for each matching index - matching_tools_data = [] - for index in matching_indices: - matching_tools_data.append(tools[index]) - - return matching_tools_data - - - - + # Get the descriptions from the search results + descriptions = search_object.search(query) + + # Create the search key from the dataframe + df_search_key = (df['name'] + ' ' + df['description']) + + # Retrieve the full tool information from the dataframe + results = df[df_search_key.isin(descriptions)] + + # Convert the results to a list of lists + return results.values.tolist() \ No newline at end of file diff --git a/backend/semantic_search.py b/backend/semantic_search.py index 7203acd..aeeb56b 100644 --- a/backend/semantic_search.py +++ b/backend/semantic_search.py @@ -1,45 +1,74 @@ from langchain_huggingface import HuggingFaceEmbeddings from langchain_community.vectorstores import FAISS from langchain_community.retrievers import BM25Retriever +import pandas as pd +class SemanticSearch: + """ + A class to perform semantic search on a list of tools. + It uses a hybrid approach with BM25 and FAISS to get the best of both worlds: + - BM25 for keyword-based search. + - FAISS for semantic similarity search. + """ + def __init__(self, data_path='backend/database/tool_list_database.csv'): + """ + Initializes the SemanticSearch object. + - Loads the tool data from the CSV file. + - Initializes the HuggingFace embeddings model. + - Creates the BM25 and FAISS retrievers. + """ + # Load the sentence transformer model for embeddings + self.embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2") + + # Load the tool data from the CSV file + self.df = pd.read_csv(data_path) + self.df.fillna('', inplace=True) + print(f"Loaded {len(self.df)} rows") -embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L12-v2") + # Create a list of documents for the retrievers + self.doc_list = (self.df['name'] + ' ' + self.df['description']).tolist() + + # Initialize the BM25 retriever for keyword-based search + self.bm25_retriever = BM25Retriever.from_texts(self.doc_list) + # Initialize the FAISS vector store and retriever for semantic search + self.faiss_vectorstore = FAISS.from_texts(self.doc_list, self.embedding) + self.faiss_retriever = self.faiss_vectorstore.as_retriever(search_kwargs={"k": 20}) -def search(doc_list, query, similarity_threshold=0.5): - # Create retrievers - bm25_retriever = BM25Retriever.from_texts(doc_list) - faiss_vectorstore = FAISS.from_texts(doc_list, embedding) - faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={"k": 20}) - - # Get results with scores - faiss_results = faiss_retriever.get_relevant_documents(query) - - # Filter FAISS results by similarity threshold - filtered_results = [] - for doc in faiss_results: - # FAISS returns distance, convert to similarity - # For cosine similarity: similarity = 1 - distance - distance = doc.metadata.get('score', 0) if 'score' in doc.metadata else 0 - similarity = 1 - distance - - if similarity >= similarity_threshold: - filtered_results.append(doc) - - # Get BM25 results (they don't have similarity scores in the same way) - bm25_results = bm25_retriever.get_relevant_documents(query) - - # Combine results (you can implement your own ensemble logic here) - # For now, let's prioritize FAISS results above threshold, then BM25 - unique_results = {} - - # Add filtered FAISS results first - for doc in filtered_results: - unique_results[doc.page_content] = doc - - # Add BM25 results (you might want to limit these too) - for doc in bm25_results[:10]: # Limit BM25 results - if doc.page_content not in unique_results: + def search(self, query, similarity_threshold=0.5): + """ + Performs a hybrid search using BM25 and FAISS. + - Gets the results from the FAISS retriever. + - Filters the FAISS results based on a similarity threshold. + - Gets the results from the BM25 retriever. + - Combines the results from both retrievers, giving priority to the FAISS results. + """ + # Get the semantic search results from FAISS + faiss_results = self.faiss_retriever.invoke(query) + + # Filter the FAISS results based on the similarity threshold + filtered_results = [] + for doc in faiss_results: + # The distance score from FAISS is converted to a similarity score + distance = doc.metadata.get('score', 0) if 'score' in doc.metadata else 0 + similarity = 1 - distance + + if similarity >= similarity_threshold: + filtered_results.append(doc) + + # Get the keyword-based search results from BM25 + bm25_results = self.bm25_retriever.invoke(query) + + # Combine the results from both retrievers + unique_results = {} + + # Add the filtered FAISS results first + for doc in filtered_results: unique_results[doc.page_content] = doc - - return list(unique_results.keys()) + + # Add the BM25 results, avoiding duplicates + for doc in bm25_results[:10]: # Limit the number of BM25 results + if doc.page_content not in unique_results: + unique_results[doc.page_content] = doc + + return list(unique_results.keys()) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..a6ef1a9 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +faiss-cpu>=1.12.0 +flask>=3.1.2 +langchain>=0.3.27 +langchain-community>=0.3.29 +langchain-huggingface>=0.3.1 +rank-bm25>=0.2.2 +sentence-transformers>=5.1.0 +pandas>=2.2.0 \ No newline at end of file diff --git a/templates/index.html b/templates/index.html index d69bae5..3f17e72 100644 --- a/templates/index.html +++ b/templates/index.html @@ -713,10 +713,11 @@
${tool.description || 'No description available'}
+${escapeHTML(tool.description || 'No description available')}
${message}
+${escapeHTML(message)}
Check network connection and try again
@@ -907,9 +896,13 @@