Skip to content

Add ToolsRetriever class and convert_retriever_to_tool() function #332

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

## Next

### Added

- Added `ToolsRetriever` class to use LLM function calling to retrieve results from a Retriever.
- Added `convert_retriever_to_tool()` function convert any Retriever to a Tool.

## 1.7.0

### Added
Expand Down
7 changes: 6 additions & 1 deletion examples/customize/llms/openai_tool_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@

from neo4j_graphrag.llm import OpenAILLM
from neo4j_graphrag.llm.types import ToolCallResponse
from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter, IntegerParameter
from neo4j_graphrag.tools.tool import (
Tool,
ObjectParameter,
StringParameter,
IntegerParameter,
)

# Load environment variables from .env file (OPENAI_API_KEY required for this example)
load_dotenv()
Expand Down
7 changes: 6 additions & 1 deletion examples/customize/llms/vertexai_tool_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@

from neo4j_graphrag.llm import VertexAILLM
from neo4j_graphrag.llm.types import ToolCallResponse
from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter, IntegerParameter
from neo4j_graphrag.tools.tool import (
Tool,
ObjectParameter,
StringParameter,
IntegerParameter,
)

# Load environment variables from .env file
load_dotenv()
Expand Down
121 changes: 121 additions & 0 deletions examples/retrieve/tools/retriever_to_tool_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Example demonstrating how to convert a retriever to a tool.

This example shows:
1. How to convert a custom StaticRetriever to a Tool
2. How to define parameters for the tool
3. How to execute the tool
"""

import neo4j
from typing import Optional, Any, cast
from unittest.mock import MagicMock

from neo4j_graphrag.retrievers.base import Retriever
from neo4j_graphrag.types import RawSearchResult
from neo4j_graphrag.tools.tool import (
StringParameter,
ObjectParameter,
)
from neo4j_graphrag.tools.utils import convert_retriever_to_tool


# Create a Retriever that returns static results about Neo4j
# This would illustrate the conversion process of any Retriever (Vector, Hybrid, etc.)
class StaticRetriever(Retriever):
"""A retriever that returns static results about Neo4j."""

# Disable Neo4j version verification
VERIFY_NEO4J_VERSION = False

def __init__(self, driver: neo4j.Driver):
# Call the parent class constructor with the driver
super().__init__(driver)

def get_search_results(
self, query_text: Optional[str] = None, **kwargs: Any
) -> RawSearchResult:
"""Return static information about Neo4j regardless of the query."""
# Create formatted Neo4j information
neo4j_info = (
"# Neo4j Graph Database\n\n"
"Neo4j is a graph database management system developed by Neo4j, Inc. "
"It is an ACID-compliant transactional database with native graph storage and processing.\n\n"
"## Key Features:\n\n"
"- **Cypher Query Language**: Neo4j's intuitive query language designed specifically for working with graph data\n"
"- **Property Graphs**: Both nodes and relationships can have properties (key-value pairs)\n"
"- **ACID Compliance**: Ensures data integrity with full transaction support\n"
"- **Native Graph Storage**: Optimized storage for graph data structures\n"
"- **High Availability**: Clustering for enterprise deployments\n"
"- **Scalability**: Handles billions of nodes and relationships"
)

# Create a Neo4j record with the information
records = [neo4j.Record({"result": neo4j_info})]

# Return a RawSearchResult with the records and metadata
return RawSearchResult(records=records, metadata={"query": query_text})


def main() -> None:
# Convert a StaticRetriever to a tool with specific parameters
static_retriever = StaticRetriever(driver=cast(Any, MagicMock()))

# Define parameters for the static retriever tool
static_parameters = ObjectParameter(
description="Parameters for the Neo4j information retriever",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this description really required? It doesn't seem to have a real added-value.

properties={
"query_text": StringParameter(
description="The query about Neo4j (any query will return general Neo4j information)",
required=True,
),
},
)

# Convert the retriever to a tool with specific parameters
static_tool = convert_retriever_to_tool(
Copy link
Contributor

@stellasia stellasia May 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd say this should be a method from the Retriever class, and parameters should be encapsulated in this class as well, these parameters are bound to the search method and won't change from one instance to another.

So something like:

class Retriever:
   def get_parameters(self) -> ObjectParameter:
       raise NotImplementedError()  # need to be implemented in subclasses

   def convert_to_tool(self, name: str, description: Optional[str] = None) -> Tool:
       # rest of the function goes here

Note: as a future improvement, I think we could infer the parameters from the search method signature without having to redeclare it.

retriever=static_retriever,
description="Get general information about Neo4j graph database",
parameters=static_parameters,
name="Neo4jInfoTool",
)

# Print tool information
print("Example: StaticRetriever with specific parameters")
print(f"Tool Name: {static_tool.get_name()}")
print(f"Tool Description: {static_tool.get_description()}")
print(f"Tool Parameters: {static_tool.get_parameters()}")
print()

# Execute the tools (in a real application, this would be done by instructions from an LLM)
try:
# Execute the static retriever tool
print("\nExecuting the static retriever tool...")
static_result = static_tool.execute(
query="What is Neo4j?",
)
print("Static Search Results:")
for i, item in enumerate(static_result):
print(f"{i + 1}. {str(item)[:100]}...")

except Exception as e:
print(f"Error executing tool: {e}")


if __name__ == "__main__":
main()
Loading