Skip to content

Commit d3d707e

Browse files
committed
feat(async): Provide a new atop_n_routes() function for async contexts
1 parent 4342f2a commit d3d707e

File tree

4 files changed

+33
-6
lines changed

4 files changed

+33
-6
lines changed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ router = SemanticRouter(config)
4949
# Start making queries (explore options in the SemanticRouter class)
5050
matches = router.top_n_routes(query="Hi!")
5151

52+
# Or, for async contexts:
53+
matches = await router.atop_n_routes(query="Hi!")
54+
5255
# Print the top matches (route, score, depth, leaf)
5356
[print(match) for match in matches]
5457
```

asero/main.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
# SPDX-License-Identifier: BSD-3-Clause
33

44
"""Main module, demonstration purposes, for asero semantic router."""
5+
import asyncio
56

67
from asero.config import get_config
78
from asero.router import SemanticRouter
89

910

10-
def main():
11+
async def main():
1112
"""Demonstrate the SemanticRouter functionality."""
1213
config = get_config()
1314
router = SemanticRouter(config) # Defaults to router_example.yaml
@@ -17,8 +18,8 @@ def main():
1718
while True:
1819
try:
1920
print(f"Type a query to see top-{top} semantic routes (ctrl-C to exit):")
20-
q = input("You: ").strip()
21-
matches = router.top_n_routes(q, top_n=top)
21+
q = (await asyncio.to_thread(input, "You: ")).strip()
22+
matches = await router.atop_n_routes(q, top_n=top)
2223
print("")
2324
print(f"Query: {q}")
2425
print("Top nodes:")
@@ -33,4 +34,4 @@ def main():
3334

3435

3536
if __name__ == "__main__":
36-
main()
37+
asyncio.run(main())

asero/router.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# SPDX-License-Identifier: BSD-3-Clause
33

44
"""Main asero semantic router classes."""
5-
5+
import asyncio
66
import logging
77
import re
88

@@ -448,7 +448,7 @@ def top_n_routes(
448448
only_leaves: bool = True,
449449
allowed_paths: list[str] | None = None
450450
) -> list[tuple[str, float, int, bool]]:
451-
"""For a given query, return the top-N most similar semantic routes in the hierarchy.
451+
"""Get (synchronously) for a given query, the top-N most similar semantic routes in the hierarchy.
452452
453453
Args:
454454
query (str): User query string.
@@ -469,6 +469,28 @@ def top_n_routes(
469469
allowed_paths=allowed_paths,
470470
)
471471

472+
async def atop_n_routes(
473+
self,
474+
query: str,
475+
top_n: int = 3,
476+
only_leaves: bool = True,
477+
allowed_paths: list[str] | None = None
478+
) -> list[tuple[str, float, int, bool]]:
479+
"""Get (asynchronously) for a given query, the top-N most similar semantic routes in the hierarchy.
480+
481+
Args:
482+
query (str): User query string.
483+
top_n (int): Number of top routes to return.
484+
only_leaves (bool): If True, only return leaf nodes.
485+
allowed_paths (list[str]): List of allowed paths to filter results.
486+
487+
Returns:
488+
list[tuple[str, float, int, bool]]: List of tuples:
489+
(route_path, similarity_score, depth, is_leaf)
490+
491+
"""
492+
return await asyncio.to_thread(self.top_n_routes, query, top_n, only_leaves, allowed_paths)
493+
472494
def add_utterance(
473495
self,
474496
path: list[str],

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ dependencies = [
2828
"colorlog ~= 6.9.0",
2929
"numpy ~= 2.3.3",
3030
"openai ~= 1.106.1",
31+
"python-dotenv ~= 1.1.1",
3132
]
3233

3334
[project.optional-dependencies]

0 commit comments

Comments
 (0)