diff --git a/migrations/006_scraped_pages_index.sql b/migrations/006_scraped_pages_index.sql new file mode 100644 index 0000000..6c22875 --- /dev/null +++ b/migrations/006_scraped_pages_index.sql @@ -0,0 +1,74 @@ +-- Scraped pages index: per-page metadata and vector embeddings for semantic search +-- Requires pgvector extension (Supabase has it built-in) + +create extension if not exists vector; + +-- Per-page metadata and content from scraped websites +create table scraped_pages ( + id uuid primary key default gen_random_uuid(), + reference_doc_id uuid not null references reference_documents(id) on delete cascade, + url text not null, + normalized_url text not null, + title text, + raw_content text not null, + word_count int not null, + scraped_at timestamptz not null, + created_at timestamptz default now() +); + +-- Chunks with vector embeddings for semantic search +create table page_chunks ( + id uuid primary key default gen_random_uuid(), + scraped_page_id uuid not null references scraped_pages(id) on delete cascade, + chunk_index int not null, + content text not null, + embedding vector(1536) not null, + word_count int not null, + created_at timestamptz default now() +); + +-- Indexes +create index idx_scraped_pages_reference_doc_id on scraped_pages(reference_doc_id); +create unique index idx_scraped_pages_normalized_url_reference_doc + on scraped_pages(normalized_url, reference_doc_id); + +-- IVFFlat index for approximate nearest-neighbor search (cosine distance) +-- lists = 100 is a reasonable default for small-to-medium datasets +create index idx_page_chunks_embedding_cosine on page_chunks + using ivfflat (embedding vector_cosine_ops) + with (lists = 100); + +create index idx_page_chunks_scraped_page_id on page_chunks(scraped_page_id); + +-- RPC for semantic search: returns chunks with source URL, ordered by cosine distance. +-- query_embedding_text is a string like '[0.1, 0.2, ...]' so Supabase/PostgREST can pass it. 
+create or replace function search_page_chunks( + query_embedding_text text, + ref_doc_id uuid, + match_limit int default 5 +) +returns table ( + id uuid, + scraped_page_id uuid, + chunk_index int, + content text, + word_count int, + page_url text, + distance float +) +language sql stable +as $$ + select + pc.id, + pc.scraped_page_id, + pc.chunk_index, + pc.content, + pc.word_count, + sp.url as page_url, + (pc.embedding <=> query_embedding_text::vector(1536)) as distance + from page_chunks pc + join scraped_pages sp on sp.id = pc.scraped_page_id + where sp.reference_doc_id = search_page_chunks.ref_doc_id + order by pc.embedding <=> query_embedding_text::vector(1536) + limit match_limit; +$$; diff --git a/src/api/webhook.py b/src/api/webhook.py index 98dc95e..325656e 100644 --- a/src/api/webhook.py +++ b/src/api/webhook.py @@ -121,6 +121,7 @@ async def process_message(page_id: str, sender_id: str, message_text: str): context = AgentContext( bot_config_id=bot_config.id, + reference_doc_id=bot_config.reference_doc_id, reference_doc=ref_doc["content"], tone=bot_config.tone, recent_messages=recent_messages, diff --git a/src/cli/setup_cli.py b/src/cli/setup_cli.py index fa9ec24..a372662 100644 --- a/src/cli/setup_cli.py +++ b/src/cli/setup_cli.py @@ -22,13 +22,17 @@ import questionary import typer -from src.services.scraper import scrape_website +from src.services.scraper import chunk_text, scrape_website from src.services.reference_doc import build_reference_document +from src.services.embedding_service import generate_embeddings from src.db.repository import ( create_bot_configuration, create_reference_document, + create_page_chunks, + create_scraped_page, create_test_session, get_reference_document_by_source_url, + get_scraped_pages_by_reference_doc, save_test_message, ) from src.models.agent_models import AgentContext @@ -282,6 +286,7 @@ def _run_test_repl( """ context = AgentContext( bot_config_id="cli-test", + reference_doc_id=reference_doc_id, reference_doc=ref_doc_content, tone=tone, recent_messages=[], @@ -379,12 +384,55 @@ def setup(): ref_doc_content = existing_doc["content"] typer.echo(f"✓ Found existing reference document for {normalized_url}") typer.echo(" Skipping scrape and document generation.") + # If no page index exists yet, scrape and index pages only (do not modify reference doc) + existing_pages = get_scraped_pages_by_reference_doc(reference_doc_id) + if not existing_pages: + typer.echo(" No page index found. 
Scraping pages for search index only...") + try: + scrape_result = _run_async_with_cleanup(scrape_website(normalized_url)) + typer.echo(f" ✓ Scraped {len(scrape_result.pages)} pages") + typer.echo(" Indexing pages and generating embeddings...") + async def _index_pages_and_chunks(): + for page in scrape_result.pages: + scraped_page_id = create_scraped_page( + reference_doc_id=reference_doc_id, + url=page.url, + normalized_url=page.normalized_url, + title=page.title, + raw_content=page.content, + word_count=page.word_count, + scraped_at=page.scraped_at, + ) + page_chunk_tuples = chunk_text(page.content) + if not page_chunk_tuples: + continue + chunk_texts = [t[0] for t in page_chunk_tuples] + embeddings = await generate_embeddings(chunk_texts) + chunks_with_embeddings = [ + (chunk_texts[i], embeddings[i], page_chunk_tuples[i][1]) + for i in range(len(chunk_texts)) + ] + create_page_chunks(scraped_page_id, chunks_with_embeddings) + return len(scrape_result.pages) + page_count = _run_async_with_cleanup(_index_pages_and_chunks()) + typer.echo(f" ✓ Indexed {page_count} pages with embeddings") + except Exception as e: + typer.echo( + typer.style( + f" ⚠ Page indexing failed (search_pages tool will be empty): {e}", + fg=typer.colors.YELLOW, + ), + err=True, + ) + else: + typer.echo(f" Page index already has {len(existing_pages)} pages.") else: # Step 2a: Scrape typer.echo(f"Scraping {normalized_url}...") try: - text_chunks = _run_async_with_cleanup(scrape_website(normalized_url)) - typer.echo(f"✓ Scraped {len(text_chunks)} text chunks") + scrape_result = _run_async_with_cleanup(scrape_website(normalized_url)) + text_chunks = scrape_result.chunks + typer.echo(f"✓ Scraped {len(text_chunks)} text chunks from {len(scrape_result.pages)} pages") except Exception as e: typer.echo(f"✗ Error scraping website: {e}", err=True) raise typer.Exit(1) @@ -415,6 +463,44 @@ def setup(): raise typer.Exit(1) ref_doc_content = markdown_content + # Step 2d: Index scraped pages and chunks with embeddings for semantic search + typer.echo("Indexing pages and generating embeddings...") + try: + + async def _index_pages_and_chunks(): + for page in scrape_result.pages: + scraped_page_id = create_scraped_page( + reference_doc_id=reference_doc_id, + url=page.url, + normalized_url=page.normalized_url, + title=page.title, + raw_content=page.content, + word_count=page.word_count, + scraped_at=page.scraped_at, + ) + page_chunk_tuples = chunk_text(page.content) + if not page_chunk_tuples: + continue + chunk_texts = [t[0] for t in page_chunk_tuples] + embeddings = await generate_embeddings(chunk_texts) + chunks_with_embeddings = [ + (chunk_texts[i], embeddings[i], page_chunk_tuples[i][1]) + for i in range(len(chunk_texts)) + ] + create_page_chunks(scraped_page_id, chunks_with_embeddings) + return len(scrape_result.pages) + + page_count = _run_async_with_cleanup(_index_pages_and_chunks()) + typer.echo(f"✓ Indexed {page_count} pages with embeddings") + except Exception as e: + typer.echo( + typer.style( + f"⚠ Indexing failed (search_pages tool will be empty): {e}", + fg=typer.colors.YELLOW, + ), + err=True, + ) + # Step 3: Action menu (arrow-key); loop so user can Test then Continue or Exit while True: action = _action_menu() diff --git a/src/config.py b/src/config.py index 614e27e..4f9dcc0 100644 --- a/src/config.py +++ b/src/config.py @@ -50,6 +50,20 @@ class Settings(BaseSettings): description="Fallback Anthropic model if primary fails", ) + # Embedding (via PydanticAI Gateway) + embedding_model: str = Field( + 
default="gateway/openai:text-embedding-3-small", + description="Embedding model via PAIG (e.g. gateway/openai:text-embedding-3-small)", + ) + embedding_dimensions: int = Field( + default=1536, + description="Embedding vector dimension (matches text-embedding-3-small)", + ) + search_result_limit: int = Field( + default=5, + description="Max number of chunks to return from page search", + ) + # OpenAI Configuration (kept for direct fallback if needed) openai_api_key: str = Field( default="", description="OpenAI API key (legacy fallback)" diff --git a/src/db/repository.py b/src/db/repository.py index 5202de4..3447334 100644 --- a/src/db/repository.py +++ b/src/db/repository.py @@ -2,7 +2,7 @@ import time from datetime import datetime -from typing import Optional +from typing import Any, List, Optional import uuid import logfire @@ -232,6 +232,109 @@ def get_reference_document_by_source_url(source_url: str) -> Optional[dict]: return result.data[0] +def _embedding_to_text(embedding: List[float]) -> str: + """Format embedding list as pgvector text literal '[a,b,c,...]'.""" + return "[" + ",".join(str(x) for x in embedding) + "]" + + +def create_scraped_page( + reference_doc_id: str, + url: str, + normalized_url: str, + title: str, + raw_content: str, + word_count: int, + scraped_at: datetime, +) -> str: + """ + Insert a single scraped page row. + + Returns: + scraped_page id (uuid string) + """ + supabase = get_supabase_client() + data = { + "reference_doc_id": reference_doc_id, + "url": url, + "normalized_url": normalized_url, + "title": title or "", + "raw_content": raw_content, + "word_count": word_count, + "scraped_at": scraped_at.isoformat() if hasattr(scraped_at, "isoformat") else scraped_at, + } + result = supabase.table("scraped_pages").insert(data).execute() + if not result.data: + raise ValueError("Failed to create scraped_page") + return result.data[0]["id"] + + +def create_page_chunks( + scraped_page_id: str, + chunks_with_embeddings: List[tuple[str, List[float], int]], +) -> None: + """ + Batch insert page chunks with embeddings. + + chunks_with_embeddings: list of (content, embedding, word_count) per chunk. + """ + if not chunks_with_embeddings: + return + supabase = get_supabase_client() + rows: List[dict[str, Any]] = [] + for idx, (content, embedding, word_count) in enumerate(chunks_with_embeddings): + rows.append({ + "scraped_page_id": scraped_page_id, + "chunk_index": idx, + "content": content, + "embedding": embedding, # Supabase accepts list for vector column + "word_count": word_count, + }) + supabase.table("page_chunks").insert(rows).execute() + logfire.info( + "Page chunks created", + scraped_page_id=scraped_page_id, + chunk_count=len(rows), + ) + + +def search_page_chunks( + query_embedding: List[float], + reference_doc_id: str, + limit: int = 5, +) -> List[dict[str, Any]]: + """ + Semantic search over page chunks for a given reference document. + + Returns list of dicts with id, scraped_page_id, chunk_index, content, word_count, page_url, distance. 
+ """ + supabase = get_supabase_client() + query_embedding_text = _embedding_to_text(query_embedding) + result = supabase.rpc( + "search_page_chunks", + { + "query_embedding_text": query_embedding_text, + "ref_doc_id": reference_doc_id, + "match_limit": limit, + }, + ).execute() + if not result.data: + return [] + return list(result.data) + + +def get_scraped_pages_by_reference_doc(reference_doc_id: str) -> List[dict[str, Any]]: + """List all scraped pages for a reference document.""" + supabase = get_supabase_client() + result = ( + supabase.table("scraped_pages") + .select("*") + .eq("reference_doc_id", reference_doc_id) + .order("created_at") + .execute() + ) + return list(result.data) if result.data else [] + + def get_user_profile(sender_id: str, page_id: str) -> dict | None: """ Get user profile by sender_id (unique per user). diff --git a/src/models/agent_models.py b/src/models/agent_models.py index 9e68440..838b941 100644 --- a/src/models/agent_models.py +++ b/src/models/agent_models.py @@ -7,6 +7,7 @@ class AgentContext(BaseModel): """Context for agent responses.""" bot_config_id: str + reference_doc_id: str reference_doc: str tone: str recent_messages: list[str] = Field(default_factory=list) diff --git a/src/models/scraper_models.py b/src/models/scraper_models.py new file mode 100644 index 0000000..4297963 --- /dev/null +++ b/src/models/scraper_models.py @@ -0,0 +1,26 @@ +"""Models for scraper results: per-page data and scrape result.""" + +from dataclasses import dataclass +from datetime import datetime +from typing import List + + +@dataclass +class ScrapedPage: + """Metadata and content for a single scraped page.""" + + url: str + normalized_url: str + title: str + content: str + word_count: int + scraped_at: datetime + + +@dataclass +class ScrapeResult: + """Result of a multi-page scrape: pages and combined chunks.""" + + pages: List[ScrapedPage] + chunks: List[str] + content_hash: str diff --git a/src/services/agent_service.py b/src/services/agent_service.py index ff8829d..a5c0b41 100644 --- a/src/services/agent_service.py +++ b/src/services/agent_service.py @@ -4,12 +4,15 @@ import re from pathlib import Path +import logfire from pydantic import BaseModel, Field from pydantic_ai import Agent, RunContext from pydantic_ai.models.fallback import FallbackModel from src.config import get_settings +from src.db.repository import search_page_chunks from src.models.agent_models import AgentContext, AgentResponse +from src.services.embedding_service import embed_query logger = logging.getLogger(__name__) @@ -21,6 +24,7 @@ class MessengerAgentDeps(BaseModel): """Dependencies passed to the agent at runtime.""" + reference_doc_id: str reference_doc: str tone: str recent_messages: list[str] = Field(default_factory=list) @@ -129,6 +133,56 @@ async def check_reference_coverage( return f"Topic '{topic}' is covered in the reference document." return f"Topic '{topic}' is NOT covered. Consider escalating to human." + @self.agent.tool + async def search_pages( + ctx: RunContext[MessengerAgentDeps], query: str + ) -> str: + """Search scraped website pages for specific information. + + Use this when you need to find detailed information that may not be + in the overview, such as specific policies, contact details, or + facts about particular topics. 
+ """ + logfire.info( + "Agent searching scraped pages beyond reference doc", + tool="search_pages", + query=query[:200], + reference_doc_id=ctx.deps.reference_doc_id, + ) + settings = get_settings() + limit = settings.search_result_limit + query_embedding = await embed_query(query) + if not query_embedding: + logfire.warning( + "search_pages skipped: empty query or embedding failed", + query_length=len(query), + ) + return "Search could not be run (empty query or embedding failed)." + results = search_page_chunks( + query_embedding=query_embedding, + reference_doc_id=ctx.deps.reference_doc_id, + limit=limit, + ) + if not results: + logfire.info( + "search_pages returned no matches", + query=query[:200], + reference_doc_id=ctx.deps.reference_doc_id, + ) + return "No matching content found in the scraped pages." + logfire.info( + "search_pages returned results from scraped pages", + result_count=len(results), + query=query[:200], + reference_doc_id=ctx.deps.reference_doc_id, + ) + parts = [] + for r in results: + page_url = r.get("page_url", "") + content = r.get("content", "")[:500] + parts.append(f"[Source: {page_url}]\n{content}...") + return "\n\n---\n\n".join(parts) + async def respond( self, context: AgentContext, @@ -146,6 +200,7 @@ async def respond( """ # Build dependencies deps = MessengerAgentDeps( + reference_doc_id=context.reference_doc_id, reference_doc=context.reference_doc, tone=context.tone, recent_messages=context.recent_messages, @@ -204,6 +259,7 @@ async def respond_with_fallback( ) deps = MessengerAgentDeps( + reference_doc_id=context.reference_doc_id, reference_doc=context.reference_doc, tone=context.tone, recent_messages=context.recent_messages, diff --git a/src/services/embedding_service.py b/src/services/embedding_service.py new file mode 100644 index 0000000..6b9c327 --- /dev/null +++ b/src/services/embedding_service.py @@ -0,0 +1,53 @@ +"""Embedding generation via PydanticAI Gateway.""" + +from typing import List + +import logfire +from pydantic_ai import Embedder + +from src.config import get_settings + + +async def generate_embeddings(texts: List[str]) -> List[List[float]]: + """ + Generate embeddings for a list of texts via PydanticAI Gateway. + + Uses settings.embedding_model (e.g. gateway/openai:text-embedding-3-small) + routed through existing pydantic_ai_gateway_api_key. + + Args: + texts: List of strings to embed (e.g. chunk contents). + + Returns: + List of embedding vectors (each a list of floats). + """ + if not texts: + return [] + settings = get_settings() + embedder = Embedder(settings.embedding_model) + with logfire.span("embedding_generate", text_count=len(texts)): + result = await embedder.embed_documents(texts) + return list(result.embeddings) + + +async def embed_query(query: str) -> List[float]: + """ + Generate a single embedding for a search query via PydanticAI Gateway. + + Use for query-side embedding when performing similarity search. + + Args: + query: Search query string. + + Returns: + Single embedding vector (list of floats). 
+ """ + if not query or not query.strip(): + return [] + settings = get_settings() + embedder = Embedder(settings.embedding_model) + with logfire.span("embedding_query"): + result = await embedder.embed_query(query) + if not result.embeddings: + return [] + return list(result.embeddings[0]) diff --git a/src/services/scraper.py b/src/services/scraper.py index f582ab1..4c0a6bb 100644 --- a/src/services/scraper.py +++ b/src/services/scraper.py @@ -5,6 +5,7 @@ import os import re import time +from datetime import datetime, timezone from typing import List from urllib.parse import urljoin, urlparse @@ -12,6 +13,34 @@ import logfire from bs4 import BeautifulSoup +from src.models.scraper_models import ScrapedPage, ScrapeResult + + +def chunk_text(text: str, target_words: int = 650) -> List[tuple[str, int]]: + """ + Split text into chunks of approximately target_words each. + + Returns list of (chunk_text, word_count) for each chunk. + """ + if not text or not text.strip(): + return [] + words = text.strip().split() + if not words: + return [] + result: List[tuple[str, int]] = [] + current: List[str] = [] + current_count = 0 + for word in words: + current.append(word) + current_count += 1 + if current_count >= target_words: + result.append((" ".join(current), current_count)) + current = [] + current_count = 0 + if current: + result.append((" ".join(current), current_count)) + return result + def _normalize_url_for_crawl(url: str) -> str: """Strip fragment and trailing slash for dedup; keep scheme and netloc.""" @@ -152,10 +181,10 @@ async def _fetch_one_page(url: str, headers: dict) -> str: return html -def _parse_page_text_and_links(html: str, current_url: str) -> tuple[str, List[str]]: +def _parse_page_text_and_links(html: str, current_url: str) -> tuple[str, List[str], str]: """ - Parse HTML: extract visible text (nav/footer removed) and same-domain links. - Returns (normalized_text, list_of_absolute_same_domain_urls). + Parse HTML: extract visible text (nav/footer removed), same-domain links, and title. + Returns (normalized_text, list_of_absolute_same_domain_urls, page_title). """ soup = BeautifulSoup(html, "html.parser") for tag in soup(["script", "style", "nav", "footer"]): @@ -163,12 +192,15 @@ def _parse_page_text_and_links(html: str, current_url: str) -> tuple[str, List[s text = soup.get_text() text = re.sub(r"\s+", " ", text).strip() links = _extract_same_domain_links(soup, current_url) - return text, links + title = "" + if soup.title and soup.title.string: + title = soup.title.string.strip() + return text, links, title -async def scrape_website(url: str, max_pages: int = 20) -> List[str]: +async def scrape_website(url: str, max_pages: int = 20) -> ScrapeResult: """ - Scrape website and return text chunks from multiple same-domain pages. + Scrape website and return per-page data plus text chunks from multiple same-domain pages. Discovers internal links from each page and crawls up to max_pages. Tries httpx first per page; on 403/503 (e.g. Cloudflare) falls back to undetected Chrome. 
@@ -177,7 +209,7 @@ async def scrape_website(url: str, max_pages: int = 20) -> List[str]: max_pages: Maximum number of pages to scrape (default 20 for richer reference docs) Returns: - List of text chunks (500-800 words each) from all crawled pages combined + ScrapeResult with pages (URL, title, content per page), chunks, and content_hash """ start_time = time.time() normalized_start = _normalize_url_for_crawl(url) @@ -199,7 +231,7 @@ async def scrape_website(url: str, max_pages: int = 20) -> List[str]: visited: set[str] = set() to_visit: List[str] = [url] # exact start URL so first fetch matches user/tests in_queue: set[str] = {_normalize_url_for_crawl(url)} - page_texts: List[str] = [] + pages: List[ScrapedPage] = [] while to_visit and len(visited) < max_pages: current = to_visit.pop(0) @@ -211,12 +243,12 @@ async def scrape_website(url: str, max_pages: int = 20) -> List[str]: try: html = await _fetch_one_page(current, headers) except ValueError: - if not page_texts: + if not pages: raise logfire.warning("Skipping page after fetch error", url=current) continue - text, new_links = _parse_page_text_and_links(html, current) + text, new_links, title = _parse_page_text_and_links(html, current) # If first page has very little text (likely JS-rendered SPA), refetch with browser if len(visited) == 1 and len(text.split()) < 400: logfire.info( @@ -226,7 +258,7 @@ async def scrape_website(url: str, max_pages: int = 20) -> List[str]: ) try: html = await asyncio.to_thread(_fetch_with_browser_sync, current, 45.0) - text, new_links = _parse_page_text_and_links(html, current) + text, new_links, title = _parse_page_text_and_links(html, current) except Exception as e: logfire.warning( "Browser refetch failed, using initial content", @@ -234,7 +266,17 @@ async def scrape_website(url: str, max_pages: int = 20) -> List[str]: error=str(e), ) if text: - page_texts.append(text) + word_count = len(text.split()) + pages.append( + ScrapedPage( + url=current, + normalized_url=current_normalized, + title=title, + content=text, + word_count=word_count, + scraped_at=datetime.now(timezone.utc), + ) + ) for link in new_links: link_norm = _normalize_url_for_crawl(link) @@ -246,6 +288,7 @@ async def scrape_website(url: str, max_pages: int = 20) -> List[str]: if to_visit and len(visited) < max_pages: await asyncio.sleep(0.5) + page_texts = [p.content for p in pages] combined_text = " ".join(page_texts) combined_text = re.sub(r"\s+", " ", combined_text).strip() @@ -279,4 +322,4 @@ async def scrape_website(url: str, max_pages: int = 20) -> List[str]: content_hash=content_hash, total_time_ms=total_elapsed * 1000, ) - return chunks + return ScrapeResult(pages=pages, chunks=chunks, content_hash=content_hash) diff --git a/tests/conftest.py b/tests/conftest.py index 1831954..342e44e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -192,6 +192,7 @@ def sample_agent_context(sample_reference_doc): """Sample AgentContext for testing.""" return AgentContext( bot_config_id="bot-123", + reference_doc_id="ref-doc-123", reference_doc=sample_reference_doc, tone="professional", recent_messages=["Hello", "How can I help?"], diff --git a/tests/stateful/test_agent_conversation.py b/tests/stateful/test_agent_conversation.py index 8014ae4..a4f07cb 100644 --- a/tests/stateful/test_agent_conversation.py +++ b/tests/stateful/test_agent_conversation.py @@ -35,6 +35,7 @@ async def test_conversation_flow_basic(self, mock_get_settings, monkeypatch): # Initialize conversation state context = AgentContext( bot_config_id="test-123", + 
reference_doc_id="test-ref-doc-id", reference_doc="# Overview\nTest content for the agent.", tone="professional", recent_messages=[], @@ -99,6 +100,7 @@ async def test_conversation_maintains_context(self, mock_get_settings, monkeypat context = AgentContext( bot_config_id="test-123", + reference_doc_id="test-ref-doc-id", reference_doc="# Overview\nTest content.", tone="professional", recent_messages=[], diff --git a/tests/unit/test_agent_service.py b/tests/unit/test_agent_service.py index 51e5ee0..5652c81 100644 --- a/tests/unit/test_agent_service.py +++ b/tests/unit/test_agent_service.py @@ -15,6 +15,7 @@ def agent_context(self): """Sample agent context.""" return AgentContext( bot_config_id="test-bot-id", + reference_doc_id="test-ref-doc-id", reference_doc="# Test Reference\n\nThis is a test document about our services.", tone="professional", recent_messages=["Hello", "How can I help?"], @@ -105,6 +106,7 @@ async def test_respond_with_tenant_id(self, mock_settings): """Test that tenant_id is passed through correctly.""" context = AgentContext( bot_config_id="test-bot-id", + reference_doc_id="test-ref-doc-id", reference_doc="Test doc", tone="professional", recent_messages=[], diff --git a/tests/unit/test_embedding_service.py b/tests/unit/test_embedding_service.py new file mode 100644 index 0000000..b67d577 --- /dev/null +++ b/tests/unit/test_embedding_service.py @@ -0,0 +1,89 @@ +"""Tests for embedding service.""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from src.services.embedding_service import generate_embeddings, embed_query + + +class TestGenerateEmbeddings: + """Test generate_embeddings().""" + + @pytest.mark.asyncio + async def test_generate_embeddings_empty_list(self): + """Empty input returns empty list.""" + result = await generate_embeddings([]) + assert result == [] + + @pytest.mark.asyncio + async def test_generate_embeddings_returns_vectors(self): + """generate_embeddings returns list of vectors (list of floats).""" + mock_result = MagicMock() + mock_result.embeddings = [[0.1] * 1536, [0.2] * 1536] + with patch( + "src.services.embedding_service.Embedder" + ) as mock_embedder_class: + mock_embedder = MagicMock() + mock_embedder.embed_documents = AsyncMock(return_value=mock_result) + mock_embedder_class.return_value = mock_embedder + result = await generate_embeddings(["text one", "text two"]) + assert len(result) == 2 + assert len(result[0]) == 1536 + assert len(result[1]) == 1536 + assert result[0][0] == 0.1 + assert result[1][0] == 0.2 + + @pytest.mark.asyncio + async def test_generate_embeddings_calls_embed_documents(self): + """generate_embeddings calls Embedder with settings.embedding_model.""" + mock_result = MagicMock() + mock_result.embeddings = [[0.0] * 1536] + mock_settings = MagicMock() + mock_settings.embedding_model = "gateway/openai:text-embedding-3-small" + with ( + patch("src.services.embedding_service.get_settings", return_value=mock_settings), + patch("src.services.embedding_service.Embedder") as mock_embedder_class, + ): + mock_embedder = MagicMock() + mock_embedder.embed_documents = AsyncMock(return_value=mock_result) + mock_embedder_class.return_value = mock_embedder + await generate_embeddings(["hello"]) + mock_embedder_class.assert_called_once_with("gateway/openai:text-embedding-3-small") + mock_embedder.embed_documents.assert_called_once_with(["hello"]) + + +class TestEmbedQuery: + """Test embed_query().""" + + @pytest.mark.asyncio + async def test_embed_query_empty_string_returns_empty(self): + """Empty or whitespace 
query returns empty list.""" + result = await embed_query("") + assert result == [] + result = await embed_query(" ") + assert result == [] + + @pytest.mark.asyncio + async def test_embed_query_returns_vector(self): + """embed_query returns a single vector.""" + mock_result = MagicMock() + mock_result.embeddings = [[0.5] * 1536] + with patch("src.services.embedding_service.Embedder") as mock_embedder_class: + mock_embedder = MagicMock() + mock_embedder.embed_query = AsyncMock(return_value=mock_result) + mock_embedder_class.return_value = mock_embedder + result = await embed_query("search query") + assert len(result) == 1536 + assert result[0] == 0.5 + + @pytest.mark.asyncio + async def test_embed_query_no_embeddings_returns_empty(self): + """When embed_query returns no embeddings, return empty list.""" + mock_result = MagicMock() + mock_result.embeddings = [] + with patch("src.services.embedding_service.Embedder") as mock_embedder_class: + mock_embedder = MagicMock() + mock_embedder.embed_query = AsyncMock(return_value=mock_result) + mock_embedder_class.return_value = mock_embedder + result = await embed_query("query") + assert result == [] diff --git a/tests/unit/test_logging.py b/tests/unit/test_logging.py index 575027b..a94e805 100644 --- a/tests/unit/test_logging.py +++ b/tests/unit/test_logging.py @@ -57,6 +57,7 @@ async def test_agent_service_logs_processing( agent = MessengerAgentService() context = AgentContext( bot_config_id="bot-123", + reference_doc_id="ref-doc-123", reference_doc="Test reference", tone="professional", recent_messages=[], @@ -96,7 +97,8 @@ async def test_scraper_logs_scraping_metrics(logfire_capture, respx_mock): ) ) - chunks = await scrape_website("https://example.com") + result = await scrape_website("https://example.com") + _ = result.chunks # ensure we have chunks for log verification # Verify scraping logs scrape_logs = [ diff --git a/tests/unit/test_models.py b/tests/unit/test_models.py index deb2812..375cb2f 100644 --- a/tests/unit/test_models.py +++ b/tests/unit/test_models.py @@ -105,6 +105,7 @@ class TestAgentModels: @given( bot_config_id=st.text(min_size=1, max_size=100), + reference_doc_id=st.uuids(), reference_doc=st.text(min_size=10, max_size=50000), tone=st.sampled_from( ["professional", "friendly", "casual", "formal", "humorous"] @@ -114,6 +115,7 @@ class TestAgentModels: def test_agent_context_properties( self, bot_config_id: str, + reference_doc_id, reference_doc: str, tone: str, recent_messages: list[str], @@ -121,11 +123,13 @@ def test_agent_context_properties( """Property: AgentContext should maintain invariants.""" context = AgentContext( bot_config_id=bot_config_id, + reference_doc_id=str(reference_doc_id), reference_doc=reference_doc, tone=tone, recent_messages=recent_messages, ) assert context.bot_config_id == bot_config_id + assert context.reference_doc_id == str(reference_doc_id) assert len(context.reference_doc) > 0 assert context.tone in [ "professional", diff --git a/tests/unit/test_scraper.py b/tests/unit/test_scraper.py index 98414aa..0c03eca 100644 --- a/tests/unit/test_scraper.py +++ b/tests/unit/test_scraper.py @@ -6,7 +6,8 @@ import httpx import respx as respx_lib -from src.services.scraper import scrape_website +from src.models.scraper_models import ScrapeResult +from src.services.scraper import chunk_text, scrape_website @pytest.fixture(autouse=True) @@ -47,11 +48,17 @@ async def test_scrape_website_valid_url(self, respx_mock): return_value=httpx.Response(200, text=html_content) ) - chunks = await 
scrape_website("https://example.com") + result = await scrape_website("https://example.com") - assert isinstance(chunks, list) - assert len(chunks) > 0 - assert all(isinstance(chunk, str) for chunk in chunks) + assert isinstance(result, ScrapeResult) + assert len(result.chunks) > 0 + assert all(isinstance(chunk, str) for chunk in result.chunks) + assert len(result.pages) == 1 + assert result.pages[0].url == "https://example.com" + # Normalizer keeps trailing slash for root path + assert result.pages[0].normalized_url in ("https://example.com", "https://example.com/") + assert result.pages[0].title == "Test Page" + assert result.content_hash @pytest.mark.asyncio async def test_scrape_website_invalid_url(self, respx_mock): @@ -93,7 +100,8 @@ async def test_scrape_website_removes_scripts(self, respx_mock): return_value=httpx.Response(200, text=html_content) ) - chunks = await scrape_website("https://example.com") + result = await scrape_website("https://example.com") + chunks = result.chunks # Script and style content should not appear combined_text = " ".join(chunks) @@ -119,7 +127,8 @@ async def test_scrape_website_removes_nav_footer(self, respx_mock): return_value=httpx.Response(200, text=html_content) ) - chunks = await scrape_website("https://example.com") + result = await scrape_website("https://example.com") + chunks = result.chunks combined_text = " ".join(chunks) assert "Navigation links" not in combined_text @@ -146,7 +155,8 @@ async def test_scrape_website_whitespace_normalization(self, respx_mock): return_value=httpx.Response(200, text=html_content) ) - chunks = await scrape_website("https://example.com") + result = await scrape_website("https://example.com") + chunks = result.chunks # Check that multiple spaces are normalized combined_text = " ".join(chunks) @@ -166,9 +176,11 @@ async def test_scrape_website_chunking_properties(self): return_value=httpx.Response(200, text=html_content) ) - chunks = await scrape_website("https://example.com") + result = await scrape_website("https://example.com") + chunks = result.chunks # Invariants + assert isinstance(result, ScrapeResult) assert isinstance(chunks, list) assert all(isinstance(chunk, str) for chunk in chunks) assert all(len(chunk) > 0 for chunk in chunks) # No empty chunks @@ -189,7 +201,8 @@ async def test_scrape_website_chunk_size(self, respx_mock): return_value=httpx.Response(200, text=html_content) ) - chunks = await scrape_website("https://example.com") + result = await scrape_website("https://example.com") + chunks = result.chunks # Check chunk sizes (target is 650 words, allow some flexibility) for chunk in chunks[:-1]: # Last chunk may be smaller @@ -208,9 +221,11 @@ async def test_scrape_website_empty_content(self, respx_mock): return_value=httpx.Response(200, text=html_content) ) - chunks = await scrape_website("https://example.com") + result = await scrape_website("https://example.com") + chunks = result.chunks - # Should return empty list or list with empty string + # Should return ScrapeResult with empty chunks when no content + assert isinstance(result, ScrapeResult) assert isinstance(chunks, list) # If there's no content, chunks might be empty or contain empty strings if chunks: @@ -228,7 +243,8 @@ async def test_scrape_website_follows_redirects(self): ) ) - chunks = await scrape_website("https://example.com/final") + result = await scrape_website("https://example.com/final") + chunks = result.chunks # Should get content from final URL combined_text = " ".join(chunks) @@ -252,10 +268,38 @@ async def 
test_scrape_website_various_html_structures(self): return_value=httpx.Response(200, text=html_content) ) - chunks = await scrape_website("https://example.com") + result = await scrape_website("https://example.com") + chunks = result.chunks - # Should always return list of strings + # Should always return ScrapeResult with list of strings + assert isinstance(result, ScrapeResult) assert isinstance(chunks, list) assert all(isinstance(chunk, str) for chunk in chunks) respx_lib.reset() + + +class TestChunkText: + """Test chunk_text() helper.""" + + def test_chunk_text_empty_returns_empty(self): + """Empty or whitespace text returns empty list.""" + assert chunk_text("") == [] + assert chunk_text(" ") == [] + + def test_chunk_text_single_chunk(self): + """Text under target words returns one chunk.""" + words = ["word"] * 100 + result = chunk_text(" ".join(words), target_words=650) + assert len(result) == 1 + assert result[0][1] == 100 + + def test_chunk_text_multiple_chunks(self): + """Text over target words is split into multiple chunks.""" + words = ["word"] * 1400 + result = chunk_text(" ".join(words), target_words=650) + assert len(result) >= 2 + total_words = sum(r[1] for r in result) + assert total_words == 1400 + for _chunk_str, wc in result[:-1]: + assert wc >= 650 diff --git a/tests/unit/test_setup_cli.py b/tests/unit/test_setup_cli.py index 78c53cf..322eea8 100644 --- a/tests/unit/test_setup_cli.py +++ b/tests/unit/test_setup_cli.py @@ -4,10 +4,11 @@ import asyncio import gc import time -from unittest.mock import patch, MagicMock +from unittest.mock import AsyncMock, patch, MagicMock import typer from src.models.agent_models import AgentResponse +from src.models.scraper_models import ScrapeResult from src.cli.setup_cli import ( setup, test as cli_test_command, @@ -167,8 +168,10 @@ def test_setup_complete_flow( mock_supabase = MagicMock() mock_get_supabase.return_value = mock_supabase - # Mock scraping - mock_scrape.return_value = ["chunk1", "chunk2", "chunk3"] + # Mock scraping (ScrapeResult with empty pages so indexing step does nothing) + mock_scrape.return_value = ScrapeResult( + pages=[], chunks=["chunk1", "chunk2", "chunk3"], content_hash="hash" + ) # Mock reference doc building (returns markdown string, not tuple) mock_build_ref.return_value = "# Reference Document" @@ -268,7 +271,7 @@ def test_setup_reference_doc_error( mock_get_settings.return_value = mock_settings mock_prompt.return_value = "https://example.com" - mock_scrape.return_value = ["chunk1"] + mock_scrape.return_value = ScrapeResult(pages=[], chunks=["chunk1"], content_hash="h") mock_build_ref.side_effect = Exception("Reference doc generation failed") with pytest.raises(typer.Exit): @@ -315,7 +318,7 @@ def test_setup_database_error( "token-123", "verify-123", ] - mock_scrape.return_value = ["chunk1"] + mock_scrape.return_value = ScrapeResult(pages=[], chunks=["chunk1"], content_hash="h") mock_build_ref.return_value = "# Doc" mock_create_ref_doc.side_effect = Exception("Database error") @@ -372,7 +375,7 @@ def test_setup_tone_selection( VALID_PAGE_ACCESS_TOKEN, VALID_VERIFY_TOKEN, ] - mock_scrape.return_value = ["chunk1"] + mock_scrape.return_value = ScrapeResult(pages=[], chunks=["chunk1"], content_hash="h") mock_build_ref.return_value = "# Doc" mock_create_ref_doc.return_value = "doc-123" mock_create_bot.return_value = MagicMock() @@ -427,7 +430,7 @@ def test_setup_prints_webhook_url( VALID_PAGE_ACCESS_TOKEN, VALID_VERIFY_TOKEN, ] - mock_scrape.return_value = ["chunk1"] + mock_scrape.return_value = 
ScrapeResult(pages=[], chunks=["chunk1"], content_hash="h") mock_build_ref.return_value = "# Doc" mock_create_ref_doc.return_value = "doc-123" mock_create_bot.return_value = MagicMock() @@ -439,6 +442,7 @@ def test_setup_prints_webhook_url( webhook_mentions = [call for call in echo_calls if "webhook" in call.lower()] assert len(webhook_mentions) > 0 + @patch("src.cli.setup_cli.get_scraped_pages_by_reference_doc") @patch("src.cli.setup_cli.create_bot_configuration") @patch("src.cli.setup_cli.create_reference_document") @patch("src.cli.setup_cli.build_reference_document") @@ -463,8 +467,10 @@ def test_setup_resume_when_ref_doc_exists( mock_build_ref, mock_create_ref_doc, mock_create_bot, + mock_get_scraped_pages, ): """When a reference doc already exists for the URL, skip scrape/build/store and resume at action menu then tone + Facebook.""" + mock_get_scraped_pages.return_value = [{"id": "page1"}] # already indexed mock_get_ref_doc.return_value = { "id": "existing-doc-456", "source_url": "https://example.com", @@ -485,7 +491,7 @@ def test_setup_resume_when_ref_doc_exists( setup() # Lookup was called with normalized URL mock_get_ref_doc.assert_called_once_with("https://example.com") - # Scrape and build were skipped + # Scrape and build were skipped (ref doc and page index both exist) mock_scrape.assert_not_called() mock_build_ref.assert_not_called() mock_create_ref_doc.assert_not_called() @@ -495,6 +501,69 @@ def test_setup_resume_when_ref_doc_exists( assert mock_create_bot.call_args[1]["tone"] == "Friendly" assert mock_create_bot.call_args[1]["page_id"] == "789012345678901" + @patch("src.cli.setup_cli.create_page_chunks") + @patch("src.cli.setup_cli.create_scraped_page") + @patch("src.cli.setup_cli.generate_embeddings", new_callable=AsyncMock) + @patch("src.cli.setup_cli.get_scraped_pages_by_reference_doc") + @patch("src.cli.setup_cli.create_bot_configuration") + @patch("src.cli.setup_cli.build_reference_document") + @patch("src.cli.setup_cli.scrape_website") + @patch("src.cli.setup_cli.get_reference_document_by_source_url") + @patch("src.config.get_settings") + @patch("src.db.client.get_supabase_client") + @patch("src.cli.setup_cli.questionary.select") + @patch("src.cli.setup_cli.typer.prompt") + @patch("src.cli.setup_cli.typer.echo") + def test_setup_existing_doc_no_pages_indexed_scrapes_and_indexes_only( + self, + mock_echo, + mock_prompt, + mock_questionary_select, + mock_get_supabase, + mock_get_settings, + mock_get_ref_doc, + mock_scrape, + mock_build_ref, + mock_create_bot, + mock_get_scraped_pages, + mock_generate_embeddings, + mock_create_scraped_page, + mock_create_page_chunks, + ): + """When ref doc exists but no pages indexed, scrape and index pages without modifying reference doc.""" + mock_get_ref_doc.return_value = { + "id": "existing-doc-456", + "source_url": "https://example.com", + "content": "# Existing doc content", + } + mock_get_scraped_pages.return_value = [] # no pages indexed yet + mock_scrape.return_value = ScrapeResult( + pages=[ + MagicMock( + url="https://example.com", + normalized_url="https://example.com", + title="Page", + content="Some content " * 100, + word_count=100, + scraped_at=MagicMock(), + ) + ], + chunks=["chunk1"], + content_hash="h", + ) + mock_create_scraped_page.return_value = "scraped-page-1" + mock_generate_embeddings.return_value = [[0.1] * 1536] # one embedding per chunk + mock_questionary_select.return_value.ask.side_effect = [ACTION_EXIT] + mock_prompt.side_effect = ["https://example.com"] + setup() + 
mock_get_ref_doc.assert_called_once_with("https://example.com") + mock_get_scraped_pages.assert_called_once_with("existing-doc-456") + mock_scrape.assert_called_once_with("https://example.com") + mock_build_ref.assert_not_called() + mock_create_scraped_page.assert_called() + mock_create_bot.assert_not_called() + + @patch("src.cli.setup_cli.get_scraped_pages_by_reference_doc") @patch("src.cli.setup_cli.create_bot_configuration") @patch("src.cli.setup_cli.get_reference_document_by_source_url") @patch("src.config.get_settings") @@ -511,8 +580,10 @@ def test_setup_exit_from_menu( mock_get_settings, mock_get_ref_doc, mock_create_bot, + mock_get_scraped_pages, ): """When user selects Exit from action menu, setup exits without creating bot.""" + mock_get_scraped_pages.return_value = [{"id": "page1"}] # already indexed mock_get_ref_doc.return_value = { "id": "existing-doc-456", "source_url": "https://example.com", @@ -555,7 +626,7 @@ def test_setup_aborts_when_confirmation_no( mock_settings.copilot_cli_host = "http://localhost:5909" mock_settings.copilot_enabled = True mock_get_settings.return_value = mock_settings - mock_scrape.return_value = ["chunk1"] + mock_scrape.return_value = ScrapeResult(pages=[], chunks=["chunk1"], content_hash="h") mock_build_ref.return_value = "# Doc" mock_create_ref_doc.return_value = "doc-123" mock_questionary_select.return_value.ask.side_effect = [ @@ -610,7 +681,7 @@ def test_setup_writes_webhook_info_file( mock_settings.copilot_cli_host = "http://localhost:5909" mock_settings.copilot_enabled = True mock_get_settings.return_value = mock_settings - mock_scrape.return_value = ["chunk1"] + mock_scrape.return_value = ScrapeResult(pages=[], chunks=["chunk1"], content_hash="h") mock_build_ref.return_value = "# Doc" mock_create_ref_doc.return_value = "doc-123" mock_create_bot.return_value = MagicMock() @@ -670,7 +741,7 @@ def test_setup_test_bot_then_continue( mock_settings.copilot_cli_host = "http://localhost:5909" mock_settings.copilot_enabled = True mock_get_settings.return_value = mock_settings - mock_scrape.return_value = ["chunk1"] + mock_scrape.return_value = ScrapeResult(pages=[], chunks=["chunk1"], content_hash="h") mock_build_ref.return_value = "# Doc" mock_create_ref_doc.return_value = "doc-123" mock_create_bot.return_value = MagicMock()
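
Taken together, the pieces in this diff form one retrieval path: `chunk_text` + `create_scraped_page`/`create_page_chunks` build the index at setup time, and at answer time the `search_pages` tool embeds the user query, calls the `search_page_chunks` RPC from migration 006, and formats the matching chunks with their source URLs. Below is a minimal sketch of exercising that path outside the agent, assuming Supabase credentials and the PydanticAI Gateway key are configured and that a reference document has already been indexed; the `reference_doc_id` and question strings are placeholders for illustration only, not values from this PR.

```python
"""Illustrative sketch (not part of the diff): query-side use of the new page search."""

import asyncio

from src.config import get_settings
from src.db.repository import search_page_chunks
from src.services.embedding_service import embed_query


async def demo(reference_doc_id: str, question: str) -> None:
    settings = get_settings()

    # Query-side embedding (1536-dim, matching the page_chunks.embedding column)
    query_embedding = await embed_query(question)
    if not query_embedding:
        print("Empty query or embedding failed; nothing to search.")
        return

    # Calls the search_page_chunks RPC added in migration 006 (cosine distance, ascending)
    results = search_page_chunks(
        query_embedding=query_embedding,
        reference_doc_id=reference_doc_id,
        limit=settings.search_result_limit,
    )
    for r in results:
        print(f"{r['distance']:.3f}  {r['page_url']}  {r['content'][:80]}...")


if __name__ == "__main__":
    # Placeholder id/question; in the app these come from AgentContext.reference_doc_id
    # and the user's message, which is how the search_pages tool invokes the same calls.
    asyncio.run(demo("00000000-0000-0000-0000-000000000000", "What is the refund policy?"))
```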