Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/app/api/api_v1/endpoints/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import backoff
import psycopg
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request, status
from fastapi.responses import StreamingResponse
from langchain_core.messages import ToolMessage
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
Expand All @@ -22,6 +22,7 @@
LanguageNotSupportedError,
bad_request,
)
from src.app.services.search import SearchService, get_search_service
from src.app.utils.logger import logger as utils_logger

logger = utils_logger(__name__)
Expand Down Expand Up @@ -345,8 +346,10 @@ async def q_and_a_stream(
factor=2,
)
async def agent_response(
background_tasks: BackgroundTasks,
body: models.AgentContext = Depends(get_agent_params),
chatfactory=Depends(get_chat_service),
sp: SearchService = Depends(get_search_service),
) -> Optional[Dict]:
try:
if body.query is None:
Expand All @@ -367,12 +370,15 @@ async def agent_response(
thread_id=body.thread_id,
corpora=body.corpora,
sdg_filter=body.sdg_filter,
sp=sp,
background_tasks=background_tasks,
)
else:
res = await chatfactory.agent_message(
query=body.query,
corpora=body.corpora,
sdg_filter=body.sdg_filter,
sp=sp,
)

if isinstance(res["messages"][-2], ToolMessage):
Expand Down
13 changes: 11 additions & 2 deletions src/app/services/abst_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from abc import ABC
from typing import AsyncIterable, Dict, List, Optional

from fastapi import Depends, Request
from fastapi import BackgroundTasks, Depends, Request
from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel # type: ignore
from langchain_core.runnables import RunnableConfig # type: ignore
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver # type: ignore
Expand All @@ -42,6 +42,7 @@
)

# from src.app.services.llm_proxy import LLMProxy
from src.app.services.search import SearchService
from src.app.utils.decorators import log_time_and_error
from src.app.utils.logger import log_environmental_impacts
from src.app.utils.logger import logger as utils_logger
Expand Down Expand Up @@ -404,6 +405,8 @@ async def agent_message(
thread_id: Optional[uuid.UUID] = None,
corpora: Optional[tuple[str, ...]] = None,
sdg_filter: Optional[List[int]] = None,
sp: SearchService | None = None,
background_tasks: BackgroundTasks | None = None,
):
"""
Sends a chat message handled by an agent.
Expand Down Expand Up @@ -440,6 +443,8 @@ async def agent_message(
"thread_id": thread_id,
"corpora": corpora,
"sdg_filter": sdg_filter,
"sp": sp,
"background_tasks": background_tasks,
}
)

Expand All @@ -450,7 +455,11 @@ async def agent_message(
},
]

res = await agent_executor.ainvoke(input={"messages": messages}, config=config)
res = await agent_executor.ainvoke(
input={"messages": messages},
config=config,
background_tasks=background_tasks,
)
return res


Expand Down
17 changes: 13 additions & 4 deletions src/app/services/agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List, Tuple

from fastapi import BackgroundTasks
from langchain_core.messages.utils import trim_messages
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import tool
Expand All @@ -15,16 +16,23 @@


async def _get_resources_about_sustainability(
rag_query: str, config: RunnableConfig
rag_query: str,
config: RunnableConfig,
) -> Tuple[str, List[Document]]:
"""Core logic for getting relevant resources about sustainability from WeLearn database."""
sp = SearchService()

qp = EnhancedSearchQuery(
query=rag_query,
sdg_filter=config["configurable"].get("sdg_filter"),
corpora=config["configurable"].get("corpora"),
)
docs = await sp.search_handler(qp)
sp: SearchService = config["configurable"].get("sp")
background_tasks: BackgroundTasks = config["configurable"].get("background_tasks")
if not sp:
logger.warning("No SearchService found.")
return "No relevant documents found.", []

docs = await sp.search_handler(background_tasks=background_tasks, qp=qp)
if not docs:
logger.warning("No documents found for the query.")
return "No relevant documents found.", []
Expand All @@ -36,7 +44,8 @@ async def _get_resources_about_sustainability(
@tool(response_format="content_and_artifact")
@log_time_and_error
async def get_resources_about_sustainability(
rag_query: str, config: RunnableConfig
rag_query: str,
config: RunnableConfig,
) -> Tuple[str, List[Document]]:
"""Get relevant resources about sustainability from WeLearn database.

Expand Down
14 changes: 10 additions & 4 deletions src/app/services/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Tuple, cast

import numpy as np
import torch
from fastapi import BackgroundTasks, Depends, Request
from fastapi.concurrency import run_in_threadpool
from numpy import ndarray
Expand All @@ -12,7 +13,6 @@
from qdrant_client.http import exceptions as qdrant_exceptions
from qdrant_client.http import models as http_models
from sklearn.metrics.pairwise import cosine_similarity
import torch
from transformers import AutoModel, AutoTokenizer

from src.app.models.collections import Collection
Expand Down Expand Up @@ -187,7 +187,9 @@ def _split_input_seq_len(self, seq_len: int, input: str) -> list[str]:
@log_time_and_error_sync
def _compute_embeddings(self, model, tokenizer, inputs: list[str]) -> np.ndarray:
with torch.no_grad():
tokenized_inputs = tokenizer(inputs, padding=True, truncation=True, return_tensors='pt')
tokenized_inputs = tokenizer(
inputs, padding=True, truncation=True, return_tensors="pt"
)
model_output = model(**tokenized_inputs)
embeddings = model_output[0][:, 0]
embeddings = torch.nn.functional.normalize(embeddings, dim=1).numpy()
Expand All @@ -206,7 +208,9 @@ async def _embed_query(self, search_input: str, curr_model: str) -> np.ndarray:
inputs = self._split_input_seq_len(seq_len, search_input)

try:
embeddings = await run_in_threadpool(self._compute_embeddings, model, tokenizer, inputs)
embeddings = await run_in_threadpool(
self._compute_embeddings, model, tokenizer, inputs
)
embeddings = np.mean(embeddings, axis=0)
except Exception as ex:
logger.error("api_error=EMBED_ERROR model=%s", curr_model)
Expand All @@ -228,7 +232,9 @@ async def simple_search_handler(self, qp: EnhancedSearchQuery):
model_instance = model["instance"]
tokenizer = model["tokenizer"]
embedding_input = qp.query if isinstance(qp.query, list) else [qp.query]
embedding = await run_in_threadpool(self._compute_embeddings, model_instance, tokenizer, embedding_input)
embedding = await run_in_threadpool(
self._compute_embeddings, model_instance, tokenizer, embedding_input
)
result = await self.search(
collection_info="collection_welearn_mul_granite-embedding-107m-multilingual",
embedding=embedding,
Expand Down
88 changes: 83 additions & 5 deletions src/app/tests/services/test_agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from unittest import IsolatedAsyncioTestCase, mock
from unittest.mock import AsyncMock

from fastapi import BackgroundTasks

from src.app.services.agent import (
_get_resources_about_sustainability,
trim_conversation_history,
Expand All @@ -12,17 +14,22 @@ async def test_get_resources_about_sustainability_found(self):
# Mock SearchService and its search_handler
with mock.patch("src.app.services.agent.SearchService") as MockSearchService:
mock_service = MockSearchService.return_value
mock_payload = {
mock_doc = mock.Mock()
mock_doc.payload = {
"document_title": "Test Title",
"slice_content": "Test Content",
"document_url": "http://test.url",
}
mock_doc = mock.Mock()
mock_doc.payload = mock_payload
mock_service.search_handler = AsyncMock(return_value=[mock_doc] * 8)
config = {
"configurable": {
"sdg_filter": None,
"corpora": None,
"sp": mock_service,
"background_tasks": BackgroundTasks(),
}
}

# config with dummy values
config = {"configurable": {"sdg_filter": None, "corpora": None}}
content, docs = await _get_resources_about_sustainability(
"test query", config
)
Expand All @@ -46,3 +53,74 @@ def test_trim_conversation_history(self):
result = trim_conversation_history(state)
self.assertIn("llm_input_messages", result)
self.assertIsInstance(result["llm_input_messages"], list)

async def test_get_resources_about_sustainability_no_search_service(self):
# config without 'sp' (SearchService)
config = {
"configurable": {
"sdg_filter": None,
"corpora": None,
"sp": None,
"background_tasks": None,
}
}
content, docs = await _get_resources_about_sustainability("test query", config)
self.assertEqual(content, "No relevant documents found.")
self.assertEqual(docs, [])

async def test_get_resources_about_sustainability_with_background_tasks(self):
# Mock SearchService and its search_handler
with mock.patch("src.app.services.agent.SearchService") as MockSearchService:
mock_service = MockSearchService.return_value
mock_doc = mock.Mock()
mock_doc.payload = {
"document_title": "Test Title",
"slice_content": "Test Content",
"document_url": "http://test.url",
}
mock_service.search_handler = AsyncMock(return_value=[mock_doc])
config = {
"configurable": {
"sdg_filter": None,
"corpora": None,
"sp": mock_service,
"background_tasks": BackgroundTasks(),
}
}
content, docs = await _get_resources_about_sustainability(
"test query", config
)
self.assertIsInstance(content, str)
self.assertEqual(len(docs), 1)

async def test_get_resources_about_sustainability_limits_to_seven_docs(self):
# Mock SearchService and its search_handler
with mock.patch("src.app.services.agent.SearchService") as MockSearchService:
mock_service = MockSearchService.return_value
mock_doc = mock.Mock()
mock_doc.payload = {
"document_title": "Test Title",
"slice_content": "Test Content",
"document_url": "http://test.url",
}
mock_service.search_handler = AsyncMock(return_value=[mock_doc] * 10)
config = {
"configurable": {
"sdg_filter": None,
"corpora": None,
"sp": mock_service,
"background_tasks": BackgroundTasks(),
}
}
with mock.patch(
"src.app.services.agent.stringify_docs_content"
) as mock_stringify:
mock_stringify.return_value = "stringified content"
content, docs = await _get_resources_about_sustainability(
"test query", config
)
mock_stringify.assert_called_once()
called_docs = mock_stringify.call_args[0][0]
self.assertEqual(len(called_docs), 7)
self.assertEqual(content, "stringified content")
self.assertEqual(len(docs), 10)