diff --git a/src/app/api/api_v1/endpoints/chat.py b/src/app/api/api_v1/endpoints/chat.py index 5ab20b3..83ffde3 100644 --- a/src/app/api/api_v1/endpoints/chat.py +++ b/src/app/api/api_v1/endpoints/chat.py @@ -3,7 +3,7 @@ import backoff import psycopg -from fastapi import APIRouter, Depends, HTTPException, Request, status +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request, status from fastapi.responses import StreamingResponse from langchain_core.messages import ToolMessage from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver @@ -22,6 +22,7 @@ LanguageNotSupportedError, bad_request, ) +from src.app.services.search import SearchService, get_search_service from src.app.utils.logger import logger as utils_logger logger = utils_logger(__name__) @@ -345,8 +346,10 @@ async def agent_response( factor=2, ) async def agent_response( + background_tasks: BackgroundTasks, body: models.AgentContext = Depends(get_agent_params), chatfactory=Depends(get_chat_service), + sp: SearchService = Depends(get_search_service), ) -> Optional[Dict]: try: if body.query is None: @@ -367,12 +370,16 @@ async def agent_response( thread_id=body.thread_id, corpora=body.corpora, sdg_filter=body.sdg_filter, + sp=sp, + background_tasks=background_tasks, ) else: res = await chatfactory.agent_message( query=body.query, corpora=body.corpora, sdg_filter=body.sdg_filter, + sp=sp, + background_tasks=background_tasks, ) if isinstance(res["messages"][-2], ToolMessage): diff --git a/src/app/services/abst_chat.py b/src/app/services/abst_chat.py index 6036d31..88a23eb 100644 --- a/src/app/services/abst_chat.py +++ b/src/app/services/abst_chat.py @@ -20,7 +20,7 @@ from abc import ABC from typing import AsyncIterable, Dict, List, Optional -from fastapi import Depends, Request +from fastapi import BackgroundTasks, Depends, Request from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel # type: ignore from langchain_core.runnables import RunnableConfig # type: ignore from 
langgraph.checkpoint.postgres.aio import AsyncPostgresSaver # type: ignore @@ -42,6 +42,7 @@ ) # from src.app.services.llm_proxy import LLMProxy +from src.app.services.search import SearchService from src.app.utils.decorators import log_time_and_error from src.app.utils.logger import log_environmental_impacts from src.app.utils.logger import logger as utils_logger @@ -404,6 +405,8 @@ async def agent_message( thread_id: Optional[uuid.UUID] = None, corpora: Optional[tuple[str, ...]] = None, sdg_filter: Optional[List[int]] = None, + sp: SearchService | None = None, + background_tasks: BackgroundTasks | None = None, ): """ Sends a chat message handled by an agent. @@ -440,6 +443,8 @@ "thread_id": thread_id, "corpora": corpora, "sdg_filter": sdg_filter, + "sp": sp, + "background_tasks": background_tasks, } ) @@ -450,7 +455,10 @@ }, ] - res = await agent_executor.ainvoke(input={"messages": messages}, config=config) + res = await agent_executor.ainvoke( + input={"messages": messages}, + config=config, + ) return res diff --git a/src/app/services/agent.py b/src/app/services/agent.py index f464d49..51d78da 100644 --- a/src/app/services/agent.py +++ b/src/app/services/agent.py @@ -1,5 +1,6 @@ from typing import List, Tuple +from fastapi import BackgroundTasks from langchain_core.messages.utils import trim_messages from langchain_core.runnables import RunnableConfig from langchain_core.tools import tool @@ -15,16 +16,23 @@ async def _get_resources_about_sustainability( - rag_query: str, config: RunnableConfig + rag_query: str, + config: RunnableConfig, ) -> Tuple[str, List[Document]]: """Core logic for getting relevant resources about sustainability from WeLearn database.""" - sp = SearchService() + qp = EnhancedSearchQuery( query=rag_query, sdg_filter=config["configurable"].get("sdg_filter"), corpora=config["configurable"].get("corpora"), ) - docs = await sp.search_handler(qp) + sp: 
SearchService = config["configurable"].get("sp") + background_tasks: BackgroundTasks = config["configurable"].get("background_tasks") + if not sp: + logger.warning("No SearchService found.") + return "No relevant documents found.", [] + + docs = await sp.search_handler(background_tasks=background_tasks, qp=qp) if not docs: logger.warning("No documents found for the query.") return "No relevant documents found.", [] @@ -36,7 +44,8 @@ async def _get_resources_about_sustainability( @tool(response_format="content_and_artifact") @log_time_and_error async def get_resources_about_sustainability( - rag_query: str, config: RunnableConfig + rag_query: str, + config: RunnableConfig, ) -> Tuple[str, List[Document]]: """Get relevant resources about sustainability from WeLearn database. diff --git a/src/app/services/search.py b/src/app/services/search.py index 5d6de76..d52fb56 100644 --- a/src/app/services/search.py +++ b/src/app/services/search.py @@ -4,6 +4,7 @@ from typing import Tuple, cast import numpy as np +import torch from fastapi import BackgroundTasks, Depends, Request from fastapi.concurrency import run_in_threadpool from numpy import ndarray @@ -12,7 +13,6 @@ from qdrant_client.http import exceptions as qdrant_exceptions from qdrant_client.http import models as http_models from sklearn.metrics.pairwise import cosine_similarity -import torch from transformers import AutoModel, AutoTokenizer from src.app.models.collections import Collection @@ -187,7 +187,9 @@ def _split_input_seq_len(self, seq_len: int, input: str) -> list[str]: @log_time_and_error_sync def _compute_embeddings(self, model, tokenizer, inputs: list[str]) -> np.ndarray: with torch.no_grad(): - tokenized_inputs = tokenizer(inputs, padding=True, truncation=True, return_tensors='pt') + tokenized_inputs = tokenizer( + inputs, padding=True, truncation=True, return_tensors="pt" + ) model_output = model(**tokenized_inputs) embeddings = model_output[0][:, 0] embeddings = 
torch.nn.functional.normalize(embeddings, dim=1).numpy() @@ -206,7 +208,9 @@ async def _embed_query(self, search_input: str, curr_model: str) -> np.ndarray: inputs = self._split_input_seq_len(seq_len, search_input) try: - embeddings = await run_in_threadpool(self._compute_embeddings, model, tokenizer, inputs) + embeddings = await run_in_threadpool( + self._compute_embeddings, model, tokenizer, inputs + ) embeddings = np.mean(embeddings, axis=0) except Exception as ex: logger.error("api_error=EMBED_ERROR model=%s", curr_model) @@ -228,7 +232,9 @@ async def simple_search_handler(self, qp: EnhancedSearchQuery): model_instance = model["instance"] tokenizer = model["tokenizer"] embedding_input = qp.query if isinstance(qp.query, list) else [qp.query] - embedding = await run_in_threadpool(self._compute_embeddings, model_instance, tokenizer, embedding_input) + embedding = await run_in_threadpool( + self._compute_embeddings, model_instance, tokenizer, embedding_input + ) result = await self.search( collection_info="collection_welearn_mul_granite-embedding-107m-multilingual", embedding=embedding, diff --git a/src/app/tests/services/test_agent.py b/src/app/tests/services/test_agent.py index 0fd7bf0..78ce29d 100644 --- a/src/app/tests/services/test_agent.py +++ b/src/app/tests/services/test_agent.py @@ -1,6 +1,8 @@ from unittest import IsolatedAsyncioTestCase, mock from unittest.mock import AsyncMock +from fastapi import BackgroundTasks + from src.app.services.agent import ( _get_resources_about_sustainability, trim_conversation_history, @@ -12,17 +14,22 @@ async def test_get_resources_about_sustainability_found(self): # Mock SearchService and its search_handler with mock.patch("src.app.services.agent.SearchService") as MockSearchService: mock_service = MockSearchService.return_value - mock_payload = { + mock_doc = mock.Mock() + mock_doc.payload = { "document_title": "Test Title", "slice_content": "Test Content", "document_url": "http://test.url", } - mock_doc = mock.Mock() - 
mock_doc.payload = mock_payload mock_service.search_handler = AsyncMock(return_value=[mock_doc] * 8) + config = { + "configurable": { + "sdg_filter": None, + "corpora": None, + "sp": mock_service, + "background_tasks": BackgroundTasks(), + } + } - # config with dummy values - config = {"configurable": {"sdg_filter": None, "corpora": None}} content, docs = await _get_resources_about_sustainability( "test query", config ) @@ -46,3 +53,74 @@ def test_trim_conversation_history(self): result = trim_conversation_history(state) self.assertIn("llm_input_messages", result) self.assertIsInstance(result["llm_input_messages"], list) + + async def test_get_resources_about_sustainability_no_search_service(self): + # config without 'sp' (SearchService) + config = { + "configurable": { + "sdg_filter": None, + "corpora": None, + "sp": None, + "background_tasks": None, + } + } + content, docs = await _get_resources_about_sustainability("test query", config) + self.assertEqual(content, "No relevant documents found.") + self.assertEqual(docs, []) + + async def test_get_resources_about_sustainability_with_background_tasks(self): + # Mock SearchService and its search_handler + with mock.patch("src.app.services.agent.SearchService") as MockSearchService: + mock_service = MockSearchService.return_value + mock_doc = mock.Mock() + mock_doc.payload = { + "document_title": "Test Title", + "slice_content": "Test Content", + "document_url": "http://test.url", + } + mock_service.search_handler = AsyncMock(return_value=[mock_doc]) + config = { + "configurable": { + "sdg_filter": None, + "corpora": None, + "sp": mock_service, + "background_tasks": BackgroundTasks(), + } + } + content, docs = await _get_resources_about_sustainability( + "test query", config + ) + self.assertIsInstance(content, str) + self.assertEqual(len(docs), 1) + + async def test_get_resources_about_sustainability_limits_to_seven_docs(self): + # Mock SearchService and its search_handler + with 
mock.patch("src.app.services.agent.SearchService") as MockSearchService: + mock_service = MockSearchService.return_value + mock_doc = mock.Mock() + mock_doc.payload = { + "document_title": "Test Title", + "slice_content": "Test Content", + "document_url": "http://test.url", + } + mock_service.search_handler = AsyncMock(return_value=[mock_doc] * 10) + config = { + "configurable": { + "sdg_filter": None, + "corpora": None, + "sp": mock_service, + "background_tasks": BackgroundTasks(), + } + } + with mock.patch( + "src.app.services.agent.stringify_docs_content" + ) as mock_stringify: + mock_stringify.return_value = "stringified content" + content, docs = await _get_resources_about_sustainability( + "test query", config + ) + mock_stringify.assert_called_once() + called_docs = mock_stringify.call_args[0][0] + self.assertEqual(len(called_docs), 7) + self.assertEqual(content, "stringified content") + self.assertEqual(len(docs), 10)