diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..042b6fc --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +__pycache__/ +app.log \ No newline at end of file diff --git a/app.py b/app.py index a1571d6..60f165e 100644 --- a/app.py +++ b/app.py @@ -23,8 +23,7 @@ def search_tools(): formatted_results.append({ 'name': tool[0], 'description': tool[1], - 'link': tool[2] if len(tool) > 2 else '', - 'category': tool[3] if len(tool) > 3 else '' + 'link': tool[2] if len(tool) > 2 else '' }) return jsonify({'results': formatted_results}) except Exception as e: diff --git a/backend/main.py b/backend/main.py index 3345ebc..badffdd 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,69 +1,22 @@ -import csv -from pathlib import Path -from .semantic_search import search +import pandas as pd +from backend.semantic_search import SemanticSearch +# Initialize the search object +search_object = SemanticSearch() +df = search_object.df # Get the dataframe from the search object -def find_indices(primary_list, query_list): +def search_tool(query: str): """ - Find the indices of elements from query_list in primary_list. - - Args: - primary_list (list): The list to search in - query_list (list): The list of elements to search for - - Returns: - list: A list of indices where query elements are found in primary list - """ - indices = [] - for query_item in query_list: - try: - index = primary_list.index(query_item) - indices.append(index) - except ValueError: - pass - return indices - - -csv_path = Path("backend/database/tool_list_database.csv") -with csv_path.open(newline='', encoding="utf-8") as f: - tools = list(csv.reader(f)) # includes header - header, *tools = tools # split header / body if you want -print("Loaded", len(tools), "rows") - - -descriptions=[] -final_outputs_list=[] - - -for r in tools: - text=f"{r[0]} {r[1]}" - descriptions.append(text.lower()) - - - -def search_tool(query): - """ - Searches for tools based on a query and returns the matching tool data. - - Args: - query (str): The search query string. - - Returns: - list: A list of lists, where each inner list represents a matching tool's data. + Searches for tools in the database based on a query. """ - # Find matching tool descriptions based on the query - matching_descriptions = search(descriptions, query.lower()) - - # Find the indices of these matching descriptions in the main descriptions list - matching_indices = find_indices(descriptions, matching_descriptions) - - # Collect the full tool data for each matching index - matching_tools_data = [] - for index in matching_indices: - matching_tools_data.append(tools[index]) - - return matching_tools_data - - - - + # Get the descriptions from the search results + descriptions = search_object.search(query) + + # Create the search key from the dataframe + df_search_key = (df['name'] + ' ' + df['description']) + + # Retrieve the full tool information from the dataframe + results = df[df_search_key.isin(descriptions)] + + # Convert the results to a list of lists + return results.values.tolist() \ No newline at end of file diff --git a/backend/semantic_search.py b/backend/semantic_search.py index 7203acd..aeeb56b 100644 --- a/backend/semantic_search.py +++ b/backend/semantic_search.py @@ -1,45 +1,74 @@ from langchain_huggingface import HuggingFaceEmbeddings from langchain_community.vectorstores import FAISS from langchain_community.retrievers import BM25Retriever +import pandas as pd +class SemanticSearch: + """ + A class to perform semantic search on a list of tools. + It uses a hybrid approach with BM25 and FAISS to get the best of both worlds: + - BM25 for keyword-based search. + - FAISS for semantic similarity search. + """ + def __init__(self, data_path='backend/database/tool_list_database.csv'): + """ + Initializes the SemanticSearch object. + - Loads the tool data from the CSV file. + - Initializes the HuggingFace embeddings model. + - Creates the BM25 and FAISS retrievers. + """ + # Load the sentence transformer model for embeddings + self.embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2") + + # Load the tool data from the CSV file + self.df = pd.read_csv(data_path) + self.df.fillna('', inplace=True) + print(f"Loaded {len(self.df)} rows") -embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L12-v2") + # Create a list of documents for the retrievers + self.doc_list = (self.df['name'] + ' ' + self.df['description']).tolist() + + # Initialize the BM25 retriever for keyword-based search + self.bm25_retriever = BM25Retriever.from_texts(self.doc_list) + # Initialize the FAISS vector store and retriever for semantic search + self.faiss_vectorstore = FAISS.from_texts(self.doc_list, self.embedding) + self.faiss_retriever = self.faiss_vectorstore.as_retriever(search_kwargs={"k": 20}) -def search(doc_list, query, similarity_threshold=0.5): - # Create retrievers - bm25_retriever = BM25Retriever.from_texts(doc_list) - faiss_vectorstore = FAISS.from_texts(doc_list, embedding) - faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={"k": 20}) - - # Get results with scores - faiss_results = faiss_retriever.get_relevant_documents(query) - - # Filter FAISS results by similarity threshold - filtered_results = [] - for doc in faiss_results: - # FAISS returns distance, convert to similarity - # For cosine similarity: similarity = 1 - distance - distance = doc.metadata.get('score', 0) if 'score' in doc.metadata else 0 - similarity = 1 - distance - - if similarity >= similarity_threshold: - filtered_results.append(doc) - - # Get BM25 results (they don't have similarity scores in the same way) - bm25_results = bm25_retriever.get_relevant_documents(query) - - # Combine results (you can implement your own ensemble logic here) - # For now, let's prioritize FAISS results above threshold, then BM25 - unique_results = {} - - # Add filtered FAISS results first - for doc in filtered_results: - unique_results[doc.page_content] = doc - - # Add BM25 results (you might want to limit these too) - for doc in bm25_results[:10]: # Limit BM25 results - if doc.page_content not in unique_results: + def search(self, query, similarity_threshold=0.5): + """ + Performs a hybrid search using BM25 and FAISS. + - Gets the results from the FAISS retriever. + - Filters the FAISS results based on a similarity threshold. + - Gets the results from the BM25 retriever. + - Combines the results from both retrievers, giving priority to the FAISS results. + """ + # Get the semantic search results from FAISS + faiss_results = self.faiss_retriever.invoke(query) + + # Filter the FAISS results based on the similarity threshold + filtered_results = [] + for doc in faiss_results: + # The distance score from FAISS is converted to a similarity score + distance = doc.metadata.get('score', 0) if 'score' in doc.metadata else 0 + similarity = 1 - distance + + if similarity >= similarity_threshold: + filtered_results.append(doc) + + # Get the keyword-based search results from BM25 + bm25_results = self.bm25_retriever.invoke(query) + + # Combine the results from both retrievers + unique_results = {} + + # Add the filtered FAISS results first + for doc in filtered_results: unique_results[doc.page_content] = doc - - return list(unique_results.keys()) + + # Add the BM25 results, avoiding duplicates + for doc in bm25_results[:10]: # Limit the number of BM25 results + if doc.page_content not in unique_results: + unique_results[doc.page_content] = doc + + return list(unique_results.keys()) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..a6ef1a9 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +faiss-cpu>=1.12.0 +flask>=3.1.2 +langchain>=0.3.27 +langchain-community>=0.3.29 +langchain-huggingface>=0.3.1 +rank-bm25>=0.2.2 +sentence-transformers>=5.1.0 +pandas>=2.2.0 \ No newline at end of file diff --git a/templates/index.html b/templates/index.html index d69bae5..3f17e72 100644 --- a/templates/index.html +++ b/templates/index.html @@ -713,10 +713,11 @@

Scanning Arsenal...

const toolsGrid = document.getElementById('toolsGrid'); const resultsCount = document.getElementById('resultsCount'); const loadingIndicator = document.getElementById('loadingIndicator'); - const sortOptions = document.getElementById('sortOptions'); const exampleTags = document.querySelectorAll('.example-tag'); - // Update time in status bar + /** + * Updates the current time in the status bar every second. + */ function updateTime() { const now = new Date(); document.getElementById('currentTime').textContent = now.toLocaleTimeString(); @@ -740,14 +741,9 @@

Scanning Arsenal...

}); }); - sortOptions.addEventListener('change', () => { - const query = searchInput.value.trim(); - if (query && toolsGrid.children.length > 0 && !toolsGrid.querySelector('.no-results')) { - performSearch(); // Re-search with current query to re-sort - } - }); - - // Search Function + /** + * Performs a search when the user clicks the search button or presses Enter. + */ function performSearch() { const query = searchInput.value.trim(); @@ -756,11 +752,11 @@

Scanning Arsenal...

return; } - // Show loading indicator + // Show loading indicator and hide results loadingIndicator.style.display = 'block'; resultsContainer.style.display = 'none'; - // Call Flask backend + // Call the Flask backend to perform the search fetch('/search', { method: 'POST', headers: { @@ -774,11 +770,11 @@

Scanning Arsenal...

throw new Error(data.error); } - // Handle null or empty results + // Display the search results const results = data.results || []; displayResults(results); - // Hide loading, show results + // Hide loading indicator and show results loadingIndicator.style.display = 'none'; resultsContainer.style.display = 'block'; }) @@ -786,7 +782,7 @@

Scanning Arsenal...

console.error('Error:', error); loadingIndicator.style.display = 'none'; - // Ensure resultsContainer is shown and properly display error + // Show an error message in the results container resultsContainer.style.display = 'block'; toolsGrid.innerHTML = `
@@ -806,25 +802,30 @@

Connection Error

}); } - // Display Results + /** + * Escapes HTML special characters to prevent XSS attacks. + * @param {string} str The string to escape. + * @returns {string} The escaped string. + */ + function escapeHTML(str) { + const div = document.createElement('div'); + div.appendChild(document.createTextNode(str)); + return div.innerHTML; + } + + /** + * Displays the search results in the tools grid. + * @param {Array} tools The list of tools to display. + */ function displayResults(tools) { // Ensure tools is an array if (!Array.isArray(tools)) { tools = []; } - // Sort tools based on selected option (removed 'name' option) - const sortBy = sortOptions ? sortOptions.value : 'relevance'; let sortedTools = [...tools]; - switch(sortBy) { - case 'category': - sortedTools.sort((a, b) => (a.category || 'Uncategorized').localeCompare(b.category || 'Uncategorized')); - break; - // 'relevance' keeps original order (from search) - } - - // Update results count + // Update the results count if (sortedTools.length === 0) { resultsCount.innerHTML = ` No tools found in arsenal`; } else { @@ -834,7 +835,7 @@

Connection Error

// Clear previous results toolsGrid.innerHTML = ''; - // Display message if no results + // Display a message if there are no results if (sortedTools.length === 0) { toolsGrid.innerHTML = `
@@ -849,30 +850,27 @@

No Tools Found in Arsenal

return; } - // Populate tools grid + // Create and append a tool card for each tool sortedTools.forEach((tool, index) => { const toolCard = document.createElement('div'); toolCard.className = 'tool-card'; - // Determine category class for styling - const categoryClass = getCategoryClass(tool.category); - + // Create the HTML for the tool card, escaping all user-provided data to prevent XSS toolCard.innerHTML = `
-

${tool.name || 'Unknown Tool'}

- ${tool.category || 'Uncategorized'} +

${escapeHTML(tool.name || 'Unknown Tool')}

-

${tool.description || 'No description available'}

+

${escapeHTML(tool.description || 'No description available')}

`; - // Add staggered animation + // Add a staggered animation to the tool cards toolCard.style.animationDelay = `${index * 0.1}s`; toolCard.style.animation = 'fadeInUp 0.6s ease forwards'; toolCard.style.opacity = '0'; @@ -881,25 +879,16 @@

${tool.name || 'Unknown Tool'}

}); } - // Get category class for styling - function getCategoryClass(category) { - if (!category) return ''; - const cat = category.toLowerCase(); - if (cat.includes('network')) return 'network'; - if (cat.includes('password')) return 'password'; - if (cat.includes('vulnerability') || cat.includes('vuln')) return 'vulnerability'; - if (cat.includes('forensic')) return 'forensics'; - if (cat.includes('web')) return 'web'; - return ''; - } - - // Show error with cybersecurity theme + /** + * Shows an error message in the tools grid. + * @param {string} message The error message to display. + */ function showError(message) { toolsGrid.innerHTML = `

System Error

-

${message}

+

${escapeHTML(message)}

Check network connection and try again

@@ -907,9 +896,13 @@

System Error

`; } - // Show alert function + /** + * Shows a temporary alert message in the top-right corner of the screen. + * @param {string} message The message to display. + * @param {string} type The type of alert ('info' or 'warning'). + */ function showAlert(message, type = 'info') { - // Create alert element + // Create the alert element const alert = document.createElement('div'); alert.style.cssText = ` position: fixed; @@ -925,18 +918,18 @@

System Error

border: 2px solid ${type === 'warning' ? '#ffaa00' : '#00ff88'}; box-shadow: 0 10px 30px rgba(0, 0, 0, 0.3); `; - alert.innerHTML = ` ${message}`; + alert.innerHTML = ` ${escapeHTML(message)}`; document.body.appendChild(alert); - // Remove after 3 seconds + // Remove the alert after 3 seconds setTimeout(() => { alert.style.animation = 'slideOutRight 0.3s ease'; setTimeout(() => alert.remove(), 300); }, 3000); } - // Add slide animations for alerts + // Add the slide-in/out animations for the alerts const style = document.createElement('style'); style.textContent = ` @keyframes slideInRight { @@ -950,7 +943,7 @@

System Error

`; document.head.appendChild(style); - // Initialize + // Initialize the application window.addEventListener('DOMContentLoaded', () => { // Ensure all elements exist before setting up if (resultsContainer) {