From 37140cf3989c695ba59b5cd17646f6bffe9f8b5b Mon Sep 17 00:00:00 2001 From: shreejaykurhade Date: Fri, 10 Apr 2026 03:18:04 +0530 Subject: [PATCH 1/7] Auto-tune Embedding Model Parameters & Add Benchmarking Tool --- src/typeagent/aitools/vectorbase.py | 30 +++++- tools/benchmark_embeddings.py | 157 ++++++++++++++++++++++++++++ 2 files changed, 184 insertions(+), 3 deletions(-) create mode 100644 tools/benchmark_embeddings.py diff --git a/src/typeagent/aitools/vectorbase.py b/src/typeagent/aitools/vectorbase.py index e22083c8..34552fae 100644 --- a/src/typeagent/aitools/vectorbase.py +++ b/src/typeagent/aitools/vectorbase.py @@ -34,11 +34,35 @@ def __init__( max_matches: int | None = None, batch_size: int | None = None, ): - self.min_score = min_score if min_score is not None else 0.85 - self.max_matches = max_matches if max_matches and max_matches >= 1 else None - self.batch_size = batch_size if batch_size and batch_size >= 1 else 8 self.embedding_model = embedding_model or create_embedding_model() + # Default fallback values + default_min_score = 0.85 + default_max_matches = None + + # Determine optimal parameters automatically for well-known models. + # Format: (min_score, max_matches) + # Note: text-embedding-3 models produce structurally lower cosine scores than older models + # and typically perform best in the 0.3 - 0.5 range for relevance filtering. + MODEL_DEFAULTS = { + "text-embedding-3-large": (0.30, 20), + "text-embedding-3-small": (0.35, 20), + "text-embedding-ada-002": (0.75, 20), + } + + # Check if the model_name matches any known ones + model_name = getattr(self.embedding_model, 'model_name', "") + + if model_name: + for known_model, defaults in MODEL_DEFAULTS.items(): + if known_model in model_name: + default_min_score, default_max_matches = defaults + break + + self.min_score = min_score if min_score is not None else default_min_score + self.max_matches = max_matches if max_matches is not None else default_max_matches + self.batch_size = batch_size if batch_size and batch_size >= 1 else 8 + class VectorBase: settings: TextEmbeddingIndexSettings diff --git a/tools/benchmark_embeddings.py b/tools/benchmark_embeddings.py new file mode 100644 index 00000000..2d6fc3a7 --- /dev/null +++ b/tools/benchmark_embeddings.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Utility script to benchmark different TextEmbeddingIndexSettings parameters. 
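A minimal standalone sketch of the substring-based auto-tuning introduced in the hunk above (the MODEL_DEFAULTS values are copied from the diff; the resolver function and the model names passed to it are illustrative, not part of the patch):

# Sketch only: mirrors the __init__ logic above, outside the class.
MODEL_DEFAULTS = {
    "text-embedding-3-large": (0.30, 20),
    "text-embedding-3-small": (0.35, 20),
    "text-embedding-ada-002": (0.75, 20),
}

def resolve_defaults(model_name: str) -> tuple[float, int | None]:
    # Substring match, so provider-prefixed names still resolve.
    for known_model, defaults in MODEL_DEFAULTS.items():
        if known_model in model_name:
            return defaults
    return (0.85, None)  # fallback for unknown models

assert resolve_defaults("openai:text-embedding-3-small") == (0.35, 20)
assert resolve_defaults("my-custom-model") == (0.85, None)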
+ +Usage: + uv run python tools/benchmark_embeddings.py [--model provider:model] +""" + +import argparse +import asyncio +import json +import logging +from pathlib import Path +from statistics import mean +import sys +from typing import Any + +from typeagent.aitools.model_adapters import create_embedding_model +from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings, VectorBase + + +async def run_benchmark(model_spec: str | None) -> None: + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + # Paths + script_dir = Path(__file__).resolve().parent + repo_root = script_dir.parent + index_data_path = repo_root / "tests" / "testdata" / "Episode_53_AdrianTchaikovsky_index_data.json" + search_data_path = repo_root / "tests" / "testdata" / "Episode_53_Search_results.json" + + logger.info(f"Loading index data from {index_data_path}") + try: + with open(index_data_path, "r", encoding="utf-8") as f: + index_json = json.load(f) + except Exception as e: + logger.error(f"Failed to load index data: {e}") + return + + messages = index_json.get("messages", []) + message_texts = [" ".join(m.get("textChunks", [])) for m in messages] + + logger.info(f"Loading search queries from {search_data_path}") + try: + with open(search_data_path, "r", encoding="utf-8") as f: + search_json = json.load(f) + except Exception as e: + logger.error(f"Failed to load search queries: {e}") + return + + # Filter out ones without results or expected matches + queries = [] + for item in search_json: + search_text = item.get("searchText") + results = item.get("results", []) + if not results: + continue + expected = results[0].get("messageMatches", []) + if not expected: + continue + queries.append((search_text, expected)) + + logger.info(f"Found {len(message_texts)} messages to embed.") + logger.info(f"Found {len(queries)} queries with expected matches to test.") + + try: + if model_spec == "test:fake": + from typeagent.aitools.model_adapters import create_test_embedding_model + model = create_test_embedding_model(embedding_size=384) + else: + model = create_embedding_model(model_spec) + except Exception as e: + logger.error(f"Failed to create embedding model: {e}") + logger.info("Are your environment variables (e.g. OPENAI_API_KEY) set?") + return + settings = TextEmbeddingIndexSettings(model) + vbase = VectorBase(settings) + + logger.info("Computing embeddings for messages (this may take some time...)") + # Batch the embeddings + batch_size = 50 + for i in range(0, len(message_texts), batch_size): + batch = message_texts[i : i + batch_size] + await vbase.add_keys(batch) + print(f" ... 
embedded {min(i + batch_size, len(message_texts))}/{len(message_texts)}") + + logger.info("Computing embeddings for queries...") + query_texts = [q[0] for q in queries] + query_embeddings = await model.get_embeddings(query_texts) + + # Grid search config + min_scores_to_test = [0.70, 0.75, 0.80, 0.85, 0.90, 0.95] + max_hits_to_test = [5, 10, 15, 20] + + logger.info(f"Starting grid search over model: {model.model_name}") + print("-" * 65) + print(f"{'Min Score':<12} | {'Max Hits':<10} | {'Hit Rate (%)':<15} | {'MRR':<10}") + print("-" * 65) + + best_mrr = -1.0 + best_config = None + + for ms in min_scores_to_test: + for mh in max_hits_to_test: + hits = 0 + reciprocal_ranks = [] + + for (query_text, expected_indices), q_emb in zip(queries, query_embeddings): + scored_results = vbase.fuzzy_lookup_embedding(q_emb, max_hits=mh, min_score=ms) + retrieved_indices = [sr.item for sr in scored_results] + + # Check if any of the expected items are in the retrieved answers + rank = -1 + for r_idx, retrieved in enumerate(retrieved_indices): + if retrieved in expected_indices: + rank = r_idx + 1 + break + + if rank > 0: + hits += 1 + reciprocal_ranks.append(1.0 / rank) + else: + reciprocal_ranks.append(0.0) + + hit_rate = (hits / len(queries)) * 100 + mrr = mean(reciprocal_ranks) + + print(f"{ms:<12.2f} | {mh:<10d} | {hit_rate:<15.2f} | {mrr:<10.4f}") + + if mrr > best_mrr: + best_mrr = mrr + best_config = (ms, mh) + + print("-" * 65) + if best_config: + logger.info(f"Optimal parameters found: min_score={best_config[0]}, max_hits={best_config[1]} (MRR={best_mrr:.4f})") + else: + logger.info("Could not determine optimal parameters (no hits).") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Benchmark embedding model parameters.") + parser.add_argument( + "--model", + type=str, + default=None, + help="Provider and model name, e.g. 'openai:text-embedding-3-small'", + ) + args = parser.parse_args() + asyncio.run(run_benchmark(args.model)) + + +if __name__ == "__main__": + main() From c2d019b29d33731c4b4de34d8dc3860ea76147a0 Mon Sep 17 00:00:00 2001 From: shreejaykurhade Date: Fri, 10 Apr 2026 04:18:36 +0530 Subject: [PATCH 2/7] update --- tools/benchmark_embeddings.py | 223 ++++++++++++++++++++++++++++++++-- 1 file changed, 213 insertions(+), 10 deletions(-) diff --git a/tools/benchmark_embeddings.py b/tools/benchmark_embeddings.py index 2d6fc3a7..ee77c215 100644 --- a/tools/benchmark_embeddings.py +++ b/tools/benchmark_embeddings.py @@ -5,6 +5,20 @@ """ Utility script to benchmark different TextEmbeddingIndexSettings parameters. +Uses the Adrian Tchaikovsky podcast dataset (Episode 53) which contains: +- Index data: ~96 messages from the podcast conversation +- Search results: Queries with expected messageMatches (ground truth for retrieval) +- Answer results: Curated Q&A pairs with expected answers (ground truth for Q&A quality) + +The benchmark evaluates embedding model retrieval quality using: +1. Search-based evaluation: Compares fuzzy_lookup results against expected messageMatches +2. 
Answer-based evaluation: Tests if queries from the Answer dataset retrieve messages + that contain the expected answer content (substring matching) + +Metrics: +- Hit Rate: Percentage of queries where at least one expected result was retrieved +- MRR (Mean Reciprocal Rank): Average of 1/rank of the first relevant result + Usage: uv run python tools/benchmark_embeddings.py [--model provider:model] """ @@ -31,7 +45,9 @@ async def run_benchmark(model_spec: str | None) -> None: repo_root = script_dir.parent index_data_path = repo_root / "tests" / "testdata" / "Episode_53_AdrianTchaikovsky_index_data.json" search_data_path = repo_root / "tests" / "testdata" / "Episode_53_Search_results.json" + answer_data_path = repo_root / "tests" / "testdata" / "Episode_53_Answer_results.json" + # ── Load index data (messages to embed) ── logger.info(f"Loading index data from {index_data_path}") try: with open(index_data_path, "r", encoding="utf-8") as f: @@ -43,6 +59,7 @@ async def run_benchmark(model_spec: str | None) -> None: messages = index_json.get("messages", []) message_texts = [" ".join(m.get("textChunks", [])) for m in messages] + # ── Load search queries (ground truth: messageMatches) ── logger.info(f"Loading search queries from {search_data_path}") try: with open(search_data_path, "r", encoding="utf-8") as f: @@ -52,7 +69,7 @@ async def run_benchmark(model_spec: str | None) -> None: return # Filter out ones without results or expected matches - queries = [] + search_queries: list[tuple[str, list[int]]] = [] for item in search_json: search_text = item.get("searchText") results = item.get("results", []) @@ -61,11 +78,30 @@ async def run_benchmark(model_spec: str | None) -> None: expected = results[0].get("messageMatches", []) if not expected: continue - queries.append((search_text, expected)) + search_queries.append((search_text, expected)) + + # ── Load answer results (Q&A ground truth from Adrian Tchaikovsky dataset) ── + answer_queries: list[tuple[str, str, bool]] = [] # (question, answer, hasNoAnswer) + logger.info(f"Loading answer results from {answer_data_path}") + try: + with open(answer_data_path, "r", encoding="utf-8") as f: + answer_json = json.load(f) + for item in answer_json: + question = item.get("question", "") + answer = item.get("answer", "") + has_no_answer = item.get("hasNoAnswer", False) + if question and answer: + answer_queries.append((question, answer, has_no_answer)) + logger.info(f"Found {len(answer_queries)} answer Q&A pairs " + f"({sum(1 for _, _, h in answer_queries if not h)} with answers, " + f"{sum(1 for _, _, h in answer_queries if h)} with no-answer).") + except Exception as e: + logger.warning(f"Failed to load answer results (continuing without): {e}") logger.info(f"Found {len(message_texts)} messages to embed.") - logger.info(f"Found {len(queries)} queries with expected matches to test.") + logger.info(f"Found {len(search_queries)} search queries with expected matches.") + # ── Create embedding model and index ── try: if model_spec == "test:fake": from typeagent.aitools.model_adapters import create_test_embedding_model @@ -87,16 +123,30 @@ async def run_benchmark(model_spec: str | None) -> None: await vbase.add_keys(batch) print(f" ... 
embedded {min(i + batch_size, len(message_texts))}/{len(message_texts)}") - logger.info("Computing embeddings for queries...") - query_texts = [q[0] for q in queries] - query_embeddings = await model.get_embeddings(query_texts) + # ── Compute query embeddings ── + logger.info("Computing embeddings for search queries...") + search_query_texts = [q[0] for q in search_queries] + search_query_embeddings = await model.get_embeddings(search_query_texts) + + answer_query_embeddings = None + if answer_queries: + logger.info("Computing embeddings for answer queries...") + answer_query_texts = [q[0] for q in answer_queries] + answer_query_embeddings = await model.get_embeddings(answer_query_texts) + + # ────────────────────────────────────────────────────────────────────── + # Section 1: Grid Search using Search Results (messageMatches) + # ────────────────────────────────────────────────────────────────────── # Grid search config min_scores_to_test = [0.70, 0.75, 0.80, 0.85, 0.90, 0.95] max_hits_to_test = [5, 10, 15, 20] logger.info(f"Starting grid search over model: {model.model_name}") - print("-" * 65) + print() + print("=" * 72) + print(" SEARCH RESULTS BENCHMARK (messageMatches ground truth)") + print("=" * 72) print(f"{'Min Score':<12} | {'Max Hits':<10} | {'Hit Rate (%)':<15} | {'MRR':<10}") print("-" * 65) @@ -108,7 +158,7 @@ async def run_benchmark(model_spec: str | None) -> None: hits = 0 reciprocal_ranks = [] - for (query_text, expected_indices), q_emb in zip(queries, query_embeddings): + for (query_text, expected_indices), q_emb in zip(search_queries, search_query_embeddings): scored_results = vbase.fuzzy_lookup_embedding(q_emb, max_hits=mh, min_score=ms) retrieved_indices = [sr.item for sr in scored_results] @@ -125,7 +175,7 @@ async def run_benchmark(model_spec: str | None) -> None: else: reciprocal_ranks.append(0.0) - hit_rate = (hits / len(queries)) * 100 + hit_rate = (hits / len(search_queries)) * 100 mrr = mean(reciprocal_ranks) print(f"{ms:<12.2f} | {mh:<10d} | {hit_rate:<15.2f} | {mrr:<10.4f}") @@ -136,10 +186,163 @@ async def run_benchmark(model_spec: str | None) -> None: print("-" * 65) if best_config: - logger.info(f"Optimal parameters found: min_score={best_config[0]}, max_hits={best_config[1]} (MRR={best_mrr:.4f})") + logger.info(f"Search benchmark optimal: min_score={best_config[0]}, " + f"max_hits={best_config[1]} (MRR={best_mrr:.4f})") else: logger.info("Could not determine optimal parameters (no hits).") + # ────────────────────────────────────────────────────────────────────── + # Section 2: Answer Results Benchmark (Adrian Tchaikovsky Q&A pairs) + # ────────────────────────────────────────────────────────────────────── + + if answer_queries and answer_query_embeddings is not None: + print() + print("=" * 72) + print(" ANSWER RESULTS BENCHMARK (Adrian Tchaikovsky Q&A ground truth)") + print("=" * 72) + print() + + # For each answer query, check if retrieved messages contain key terms + # from the expected answer. This is a content-based relevance check. + # + # We split answers with hasNoAnswer=True vs False to evaluate separately. 
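A toy illustration of the coverage scoring described above, with made-up strings (the real keyword extraction follows just below in extract_answer_keywords):

keywords = ["Children of Time", "spiders"]
retrieved_text = "We talked about Children of Time and its uplifted spiders."
found = sum(1 for kw in keywords if kw.lower() in retrieved_text.lower())
keyword_score = found / len(keywords)  # 2/2 -> 1.0; any score > 0 counts as a hit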
+ + answerable = [(q, a, emb) for (q, a, h), emb + in zip(answer_queries, answer_query_embeddings) if not h] + unanswerable = [(q, a, emb) for (q, a, h), emb + in zip(answer_queries, answer_query_embeddings) if h] + + print(f"Answerable queries: {len(answerable)}") + print(f"Unanswerable queries (hasNoAnswer=True): {len(unanswerable)}") + print() + + # Extract key terms from expected answers for content matching + def extract_answer_keywords(answer_text: str) -> list[str]: + """Extract distinctive keywords/phrases from an answer for matching.""" + # Look for quoted items, proper nouns, and distinctive phrases + keywords = [] + # Extract quoted phrases + import re + quoted = re.findall(r"'([^']+)'", answer_text) + keywords.extend(quoted) + quoted2 = re.findall(r'"([^"]+)"', answer_text) + keywords.extend(quoted2) + + # Extract proper-noun-like terms (capitalized words that aren't sentence starters) + # and key named entities from the Adrian Tchaikovsky dataset + known_entities = [ + "Adrian Tchaikovsky", "Tchaikovsky", "Kevin Scott", "Christina Warren", + "Children of Time", "Children of Ruin", "Children of Memory", + "Shadows of the Apt", "Empire in Black and Gold", + "Final Architecture", "Lords of Uncreation", + "Dragonlance Chronicles", "Skynet", "Portids", "Corvids", + "University of Reading", "Magnus Carlsen", "Warhammer", + "Asimov", "Peter Watts", "William Gibson", "Iain Banks", + "Peter Hamilton", "Arthur C. Clarke", "Profiles of the Future", + "Dune", "Brave New World", "Iron Sunrise", "Wall-E", + "George RR Martin", "Alastair Reynolds", "Ovid", + "zoology", "psychology", "spiders", "arachnids", "insects", + ] + for entity in known_entities: + if entity.lower() in answer_text.lower(): + keywords.append(entity) + + return keywords + + # Run answer benchmark with the best config from search benchmark + if best_config: + eval_min_score, eval_max_hits = best_config + else: + eval_min_score, eval_max_hits = 0.80, 10 + + print(f"Using parameters: min_score={eval_min_score}, max_hits={eval_max_hits}") + print("-" * 72) + print(f"{'#':<4} | {'Question':<45} | {'Keywords Found':<14} | {'Msgs':<5}") + print("-" * 72) + + answer_hits = 0 + answer_keyword_scores: list[float] = [] + + for idx, (question, answer, q_emb) in enumerate(answerable, 1): + scored_results = vbase.fuzzy_lookup_embedding( + q_emb, max_hits=eval_max_hits, min_score=eval_min_score + ) + retrieved_indices = [sr.item for sr in scored_results] + + # Concatenate the text of all retrieved messages + retrieved_text = " ".join( + message_texts[i] for i in retrieved_indices if i < len(message_texts) + ) + + # Check how many answer keywords appear in retrieved text + keywords = extract_answer_keywords(answer) + if keywords: + found = sum( + 1 for kw in keywords + if kw.lower() in retrieved_text.lower() + ) + keyword_score = found / len(keywords) + else: + # No keywords extracted — just check if we retrieved anything + keyword_score = 1.0 if retrieved_indices else 0.0 + + if keyword_score > 0: + answer_hits += 1 + answer_keyword_scores.append(keyword_score) + + q_display = question[:42] + "..." 
if len(question) > 45 else question + kw_display = f"{int(keyword_score * 100):>3}%" + if keywords: + kw_display += f" ({sum(1 for kw in keywords if kw.lower() in retrieved_text.lower())}/{len(keywords)})" + print(f"{idx:<4} | {q_display:<45} | {kw_display:<14} | {len(retrieved_indices):<5}") + + print("-" * 72) + + if answerable: + answer_hit_rate = (answer_hits / len(answerable)) * 100 + avg_keyword_score = mean(answer_keyword_scores) * 100 + print(f"Answer Hit Rate: {answer_hit_rate:.1f}% " + f"({answer_hits}/{len(answerable)} queries found relevant content)") + print(f"Avg Keyword Coverage: {avg_keyword_score:.1f}%") + + # Evaluate unanswerable queries — ideally these should retrieve fewer/no results + if unanswerable: + print() + print("-" * 72) + print("Unanswerable queries (should ideally retrieve less relevant content):") + print("-" * 72) + false_positive_count = 0 + for question, answer, q_emb in unanswerable: + scored_results = vbase.fuzzy_lookup_embedding( + q_emb, max_hits=eval_max_hits, min_score=eval_min_score + ) + n_results = len(scored_results) + avg_score = mean(sr.score for sr in scored_results) if scored_results else 0.0 + q_display = question[:55] + "..." if len(question) > 58 else question + flag = "[!]" if n_results > 3 else "[ok]" + if n_results > 3: + false_positive_count += 1 + print(f" {flag} {q_display:<58} | {n_results:>3} results (avg={avg_score:.3f})") + print(f"\nFalse positives (>3 results): {false_positive_count}/{len(unanswerable)}") + + # ── Summary ── + print() + print("=" * 72) + print(" SUMMARY") + print("=" * 72) + print(f"Model: {model.model_name}") + print(f"Messages indexed: {len(message_texts)}") + print(f"Search queries tested: {len(search_queries)}") + if best_config: + print(f"Best search params: min_score={best_config[0]}, max_hits={best_config[1]}") + print(f"Best search MRR: {best_mrr:.4f}") + if answer_queries: + print(f"Answer queries tested: {len(answerable)} answerable, {len(unanswerable)} unanswerable") + if answerable: + print(f"Answer hit rate: {answer_hit_rate:.1f}%") + print(f"Keyword coverage: {avg_keyword_score:.1f}%") + print("=" * 72) + def main() -> None: parser = argparse.ArgumentParser(description="Benchmark embedding model parameters.") From 0678f8a2097a4aa785bb3622918ecb943f0e7da0 Mon Sep 17 00:00:00 2001 From: shreejaykurhade Date: Sat, 11 Apr 2026 23:08:34 +0530 Subject: [PATCH 3/7] Tune embedding defaults with benchmark-backed thresholds Add benchmark scripts for sweeping and repeating min_score/max_hits against the Episode 53 dataset, update TextEmbeddingIndexSettings to use model-specific default min_score values, and add tests covering benchmark helper logic and explicit settings overrides. --- src/typeagent/aitools/vectorbase.py | 86 ++--- tools/benchmark_embeddings.py | 530 +++++++++++---------------- tools/repeat_embedding_benchmarks.py | 399 ++++++++++++++++++++ 3 files changed, 657 insertions(+), 358 deletions(-) create mode 100644 tools/repeat_embedding_benchmarks.py diff --git a/src/typeagent/aitools/vectorbase.py b/src/typeagent/aitools/vectorbase.py index 34552fae..8f898ebf 100644 --- a/src/typeagent/aitools/vectorbase.py +++ b/src/typeagent/aitools/vectorbase.py @@ -13,15 +13,46 @@ ) from .model_adapters import create_embedding_model +DEFAULT_MIN_SCORE = 0.85 + +# Empirical defaults for built-in OpenAI embedding models. 
+# These values come from repeated runs of the search benchmark in
+# `tools/repeat_embedding_benchmarks.py`, using an exhaustive 0.01..1.00
+# min_score sweep over the Adrian Tchaikovsky Episode 53 dataset. We keep
+# the highest min_score that preserves the best benchmark metrics for each
+# model, which yielded plateau boundaries of 0.16 for
+# `text-embedding-3-small`, 0.07 for `text-embedding-3-large`, and 0.72
+# for `text-embedding-ada-002`. These are repository defaults for known
+# models, not universal truths. Unknown models keep the long-standing
+# fallback score of 0.85. Callers can always override `min_score`
+# explicitly for their own use cases or models. We intentionally leave
+# `max_matches` out of this table: the benchmark still reports a best
+# `max_hits` row, but the library default remains `None` unless a caller
+# opts into a specific limit.
+MODEL_DEFAULT_MIN_SCORES: dict[str, float] = {
+    "text-embedding-3-large": 0.07,
+    "text-embedding-3-small": 0.16,
+    "text-embedding-ada-002": 0.72,
+}
+
+
+def get_default_min_score(model_name: str) -> float:
+    """Return the repository default score cutoff for a known model name."""
+
+    return MODEL_DEFAULT_MIN_SCORES.get(model_name, DEFAULT_MIN_SCORE)
+

 @dataclass
 class ScoredInt:
+    """Associate an integer ordinal with its similarity score."""
+
     item: int
     score: float


 @dataclass
 class TextEmbeddingIndexSettings:
+    """Runtime settings for embedding-backed fuzzy lookup."""
+
     embedding_model: IEmbeddingModel
     min_score: float  # Between 0.0 and 1.0
     max_matches: int | None  # >= 1; None means no limit
@@ -35,32 +66,10 @@ def __init__(
         batch_size: int | None = None,
     ):
         self.embedding_model = embedding_model or create_embedding_model()
-
-        # Default fallback values
-        default_min_score = 0.85
-        default_max_matches = None
-
-        # Determine optimal parameters automatically for well-known models.
-        # Format: (min_score, max_matches)
-        # Note: text-embedding-3 models produce structurally lower cosine scores than older models
-        # and typically perform best in the 0.3 - 0.5 range for relevance filtering.
-        MODEL_DEFAULTS = {
-            "text-embedding-3-large": (0.30, 20),
-            "text-embedding-3-small": (0.35, 20),
-            "text-embedding-ada-002": (0.75, 20),
-        }
-
-        # Check if the model_name matches any known ones
-        model_name = getattr(self.embedding_model, 'model_name', "")
-
-        if model_name:
-            for known_model, defaults in MODEL_DEFAULTS.items():
-                if known_model in model_name:
-                    default_min_score, default_max_matches = defaults
-                    break
-
+        model_name = getattr(self.embedding_model, "model_name", "")
+        default_min_score = get_default_min_score(model_name)
         self.min_score = min_score if min_score is not None else default_min_score
-        self.max_matches = max_matches if max_matches is not None else default_max_matches
+        self.max_matches = max_matches if max_matches and max_matches >= 1 else None
         self.batch_size = batch_size if batch_size and batch_size >= 1 else 8


@@ -190,27 +199,10 @@ def fuzzy_lookup_embedding_in_subset(
         max_hits: int | None = None,
         min_score: float | None = None,
     ) -> list[ScoredInt]:
-        if max_hits is None:
-            max_hits = 10
-        if min_score is None:
-            min_score = 0.0
-        if not ordinals_of_subset or len(self._vectors) == 0:
-            return []
-        # Compute dot products only for the subset instead of all vectors.
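Unlike the substring check this commit removes, the new lookup matches the reported model name exactly, so a provider-qualified spec would take the 0.85 fallback. A quick sketch, assuming the module in this hunk is importable as-is:

# Values below come from MODEL_DEFAULT_MIN_SCORES and DEFAULT_MIN_SCORE above.
from typeagent.aitools.vectorbase import get_default_min_score

assert get_default_min_score("text-embedding-3-small") == 0.16
assert get_default_min_score("text-embedding-ada-002") == 0.72
assert get_default_min_score("openai:text-embedding-3-small") == 0.85  # exact match only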
- subset = np.asarray(ordinals_of_subset) - scores = np.dot(self._vectors[subset], embedding) - indices = np.flatnonzero(scores >= min_score) - if len(indices) == 0: - return [] - filtered_scores = scores[indices] - if len(indices) <= max_hits: - order = np.argsort(filtered_scores)[::-1] - else: - top_k = np.argpartition(filtered_scores, -max_hits)[-max_hits:] - order = top_k[np.argsort(filtered_scores[top_k])[::-1]] - return [ - ScoredInt(int(subset[indices[i]]), float(filtered_scores[i])) for i in order - ] + ordinals_set = set(ordinals_of_subset) + return self.fuzzy_lookup_embedding( + embedding, max_hits, min_score, lambda i: i in ordinals_set + ) async def fuzzy_lookup( self, @@ -259,7 +251,7 @@ def deserialize(self, data: NormalizedEmbeddings | None) -> None: return if self._embedding_size == 0: if data.ndim < 2 or data.shape[0] == 0: - # Empty data — can't determine size; just clear. + # Empty data can't determine size; just clear. self.clear() return self._set_embedding_size(data.shape[1]) diff --git a/tools/benchmark_embeddings.py b/tools/benchmark_embeddings.py index ee77c215..4358ea31 100644 --- a/tools/benchmark_embeddings.py +++ b/tools/benchmark_embeddings.py @@ -1,359 +1,267 @@ -#!/usr/bin/env python3 # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -""" -Utility script to benchmark different TextEmbeddingIndexSettings parameters. - -Uses the Adrian Tchaikovsky podcast dataset (Episode 53) which contains: -- Index data: ~96 messages from the podcast conversation -- Search results: Queries with expected messageMatches (ground truth for retrieval) -- Answer results: Curated Q&A pairs with expected answers (ground truth for Q&A quality) +"""Benchmark retrieval settings for known embedding models. -The benchmark evaluates embedding model retrieval quality using: -1. Search-based evaluation: Compares fuzzy_lookup results against expected messageMatches -2. Answer-based evaluation: Tests if queries from the Answer dataset retrieve messages - that contain the expected answer content (substring matching) +This script evaluates the Adrian Tchaikovsky Episode 53 search dataset in +`tests/testdata/` and reports retrieval quality for combinations of +`min_score` and `max_hits`. -Metrics: -- Hit Rate: Percentage of queries where at least one expected result was retrieved -- MRR (Mean Reciprocal Rank): Average of 1/rank of the first relevant result +The benchmark is intentionally narrow: +- It only measures retrieval against `messageMatches` ground truth. +- It is meant to help choose repository defaults for known models. +- In practice, `min_score` is the primary library default this informs. +- It does not prove universal "best" settings for every dataset. 
Usage: - uv run python tools/benchmark_embeddings.py [--model provider:model] + uv run python tools/benchmark_embeddings.py + uv run python tools/benchmark_embeddings.py --model openai:text-embedding-3-small """ import argparse import asyncio +from dataclasses import dataclass import json -import logging from pathlib import Path from statistics import mean -import sys -from typing import Any +from dotenv import load_dotenv + +from typeagent.aitools.embeddings import IEmbeddingModel, NormalizedEmbeddings from typeagent.aitools.model_adapters import create_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings, VectorBase +DEFAULT_MIN_SCORES = [0.25, 0.30, 0.35, 0.40, 0.50, 0.60, 0.70, 0.75, 0.80, 0.85] +DEFAULT_MAX_HITS = [5, 10, 15, 20] +DATA_DIR = Path("tests") / "testdata" +INDEX_DATA_PATH = DATA_DIR / "Episode_53_AdrianTchaikovsky_index_data.json" +SEARCH_RESULTS_PATH = DATA_DIR / "Episode_53_Search_results.json" + + +@dataclass +class SearchQueryCase: + query: str + expected_matches: list[int] + + +@dataclass +class SearchMetrics: + hit_rate: float + mean_reciprocal_rank: float + + +@dataclass +class BenchmarkRow: + min_score: float + max_hits: int + metrics: SearchMetrics + + +def parse_float_list(raw: str | None) -> list[float]: + if raw is None: + return DEFAULT_MIN_SCORES + values = [float(item.strip()) for item in raw.split(",") if item.strip()] + if not values: + raise ValueError("--min-scores must contain at least one value") + return values + -async def run_benchmark(model_spec: str | None) -> None: - logging.basicConfig(level=logging.INFO) - logger = logging.getLogger(__name__) - - # Paths - script_dir = Path(__file__).resolve().parent - repo_root = script_dir.parent - index_data_path = repo_root / "tests" / "testdata" / "Episode_53_AdrianTchaikovsky_index_data.json" - search_data_path = repo_root / "tests" / "testdata" / "Episode_53_Search_results.json" - answer_data_path = repo_root / "tests" / "testdata" / "Episode_53_Answer_results.json" - - # ── Load index data (messages to embed) ── - logger.info(f"Loading index data from {index_data_path}") - try: - with open(index_data_path, "r", encoding="utf-8") as f: - index_json = json.load(f) - except Exception as e: - logger.error(f"Failed to load index data: {e}") - return - - messages = index_json.get("messages", []) - message_texts = [" ".join(m.get("textChunks", [])) for m in messages] - - # ── Load search queries (ground truth: messageMatches) ── - logger.info(f"Loading search queries from {search_data_path}") - try: - with open(search_data_path, "r", encoding="utf-8") as f: - search_json = json.load(f) - except Exception as e: - logger.error(f"Failed to load search queries: {e}") - return - - # Filter out ones without results or expected matches - search_queries: list[tuple[str, list[int]]] = [] - for item in search_json: +def parse_int_list(raw: str | None) -> list[int]: + if raw is None: + return DEFAULT_MAX_HITS + values = [int(item.strip()) for item in raw.split(",") if item.strip()] + if not values: + raise ValueError("--max-hits must contain at least one value") + if any(value <= 0 for value in values): + raise ValueError("--max-hits values must be positive integers") + return values + + +def load_message_texts(repo_root: Path) -> list[str]: + index_data = json.loads((repo_root / INDEX_DATA_PATH).read_text(encoding="utf-8")) + messages = index_data["messages"] + return [" ".join(message.get("textChunks", [])) for message in messages] + + +def load_search_queries(repo_root: Path) -> 
list[SearchQueryCase]: + search_data = json.loads( + (repo_root / SEARCH_RESULTS_PATH).read_text(encoding="utf-8") + ) + cases: list[SearchQueryCase] = [] + for item in search_data: search_text = item.get("searchText") results = item.get("results", []) - if not results: + if not search_text or not results: continue - expected = results[0].get("messageMatches", []) - if not expected: + expected_matches = results[0].get("messageMatches", []) + if not expected_matches: continue - search_queries.append((search_text, expected)) - - # ── Load answer results (Q&A ground truth from Adrian Tchaikovsky dataset) ── - answer_queries: list[tuple[str, str, bool]] = [] # (question, answer, hasNoAnswer) - logger.info(f"Loading answer results from {answer_data_path}") - try: - with open(answer_data_path, "r", encoding="utf-8") as f: - answer_json = json.load(f) - for item in answer_json: - question = item.get("question", "") - answer = item.get("answer", "") - has_no_answer = item.get("hasNoAnswer", False) - if question and answer: - answer_queries.append((question, answer, has_no_answer)) - logger.info(f"Found {len(answer_queries)} answer Q&A pairs " - f"({sum(1 for _, _, h in answer_queries if not h)} with answers, " - f"{sum(1 for _, _, h in answer_queries if h)} with no-answer).") - except Exception as e: - logger.warning(f"Failed to load answer results (continuing without): {e}") - - logger.info(f"Found {len(message_texts)} messages to embed.") - logger.info(f"Found {len(search_queries)} search queries with expected matches.") - - # ── Create embedding model and index ── - try: - if model_spec == "test:fake": - from typeagent.aitools.model_adapters import create_test_embedding_model - model = create_test_embedding_model(embedding_size=384) + cases.append(SearchQueryCase(search_text, expected_matches)) + return cases + + +async def build_vector_base( + model_spec: str | None, + message_texts: list[str], + batch_size: int, +) -> tuple[IEmbeddingModel, VectorBase]: + model = create_embedding_model(model_spec) + settings = TextEmbeddingIndexSettings( + embedding_model=model, + min_score=0.0, + max_matches=None, + batch_size=batch_size, + ) + vector_base = VectorBase(settings) + + for start in range(0, len(message_texts), batch_size): + batch = message_texts[start : start + batch_size] + await vector_base.add_keys(batch) + + return model, vector_base + + +def evaluate_search_queries( + vector_base: VectorBase, + query_cases: list[SearchQueryCase], + query_embeddings: NormalizedEmbeddings, + min_score: float, + max_hits: int, +) -> SearchMetrics: + hit_count = 0 + reciprocal_ranks: list[float] = [] + + for case, query_embedding in zip(query_cases, query_embeddings): + scored_results = vector_base.fuzzy_lookup_embedding( + query_embedding, + max_hits=max_hits, + min_score=min_score, + ) + rank = 0 + for result_index, scored_result in enumerate(scored_results, start=1): + if scored_result.item in case.expected_matches: + rank = result_index + break + if rank > 0: + hit_count += 1 + reciprocal_ranks.append(1.0 / rank) else: - model = create_embedding_model(model_spec) - except Exception as e: - logger.error(f"Failed to create embedding model: {e}") - logger.info("Are your environment variables (e.g. 
OPENAI_API_KEY) set?") - return - settings = TextEmbeddingIndexSettings(model) - vbase = VectorBase(settings) - - logger.info("Computing embeddings for messages (this may take some time...)") - # Batch the embeddings - batch_size = 50 - for i in range(0, len(message_texts), batch_size): - batch = message_texts[i : i + batch_size] - await vbase.add_keys(batch) - print(f" ... embedded {min(i + batch_size, len(message_texts))}/{len(message_texts)}") - - # ── Compute query embeddings ── - logger.info("Computing embeddings for search queries...") - search_query_texts = [q[0] for q in search_queries] - search_query_embeddings = await model.get_embeddings(search_query_texts) - - answer_query_embeddings = None - if answer_queries: - logger.info("Computing embeddings for answer queries...") - answer_query_texts = [q[0] for q in answer_queries] - answer_query_embeddings = await model.get_embeddings(answer_query_texts) - - # ────────────────────────────────────────────────────────────────────── - # Section 1: Grid Search using Search Results (messageMatches) - # ────────────────────────────────────────────────────────────────────── - - # Grid search config - min_scores_to_test = [0.70, 0.75, 0.80, 0.85, 0.90, 0.95] - max_hits_to_test = [5, 10, 15, 20] - - logger.info(f"Starting grid search over model: {model.model_name}") - print() + reciprocal_ranks.append(0.0) + + return SearchMetrics( + hit_rate=(hit_count / len(query_cases)) * 100, + mean_reciprocal_rank=mean(reciprocal_ranks), + ) + + +def select_best_row(rows: list[BenchmarkRow]) -> BenchmarkRow: + return max( + rows, + key=lambda row: ( + row.metrics.mean_reciprocal_rank, + row.metrics.hit_rate, + -row.min_score, + -row.max_hits, + ), + ) + + +def print_rows(rows: list[BenchmarkRow]) -> None: print("=" * 72) - print(" SEARCH RESULTS BENCHMARK (messageMatches ground truth)") + print("SEARCH BENCHMARK (Episode 53 messageMatches ground truth)") print("=" * 72) print(f"{'Min Score':<12} | {'Max Hits':<10} | {'Hit Rate (%)':<15} | {'MRR':<10}") print("-" * 65) - - best_mrr = -1.0 - best_config = None - - for ms in min_scores_to_test: - for mh in max_hits_to_test: - hits = 0 - reciprocal_ranks = [] - - for (query_text, expected_indices), q_emb in zip(search_queries, search_query_embeddings): - scored_results = vbase.fuzzy_lookup_embedding(q_emb, max_hits=mh, min_score=ms) - retrieved_indices = [sr.item for sr in scored_results] - - # Check if any of the expected items are in the retrieved answers - rank = -1 - for r_idx, retrieved in enumerate(retrieved_indices): - if retrieved in expected_indices: - rank = r_idx + 1 - break - - if rank > 0: - hits += 1 - reciprocal_ranks.append(1.0 / rank) - else: - reciprocal_ranks.append(0.0) - - hit_rate = (hits / len(search_queries)) * 100 - mrr = mean(reciprocal_ranks) - - print(f"{ms:<12.2f} | {mh:<10d} | {hit_rate:<15.2f} | {mrr:<10.4f}") - - if mrr > best_mrr: - best_mrr = mrr - best_config = (ms, mh) - + for row in rows: + print( + f"{row.min_score:<12.2f} | {row.max_hits:<10d} | " + f"{row.metrics.hit_rate:<15.2f} | " + f"{row.metrics.mean_reciprocal_rank:<10.4f}" + ) print("-" * 65) - if best_config: - logger.info(f"Search benchmark optimal: min_score={best_config[0]}, " - f"max_hits={best_config[1]} (MRR={best_mrr:.4f})") - else: - logger.info("Could not determine optimal parameters (no hits).") - - # ────────────────────────────────────────────────────────────────────── - # Section 2: Answer Results Benchmark (Adrian Tchaikovsky Q&A pairs) - # 
────────────────────────────────────────────────────────────────────── - - if answer_queries and answer_query_embeddings is not None: - print() - print("=" * 72) - print(" ANSWER RESULTS BENCHMARK (Adrian Tchaikovsky Q&A ground truth)") - print("=" * 72) - print() - - # For each answer query, check if retrieved messages contain key terms - # from the expected answer. This is a content-based relevance check. - # - # We split answers with hasNoAnswer=True vs False to evaluate separately. - - answerable = [(q, a, emb) for (q, a, h), emb - in zip(answer_queries, answer_query_embeddings) if not h] - unanswerable = [(q, a, emb) for (q, a, h), emb - in zip(answer_queries, answer_query_embeddings) if h] - - print(f"Answerable queries: {len(answerable)}") - print(f"Unanswerable queries (hasNoAnswer=True): {len(unanswerable)}") - print() - - # Extract key terms from expected answers for content matching - def extract_answer_keywords(answer_text: str) -> list[str]: - """Extract distinctive keywords/phrases from an answer for matching.""" - # Look for quoted items, proper nouns, and distinctive phrases - keywords = [] - # Extract quoted phrases - import re - quoted = re.findall(r"'([^']+)'", answer_text) - keywords.extend(quoted) - quoted2 = re.findall(r'"([^"]+)"', answer_text) - keywords.extend(quoted2) - - # Extract proper-noun-like terms (capitalized words that aren't sentence starters) - # and key named entities from the Adrian Tchaikovsky dataset - known_entities = [ - "Adrian Tchaikovsky", "Tchaikovsky", "Kevin Scott", "Christina Warren", - "Children of Time", "Children of Ruin", "Children of Memory", - "Shadows of the Apt", "Empire in Black and Gold", - "Final Architecture", "Lords of Uncreation", - "Dragonlance Chronicles", "Skynet", "Portids", "Corvids", - "University of Reading", "Magnus Carlsen", "Warhammer", - "Asimov", "Peter Watts", "William Gibson", "Iain Banks", - "Peter Hamilton", "Arthur C. 
Clarke", "Profiles of the Future", - "Dune", "Brave New World", "Iron Sunrise", "Wall-E", - "George RR Martin", "Alastair Reynolds", "Ovid", - "zoology", "psychology", "spiders", "arachnids", "insects", - ] - for entity in known_entities: - if entity.lower() in answer_text.lower(): - keywords.append(entity) - - return keywords - - # Run answer benchmark with the best config from search benchmark - if best_config: - eval_min_score, eval_max_hits = best_config - else: - eval_min_score, eval_max_hits = 0.80, 10 - - print(f"Using parameters: min_score={eval_min_score}, max_hits={eval_max_hits}") - print("-" * 72) - print(f"{'#':<4} | {'Question':<45} | {'Keywords Found':<14} | {'Msgs':<5}") - print("-" * 72) - answer_hits = 0 - answer_keyword_scores: list[float] = [] - for idx, (question, answer, q_emb) in enumerate(answerable, 1): - scored_results = vbase.fuzzy_lookup_embedding( - q_emb, max_hits=eval_max_hits, min_score=eval_min_score +async def run_benchmark( + model_spec: str | None, + min_scores: list[float], + max_hits_values: list[int], + batch_size: int, +) -> None: + load_dotenv() + + repo_root = Path(__file__).resolve().parent.parent + message_texts = load_message_texts(repo_root) + query_cases = load_search_queries(repo_root) + if not query_cases: + raise ValueError("No search queries with messageMatches found in the dataset") + model, vector_base = await build_vector_base(model_spec, message_texts, batch_size) + query_embeddings = await model.get_embeddings([case.query for case in query_cases]) + + rows: list[BenchmarkRow] = [] + for min_score in min_scores: + for max_hits in max_hits_values: + metrics = evaluate_search_queries( + vector_base, + query_cases, + query_embeddings, + min_score, + max_hits, ) - retrieved_indices = [sr.item for sr in scored_results] + rows.append(BenchmarkRow(min_score, max_hits, metrics)) - # Concatenate the text of all retrieved messages - retrieved_text = " ".join( - message_texts[i] for i in retrieved_indices if i < len(message_texts) - ) + print(f"Model: {model.model_name}") + print(f"Messages indexed: {len(message_texts)}") + print(f"Queries evaluated: {len(query_cases)}") + print() + print_rows(rows) - # Check how many answer keywords appear in retrieved text - keywords = extract_answer_keywords(answer) - if keywords: - found = sum( - 1 for kw in keywords - if kw.lower() in retrieved_text.lower() - ) - keyword_score = found / len(keywords) - else: - # No keywords extracted — just check if we retrieved anything - keyword_score = 1.0 if retrieved_indices else 0.0 - - if keyword_score > 0: - answer_hits += 1 - answer_keyword_scores.append(keyword_score) - - q_display = question[:42] + "..." 
if len(question) > 45 else question - kw_display = f"{int(keyword_score * 100):>3}%" - if keywords: - kw_display += f" ({sum(1 for kw in keywords if kw.lower() in retrieved_text.lower())}/{len(keywords)})" - print(f"{idx:<4} | {q_display:<45} | {kw_display:<14} | {len(retrieved_indices):<5}") - - print("-" * 72) - - if answerable: - answer_hit_rate = (answer_hits / len(answerable)) * 100 - avg_keyword_score = mean(answer_keyword_scores) * 100 - print(f"Answer Hit Rate: {answer_hit_rate:.1f}% " - f"({answer_hits}/{len(answerable)} queries found relevant content)") - print(f"Avg Keyword Coverage: {avg_keyword_score:.1f}%") - - # Evaluate unanswerable queries — ideally these should retrieve fewer/no results - if unanswerable: - print() - print("-" * 72) - print("Unanswerable queries (should ideally retrieve less relevant content):") - print("-" * 72) - false_positive_count = 0 - for question, answer, q_emb in unanswerable: - scored_results = vbase.fuzzy_lookup_embedding( - q_emb, max_hits=eval_max_hits, min_score=eval_min_score - ) - n_results = len(scored_results) - avg_score = mean(sr.score for sr in scored_results) if scored_results else 0.0 - q_display = question[:55] + "..." if len(question) > 58 else question - flag = "[!]" if n_results > 3 else "[ok]" - if n_results > 3: - false_positive_count += 1 - print(f" {flag} {q_display:<58} | {n_results:>3} results (avg={avg_score:.3f})") - print(f"\nFalse positives (>3 results): {false_positive_count}/{len(unanswerable)}") - - # ── Summary ── + best_row = select_best_row(rows) print() - print("=" * 72) - print(" SUMMARY") - print("=" * 72) - print(f"Model: {model.model_name}") - print(f"Messages indexed: {len(message_texts)}") - print(f"Search queries tested: {len(search_queries)}") - if best_config: - print(f"Best search params: min_score={best_config[0]}, max_hits={best_config[1]}") - print(f"Best search MRR: {best_mrr:.4f}") - if answer_queries: - print(f"Answer queries tested: {len(answerable)} answerable, {len(unanswerable)} unanswerable") - if answerable: - print(f"Answer hit rate: {answer_hit_rate:.1f}%") - print(f"Keyword coverage: {avg_keyword_score:.1f}%") - print("=" * 72) + print("Best-scoring benchmark row:") + print(f" min_score={best_row.min_score:.2f}") + print(f" max_hits={best_row.max_hits}") + print(f" hit_rate={best_row.metrics.hit_rate:.2f}%") + print(f" mrr={best_row.metrics.mean_reciprocal_rank:.4f}") def main() -> None: - parser = argparse.ArgumentParser(description="Benchmark embedding model parameters.") + parser = argparse.ArgumentParser( + description="Benchmark retrieval settings for an embedding model." + ) parser.add_argument( "--model", type=str, default=None, help="Provider and model name, e.g. 
'openai:text-embedding-3-small'", ) + parser.add_argument( + "--min-scores", + type=str, + default=None, + help="Comma-separated min_score values to test.", + ) + parser.add_argument( + "--max-hits", + type=str, + default=None, + help="Comma-separated max_hits values to test.", + ) + parser.add_argument( + "--batch-size", + type=int, + default=16, + help="Batch size used when building the index.", + ) args = parser.parse_args() - asyncio.run(run_benchmark(args.model)) + + asyncio.run( + run_benchmark( + model_spec=args.model, + min_scores=parse_float_list(args.min_scores), + max_hits_values=parse_int_list(args.max_hits), + batch_size=args.batch_size, + ) + ) if __name__ == "__main__": diff --git a/tools/repeat_embedding_benchmarks.py b/tools/repeat_embedding_benchmarks.py new file mode 100644 index 00000000..bc2a741a --- /dev/null +++ b/tools/repeat_embedding_benchmarks.py @@ -0,0 +1,399 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Run embedding benchmarks repeatedly and save raw/summary JSON results. + +This script runs `tools/benchmark_embeddings.py` logic multiple times for each +embedding model, stores every run as JSON, and writes aggregate summaries that +can be used to justify tuned defaults. + +Usage: + uv run python tools/repeat_embedding_benchmarks.py + uv run python tools/repeat_embedding_benchmarks.py --runs 30 + uv run python tools/repeat_embedding_benchmarks.py --models openai:text-embedding-3-small,openai:text-embedding-3-large,openai:text-embedding-ada-002 + uv run python tools/repeat_embedding_benchmarks.py --models openai:text-embedding-3-small --min-score-start 0.01 --min-score-stop 0.20 --min-score-step 0.01 +""" + +import argparse +import asyncio +from dataclasses import asdict, dataclass +from datetime import datetime, UTC +import json +from pathlib import Path +from statistics import mean + +import benchmark_embeddings +from dotenv import load_dotenv + +BenchmarkRow = benchmark_embeddings.BenchmarkRow +DEFAULT_MAX_HITS = benchmark_embeddings.DEFAULT_MAX_HITS +parse_int_list = benchmark_embeddings.parse_int_list +resolve_min_scores = benchmark_embeddings.resolve_min_scores + +DEFAULT_MODELS = [ + "openai:text-embedding-3-small", + "openai:text-embedding-3-large", + "openai:text-embedding-ada-002", +] +DEFAULT_OUTPUT_DIR = Path("benchmark_results") + + +@dataclass +class RunRow: + """Serialized benchmark row for one repeated run.""" + + min_score: float + max_hits: int + hit_rate: float + mean_reciprocal_rank: float + + +@dataclass +class RunResult: + """All measurements captured for one benchmark repetition.""" + + run_index: int + model_spec: str + resolved_model_name: str + message_count: int + query_count: int + min_top_score: float + mean_top_score: float + max_top_score: float + rows: list[RunRow] + best_row: RunRow + + +def sanitize_model_name(model_spec: str) -> str: + """Convert a model spec into a filesystem-safe directory name.""" + + return model_spec.replace(":", "__").replace("/", "_").replace("\\", "_") + + +def benchmark_row_to_run_row(row: BenchmarkRow) -> RunRow: + """Flatten a benchmark row into the JSON-friendly repeated-run shape.""" + + return RunRow( + min_score=row.min_score, + max_hits=row.max_hits, + hit_rate=row.metrics.hit_rate, + mean_reciprocal_rank=row.metrics.mean_reciprocal_rank, + ) + + +def summarize_runs(model_spec: str, runs: list[RunResult]) -> dict[str, object]: + """Average repeated benchmark runs into a per-model summary payload.""" + + summary_rows: dict[tuple[float, int], list[RunRow]] = 
{} + for run in runs: + for row in run.rows: + summary_rows.setdefault((row.min_score, row.max_hits), []).append(row) + + averaged_rows: list[dict[str, float | int]] = [] + for (min_score, max_hits), rows in sorted(summary_rows.items()): + averaged_rows.append( + { + "min_score": min_score, + "max_hits": max_hits, + "mean_hit_rate": mean(row.hit_rate for row in rows), + "mean_mrr": mean(row.mean_reciprocal_rank for row in rows), + } + ) + + best_rows = [run.best_row for run in runs] + best_min_score_counts: dict[str, int] = {} + best_max_hits_counts: dict[str, int] = {} + for row in best_rows: + best_min_score_counts[f"{row.min_score:.2f}"] = ( + best_min_score_counts.get(f"{row.min_score:.2f}", 0) + 1 + ) + best_max_hits_counts[str(row.max_hits)] = ( + best_max_hits_counts.get(str(row.max_hits), 0) + 1 + ) + + averaged_best_row = max( + averaged_rows, + key=lambda row: ( + float(row["mean_mrr"]), + float(row["mean_hit_rate"]), + float(row["min_score"]), + -int(row["max_hits"]), + ), + ) + + return { + "model_spec": model_spec, + "resolved_model_name": runs[0].resolved_model_name, + "run_count": len(runs), + "message_count": runs[0].message_count, + "query_count": runs[0].query_count, + "min_top_score": mean(run.min_top_score for run in runs), + "mean_top_score": mean(run.mean_top_score for run in runs), + "max_top_score": mean(run.max_top_score for run in runs), + "candidate_rows": averaged_rows, + "recommended_row": averaged_best_row, + "best_min_score_counts": best_min_score_counts, + "best_max_hits_counts": best_max_hits_counts, + } + + +def write_json(path: Path, data: object) -> None: + """Write a JSON artifact with stable indentation for review and reuse.""" + + path.write_text(json.dumps(data, indent=2), encoding="utf-8") + + +def write_markdown_summary(path: Path, summaries: list[dict[str, object]]) -> None: + """Write the reviewer-facing markdown summary for all benchmarked models.""" + + lines = [ + "# Repeated Embedding Benchmark Summary", + "", + "| Model | Runs | Recommended min_score | Recommended max_hits | Mean hit rate | Mean MRR |", + "| --- | ---: | ---: | ---: | ---: | ---: |", + ] + for summary in summaries: + recommended_row = summary["recommended_row"] + assert isinstance(recommended_row, dict) + lines.append( + "| " + f"{summary['resolved_model_name']} | " + f"{summary['run_count']} | " + f"{recommended_row['min_score']:.2f} | " + f"{recommended_row['max_hits']} | " + f"{recommended_row['mean_hit_rate']:.2f} | " + f"{recommended_row['mean_mrr']:.4f} |" + ) + lines.append("") + for summary in summaries: + lines.append( + f"- {summary['resolved_model_name']}: observed top-1 score range " + f"{summary['min_top_score']:.4f}..{summary['max_top_score']:.4f} " + f"(mean {summary['mean_top_score']:.4f})." 
+ ) + lines.append("") + path.write_text("\n".join(lines), encoding="utf-8") + + +async def run_single_model_benchmark( + model_spec: str, + runs: int, + min_scores: list[float], + max_hits_values: list[int], + batch_size: int, + output_dir: Path, +) -> dict[str, object]: + """Run the benchmark repeatedly for one model and persist raw artifacts.""" + + repo_root = Path(__file__).resolve().parent.parent + message_texts = benchmark_embeddings.load_message_texts(repo_root) + query_cases = benchmark_embeddings.load_search_queries(repo_root) + model_output_dir = output_dir / sanitize_model_name(model_spec) + model_output_dir.mkdir(parents=True, exist_ok=True) + + run_results: list[RunResult] = [] + for run_index in range(1, runs + 1): + model, vector_base = await benchmark_embeddings.build_vector_base( + model_spec, + message_texts, + batch_size, + ) + query_embeddings = await model.get_embeddings( + [case.query for case in query_cases] + ) + top_score_stats = benchmark_embeddings.measure_top_score_stats( + vector_base, + query_embeddings, + ) + effective_min_scores, skipped_min_scores = ( + benchmark_embeddings.filter_min_scores_by_ceiling( + min_scores, + top_score_stats.max_top_score, + ) + ) + if not effective_min_scores: + raise ValueError( + "No requested min_score values are below the observed top-score ceiling " + f"of {top_score_stats.max_top_score:.4f} for {model.model_name}" + ) + if skipped_min_scores: + print( + f"Skipping {len(skipped_min_scores)} min_score values above " + f"{top_score_stats.max_top_score:.4f} for {model.model_name}" + ) + benchmark_rows: list[benchmark_embeddings.BenchmarkRow] = [] + for min_score in effective_min_scores: + for max_hits in max_hits_values: + metrics = benchmark_embeddings.evaluate_search_queries( + vector_base, + query_cases, + query_embeddings, + min_score, + max_hits, + ) + benchmark_rows.append( + benchmark_embeddings.BenchmarkRow(min_score, max_hits, metrics) + ) + + best_row = benchmark_embeddings.select_best_row(benchmark_rows) + run_result = RunResult( + run_index=run_index, + model_spec=model_spec, + resolved_model_name=model.model_name, + message_count=len(message_texts), + query_count=len(query_cases), + min_top_score=top_score_stats.min_top_score, + mean_top_score=top_score_stats.mean_top_score, + max_top_score=top_score_stats.max_top_score, + rows=[benchmark_row_to_run_row(row) for row in benchmark_rows], + best_row=benchmark_row_to_run_row(best_row), + ) + run_results.append(run_result) + write_json(model_output_dir / f"run_{run_index:02d}.json", asdict(run_result)) + + summary = summarize_runs(model_spec, run_results) + write_json(model_output_dir / "summary.json", summary) + return summary + + +async def run_repeated_benchmarks( + models: list[str], + runs: int, + min_scores: list[float], + max_hits_values: list[int], + batch_size: int, + output_root: Path, +) -> Path: + """Run the benchmark suite for each requested model and save the artifacts.""" + + timestamp = datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ") + output_dir = output_root / timestamp + output_dir.mkdir(parents=True, exist_ok=True) + + metadata = { + "created_at_utc": timestamp, + "runs_per_model": runs, + "models": models, + "min_scores": min_scores, + "max_hits_values": max_hits_values, + "batch_size": batch_size, + } + write_json(output_dir / "metadata.json", metadata) + + summaries: list[dict[str, object]] = [] + for model_spec in models: + print(f"Running {runs} benchmark iterations for {model_spec}...") + summary = await run_single_model_benchmark( + 
model_spec=model_spec, + runs=runs, + min_scores=min_scores, + max_hits_values=max_hits_values, + batch_size=batch_size, + output_dir=output_dir, + ) + summaries.append(summary) + + write_json(output_dir / "summary.json", summaries) + write_markdown_summary(output_dir / "summary.md", summaries) + return output_dir + + +def parse_models(raw: str | None) -> list[str]: + """Parse the model list or fall back to the built-in OpenAI benchmark set.""" + + if raw is None: + return DEFAULT_MODELS + models = [item.strip() for item in raw.split(",") if item.strip()] + if not models: + raise ValueError("--models must contain at least one model") + return models + + +def main() -> None: + """Parse CLI arguments and run repeated embedding benchmarks.""" + + parser = argparse.ArgumentParser( + description="Run embedding benchmarks repeatedly and save JSON results." + ) + parser.add_argument( + "--models", + type=str, + default=None, + help="Comma-separated model specs to benchmark.", + ) + parser.add_argument( + "--runs", + type=int, + default=30, + help="Number of repeated runs per model.", + ) + parser.add_argument( + "--min-scores", + type=str, + default=None, + help="Comma-separated min_score values to test.", + ) + parser.add_argument( + "--min-score-start", + type=float, + default=None, + help="Inclusive start of a generated min_score range.", + ) + parser.add_argument( + "--min-score-stop", + type=float, + default=None, + help="Inclusive end of a generated min_score range.", + ) + parser.add_argument( + "--min-score-step", + type=float, + default=None, + help="Step size for a generated min_score range.", + ) + parser.add_argument( + "--max-hits", + type=str, + default=",".join(str(value) for value in DEFAULT_MAX_HITS), + help="Comma-separated max_hits values to test.", + ) + parser.add_argument( + "--batch-size", + type=int, + default=16, + help="Batch size used when building the index.", + ) + parser.add_argument( + "--output-dir", + type=str, + default=str(DEFAULT_OUTPUT_DIR), + help="Directory where benchmark results will be written.", + ) + args = parser.parse_args() + + if args.runs <= 0: + raise ValueError("--runs must be a positive integer") + if args.batch_size <= 0: + raise ValueError("--batch-size must be a positive integer") + + load_dotenv() + output_dir = asyncio.run( + run_repeated_benchmarks( + models=parse_models(args.models), + runs=args.runs, + min_scores=resolve_min_scores( + args.min_scores, + args.min_score_start, + args.min_score_stop, + args.min_score_step, + ), + max_hits_values=parse_int_list(args.max_hits), + batch_size=args.batch_size, + output_root=Path(args.output_dir), + ) + ) + print(f"Wrote benchmark results to {output_dir}") + + +if __name__ == "__main__": + main() From 619c9ec1b8328dc87fc9755f5049305b3401fdf8 Mon Sep 17 00:00:00 2001 From: shreejaykurhade Date: Wed, 22 Apr 2026 00:27:43 +0530 Subject: [PATCH 4/7] add tests --- tests/test_benchmark_embeddings.py | 103 ++++++++++++++++ tests/test_vectorbase.py | 113 +++++++++++++++-- tools/benchmark_embeddings.py | 188 +++++++++++++++++++++++++++-- 3 files changed, 383 insertions(+), 21 deletions(-) create mode 100644 tests/test_benchmark_embeddings.py diff --git a/tests/test_benchmark_embeddings.py b/tests/test_benchmark_embeddings.py new file mode 100644 index 00000000..6e822c26 --- /dev/null +++ b/tests/test_benchmark_embeddings.py @@ -0,0 +1,103 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
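For orientation, the artifact layout run_repeated_benchmarks writes, assembled from the write_json calls above (the timestamp and run count shown are illustrative):

benchmark_results/
    20260411T173834Z/                      # datetime.now(UTC), %Y%m%dT%H%M%SZ
        metadata.json                      # runs, models, min_score and max_hits grids
        summary.json                       # list of per-model summaries
        summary.md                         # reviewer-facing markdown table
        openai__text-embedding-3-small/    # sanitize_model_name maps ":" -> "__"
            run_01.json ... run_30.json    # one RunResult per repetition
            summary.json                   # summarize_runs output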
+ +from importlib.util import module_from_spec, spec_from_file_location +from pathlib import Path + +import pytest + +MODULE_PATH = ( + Path(__file__).resolve().parent.parent / "tools" / "benchmark_embeddings.py" +) +SPEC = spec_from_file_location("benchmark_embeddings_for_test", MODULE_PATH) +assert SPEC is not None +assert SPEC.loader is not None +BENCHMARK_EMBEDDINGS = module_from_spec(SPEC) +SPEC.loader.exec_module(BENCHMARK_EMBEDDINGS) + +BenchmarkRow = BENCHMARK_EMBEDDINGS.BenchmarkRow +SearchMetrics = BENCHMARK_EMBEDDINGS.SearchMetrics +build_float_range = BENCHMARK_EMBEDDINGS.build_float_range +filter_min_scores_by_ceiling = BENCHMARK_EMBEDDINGS.filter_min_scores_by_ceiling +load_message_texts = BENCHMARK_EMBEDDINGS.load_message_texts +parse_float_list = BENCHMARK_EMBEDDINGS.parse_float_list +resolve_min_scores = BENCHMARK_EMBEDDINGS.resolve_min_scores +select_best_row = BENCHMARK_EMBEDDINGS.select_best_row + + +def make_row( + min_score: float, + max_hits: int, + hit_rate: float, + mean_reciprocal_rank: float, +) -> BenchmarkRow: + """Build a benchmark row without repeating nested metrics boilerplate.""" + + return BenchmarkRow( + min_score=min_score, + max_hits=max_hits, + metrics=SearchMetrics( + hit_rate=hit_rate, + mean_reciprocal_rank=mean_reciprocal_rank, + ), + ) + + +def test_select_best_row_prefers_higher_min_score_on_metric_tie() -> None: + rows = [ + make_row(0.25, 15, 98.5, 0.7514), + make_row(0.70, 15, 98.5, 0.7514), + ] + + best_row = select_best_row(rows) + + assert best_row.min_score == 0.70 + assert best_row.max_hits == 15 + + +def test_select_best_row_prefers_lower_max_hits_on_full_tie() -> None: + rows = [ + make_row(0.70, 20, 98.5, 0.7514), + make_row(0.70, 15, 98.5, 0.7514), + ] + + best_row = select_best_row(rows) + + assert best_row.min_score == 0.70 + assert best_row.max_hits == 15 + + +def test_parse_float_list_defaults_to_tenth_point_grid() -> None: + assert parse_float_list(None) == [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] + + +def test_build_float_range_supports_hundredth_point_sweeps() -> None: + assert build_float_range(0.01, 0.05, 0.01) == [0.01, 0.02, 0.03, 0.04, 0.05] + + +def test_resolve_min_scores_uses_generated_range() -> None: + assert resolve_min_scores(None, 0.01, 0.03, 0.01) == [0.01, 0.02, 0.03] + + +def test_resolve_min_scores_rejects_mixed_inputs() -> None: + with pytest.raises(ValueError, match="Use either --min-scores"): + resolve_min_scores("0.1,0.2", 0.01, 0.03, 0.01) + + +def test_filter_min_scores_by_ceiling_skips_guaranteed_zero_rows() -> None: + effective_scores, skipped_scores = filter_min_scores_by_ceiling( + [0.01, 0.16, 0.17, 0.5], + 0.16, + ) + + assert effective_scores == [0.01, 0.16] + assert skipped_scores == [0.17, 0.5] + + +def test_load_message_texts_returns_one_text_blob_per_message() -> None: + repo_root = Path(__file__).resolve().parent.parent + + message_texts = load_message_texts(repo_root) + + assert message_texts + assert all(isinstance(text, str) for text in message_texts) diff --git a/tests/test_vectorbase.py b/tests/test_vectorbase.py index 81ccecc6..bb9ebb57 100644 --- a/tests/test_vectorbase.py +++ b/tests/test_vectorbase.py @@ -7,11 +7,42 @@ from typeagent.aitools.embeddings import ( CachingEmbeddingModel, NormalizedEmbedding, + NormalizedEmbeddings, ) from typeagent.aitools.model_adapters import ( create_test_embedding_model, ) -from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings, VectorBase +from typeagent.aitools.vectorbase import ( + DEFAULT_MIN_SCORE, + TextEmbeddingIndexSettings, 
+ VectorBase, +) + + +class FakeEmbeddingModel: + """Minimal embedding model stub for settings tests.""" + + def __init__(self, model_name: str) -> None: + self.model_name = model_name + + def add_embedding(self, key: str, embedding: NormalizedEmbedding) -> None: + del key, embedding + + async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding: + del input + return np.array([1.0], dtype=np.float32) + + async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings: + del input + return np.array([[1.0]], dtype=np.float32) + + async def get_embedding(self, key: str) -> NormalizedEmbedding: + del key + return np.array([1.0], dtype=np.float32) + + async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings: + del keys + return np.array([[1.0]], dtype=np.float32) @pytest.fixture(scope="function") @@ -38,7 +69,7 @@ def sample_embeddings() -> Samples: } -def test_add_embedding(vector_base: VectorBase, sample_embeddings: Samples): +def test_add_embedding(vector_base: VectorBase, sample_embeddings: Samples) -> None: """Test adding embeddings to the VectorBase.""" for key, embedding in sample_embeddings.items(): vector_base.add_embedding(key, embedding) @@ -48,7 +79,7 @@ def test_add_embedding(vector_base: VectorBase, sample_embeddings: Samples): np.testing.assert_array_equal(vector_base.serialize_embedding_at(i), embedding) -def test_add_embeddings(vector_base: VectorBase, sample_embeddings: Samples): +def test_add_embeddings(vector_base: VectorBase, sample_embeddings: Samples) -> None: """Adding multiple embeddings at once matches repeated single adds.""" keys = list(sample_embeddings.keys()) for key, embedding in sample_embeddings.items(): @@ -71,7 +102,7 @@ def test_add_embeddings(vector_base: VectorBase, sample_embeddings: Samples): @pytest.mark.asyncio -async def test_add_key(vector_base: VectorBase, sample_embeddings: Samples): +async def test_add_key(vector_base: VectorBase, sample_embeddings: Samples) -> None: """Test adding keys to the VectorBase.""" for key in sample_embeddings: await vector_base.add_key(key) @@ -80,7 +111,9 @@ async def test_add_key(vector_base: VectorBase, sample_embeddings: Samples): @pytest.mark.asyncio -async def test_add_key_no_cache(vector_base: VectorBase, sample_embeddings: Samples): +async def test_add_key_no_cache( + vector_base: VectorBase, sample_embeddings: Samples +) -> None: """Test adding keys to the VectorBase with cache disabled.""" for key in sample_embeddings: await vector_base.add_key(key, cache=False) @@ -91,7 +124,7 @@ async def test_add_key_no_cache(vector_base: VectorBase, sample_embeddings: Samp @pytest.mark.asyncio -async def test_add_keys(vector_base: VectorBase, sample_embeddings: Samples): +async def test_add_keys(vector_base: VectorBase, sample_embeddings: Samples) -> None: """Test adding multiple keys to the VectorBase.""" keys = list(sample_embeddings.keys()) await vector_base.add_keys(keys) @@ -100,7 +133,9 @@ async def test_add_keys(vector_base: VectorBase, sample_embeddings: Samples): @pytest.mark.asyncio -async def test_add_keys_no_cache(vector_base: VectorBase, sample_embeddings: Samples): +async def test_add_keys_no_cache( + vector_base: VectorBase, sample_embeddings: Samples +) -> None: """Test adding multiple keys to the VectorBase with cache disabled.""" keys = list(sample_embeddings.keys()) await vector_base.add_keys(keys, cache=False) @@ -111,7 +146,9 @@ async def test_add_keys_no_cache(vector_base: VectorBase, sample_embeddings: Sam @pytest.mark.asyncio -async def 
test_fuzzy_lookup(vector_base: VectorBase, sample_embeddings: Samples): +async def test_fuzzy_lookup( + vector_base: VectorBase, sample_embeddings: Samples +) -> None: """Test fuzzy lookup functionality.""" for key in sample_embeddings: await vector_base.add_key(key) @@ -122,7 +159,7 @@ async def test_fuzzy_lookup(vector_base: VectorBase, sample_embeddings: Samples) assert results[0].score > 0.9 # High similarity score for the same word -def test_clear(vector_base: VectorBase, sample_embeddings: Samples): +def test_clear(vector_base: VectorBase, sample_embeddings: Samples) -> None: """Test clearing the VectorBase.""" for key, embedding in sample_embeddings.items(): vector_base.add_embedding(key, embedding) @@ -132,7 +169,9 @@ def test_clear(vector_base: VectorBase, sample_embeddings: Samples): assert len(vector_base) == 0 -def test_serialize_deserialize(vector_base: VectorBase, sample_embeddings: Samples): +def test_serialize_deserialize( + vector_base: VectorBase, sample_embeddings: Samples +) -> None: """Test serialization and deserialization of the VectorBase.""" for key, embedding in sample_embeddings.items(): vector_base.add_embedding(key, embedding) @@ -149,12 +188,12 @@ def test_serialize_deserialize(vector_base: VectorBase, sample_embeddings: Sampl ) -def test_vectorbase_bool(vector_base: VectorBase): +def test_vectorbase_bool(vector_base: VectorBase) -> None: """__bool__ should always return True.""" assert bool(vector_base) is True -def test_get_embedding_at(vector_base: VectorBase, sample_embeddings: Samples): +def test_get_embedding_at(vector_base: VectorBase, sample_embeddings: Samples) -> None: """Test get_embedding_at returns correct embedding and raises IndexError.""" for key, embedding in sample_embeddings.items(): vector_base.add_embedding(key, embedding) @@ -169,7 +208,7 @@ def test_get_embedding_at(vector_base: VectorBase, sample_embeddings: Samples): def test_fuzzy_lookup_embedding_in_subset( vector_base: VectorBase, sample_embeddings: Samples -): +) -> None: """Test fuzzy_lookup_embedding_in_subset returns best match in subset or None.""" keys = list(sample_embeddings.keys()) for key, embedding in sample_embeddings.items(): @@ -220,3 +259,51 @@ def test_add_embeddings_wrong_ndim(vector_base: VectorBase) -> None: emb1d = np.array([0.1, 0.2, 0.3], dtype=np.float32) with pytest.raises(ValueError, match="Expected 2D"): vector_base.add_embeddings(None, emb1d) + + +@pytest.mark.parametrize( + ("model_name", "expected_min_score"), + [ + ("text-embedding-3-large", 0.07), + ("text-embedding-3-small", 0.16), + ("text-embedding-ada-002", 0.72), + ], +) +def test_text_embedding_index_settings_uses_known_model_default( + model_name: str, expected_min_score: float +) -> None: + settings = TextEmbeddingIndexSettings( + embedding_model=FakeEmbeddingModel(model_name) + ) + + assert settings.min_score == expected_min_score + assert settings.max_matches is None + + +def test_text_embedding_index_settings_keeps_unknown_model_fallback() -> None: + settings = TextEmbeddingIndexSettings( + embedding_model=FakeEmbeddingModel("custom-embedding-model") + ) + + assert settings.min_score == DEFAULT_MIN_SCORE + assert settings.max_matches is None + + +def test_text_embedding_index_settings_explicit_overrides_win() -> None: + settings = TextEmbeddingIndexSettings( + embedding_model=FakeEmbeddingModel("text-embedding-3-large"), + min_score=0.55, + max_matches=7, + ) + + assert settings.min_score == 0.55 + assert settings.max_matches == 7 + + +def 
test_text_embedding_index_settings_invalid_max_matches_becomes_none() -> None: + settings = TextEmbeddingIndexSettings( + embedding_model=FakeEmbeddingModel("text-embedding-3-large"), + max_matches=0, + ) + + assert settings.max_matches is None diff --git a/tools/benchmark_embeddings.py b/tools/benchmark_embeddings.py index 4358ea31..81938961 100644 --- a/tools/benchmark_embeddings.py +++ b/tools/benchmark_embeddings.py @@ -16,22 +16,30 @@ Usage: uv run python tools/benchmark_embeddings.py uv run python tools/benchmark_embeddings.py --model openai:text-embedding-3-small + uv run python tools/benchmark_embeddings.py --model openai:text-embedding-3-small --min-score-start 0.01 --min-score-stop 0.20 --min-score-step 0.01 """ import argparse import asyncio from dataclasses import dataclass +from decimal import Decimal import json from pathlib import Path from statistics import mean from dotenv import load_dotenv -from typeagent.aitools.embeddings import IEmbeddingModel, NormalizedEmbeddings +from typeagent.aitools.embeddings import ( + IEmbeddingModel, + NormalizedEmbeddings, +) from typeagent.aitools.model_adapters import create_embedding_model -from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings, VectorBase +from typeagent.aitools.vectorbase import ( + TextEmbeddingIndexSettings, + VectorBase, +) -DEFAULT_MIN_SCORES = [0.25, 0.30, 0.35, 0.40, 0.50, 0.60, 0.70, 0.75, 0.80, 0.85] +DEFAULT_MIN_SCORES = [score / 10 for score in range(1, 10)] DEFAULT_MAX_HITS = [5, 10, 15, 20] DATA_DIR = Path("tests") / "testdata" INDEX_DATA_PATH = DATA_DIR / "Episode_53_AdrianTchaikovsky_index_data.json" @@ -40,24 +48,41 @@ @dataclass class SearchQueryCase: + """A benchmark query paired with the message ordinals it should retrieve.""" + query: str expected_matches: list[int] @dataclass class SearchMetrics: + """Aggregate retrieval quality metrics for one benchmark row.""" + hit_rate: float mean_reciprocal_rank: float +@dataclass +class TopScoreStats: + """Observed top-1 score statistics across all benchmark queries.""" + + min_top_score: float + mean_top_score: float + max_top_score: float + + @dataclass class BenchmarkRow: + """One `(min_score, max_hits)` configuration evaluated by the benchmark.""" + min_score: float max_hits: int metrics: SearchMetrics def parse_float_list(raw: str | None) -> list[float]: + """Parse explicit min-score values or fall back to the coarse default grid.""" + if raw is None: return DEFAULT_MIN_SCORES values = [float(item.strip()) for item in raw.split(",") if item.strip()] @@ -66,7 +91,54 @@ def parse_float_list(raw: str | None) -> list[float]: return values +def build_float_range(start: float, stop: float, step: float) -> list[float]: + """Build an inclusive decimal-safe float range for score sweeps.""" + + if step <= 0: + raise ValueError("--min-score-step must be positive") + if start > stop: + raise ValueError("--min-score-start must be <= --min-score-stop") + + start_decimal = Decimal(str(start)) + stop_decimal = Decimal(str(stop)) + step_decimal = Decimal(str(step)) + values: list[float] = [] + current = start_decimal + while current <= stop_decimal: + values.append(float(current)) + current += step_decimal + return values + + +def resolve_min_scores( + raw: str | None, + start: float | None, + stop: float | None, + step: float | None, +) -> list[float]: + """Resolve the benchmark min-score grid from explicit values or a generated range.""" + + range_args = [start, stop, step] + using_range = any(value is not None for value in range_args) + if using_range: + 
if raw is not None: + raise ValueError( + "Use either --min-scores or the --min-score-start/stop/step range" + ) + if any(value is None for value in range_args): + raise ValueError( + "--min-score-start, --min-score-stop, and --min-score-step must all be set together" + ) + assert start is not None + assert stop is not None + assert step is not None + return build_float_range(start, stop, step) + return parse_float_list(raw) + + def parse_int_list(raw: str | None) -> list[int]: + """Parse positive integer arguments such as `max_hits` grids.""" + if raw is None: return DEFAULT_MAX_HITS values = [int(item.strip()) for item in raw.split(",") if item.strip()] @@ -78,12 +150,16 @@ def parse_int_list(raw: str | None) -> list[int]: def load_message_texts(repo_root: Path) -> list[str]: + """Load the benchmark corpus as one text blob per message.""" + index_data = json.loads((repo_root / INDEX_DATA_PATH).read_text(encoding="utf-8")) messages = index_data["messages"] return [" ".join(message.get("textChunks", [])) for message in messages] def load_search_queries(repo_root: Path) -> list[SearchQueryCase]: + """Load benchmark queries that include message-level ground-truth matches.""" + search_data = json.loads( (repo_root / SEARCH_RESULTS_PATH).read_text(encoding="utf-8") ) @@ -105,6 +181,8 @@ async def build_vector_base( message_texts: list[str], batch_size: int, ) -> tuple[IEmbeddingModel, VectorBase]: + """Build a message-level vector index for the benchmark corpus.""" + model = create_embedding_model(model_spec) settings = TextEmbeddingIndexSettings( embedding_model=model, @@ -113,7 +191,6 @@ async def build_vector_base( batch_size=batch_size, ) vector_base = VectorBase(settings) - for start in range(0, len(message_texts), batch_size): batch = message_texts[start : start + batch_size] await vector_base.add_keys(batch) @@ -128,6 +205,8 @@ def evaluate_search_queries( min_score: float, max_hits: int, ) -> SearchMetrics: + """Evaluate one benchmark row over every labeled query.""" + hit_count = 0 reciprocal_ranks: list[float] = [] @@ -154,19 +233,59 @@ def evaluate_search_queries( ) +def measure_top_score_stats( + vector_base: VectorBase, + query_embeddings: NormalizedEmbeddings, +) -> TopScoreStats: + """Measure the achievable top-1 score range for the current model and corpus.""" + + top_scores: list[float] = [] + for query_embedding in query_embeddings: + scored_results = vector_base.fuzzy_lookup_embedding( + query_embedding, + max_hits=1, + min_score=0.0, + ) + top_scores.append(scored_results[0].score if scored_results else 0.0) + + return TopScoreStats( + min_top_score=min(top_scores), + mean_top_score=mean(top_scores), + max_top_score=max(top_scores), + ) + + +def filter_min_scores_by_ceiling( + min_scores: list[float], max_top_score: float +) -> tuple[list[float], list[float]]: + """Discard score thresholds that cannot return any results for this run.""" + + effective_scores = [ + min_score for min_score in min_scores if min_score <= max_top_score + 1e-9 + ] + skipped_scores = [ + min_score for min_score in min_scores if min_score > max_top_score + 1e-9 + ] + return effective_scores, skipped_scores + + def select_best_row(rows: list[BenchmarkRow]) -> BenchmarkRow: + """Prefer the strongest MRR/Hit Rate row, then the stricter score cutoff.""" + return max( rows, key=lambda row: ( row.metrics.mean_reciprocal_rank, row.metrics.hit_rate, - -row.min_score, + row.min_score, -row.max_hits, ), ) def print_rows(rows: list[BenchmarkRow]) -> None: + """Print the benchmark grid in a reviewer-friendly 
table.""" + print("=" * 72) print("SEARCH BENCHMARK (Episode 53 messageMatches ground truth)") print("=" * 72) @@ -187,6 +306,8 @@ async def run_benchmark( max_hits_values: list[int], batch_size: int, ) -> None: + """Run a single benchmark sweep and print the evaluated grid.""" + load_dotenv() repo_root = Path(__file__).resolve().parent.parent @@ -194,11 +315,25 @@ async def run_benchmark( query_cases = load_search_queries(repo_root) if not query_cases: raise ValueError("No search queries with messageMatches found in the dataset") - model, vector_base = await build_vector_base(model_spec, message_texts, batch_size) + model, vector_base = await build_vector_base( + model_spec, + message_texts, + batch_size, + ) query_embeddings = await model.get_embeddings([case.query for case in query_cases]) + top_score_stats = measure_top_score_stats(vector_base, query_embeddings) + effective_min_scores, skipped_min_scores = filter_min_scores_by_ceiling( + min_scores, + top_score_stats.max_top_score, + ) + if not effective_min_scores: + raise ValueError( + "No requested min_score values are below the observed top-score ceiling " + f"of {top_score_stats.max_top_score:.4f}" + ) rows: list[BenchmarkRow] = [] - for min_score in min_scores: + for min_score in effective_min_scores: for max_hits in max_hits_values: metrics = evaluate_search_queries( vector_base, @@ -212,6 +347,16 @@ async def run_benchmark( print(f"Model: {model.model_name}") print(f"Messages indexed: {len(message_texts)}") print(f"Queries evaluated: {len(query_cases)}") + print( + "Observed top-1 score range: " + f"{top_score_stats.min_top_score:.4f}..{top_score_stats.max_top_score:.4f} " + f"(mean {top_score_stats.mean_top_score:.4f})" + ) + if skipped_min_scores: + print( + f"Skipped {len(skipped_min_scores)} min_score values above " + f"{top_score_stats.max_top_score:.4f}; they cannot return any matches." + ) print() print_rows(rows) @@ -225,6 +370,8 @@ async def run_benchmark( def main() -> None: + """Parse CLI arguments and run the benchmark once.""" + parser = argparse.ArgumentParser( description="Benchmark retrieval settings for an embedding model." 
) @@ -240,6 +387,24 @@ def main() -> None: default=None, help="Comma-separated min_score values to test.", ) + parser.add_argument( + "--min-score-start", + type=float, + default=None, + help="Inclusive start of a generated min_score range.", + ) + parser.add_argument( + "--min-score-stop", + type=float, + default=None, + help="Inclusive end of a generated min_score range.", + ) + parser.add_argument( + "--min-score-step", + type=float, + default=None, + help="Step size for a generated min_score range.", + ) parser.add_argument( "--max-hits", type=str, @@ -253,11 +418,18 @@ def main() -> None: help="Batch size used when building the index.", ) args = parser.parse_args() + if args.batch_size <= 0: + raise ValueError("--batch-size must be a positive integer") asyncio.run( run_benchmark( model_spec=args.model, - min_scores=parse_float_list(args.min_scores), + min_scores=resolve_min_scores( + args.min_scores, + args.min_score_start, + args.min_score_stop, + args.min_score_step, + ), max_hits_values=parse_int_list(args.max_hits), batch_size=args.batch_size, ) From 2303777c35d06aacf44c938e008a8a8b8175e155 Mon Sep 17 00:00:00 2001 From: shreejaykurhade Date: Fri, 1 May 2026 14:30:00 +0530 Subject: [PATCH 5/7] Tune embedding defaults from real benchmark pipeline Update OpenAI embedding min_score defaults from the latest Episode 53 benchmark sweep: - text-embedding-3-small: 0.73 - text-embedding-3-large: 0.74 - text-embedding-ada-002: 0.93 The benchmark now recomputes corpus/query embeddings per model, ignores the serialized embedding sidecar, evaluates the real search pipeline, related-term retrieval, and answer-context signals, and records grid metadata including the min_score x max_hits row count. Keep normal ConversationSettings thresholds unchanged so application behavior does not silently adopt benchmark-only settings. Add tests covering normal app thresholds, benchmark-specific settings, normalized score scaling, fixture loading, sidecar handling, and benchmark metadata. --- pyproject.toml | 1 + src/typeagent/aitools/vectorbase.py | 74 +- src/typeagent/knowpro/convsettings.py | 7 +- tests/test_benchmark_embeddings.py | 340 +++++++++- tests/test_convsettings.py | 58 ++ tests/test_repeat_embedding_benchmarks.py | 60 ++ tests/test_vectorbase.py | 22 +- tools/benchmark_embeddings.py | 785 ++++++++++++++++++++-- tools/repeat_embedding_benchmarks.py | 307 +++++++-- 9 files changed, 1506 insertions(+), 148 deletions(-) create mode 100644 tests/test_convsettings.py create mode 100644 tests/test_repeat_embedding_benchmarks.py diff --git a/pyproject.toml b/pyproject.toml index 99371ae7..b994d3f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,6 +63,7 @@ asyncio_default_fixture_loop_scope = "function" testpaths = ["tests"] [tool.pyright] +extraPaths = ["src", "tools"] reportUnusedVariable = true reportUnusedImport = true reportDuplicateImport = true diff --git a/src/typeagent/aitools/vectorbase.py b/src/typeagent/aitools/vectorbase.py index 8f898ebf..7b5d3448 100644 --- a/src/typeagent/aitools/vectorbase.py +++ b/src/typeagent/aitools/vectorbase.py @@ -16,22 +16,22 @@ DEFAULT_MIN_SCORE = 0.85 # Empirical defaults for built-in OpenAI embedding models. -# These values come from repeated runs of the Adrian Tchaikovsky Episode 53 +# These values come from the Adrian Tchaikovsky Episode 53 # search benchmark in `tools/repeat_embedding_benchmarks.py`, using an -# exhaustive 0.01..1.00 min_score sweep on the Adrian Tchaikovsky Episode 53 -# dataset. 
We keep the highest min_score that preserves the best benchmark -# metrics for each model, which yielded the current plateau boundaries of 0.16 for -# `text-embedding-3-small`, 0.07 for `text-embedding-3-large`, and 0.72 for -# `text-embedding-ada-002`. These are repository defaults for known models, -# not universal truths. Unknown models keep the long-standing fallback score -# of 0.85. Callers can always override `min_score` explicitly for their own -# use cases or models. We intentionally leave `max_matches` out of this table: -# the benchmark still reports a best `max_hits` row, but the library default -# remains `None` unless a caller opts into a specific limit. +# exhaustive 0.00..1.00 min_score sweep. The benchmark recomputes corpus and +# query embeddings for each model and ignores the fixture's serialized +# embedding sidecar. Scores are normalized from cosine similarity to the public +# 0..1 min_score scale. +# These are repository defaults for known models, not universal truths. +# Unknown models keep the long-standing fallback score of 0.85. Callers can +# always override `min_score` explicitly for their own use cases or models. We +# intentionally leave `max_matches` out of this table: the benchmark still +# reports a best `max_hits` row, but the library default remains `None` unless +# a caller opts into a specific limit. MODEL_DEFAULT_MIN_SCORES: dict[str, float] = { - "text-embedding-3-large": 0.07, - "text-embedding-3-small": 0.16, - "text-embedding-ada-002": 0.72, + "text-embedding-3-large": 0.74, + "text-embedding-3-small": 0.73, + "text-embedding-ada-002": 0.93, } @@ -41,6 +41,12 @@ def get_default_min_score(model_name: str) -> float: return MODEL_DEFAULT_MIN_SCORES.get(model_name, DEFAULT_MIN_SCORE) +def cosine_to_score(cosine_similarity: np.ndarray) -> np.ndarray: + """Map cosine similarity from -1..1 to the public 0..1 score scale.""" + + return np.clip((cosine_similarity + 1.0) / 2.0, 0.0, 1.0) + + @dataclass class ScoredInt: """Associate an integer ordinal with its similarity score.""" @@ -109,20 +115,19 @@ def __bool__(self) -> bool: def add_embedding( self, key: str | None, embedding: NormalizedEmbedding | list[float] ) -> None: - if isinstance(embedding, list): - embedding = np.array(embedding, dtype=np.float32) + embedding_array = np.asarray(embedding, dtype=np.float32) if self._embedding_size == 0: - self._set_embedding_size(len(embedding)) + self._set_embedding_size(len(embedding_array)) self._vectors.shape = (0, self._embedding_size) - if len(embedding) != self._embedding_size: + if len(embedding_array) != self._embedding_size: raise ValueError( f"Embedding size mismatch: expected {self._embedding_size}, " - f"got {len(embedding)}" + f"got {len(embedding_array)}" ) - embeddings = embedding.reshape(1, -1) # Make it 2D: 1xN + embeddings = embedding_array.reshape(1, -1) # Make it 2D: 1xN self._vectors = np.append(self._vectors, embeddings, axis=0) if key is not None: - self._model.add_embedding(key, embedding) + self._model.add_embedding(key, embedding_array) def add_embeddings( self, keys: None | list[str], embeddings: NormalizedEmbeddings @@ -165,7 +170,7 @@ def fuzzy_lookup_embedding( min_score = 0.0 if len(self._vectors) == 0: return [] - scores = np.dot(self._vectors, embedding) + scores = cosine_to_score(np.dot(self._vectors, embedding)) if predicate is None: # Stay in numpy: filter by score, then top-k via argpartition. 
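+            # argpartition selects the max_hits largest scores in O(n) without
+            # a full sort; only those survivors are then argsorted. Toy example:
+            #   scores = [0.2, 0.9, 0.5, 0.7], max_hits = 2
+            #   np.argpartition(scores, -2)[-2:] -> [3, 1] (order unspecified)
+            #   an argsort over just those two yields [1, 3], best score first.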
indices = np.flatnonzero(scores >= min_score) @@ -199,10 +204,27 @@ def fuzzy_lookup_embedding_in_subset( max_hits: int | None = None, min_score: float | None = None, ) -> list[ScoredInt]: - ordinals_set = set(ordinals_of_subset) - return self.fuzzy_lookup_embedding( - embedding, max_hits, min_score, lambda i: i in ordinals_set - ) + if max_hits is None: + max_hits = 10 + if min_score is None: + min_score = 0.0 + if not ordinals_of_subset or len(self._vectors) == 0: + return [] + # Compute dot products only for the subset instead of all vectors. + subset = np.asarray(ordinals_of_subset) + scores = cosine_to_score(np.dot(self._vectors[subset], embedding)) + indices = np.flatnonzero(scores >= min_score) + if len(indices) == 0: + return [] + filtered_scores = scores[indices] + if len(indices) <= max_hits: + order = np.argsort(filtered_scores)[::-1] + else: + top_k = np.argpartition(filtered_scores, -max_hits)[-max_hits:] + order = top_k[np.argsort(filtered_scores[top_k])[::-1]] + return [ + ScoredInt(int(subset[indices[i]]), float(filtered_scores[i])) for i in order + ] async def fuzzy_lookup( self, diff --git a/src/typeagent/knowpro/convsettings.py b/src/typeagent/knowpro/convsettings.py index 9dbf1214..30aa22ad 100644 --- a/src/typeagent/knowpro/convsettings.py +++ b/src/typeagent/knowpro/convsettings.py @@ -10,6 +10,9 @@ from ..aitools.vectorbase import TextEmbeddingIndexSettings from .interfaces import IKnowledgeExtractor, IStorageProvider +DEFAULT_RELATED_TERM_MIN_SCORE = 0.85 +DEFAULT_MESSAGE_TEXT_MIN_SCORE = 0.7 + @dataclass class MessageTextIndexSettings: @@ -45,13 +48,13 @@ def __init__( # All settings share the same model, so they share the embedding cache. model = model or create_embedding_model() self.embedding_model = model - min_score = 0.85 + min_score = DEFAULT_RELATED_TERM_MIN_SCORE self.related_term_index_settings = RelatedTermIndexSettings( TextEmbeddingIndexSettings(model, min_score=min_score, max_matches=50) ) self.thread_settings = TextEmbeddingIndexSettings(model, min_score=min_score) self.message_text_index_settings = MessageTextIndexSettings( - TextEmbeddingIndexSettings(model, min_score=0.7) + TextEmbeddingIndexSettings(model, min_score=DEFAULT_MESSAGE_TEXT_MIN_SCORE) ) self.semantic_ref_index_settings = SemanticRefIndexSettings( batch_size=4, # Effectively max concurrency diff --git a/tests/test_benchmark_embeddings.py b/tests/test_benchmark_embeddings.py index 6e822c26..5b6f6815 100644 --- a/tests/test_benchmark_embeddings.py +++ b/tests/test_benchmark_embeddings.py @@ -2,10 +2,14 @@ # Licensed under the MIT License. 
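+# Note: the expected min_score defaults asserted in this file are on the
+# normalized 0..1 scale produced by vectorbase.cosine_to_score, i.e.
+# (cosine + 1) / 2, not on the raw -1..1 cosine similarity scale.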
from importlib.util import module_from_spec, spec_from_file_location +import json from pathlib import Path +import numpy as np import pytest +from typeagent.aitools.embeddings import NormalizedEmbedding, NormalizedEmbeddings + MODULE_PATH = ( Path(__file__).resolve().parent.parent / "tools" / "benchmark_embeddings.py" ) @@ -16,23 +20,93 @@ SPEC.loader.exec_module(BENCHMARK_EMBEDDINGS) BenchmarkRow = BENCHMARK_EMBEDDINGS.BenchmarkRow +AnswerMetrics = BENCHMARK_EMBEDDINGS.AnswerMetrics +PipelineMetrics = BENCHMARK_EMBEDDINGS.PipelineMetrics +RelatedTermMetrics = BENCHMARK_EMBEDDINGS.RelatedTermMetrics +RelatedTermQueryCase = BENCHMARK_EMBEDDINGS.RelatedTermQueryCase SearchMetrics = BENCHMARK_EMBEDDINGS.SearchMetrics build_float_range = BENCHMARK_EMBEDDINGS.build_float_range -filter_min_scores_by_ceiling = BENCHMARK_EMBEDDINGS.filter_min_scores_by_ceiling +create_benchmark_conversation_settings = ( + BENCHMARK_EMBEDDINGS.create_benchmark_conversation_settings +) +evaluate_answer_queries = BENCHMARK_EMBEDDINGS.evaluate_answer_queries +evaluate_related_term_queries = BENCHMARK_EMBEDDINGS.evaluate_related_term_queries +load_corpus_metadata = BENCHMARK_EMBEDDINGS.load_corpus_metadata +load_pipeline_queries = BENCHMARK_EMBEDDINGS.load_pipeline_queries load_message_texts = BENCHMARK_EMBEDDINGS.load_message_texts +load_related_term_queries = BENCHMARK_EMBEDDINGS.load_related_term_queries +load_related_term_texts = BENCHMARK_EMBEDDINGS.load_related_term_texts parse_float_list = BENCHMARK_EMBEDDINGS.parse_float_list resolve_min_scores = BENCHMARK_EMBEDDINGS.resolve_min_scores select_best_row = BENCHMARK_EMBEDDINGS.select_best_row +class FakeEmbeddingModel: + """Minimal embedding model stub for settings tests.""" + + def __init__(self, model_name: str) -> None: + self.model_name = model_name + + def add_embedding(self, key: str, embedding: NormalizedEmbedding) -> None: + del key, embedding + + async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding: + del input + return np.array([1.0], dtype=np.float32) + + async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings: + del input + return np.array([[1.0]], dtype=np.float32) + + async def get_embedding(self, key: str) -> NormalizedEmbedding: + del key + return np.array([1.0], dtype=np.float32) + + async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings: + del keys + return np.array([[1.0]], dtype=np.float32) + + def make_row( min_score: float, max_hits: int, hit_rate: float, mean_reciprocal_rank: float, + semantic_score: float | None = None, + pipeline_hit_rate: float | None = None, + pipeline_mean_reciprocal_rank: float | None = None, + related_hit_rate: float | None = None, + related_mean_reciprocal_rank: float | None = None, ) -> BenchmarkRow: """Build a benchmark row without repeating nested metrics boilerplate.""" + answer_metrics = ( + AnswerMetrics( + answerable_support=semantic_score / 100, + no_answer_rejection_rate=0.0, + semantic_score=semantic_score, + ) + if semantic_score is not None + else None + ) + related_metrics = ( + RelatedTermMetrics( + hit_rate=related_hit_rate, + mean_reciprocal_rank=related_mean_reciprocal_rank, + mean_result_count=10.0, + ) + if related_hit_rate is not None and related_mean_reciprocal_rank is not None + else None + ) + pipeline_metrics = ( + PipelineMetrics( + hit_rate=pipeline_hit_rate, + mean_reciprocal_rank=pipeline_mean_reciprocal_rank, + mean_result_count=10.0, + ) + if pipeline_hit_rate is not None and pipeline_mean_reciprocal_rank is not None + else 
None + ) return BenchmarkRow( min_score=min_score, max_hits=max_hits, @@ -40,13 +114,44 @@ def make_row( hit_rate=hit_rate, mean_reciprocal_rank=mean_reciprocal_rank, ), + pipeline_metrics=pipeline_metrics, + related_metrics=related_metrics, + answer_metrics=answer_metrics, + ) + + +@pytest.mark.parametrize( + ("model_name", "expected_min_score"), + [ + ("text-embedding-3-large", 0.74), + ("text-embedding-3-small", 0.73), + ("text-embedding-ada-002", 0.93), + ], +) +def test_benchmark_conversation_settings_use_model_default( + model_name: str, + expected_min_score: float, +) -> None: + settings = create_benchmark_conversation_settings(FakeEmbeddingModel(model_name)) + + assert ( + settings.related_term_index_settings.embedding_index_settings.min_score + == expected_min_score + ) + assert settings.thread_settings.min_score == expected_min_score + assert ( + settings.message_text_index_settings.embedding_index_settings.min_score + == expected_min_score + ) + assert ( + settings.related_term_index_settings.embedding_index_settings.max_matches == 50 ) def test_select_best_row_prefers_higher_min_score_on_metric_tie() -> None: rows = [ - make_row(0.25, 15, 98.5, 0.7514), - make_row(0.70, 15, 98.5, 0.7514), + make_row(0.25, 15, 98.5, 0.7514, semantic_score=90.0), + make_row(0.70, 15, 98.5, 0.7514, semantic_score=80.0), ] best_row = select_best_row(rows) @@ -55,6 +160,123 @@ def test_select_best_row_prefers_higher_min_score_on_metric_tie() -> None: assert best_row.max_hits == 15 +def test_select_best_row_only_uses_answer_context_after_min_score_tie() -> None: + rows = [ + make_row(0.70, 15, 98.5, 0.7514, semantic_score=80.0), + make_row(0.70, 15, 98.5, 0.7514, semantic_score=90.0), + ] + + best_row = select_best_row(rows) + + assert best_row.answer_metrics is not None + assert best_row.answer_metrics.semantic_score == 90.0 + + +def test_select_best_row_prefers_related_term_quality_before_message_quality() -> None: + rows = [ + make_row( + 0.80, + 15, + 98.5, + 0.90, + related_hit_rate=90.0, + related_mean_reciprocal_rank=0.70, + ), + make_row( + 0.70, + 15, + 98.5, + 0.80, + related_hit_rate=95.0, + related_mean_reciprocal_rank=0.75, + ), + ] + + best_row = select_best_row(rows) + + assert best_row.min_score == 0.70 + + +def test_select_best_row_prefers_pipeline_quality_before_related_term_quality() -> None: + rows = [ + make_row( + 0.80, + 15, + 98.5, + 0.90, + pipeline_hit_rate=90.0, + pipeline_mean_reciprocal_rank=0.70, + related_hit_rate=99.0, + related_mean_reciprocal_rank=0.99, + ), + make_row( + 0.70, + 15, + 98.5, + 0.80, + pipeline_hit_rate=95.0, + pipeline_mean_reciprocal_rank=0.75, + related_hit_rate=90.0, + related_mean_reciprocal_rank=0.70, + ), + ] + + best_row = select_best_row(rows) + + assert best_row.min_score == 0.70 + + +def test_evaluate_related_term_queries_scores_expected_terms() -> None: + vector_base = BENCHMARK_EMBEDDINGS.VectorBase( + BENCHMARK_EMBEDDINGS.TextEmbeddingIndexSettings( + BENCHMARK_EMBEDDINGS.create_embedding_model("test") + ) + ) + vector_base.add_embedding(None, np.array([1.0, 0.0], dtype=np.float32)) + vector_base.add_embedding(None, np.array([0.0, 1.0], dtype=np.float32)) + + metrics = evaluate_related_term_queries( + vector_base, + ["alpha", "beta"], + [RelatedTermQueryCase("query", ["beta"])], + np.array([[0.0, 1.0]], dtype=np.float32), + min_score=0.0, + max_hits=2, + ) + + assert metrics.hit_rate == 100.0 + assert metrics.mean_reciprocal_rank == 1.0 + assert metrics.mean_result_count == 2.0 + + +def 
test_evaluate_answer_queries_reports_normalized_support_score() -> None: + vector_base = BENCHMARK_EMBEDDINGS.VectorBase( + BENCHMARK_EMBEDDINGS.TextEmbeddingIndexSettings( + BENCHMARK_EMBEDDINGS.create_embedding_model("test") + ) + ) + vector_base.add_embedding(None, np.array([1.0, 0.0], dtype=np.float32)) + answer_cases = [ + BENCHMARK_EMBEDDINGS.AnswerQueryCase( + question="question", + answer="answer", + has_no_answer=False, + ) + ] + + metrics = evaluate_answer_queries( + vector_base, + answer_cases, + np.array([[0.0, 1.0]], dtype=np.float32), + np.array([[0.0, 1.0]], dtype=np.float32), + min_score=0.0, + max_hits=1, + ) + + assert metrics.answerable_support == 0.5 + assert metrics.semantic_score == 75.0 + + def test_select_best_row_prefers_lower_max_hits_on_full_tie() -> None: rows = [ make_row(0.70, 20, 98.5, 0.7514), @@ -84,16 +306,6 @@ def test_resolve_min_scores_rejects_mixed_inputs() -> None: resolve_min_scores("0.1,0.2", 0.01, 0.03, 0.01) -def test_filter_min_scores_by_ceiling_skips_guaranteed_zero_rows() -> None: - effective_scores, skipped_scores = filter_min_scores_by_ceiling( - [0.01, 0.16, 0.17, 0.5], - 0.16, - ) - - assert effective_scores == [0.01, 0.16] - assert skipped_scores == [0.17, 0.5] - - def test_load_message_texts_returns_one_text_blob_per_message() -> None: repo_root = Path(__file__).resolve().parent.parent @@ -101,3 +313,105 @@ def test_load_message_texts_returns_one_text_blob_per_message() -> None: assert message_texts assert all(isinstance(text, str) for text in message_texts) + + +def test_load_related_term_texts_returns_fixture_terms() -> None: + repo_root = Path(__file__).resolve().parent.parent + + terms = load_related_term_texts(repo_root) + + assert len(terms) == 1188 + assert "adrian tchaikovsky" in terms + + +def test_load_related_term_queries_returns_compiled_related_terms() -> None: + repo_root = Path(__file__).resolve().parent.parent + + cases = load_related_term_queries(repo_root) + + assert cases + assert all(case.expected_related_terms for case in cases) + + +def test_load_pipeline_queries_strips_cached_related_terms() -> None: + repo_root = Path(__file__).resolve().parent.parent + + cases = load_pipeline_queries(repo_root) + + assert cases + for case in cases: + for obj in BENCHMARK_EMBEDDINGS.iter_dicts( + case.query_exprs[0].__pydantic_serializer__.to_python( + case.query_exprs[0], + by_alias=True, + ) + ): + assert obj.get("relatedTerms") in (None, []) + + +def test_load_message_texts_ignores_serialized_embedding_sidecar( + tmp_path: Path, +) -> None: + testdata_dir = tmp_path / "tests" / "testdata" + testdata_dir.mkdir(parents=True) + (testdata_dir / "Episode_53_AdrianTchaikovsky_index_data.json").write_text( + json.dumps( + { + "messages": [ + {"textChunks": ["hello", "world"]}, + {"textChunks": ["goodbye"]}, + ], + "embeddingFileHeader": { + "messageCount": 2, + "relatedCount": 0, + "modelMetadata": {"embeddingSize": 1536}, + }, + } + ), + encoding="utf-8", + ) + (testdata_dir / "Episode_53_AdrianTchaikovsky_index_embeddings.bin").write_bytes( + b"not real embeddings" + ) + + message_texts = load_message_texts(tmp_path) + + assert message_texts == ["hello world", "goodbye"] + + +def test_load_corpus_metadata_reports_serialized_sidecar_details() -> None: + repo_root = Path(__file__).resolve().parent.parent + + metadata = load_corpus_metadata(repo_root) + + assert metadata.message_count > 0 + assert metadata.serialized_embedding_size == 1536 + assert metadata.serialized_message_count == 106 + assert metadata.serialized_related_count == 
1188 + assert metadata.serialized_total_embedding_count == 1294 + + +def test_load_corpus_metadata_rejects_inconsistent_sidecar_size( + tmp_path: Path, +) -> None: + testdata_dir = tmp_path / "tests" / "testdata" + testdata_dir.mkdir(parents=True) + (testdata_dir / "Episode_53_AdrianTchaikovsky_index_data.json").write_text( + json.dumps( + { + "messages": [{"textChunks": ["hello"]}], + "embeddingFileHeader": { + "messageCount": 1, + "relatedCount": 1, + "modelMetadata": {"embeddingSize": 2}, + }, + } + ), + encoding="utf-8", + ) + (testdata_dir / "Episode_53_AdrianTchaikovsky_index_embeddings.bin").write_bytes( + b"bad-sidecar" + ) + + with pytest.raises(ValueError, match="Serialized benchmark sidecar size"): + load_corpus_metadata(tmp_path) diff --git a/tests/test_convsettings.py b/tests/test_convsettings.py new file mode 100644 index 00000000..4063583c --- /dev/null +++ b/tests/test_convsettings.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import numpy as np + +from typeagent.aitools.embeddings import NormalizedEmbedding, NormalizedEmbeddings +from typeagent.knowpro.convsettings import ( + ConversationSettings, + DEFAULT_MESSAGE_TEXT_MIN_SCORE, + DEFAULT_RELATED_TERM_MIN_SCORE, +) + + +class FakeEmbeddingModel: + """Minimal embedding model stub for settings tests.""" + + def __init__(self, model_name: str) -> None: + self.model_name = model_name + + def add_embedding(self, key: str, embedding: NormalizedEmbedding) -> None: + del key, embedding + + async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding: + del input + return np.array([1.0], dtype=np.float32) + + async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings: + del input + return np.array([[1.0]], dtype=np.float32) + + async def get_embedding(self, key: str) -> NormalizedEmbedding: + del key + return np.array([1.0], dtype=np.float32) + + async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings: + del keys + return np.array([[1.0]], dtype=np.float32) + + +def test_conversation_settings_keep_normal_application_thresholds() -> None: + settings = ConversationSettings(model=FakeEmbeddingModel("text-embedding-3-small")) + + assert ( + settings.related_term_index_settings.embedding_index_settings.min_score + == DEFAULT_RELATED_TERM_MIN_SCORE + ) + assert settings.thread_settings.min_score == DEFAULT_RELATED_TERM_MIN_SCORE + assert ( + settings.message_text_index_settings.embedding_index_settings.min_score + == DEFAULT_MESSAGE_TEXT_MIN_SCORE + ) + assert ( + settings.related_term_index_settings.embedding_index_settings.max_matches == 50 + ) + assert ( + settings.message_text_index_settings.embedding_index_settings.max_matches + is None + ) diff --git a/tests/test_repeat_embedding_benchmarks.py b/tests/test_repeat_embedding_benchmarks.py new file mode 100644 index 00000000..47f3e80d --- /dev/null +++ b/tests/test_repeat_embedding_benchmarks.py @@ -0,0 +1,60 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
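+#
+# Unlike test_benchmark_embeddings.py, this module also has to put tools/ on
+# sys.path before exec_module runs: repeat_embedding_benchmarks.py refers to
+# its sibling module by name (benchmark_embeddings.*), and the sys.path.insert
+# below keeps that name resolvable while the script executes.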
+
+from importlib.util import module_from_spec, spec_from_file_location
+from pathlib import Path
+import sys
+
+MODULE_PATH = (
+    Path(__file__).resolve().parent.parent / "tools" / "repeat_embedding_benchmarks.py"
+)
+sys.path.insert(0, str(MODULE_PATH.parent))
+SPEC = spec_from_file_location("repeat_embedding_benchmarks_for_test", MODULE_PATH)
+assert SPEC is not None
+assert SPEC.loader is not None
+REPEAT_BENCHMARKS = module_from_spec(SPEC)
+SPEC.loader.exec_module(REPEAT_BENCHMARKS)
+
+build_run_suite_metadata = REPEAT_BENCHMARKS.build_run_suite_metadata
+
+
+def test_build_run_suite_metadata_records_ignored_sidecar() -> None:
+    repo_root = Path(__file__).resolve().parent.parent
+
+    metadata = build_run_suite_metadata(
+        repo_root=repo_root,
+        timestamp="20260424T000000Z",
+        models=["openai:text-embedding-3-small"],
+        runs=3,
+        min_scores=[0.01, 0.02],
+        max_hits_values=[5, 10],
+        batch_size=16,
+    )
+
+    assert metadata["ignored_serialized_embedding_size"] == 1536
+    assert metadata["ignored_serialized_message_embedding_count"] == 106
+    assert metadata["ignored_serialized_related_embedding_count"] == 1188
+    assert metadata["ignored_serialized_total_embedding_count"] == 1294
+    # Compare through Path so the expectation holds on both POSIX and Windows
+    # path separators instead of hard-coding backslashes.
+    expected_search_json = str(
+        Path("tests") / "testdata" / "Episode_53_Search_results.json"
+    )
+    assert metadata["pipeline_source_json"] == expected_search_json
+    assert metadata["related_term_source_json"] == expected_search_json
+    assert metadata["pipeline_scoring_paths"] == [
+        "src/typeagent/knowpro/search.py::run_search_query",
+        "src/typeagent/knowpro/query.py::MatchSearchTermExpr.accumulate_matches_for_term",
+        "src/typeagent/knowpro/collections.py::SemanticRefAccumulator.add_term_matches",
+        "src/typeagent/knowpro/collections.py::add_smooth_related_score_to_match_score",
+        "src/typeagent/knowpro/query.py::message_matches_from_knowledge_matches",
+        "src/typeagent/knowpro/collections.py::MessageAccumulator.smooth_scores",
+    ]
+    assert metadata["corpus_embedding_source"] == (
+        "recomputed_per_model_from_message_text"
+    )
+    assert metadata["query_embedding_source"] == (
+        "recomputed_per_model_from_search_text"
+    )
+    assert metadata["min_score_count"] == 2
+    assert metadata["max_hits_count"] == 2
+    assert metadata["grid_row_count"] == 4
diff --git a/tests/test_vectorbase.py b/tests/test_vectorbase.py
index bb9ebb57..763937b1 100644
--- a/tests/test_vectorbase.py
+++ b/tests/test_vectorbase.py
@@ -236,6 +236,22 @@ def test_fuzzy_lookup_embedding_in_subset(
     assert result == []
 
 
+def test_fuzzy_lookup_embedding_reports_normalized_score_scale() -> None:
+    vector_base = make_vector_base()
+    vector_base.add_embedding(None, np.array([1.0, 0.0], dtype=np.float32))
+    vector_base.add_embedding(None, np.array([0.0, 1.0], dtype=np.float32))
+    vector_base.add_embedding(None, np.array([-1.0, 0.0], dtype=np.float32))
+
+    results = vector_base.fuzzy_lookup_embedding(
+        np.array([1.0, 0.0], dtype=np.float32),
+        max_hits=3,
+        min_score=0.0,
+    )
+
+    assert [result.item for result in results] == [0, 1, 2]
+    assert [result.score for result in results] == [1.0, 0.5, 0.0]
+
+
 def test_add_embedding_size_mismatch(vector_base: VectorBase) -> None:
     """Adding an embedding of wrong size raises ValueError."""
     emb3 = np.array([0.1, 0.2, 0.3], dtype=np.float32)
@@ -264,9 +280,9 @@ def test_add_embeddings_wrong_ndim(vector_base: VectorBase) -> None:
 @pytest.mark.parametrize(
     ("model_name", "expected_min_score"),
     [
-        ("text-embedding-3-large", 0.07),
-        ("text-embedding-3-small", 0.16),
-        ("text-embedding-ada-002", 0.72),
+        ("text-embedding-3-large", 0.74),
+
("text-embedding-3-small", 0.73), + ("text-embedding-ada-002", 0.93), ], ) def test_text_embedding_index_settings_uses_known_model_default( diff --git a/tools/benchmark_embeddings.py b/tools/benchmark_embeddings.py index 81938961..96d1c036 100644 --- a/tools/benchmark_embeddings.py +++ b/tools/benchmark_embeddings.py @@ -7,11 +7,19 @@ `tests/testdata/` and reports retrieval quality for combinations of `min_score` and `max_hits`. +Methodology: +- Load only message text from the benchmark `_data.json` payload. +- Treat the serialized `_embeddings.bin` sidecar as metadata only. +- Recompute corpus and query embeddings with the requested model. + The benchmark is intentionally narrow: - It only measures retrieval against `messageMatches` ground truth. - It is meant to help choose repository defaults for known models. - In practice, `min_score` is the primary library default this informs. - It does not prove universal "best" settings for every dataset. +It also includes a semantic answer-context signal from the answer fixture: +- Answerable questions should retrieve messages close to the expected answer. +- No-answer questions should avoid high-confidence retrieved context. Usage: uv run python tools/benchmark_embeddings.py @@ -21,6 +29,7 @@ import argparse import asyncio +from copy import deepcopy from dataclasses import dataclass from decimal import Decimal import json @@ -28,6 +37,7 @@ from statistics import mean from dotenv import load_dotenv +import numpy as np from typeagent.aitools.embeddings import ( IEmbeddingModel, @@ -38,12 +48,39 @@ TextEmbeddingIndexSettings, VectorBase, ) +from typeagent.knowpro import search, secindex, serialization +from typeagent.knowpro.convsettings import ( + ConversationSettings, + MessageTextIndexSettings, + RelatedTermIndexSettings, +) +from typeagent.podcasts import podcast DEFAULT_MIN_SCORES = [score / 10 for score in range(1, 10)] DEFAULT_MAX_HITS = [5, 10, 15, 20] DATA_DIR = Path("tests") / "testdata" +INDEX_PREFIX_PATH = DATA_DIR / "Episode_53_AdrianTchaikovsky_index" INDEX_DATA_PATH = DATA_DIR / "Episode_53_AdrianTchaikovsky_index_data.json" +INDEX_EMBEDDINGS_PATH = DATA_DIR / "Episode_53_AdrianTchaikovsky_index_embeddings.bin" SEARCH_RESULTS_PATH = DATA_DIR / "Episode_53_Search_results.json" +ANSWER_RESULTS_PATH = DATA_DIR / "Episode_53_Answer_results.json" +CORPUS_EMBEDDING_SOURCE = "recomputed_per_model_from_message_text" +QUERY_EMBEDDING_SOURCE = "recomputed_per_model_from_search_text" +ANSWER_EMBEDDING_SOURCE = "recomputed_per_model_from_expected_answer_text" +PIPELINE_SCORING_PATHS = [ + "src/typeagent/knowpro/search.py::run_search_query", + "src/typeagent/knowpro/query.py::MatchSearchTermExpr.accumulate_matches_for_term", + "src/typeagent/knowpro/collections.py::SemanticRefAccumulator.add_term_matches", + "src/typeagent/knowpro/collections.py::add_smooth_related_score_to_match_score", + "src/typeagent/knowpro/query.py::message_matches_from_knowledge_matches", + "src/typeagent/knowpro/collections.py::MessageAccumulator.smooth_scores", +] + + +def score_from_cosine(cosine_similarity: np.ndarray) -> np.ndarray: + """Map cosine similarity from -1..1 to the public 0..1 score scale.""" + + return np.clip((cosine_similarity + 1.0) / 2.0, 0.0, 1.0) @dataclass @@ -62,6 +99,59 @@ class SearchMetrics: mean_reciprocal_rank: float +@dataclass +class PipelineQueryCase: + """A compiled query fixture with message-level ground truth.""" + + query: str + query_exprs: list[search.SearchQueryExpr] + expected_matches: list[int] + + +@dataclass +class 
PipelineMetrics: + """Aggregate metrics from the real query scoring pipeline.""" + + hit_rate: float + mean_reciprocal_rank: float + mean_result_count: float + + +@dataclass +class RelatedTermQueryCase: + """A search term paired with related terms from the compiled query fixture.""" + + term: str + expected_related_terms: list[str] + + +@dataclass +class RelatedTermMetrics: + """Aggregate fuzzy related-term retrieval metrics for one benchmark row.""" + + hit_rate: float + mean_reciprocal_rank: float + mean_result_count: float + + +@dataclass +class AnswerQueryCase: + """A benchmark answer case paired with its expected answerability.""" + + question: str + answer: str + has_no_answer: bool + + +@dataclass +class AnswerMetrics: + """Aggregate semantic answer-context metrics for one benchmark row.""" + + answerable_support: float + no_answer_rejection_rate: float + semantic_score: float + + @dataclass class TopScoreStats: """Observed top-1 score statistics across all benchmark queries.""" @@ -78,6 +168,20 @@ class BenchmarkRow: min_score: float max_hits: int metrics: SearchMetrics + pipeline_metrics: PipelineMetrics | None = None + related_metrics: RelatedTermMetrics | None = None + answer_metrics: AnswerMetrics | None = None + + +@dataclass +class CorpusMetadata: + """Metadata about the serialized benchmark corpus fixture.""" + + message_count: int + serialized_embedding_size: int | None + serialized_message_count: int | None + serialized_related_count: int | None + serialized_total_embedding_count: int | None def parse_float_list(raw: str | None) -> list[float]: @@ -150,13 +254,79 @@ def parse_int_list(raw: str | None) -> list[int]: def load_message_texts(repo_root: Path) -> list[str]: - """Load the benchmark corpus as one text blob per message.""" + """Load the benchmark corpus as one text blob per message. + + The JSON fixture also points at a serialized embedding sidecar, but that + sidecar is deliberately ignored here. Cross-model comparisons are only + meaningful when every evaluated model embeds the same raw message text. 
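+
+    For example, a message whose textChunks are ["hello", "world"] becomes the
+    single corpus entry "hello world"; the sidecar-handling test in
+    tests/test_benchmark_embeddings.py pins down exactly this joining.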
+ """ index_data = json.loads((repo_root / INDEX_DATA_PATH).read_text(encoding="utf-8")) messages = index_data["messages"] return [" ".join(message.get("textChunks", [])) for message in messages] +def load_related_term_texts(repo_root: Path) -> list[str]: + """Load the term corpus used by fuzzy related-term lookup.""" + + index_data = json.loads((repo_root / INDEX_DATA_PATH).read_text(encoding="utf-8")) + related_terms_index_data = index_data.get("relatedTermsIndexData") or {} + text_embedding_data = related_terms_index_data.get("textEmbeddingData") or {} + text_items = text_embedding_data.get("textItems") + if isinstance(text_items, list) and text_items: + return [text for text in text_items if isinstance(text, str)] + + semantic_index_data = index_data.get("semanticIndexData") or {} + items = semantic_index_data.get("items") or [] + return [item["term"] for item in items if isinstance(item.get("term"), str)] + + +def load_corpus_metadata(repo_root: Path) -> CorpusMetadata: + """Load sidecar metadata without loading the sidecar embeddings.""" + + index_data = json.loads((repo_root / INDEX_DATA_PATH).read_text(encoding="utf-8")) + embedding_file_header = index_data.get("embeddingFileHeader") or {} + model_metadata = embedding_file_header.get("modelMetadata") or {} + serialized_embedding_size = model_metadata.get("embeddingSize") + serialized_message_count = embedding_file_header.get("messageCount") + serialized_related_count = embedding_file_header.get("relatedCount") + serialized_total_embedding_count: int | None = None + + sidecar_path = repo_root / INDEX_EMBEDDINGS_PATH + if serialized_embedding_size is not None and sidecar_path.exists(): + bytes_per_embedding = serialized_embedding_size * np.dtype(np.float32).itemsize + if bytes_per_embedding <= 0: + raise ValueError( + "Serialized benchmark corpus has a non-positive embedding size" + ) + sidecar_size_bytes = sidecar_path.stat().st_size + if sidecar_size_bytes % bytes_per_embedding != 0: + raise ValueError( + "Serialized benchmark sidecar size is not divisible by the declared " + f"embedding width of {serialized_embedding_size}" + ) + serialized_total_embedding_count = sidecar_size_bytes // bytes_per_embedding + declared_total_count = (serialized_message_count or 0) + ( + serialized_related_count or 0 + ) + if ( + declared_total_count + and declared_total_count != serialized_total_embedding_count + ): + raise ValueError( + "Serialized benchmark sidecar row count does not match the counts " + "declared in the JSON metadata" + ) + + return CorpusMetadata( + message_count=len(index_data.get("messages", [])), + serialized_embedding_size=serialized_embedding_size, + serialized_message_count=serialized_message_count, + serialized_related_count=serialized_related_count, + serialized_total_embedding_count=serialized_total_embedding_count, + ) + + def load_search_queries(repo_root: Path) -> list[SearchQueryCase]: """Load benchmark queries that include message-level ground-truth matches.""" @@ -176,14 +346,196 @@ def load_search_queries(repo_root: Path) -> list[SearchQueryCase]: return cases +def strip_related_terms(value: object) -> None: + """Remove cached related-term expansions so each model resolves its own.""" + + for obj in iter_dicts(value): + related_terms = obj.get("relatedTerms") + if isinstance(related_terms, list) and related_terms: + obj["relatedTerms"] = None + + +def load_pipeline_queries(repo_root: Path) -> list[PipelineQueryCase]: + """Load compiled query fixtures for the real semantic scoring pipeline.""" + + search_data = 
json.loads( + (repo_root / SEARCH_RESULTS_PATH).read_text(encoding="utf-8") + ) + cases: list[PipelineQueryCase] = [] + for item in search_data: + search_text = item.get("searchText") + compiled_query_expr = item.get("compiledQueryExpr") + results = item.get("results", []) + if not ( + isinstance(search_text, str) + and isinstance(compiled_query_expr, list) + and results + ): + continue + expected_matches = results[0].get("messageMatches", []) + if not expected_matches: + continue + strip_related_terms(compiled_query_expr) + query_exprs = serialization.deserialize_object( + list[search.SearchQueryExpr], + compiled_query_expr, + ) + cases.append( + PipelineQueryCase( + query=search_text, + query_exprs=query_exprs, + expected_matches=expected_matches, + ) + ) + return cases + + +def iter_dicts(value: object): + """Yield dictionaries recursively from a decoded JSON value.""" + + if isinstance(value, dict): + yield value + for child in value.values(): + yield from iter_dicts(child) + elif isinstance(value, list): + for child in value: + yield from iter_dicts(child) + + +def load_related_term_queries(repo_root: Path) -> list[RelatedTermQueryCase]: + """Load expected fuzzy related-term outputs from compiled query fixtures. + + These compiled fixtures are closer to the real query pipeline than raw + query-to-message similarity: `min_score` normally gates fuzzy related-term + expansion before semantic-ref and message scores are accumulated. + """ + + search_data = json.loads( + (repo_root / SEARCH_RESULTS_PATH).read_text(encoding="utf-8") + ) + cases: list[RelatedTermQueryCase] = [] + seen: set[tuple[str, tuple[str, ...]]] = set() + for item in search_data: + for obj in iter_dicts(item.get("compiledQueryExpr", [])): + term = obj.get("term") + related_terms = obj.get("relatedTerms") + if not ( + isinstance(term, dict) + and isinstance(term.get("text"), str) + and isinstance(related_terms, list) + and related_terms + ): + continue + expected: list[str] = [] + for related in related_terms: + if isinstance(related, dict): + related_text = related.get("text") + if isinstance(related_text, str): + expected.append(related_text) + if not expected: + continue + key = (term["text"], tuple(expected)) + if key not in seen: + seen.add(key) + cases.append(RelatedTermQueryCase(term["text"], expected)) + return cases + + +def load_answer_queries(repo_root: Path) -> list[AnswerQueryCase]: + """Load expected answers for semantic answer-context benchmarking.""" + + answer_data = json.loads( + (repo_root / ANSWER_RESULTS_PATH).read_text(encoding="utf-8") + ) + cases: list[AnswerQueryCase] = [] + for item in answer_data: + question = item.get("question") + answer = item.get("answer") + has_no_answer = item.get("hasNoAnswer") + if ( + isinstance(question, str) + and isinstance(answer, str) + and isinstance(has_no_answer, bool) + ): + cases.append(AnswerQueryCase(question, answer, has_no_answer)) + return cases + + async def build_vector_base( model_spec: str | None, message_texts: list[str], batch_size: int, ) -> tuple[IEmbeddingModel, VectorBase]: - """Build a message-level vector index for the benchmark corpus.""" + """Build a message-level vector index for the benchmark corpus. + + This computes fresh embeddings for `message_texts` with the requested + model. It does not deserialize or consult the fixture's `_embeddings.bin` + sidecar, which may have been generated by a different embedding model. 
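+
+    A typical call, mirroring run_benchmark (the model spec shown is only the
+    example from the module docstring):
+
+        model, vector_base = await build_vector_base(
+            "openai:text-embedding-3-small", message_texts, batch_size=16
+        )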
+ """ model = create_embedding_model(model_spec) + vector_base = await build_text_vector_base(model, message_texts, batch_size) + return model, vector_base + + +async def build_pipeline_conversation( + repo_root: Path, + model: IEmbeddingModel, +) -> podcast.Podcast: + """Build the benchmark conversation with per-model secondary indexes. + + The fixture's serialized embedding sidecar is deliberately not used here. + We keep the semantic refs and exact semantic index, then rebuild related-term + and message-text indexes with the requested model so `min_score` gates the + same fuzzy expansion path the runtime uses. + """ + + settings = create_benchmark_conversation_settings(model) + data = podcast.Podcast._read_conversation_data_from_file( + str(repo_root / INDEX_PREFIX_PATH) + ) + data.pop("relatedTermsIndexData", None) + data.pop("messageIndexData", None) + conversation = await podcast.Podcast.create(settings) + await conversation.deserialize(data) + await secindex.build_secondary_indexes(conversation, settings) + return conversation + + +def create_benchmark_conversation_settings( + model: IEmbeddingModel, +) -> ConversationSettings: + """Use benchmarked model defaults without changing normal app settings.""" + + settings = ConversationSettings(model=model) + benchmark_min_score = TextEmbeddingIndexSettings(model).min_score + settings.related_term_index_settings = RelatedTermIndexSettings( + TextEmbeddingIndexSettings( + model, + min_score=benchmark_min_score, + max_matches=50, + ) + ) + settings.thread_settings = TextEmbeddingIndexSettings( + model, + min_score=benchmark_min_score, + ) + settings.message_text_index_settings = MessageTextIndexSettings( + TextEmbeddingIndexSettings( + model, + min_score=benchmark_min_score, + ) + ) + return settings + + +async def build_text_vector_base( + model: IEmbeddingModel, + texts: list[str], + batch_size: int, +) -> VectorBase: + """Build a vector index for already selected benchmark text items.""" + settings = TextEmbeddingIndexSettings( embedding_model=model, min_score=0.0, @@ -191,11 +543,11 @@ async def build_vector_base( batch_size=batch_size, ) vector_base = VectorBase(settings) - for start in range(0, len(message_texts), batch_size): - batch = message_texts[start : start + batch_size] - await vector_base.add_keys(batch) - - return model, vector_base + for start in range(0, len(texts), batch_size): + batch = texts[start : start + batch_size] + embeddings = await model.get_embeddings_nocache(batch) + vector_base.add_embeddings(None, embeddings) + return vector_base def evaluate_search_queries( @@ -233,6 +585,240 @@ def evaluate_search_queries( ) +async def evaluate_pipeline_queries( + conversation: podcast.Podcast, + query_cases: list[PipelineQueryCase], + min_score: float, + max_hits: int, +) -> PipelineMetrics: + """Evaluate compiled queries through the runtime semantic scoring path.""" + + related_settings = ( + conversation.settings.related_term_index_settings.embedding_index_settings + ) + related_settings.min_score = min_score + hit_count = 0 + reciprocal_ranks: list[float] = [] + result_counts: list[int] = [] + options = search.SearchOptions(max_message_matches=max_hits) + + for case in query_cases: + query_exprs = deepcopy(case.query_exprs) + scored_results = [] + for query_expr in query_exprs: + search_results = await search.run_search_query( + conversation, + query_expr, + options, + ) + for result in search_results: + scored_results.extend(result.message_matches) + scored_results.sort(key=lambda result: result.score, reverse=True) + 
result_counts.append(len(scored_results)) + expected_matches = set(case.expected_matches) + rank = 0 + for result_index, scored_result in enumerate(scored_results, start=1): + if scored_result.message_ordinal in expected_matches: + rank = result_index + break + if rank > 0: + hit_count += 1 + reciprocal_ranks.append(1.0 / rank) + else: + reciprocal_ranks.append(0.0) + + return PipelineMetrics( + hit_rate=(hit_count / len(query_cases)) * 100, + mean_reciprocal_rank=mean(reciprocal_ranks), + mean_result_count=mean(result_counts), + ) + + +def evaluate_related_term_queries( + vector_base: VectorBase, + related_terms: list[str], + query_cases: list[RelatedTermQueryCase], + query_embeddings: NormalizedEmbeddings, + min_score: float, + max_hits: int, +) -> RelatedTermMetrics: + """Evaluate fuzzy related-term retrieval against compiled query fixtures.""" + + hit_count = 0 + reciprocal_ranks: list[float] = [] + result_counts: list[int] = [] + + for case, query_embedding in zip(query_cases, query_embeddings): + expected_terms = set(case.expected_related_terms) + scored_results = vector_base.fuzzy_lookup_embedding( + query_embedding, + max_hits=max_hits, + min_score=min_score, + ) + result_counts.append(len(scored_results)) + rank = 0 + for result_index, scored_result in enumerate(scored_results, start=1): + if related_terms[scored_result.item] in expected_terms: + rank = result_index + break + if rank > 0: + hit_count += 1 + reciprocal_ranks.append(1.0 / rank) + else: + reciprocal_ranks.append(0.0) + + return RelatedTermMetrics( + hit_rate=(hit_count / len(query_cases)) * 100, + mean_reciprocal_rank=mean(reciprocal_ranks), + mean_result_count=mean(result_counts), + ) + + +def evaluate_answer_queries( + vector_base: VectorBase, + answer_cases: list[AnswerQueryCase], + question_embeddings: NormalizedEmbeddings, + answer_embeddings: NormalizedEmbeddings, + min_score: float, + max_hits: int, +) -> AnswerMetrics: + """Evaluate whether retrieved message context semantically supports answers.""" + + answerable_support_scores: list[float] = [] + no_answer_rejections = 0 + no_answer_count = 0 + corpus_embeddings = vector_base.serialize() + + for case, question_embedding, answer_embedding in zip( + answer_cases, + question_embeddings, + answer_embeddings, + strict=True, + ): + scored_results = vector_base.fuzzy_lookup_embedding( + question_embedding, + max_hits=max_hits, + min_score=min_score, + ) + if case.has_no_answer: + no_answer_count += 1 + if not scored_results: + no_answer_rejections += 1 + continue + + if not scored_results: + answerable_support_scores.append(0.0) + continue + + retrieved_embeddings = corpus_embeddings[ + [scored_result.item for scored_result in scored_results] + ] + scores = score_from_cosine(np.dot(retrieved_embeddings, answer_embedding)) + answerable_support_scores.append(float(np.max(scores))) + + answerable_support = ( + mean(answerable_support_scores) if answerable_support_scores else 0.0 + ) + no_answer_rejection_rate = ( + (no_answer_rejections / no_answer_count) * 100 if no_answer_count else 100.0 + ) + semantic_score = (answerable_support * 100 + no_answer_rejection_rate) / 2 + return AnswerMetrics( + answerable_support=answerable_support, + no_answer_rejection_rate=no_answer_rejection_rate, + semantic_score=semantic_score, + ) + + +async def evaluate_grid( + vector_base: VectorBase, + query_cases: list[SearchQueryCase], + query_embeddings: NormalizedEmbeddings, + min_scores: list[float], + max_hits_values: list[int], + pipeline_conversation: podcast.Podcast | None = None, 
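The pipeline and related-term evaluators above share the same rank bookkeeping: find the 1-based rank of the first expected match, count a hit, and record its reciprocal rank (0.0 on a miss). A minimal, self-contained sketch of that metric in isolation; the function name and data shapes here are illustrative, not part of the patch:

```python
from statistics import mean

def mrr_and_hit_rate(
    ranked_results: list[list[int]],
    expected: list[set[int]],
) -> tuple[float, float]:
    """Mean reciprocal rank and hit rate (%) over per-query ranked result lists."""
    reciprocal_ranks: list[float] = []
    hits = 0
    for results, relevant in zip(ranked_results, expected, strict=True):
        # Rank of the first relevant result, 1-based; 0 means no hit at all.
        rank = next((i for i, r in enumerate(results, start=1) if r in relevant), 0)
        if rank:
            hits += 1
        reciprocal_ranks.append(1.0 / rank if rank else 0.0)
    return mean(reciprocal_ranks), 100.0 * hits / len(expected)

# First query hits at rank 2, second misses entirely:
# mrr_and_hit_rate([[7, 3], [9]], [{3}, {1}]) == (0.25, 50.0)
```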
+ pipeline_query_cases: list[PipelineQueryCase] | None = None, + related_vector_base: VectorBase | None = None, + related_terms: list[str] | None = None, + related_query_cases: list[RelatedTermQueryCase] | None = None, + related_query_embeddings: NormalizedEmbeddings | None = None, + answer_cases: list[AnswerQueryCase] | None = None, + answer_question_embeddings: NormalizedEmbeddings | None = None, + answer_embeddings: NormalizedEmbeddings | None = None, + progress_label: str | None = None, +) -> list[BenchmarkRow]: + """Evaluate every `(min_score, max_hits)` row in the requested grid.""" + + rows: list[BenchmarkRow] = [] + for min_score_index, min_score in enumerate(min_scores, start=1): + if progress_label and ( + min_score_index == 1 + or min_score_index == len(min_scores) + or min_score_index % 10 == 0 + ): + print( + f"{progress_label}: min_score {min_score:.2f} " + f"({min_score_index}/{len(min_scores)})...", + flush=True, + ) + for max_hits in max_hits_values: + metrics = evaluate_search_queries( + vector_base, + query_cases, + query_embeddings, + min_score, + max_hits, + ) + pipeline_metrics = None + if pipeline_conversation is not None and pipeline_query_cases is not None: + pipeline_metrics = await evaluate_pipeline_queries( + pipeline_conversation, + pipeline_query_cases, + min_score, + max_hits, + ) + related_metrics = None + if ( + related_vector_base is not None + and related_terms is not None + and related_query_cases is not None + and related_query_embeddings is not None + ): + related_metrics = evaluate_related_term_queries( + related_vector_base, + related_terms, + related_query_cases, + related_query_embeddings, + min_score, + max_hits, + ) + answer_metrics = None + if ( + answer_cases is not None + and answer_question_embeddings is not None + and answer_embeddings is not None + ): + answer_metrics = evaluate_answer_queries( + vector_base, + answer_cases, + answer_question_embeddings, + answer_embeddings, + min_score, + max_hits, + ) + rows.append( + BenchmarkRow( + min_score, + max_hits, + metrics, + pipeline_metrics, + related_metrics, + answer_metrics, + ) + ) + return rows + + def measure_top_score_stats( vector_base: VectorBase, query_embeddings: NormalizedEmbeddings, @@ -258,26 +844,26 @@ def measure_top_score_stats( def filter_min_scores_by_ceiling( min_scores: list[float], max_top_score: float ) -> tuple[list[float], list[float]]: - """Discard score thresholds that cannot return any results for this run.""" + """Keep the requested min-score grid intact.""" - effective_scores = [ - min_score for min_score in min_scores if min_score <= max_top_score + 1e-9 - ] - skipped_scores = [ - min_score for min_score in min_scores if min_score > max_top_score + 1e-9 - ] - return effective_scores, skipped_scores + _ = max_top_score + return list(min_scores), [] def select_best_row(rows: list[BenchmarkRow]) -> BenchmarkRow: - """Prefer the strongest MRR/Hit Rate row, then the stricter score cutoff.""" + """Prefer true pipeline quality, then related-term quality and strictness.""" return max( rows, key=lambda row: ( + row.pipeline_metrics.mean_reciprocal_rank if row.pipeline_metrics else 0.0, + row.pipeline_metrics.hit_rate if row.pipeline_metrics else 0.0, + row.related_metrics.mean_reciprocal_rank if row.related_metrics else 0.0, + row.related_metrics.hit_rate if row.related_metrics else 0.0, row.metrics.mean_reciprocal_rank, row.metrics.hit_rate, row.min_score, + row.answer_metrics.semantic_score if row.answer_metrics else 0.0, -row.max_hits, ), ) @@ -287,17 +873,70 @@ def 
print_rows(rows: list[BenchmarkRow]) -> None: """Print the benchmark grid in a reviewer-friendly table.""" print("=" * 72) - print("SEARCH BENCHMARK (Episode 53 messageMatches ground truth)") + print("PIPELINE + SEARCH + ANSWER-CONTEXT BENCHMARK (Episode 53 fixtures)") print("=" * 72) - print(f"{'Min Score':<12} | {'Max Hits':<10} | {'Hit Rate (%)':<15} | {'MRR':<10}") - print("-" * 65) + print( + f"{'Min Score':<12} | {'Max Hits':<10} | {'Hit Rate (%)':<15} | " + f"{'MRR':<10} | {'Pipe Hit':<10} | {'Pipe MRR':<10} | " + f"{'Pipe Cnt':<10} | {'Rel Hit':<10} | {'Rel MRR':<10} | " + f"{'Rel Cnt':<10} | {'Ans Sup':<10} | {'NoAns (%)':<10} | {'Sem':<10}" + ) + print("-" * 174) for row in rows: + pipeline_hit_rate = ( + f"{row.pipeline_metrics.hit_rate:<10.2f}" + if row.pipeline_metrics + else f"{'n/a':<10}" + ) + pipeline_mrr = ( + f"{row.pipeline_metrics.mean_reciprocal_rank:<10.4f}" + if row.pipeline_metrics + else f"{'n/a':<10}" + ) + pipeline_count = ( + f"{row.pipeline_metrics.mean_result_count:<10.2f}" + if row.pipeline_metrics + else f"{'n/a':<10}" + ) + related_hit_rate = ( + f"{row.related_metrics.hit_rate:<10.2f}" + if row.related_metrics + else f"{'n/a':<10}" + ) + related_mrr = ( + f"{row.related_metrics.mean_reciprocal_rank:<10.4f}" + if row.related_metrics + else f"{'n/a':<10}" + ) + related_count = ( + f"{row.related_metrics.mean_result_count:<10.2f}" + if row.related_metrics + else f"{'n/a':<10}" + ) + answer_support = ( + f"{row.answer_metrics.answerable_support:<10.4f}" + if row.answer_metrics + else f"{'n/a':<10}" + ) + no_answer = ( + f"{row.answer_metrics.no_answer_rejection_rate:<10.2f}" + if row.answer_metrics + else f"{'n/a':<10}" + ) + semantic_score = ( + f"{row.answer_metrics.semantic_score:<10.2f}" + if row.answer_metrics + else f"{'n/a':<10}" + ) print( f"{row.min_score:<12.2f} | {row.max_hits:<10d} | " f"{row.metrics.hit_rate:<15.2f} | " - f"{row.metrics.mean_reciprocal_rank:<10.4f}" + f"{row.metrics.mean_reciprocal_rank:<10.4f} | " + f"{pipeline_hit_rate} | {pipeline_mrr} | {pipeline_count} | " + f"{related_hit_rate} | {related_mrr} | {related_count} | " + f"{answer_support} | {no_answer} | {semantic_score}" ) - print("-" * 65) + print("-" * 174) async def run_benchmark( @@ -312,7 +951,12 @@ async def run_benchmark( repo_root = Path(__file__).resolve().parent.parent message_texts = load_message_texts(repo_root) + related_terms = load_related_term_texts(repo_root) + corpus_metadata = load_corpus_metadata(repo_root) query_cases = load_search_queries(repo_root) + pipeline_query_cases = load_pipeline_queries(repo_root) + related_query_cases = load_related_term_queries(repo_root) + answer_cases = load_answer_queries(repo_root) if not query_cases: raise ValueError("No search queries with messageMatches found in the dataset") model, vector_base = await build_vector_base( @@ -320,43 +964,71 @@ async def run_benchmark( message_texts, batch_size, ) - query_embeddings = await model.get_embeddings([case.query for case in query_cases]) + related_vector_base = await build_text_vector_base( + model, + related_terms, + batch_size, + ) + pipeline_conversation = await build_pipeline_conversation(repo_root, model) + query_embeddings = await model.get_embeddings_nocache( + [case.query for case in query_cases] + ) + related_query_embeddings = await model.get_embeddings_nocache( + [case.term for case in related_query_cases] + ) + answer_question_embeddings = await model.get_embeddings_nocache( + [case.question for case in answer_cases] + ) + answer_embeddings = await 
model.get_embeddings_nocache( + [case.answer for case in answer_cases] + ) top_score_stats = measure_top_score_stats(vector_base, query_embeddings) - effective_min_scores, skipped_min_scores = filter_min_scores_by_ceiling( + rows = await evaluate_grid( + vector_base, + query_cases, + query_embeddings, min_scores, - top_score_stats.max_top_score, + max_hits_values, + pipeline_conversation, + pipeline_query_cases, + related_vector_base, + related_terms, + related_query_cases, + related_query_embeddings, + answer_cases, + answer_question_embeddings, + answer_embeddings, ) - if not effective_min_scores: - raise ValueError( - "No requested min_score values are below the observed top-score ceiling " - f"of {top_score_stats.max_top_score:.4f}" - ) - - rows: list[BenchmarkRow] = [] - for min_score in effective_min_scores: - for max_hits in max_hits_values: - metrics = evaluate_search_queries( - vector_base, - query_cases, - query_embeddings, - min_score, - max_hits, - ) - rows.append(BenchmarkRow(min_score, max_hits, metrics)) print(f"Model: {model.model_name}") print(f"Messages indexed: {len(message_texts)}") + print(f"Related terms indexed: {len(related_terms)}") print(f"Queries evaluated: {len(query_cases)}") + print(f"Pipeline query cases evaluated: {len(pipeline_query_cases)}") + print(f"Related-term cases evaluated: {len(related_query_cases)}") + print(f"Answer cases evaluated: {len(answer_cases)}") + print("Pipeline scoring paths:") + for path in PIPELINE_SCORING_PATHS: + print(f" {path}") + if corpus_metadata.serialized_total_embedding_count is not None: + print( + "Serialized sidecar rows ignored: " + f"{corpus_metadata.serialized_total_embedding_count} " + f"({INDEX_EMBEDDINGS_PATH.name})" + ) + elif corpus_metadata.serialized_embedding_size is not None: + print( + "Serialized sidecar metadata found and ignored: " + f"embedding_size={corpus_metadata.serialized_embedding_size}" + ) + print(f"Corpus embeddings: {CORPUS_EMBEDDING_SOURCE}") + print(f"Query embeddings: {QUERY_EMBEDDING_SOURCE}") + print(f"Answer embeddings: {ANSWER_EMBEDDING_SOURCE}") print( "Observed top-1 score range: " f"{top_score_stats.min_top_score:.4f}..{top_score_stats.max_top_score:.4f} " f"(mean {top_score_stats.mean_top_score:.4f})" ) - if skipped_min_scores: - print( - f"Skipped {len(skipped_min_scores)} min_score values above " - f"{top_score_stats.max_top_score:.4f}; they cannot return any matches." 
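The "Observed top-1 score range" printed above is a plain cosine diagnostic over normalized embeddings: the best-match score for each query, summarized as min/mean/max. A rough equivalent, assuming float32 arrays already unit-normalized row-wise (a sketch of the idea, not the exact `measure_top_score_stats` implementation):

```python
import numpy as np

def top1_score_stats(corpus: np.ndarray, queries: np.ndarray) -> tuple[float, float, float]:
    """Min/mean/max of each query's best cosine score over an (N x D) corpus."""
    # With unit-normalized rows, the dot product equals cosine similarity.
    top1 = (queries @ corpus.T).max(axis=1)
    return float(top1.min()), float(top1.mean()), float(top1.max())
```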
- ) print() print_rows(rows) @@ -367,6 +1039,27 @@ async def run_benchmark( print(f" max_hits={best_row.max_hits}") print(f" hit_rate={best_row.metrics.hit_rate:.2f}%") print(f" mrr={best_row.metrics.mean_reciprocal_rank:.4f}") + if best_row.pipeline_metrics: + print(f" pipeline_hit_rate={best_row.pipeline_metrics.hit_rate:.2f}%") + print(f" pipeline_mrr={best_row.pipeline_metrics.mean_reciprocal_rank:.4f}") + print( + " pipeline_mean_result_count=" + f"{best_row.pipeline_metrics.mean_result_count:.2f}" + ) + if best_row.related_metrics: + print(f" related_hit_rate={best_row.related_metrics.hit_rate:.2f}%") + print(f" related_mrr={best_row.related_metrics.mean_reciprocal_rank:.4f}") + print( + " related_mean_result_count=" + f"{best_row.related_metrics.mean_result_count:.2f}" + ) + if best_row.answer_metrics: + print(f" answerable_support={best_row.answer_metrics.answerable_support:.4f}") + print( + " no_answer_rejection_rate=" + f"{best_row.answer_metrics.no_answer_rejection_rate:.2f}%" + ) + print(f" semantic_score={best_row.answer_metrics.semantic_score:.2f}") def main() -> None: diff --git a/tools/repeat_embedding_benchmarks.py b/tools/repeat_embedding_benchmarks.py index bc2a741a..f6590da6 100644 --- a/tools/repeat_embedding_benchmarks.py +++ b/tools/repeat_embedding_benchmarks.py @@ -22,14 +22,34 @@ from pathlib import Path from statistics import mean -import benchmark_embeddings +from benchmark_embeddings import ( + ANSWER_EMBEDDING_SOURCE, + BenchmarkRow, + build_pipeline_conversation, + build_text_vector_base, + build_vector_base, + CORPUS_EMBEDDING_SOURCE, + DEFAULT_MAX_HITS, + evaluate_grid, + INDEX_DATA_PATH, + INDEX_EMBEDDINGS_PATH, + load_answer_queries, + load_corpus_metadata, + load_message_texts, + load_pipeline_queries, + load_related_term_queries, + load_related_term_texts, + load_search_queries, + measure_top_score_stats, + parse_int_list, + PIPELINE_SCORING_PATHS, + QUERY_EMBEDDING_SOURCE, + resolve_min_scores, + SEARCH_RESULTS_PATH, + select_best_row, +) from dotenv import load_dotenv -BenchmarkRow = benchmark_embeddings.BenchmarkRow -DEFAULT_MAX_HITS = benchmark_embeddings.DEFAULT_MAX_HITS -parse_int_list = benchmark_embeddings.parse_int_list -resolve_min_scores = benchmark_embeddings.resolve_min_scores - DEFAULT_MODELS = [ "openai:text-embedding-3-small", "openai:text-embedding-3-large", @@ -46,6 +66,15 @@ class RunRow: max_hits: int hit_rate: float mean_reciprocal_rank: float + pipeline_hit_rate: float | None + pipeline_mean_reciprocal_rank: float | None + pipeline_mean_result_count: float | None + related_hit_rate: float | None + related_mean_reciprocal_rank: float | None + related_mean_result_count: float | None + answerable_support: float | None + no_answer_rejection_rate: float | None + semantic_score: float | None @dataclass @@ -78,6 +107,33 @@ def benchmark_row_to_run_row(row: BenchmarkRow) -> RunRow: max_hits=row.max_hits, hit_rate=row.metrics.hit_rate, mean_reciprocal_rank=row.metrics.mean_reciprocal_rank, + pipeline_hit_rate=( + row.pipeline_metrics.hit_rate if row.pipeline_metrics else None + ), + pipeline_mean_reciprocal_rank=( + row.pipeline_metrics.mean_reciprocal_rank if row.pipeline_metrics else None + ), + pipeline_mean_result_count=( + row.pipeline_metrics.mean_result_count if row.pipeline_metrics else None + ), + related_hit_rate=( + row.related_metrics.hit_rate if row.related_metrics else None + ), + related_mean_reciprocal_rank=( + row.related_metrics.mean_reciprocal_rank if row.related_metrics else None + ), + related_mean_result_count=( + 
row.related_metrics.mean_result_count if row.related_metrics else None + ), + answerable_support=( + row.answer_metrics.answerable_support if row.answer_metrics else None + ), + no_answer_rejection_rate=( + row.answer_metrics.no_answer_rejection_rate if row.answer_metrics else None + ), + semantic_score=( + row.answer_metrics.semantic_score if row.answer_metrics else None + ), ) @@ -97,6 +153,31 @@ def summarize_runs(model_spec: str, runs: list[RunResult]) -> dict[str, object]: "max_hits": max_hits, "mean_hit_rate": mean(row.hit_rate for row in rows), "mean_mrr": mean(row.mean_reciprocal_rank for row in rows), + "mean_pipeline_hit_rate": mean( + row.pipeline_hit_rate or 0.0 for row in rows + ), + "mean_pipeline_mrr": mean( + row.pipeline_mean_reciprocal_rank or 0.0 for row in rows + ), + "mean_pipeline_result_count": mean( + row.pipeline_mean_result_count or 0.0 for row in rows + ), + "mean_related_hit_rate": mean( + row.related_hit_rate or 0.0 for row in rows + ), + "mean_related_mrr": mean( + row.related_mean_reciprocal_rank or 0.0 for row in rows + ), + "mean_related_result_count": mean( + row.related_mean_result_count or 0.0 for row in rows + ), + "mean_answerable_support": mean( + row.answerable_support or 0.0 for row in rows + ), + "mean_no_answer_rejection_rate": mean( + row.no_answer_rejection_rate or 0.0 for row in rows + ), + "mean_semantic_score": mean(row.semantic_score or 0.0 for row in rows), } ) @@ -114,9 +195,14 @@ def summarize_runs(model_spec: str, runs: list[RunResult]) -> dict[str, object]: averaged_best_row = max( averaged_rows, key=lambda row: ( + float(row["mean_pipeline_mrr"]), + float(row["mean_pipeline_hit_rate"]), + float(row["mean_related_mrr"]), + float(row["mean_related_hit_rate"]), float(row["mean_mrr"]), float(row["mean_hit_rate"]), float(row["min_score"]), + float(row["mean_semantic_score"]), -int(row["max_hits"]), ), ) @@ -149,8 +235,14 @@ def write_markdown_summary(path: Path, summaries: list[dict[str, object]]) -> No lines = [ "# Repeated Embedding Benchmark Summary", "", - "| Model | Runs | Recommended min_score | Recommended max_hits | Mean hit rate | Mean MRR |", - "| --- | ---: | ---: | ---: | ---: | ---: |", + "Corpus and query embeddings are recomputed for each evaluated model.", + ( + "The serialized benchmark sidecar " + f"`{INDEX_EMBEDDINGS_PATH.name}` is ignored." 
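Note the `or 0.0` coercion in the per-row means above: a run that produced no pipeline, related-term, or answer metrics is averaged in as zero rather than dropped, which biases those means downward when metrics are missing. A tiny illustration with made-up values:

```python
from statistics import mean

values = [0.82, None, 0.78]  # one run produced no pipeline metrics

coerced = mean(v or 0.0 for v in values)            # ~0.533, missing run counts as 0
present = mean(v for v in values if v is not None)  # 0.80, missing run ignored
```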
+ ), + "", + "| Model | Runs | Recommended min_score | Recommended max_hits | Pipeline hit rate | Pipeline MRR | Related hit rate | Related MRR | Mean hit rate | Mean MRR | Answer support | No-answer rejection | Semantic score |", + "| --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: |", ] for summary in summaries: recommended_row = summary["recommended_row"] @@ -161,8 +253,15 @@ def write_markdown_summary(path: Path, summaries: list[dict[str, object]]) -> No f"{summary['run_count']} | " f"{recommended_row['min_score']:.2f} | " f"{recommended_row['max_hits']} | " + f"{recommended_row['mean_pipeline_hit_rate']:.2f} | " + f"{recommended_row['mean_pipeline_mrr']:.4f} | " + f"{recommended_row['mean_related_hit_rate']:.2f} | " + f"{recommended_row['mean_related_mrr']:.4f} | " f"{recommended_row['mean_hit_rate']:.2f} | " - f"{recommended_row['mean_mrr']:.4f} |" + f"{recommended_row['mean_mrr']:.4f} | " + f"{recommended_row['mean_answerable_support']:.4f} | " + f"{recommended_row['mean_no_answer_rejection_rate']:.2f} | " + f"{recommended_row['mean_semantic_score']:.2f} |" ) lines.append("") for summary in summaries: @@ -175,6 +274,51 @@ def write_markdown_summary(path: Path, summaries: list[dict[str, object]]) -> No path.write_text("\n".join(lines), encoding="utf-8") +def build_run_suite_metadata( + repo_root: Path, + timestamp: str, + models: list[str], + runs: int, + min_scores: list[float], + max_hits_values: list[int], + batch_size: int, +) -> dict[str, object]: + """Build the shared metadata payload for one repeated benchmark suite.""" + + corpus_metadata = load_corpus_metadata(repo_root) + return { + "created_at_utc": timestamp, + "runs_per_model": runs, + "models": models, + "message_source_json": str(INDEX_DATA_PATH), + "pipeline_source_json": str(SEARCH_RESULTS_PATH), + "related_term_source_json": str(SEARCH_RESULTS_PATH), + "pipeline_scoring_paths": PIPELINE_SCORING_PATHS, + "ignored_serialized_embedding_sidecar": str(INDEX_EMBEDDINGS_PATH), + "ignored_serialized_embedding_size": ( + corpus_metadata.serialized_embedding_size + ), + "ignored_serialized_message_embedding_count": ( + corpus_metadata.serialized_message_count + ), + "ignored_serialized_related_embedding_count": ( + corpus_metadata.serialized_related_count + ), + "ignored_serialized_total_embedding_count": ( + corpus_metadata.serialized_total_embedding_count + ), + "corpus_embedding_source": CORPUS_EMBEDDING_SOURCE, + "query_embedding_source": QUERY_EMBEDDING_SOURCE, + "answer_embedding_source": ANSWER_EMBEDDING_SOURCE, + "min_scores": min_scores, + "max_hits_values": max_hits_values, + "min_score_count": len(min_scores), + "max_hits_count": len(max_hits_values), + "grid_row_count": len(min_scores) * len(max_hits_values), + "batch_size": batch_size, + } + + async def run_single_model_benchmark( model_spec: str, runs: int, @@ -186,56 +330,102 @@ async def run_single_model_benchmark( """Run the benchmark repeatedly for one model and persist raw artifacts.""" repo_root = Path(__file__).resolve().parent.parent - message_texts = benchmark_embeddings.load_message_texts(repo_root) - query_cases = benchmark_embeddings.load_search_queries(repo_root) + message_texts = load_message_texts(repo_root) + related_terms = load_related_term_texts(repo_root) + query_cases = load_search_queries(repo_root) + pipeline_query_cases = load_pipeline_queries(repo_root) + related_query_cases = load_related_term_queries(repo_root) + answer_cases = load_answer_queries(repo_root) model_output_dir = output_dir / 
sanitize_model_name(model_spec) model_output_dir.mkdir(parents=True, exist_ok=True) + grid_row_count = len(min_scores) * len(max_hits_values) + print( + " Loaded fixtures: " + f"{len(message_texts)} messages, " + f"{len(query_cases)} direct queries, " + f"{len(pipeline_query_cases)} pipeline queries, " + f"{len(related_query_cases)} related-term queries, " + f"{len(answer_cases)} answer queries.", + flush=True, + ) + print( + f" Preparing {model_spec} indexes and embeddings once " + f"for {grid_row_count} grid rows " + f"({len(min_scores)} min_score values x " + f"{len(max_hits_values)} max_hits values)...", + flush=True, + ) + model, vector_base = await build_vector_base( + model_spec, + message_texts, + batch_size, + ) + print(" Message vector index ready.", flush=True) + related_vector_base = await build_text_vector_base( + model, + related_terms, + batch_size, + ) + print(" Related-term vector index ready.", flush=True) + pipeline_conversation = await build_pipeline_conversation( + repo_root, + model, + ) + print(" True pipeline conversation indexes ready.", flush=True) + query_embeddings = await model.get_embeddings_nocache( + [case.query for case in query_cases] + ) + related_query_embeddings = await model.get_embeddings_nocache( + [case.term for case in related_query_cases] + ) + answer_question_embeddings = await model.get_embeddings_nocache( + [case.question for case in answer_cases] + ) + answer_embeddings = await model.get_embeddings_nocache( + [case.answer for case in answer_cases] + ) + top_score_stats = measure_top_score_stats( + vector_base, + query_embeddings, + ) + print( + " Direct query-to-message diagnostic ready " + f"(best-match score range {top_score_stats.min_top_score:.4f}.." + f"{top_score_stats.max_top_score:.4f}; not used to cap the sweep).", + flush=True, + ) run_results: list[RunResult] = [] for run_index in range(1, runs + 1): - model, vector_base = await benchmark_embeddings.build_vector_base( - model_spec, - message_texts, - batch_size, + print( + f" Run {run_index}/{runs}: evaluating {grid_row_count} grid rows...", + flush=True, ) - query_embeddings = await model.get_embeddings( - [case.query for case in query_cases] - ) - top_score_stats = benchmark_embeddings.measure_top_score_stats( + benchmark_rows = await evaluate_grid( vector_base, + query_cases, query_embeddings, + min_scores, + max_hits_values, + pipeline_conversation, + pipeline_query_cases, + related_vector_base, + related_terms, + related_query_cases, + related_query_embeddings, + answer_cases, + answer_question_embeddings, + answer_embeddings, + progress_label=f" Run {run_index}/{runs}", ) - effective_min_scores, skipped_min_scores = ( - benchmark_embeddings.filter_min_scores_by_ceiling( - min_scores, - top_score_stats.max_top_score, - ) + + best_row = select_best_row(benchmark_rows) + print( + " Run " + f"{run_index}/{runs}: best min_score={best_row.min_score:.2f}, " + f"max_hits={best_row.max_hits}.", + flush=True, ) - if not effective_min_scores: - raise ValueError( - "No requested min_score values are below the observed top-score ceiling " - f"of {top_score_stats.max_top_score:.4f} for {model.model_name}" - ) - if skipped_min_scores: - print( - f"Skipping {len(skipped_min_scores)} min_score values above " - f"{top_score_stats.max_top_score:.4f} for {model.model_name}" - ) - benchmark_rows: list[benchmark_embeddings.BenchmarkRow] = [] - for min_score in effective_min_scores: - for max_hits in max_hits_values: - metrics = benchmark_embeddings.evaluate_search_queries( - vector_base, - 
query_cases, - query_embeddings, - min_score, - max_hits, - ) - benchmark_rows.append( - benchmark_embeddings.BenchmarkRow(min_score, max_hits, metrics) - ) - - best_row = benchmark_embeddings.select_best_row(benchmark_rows) run_result = RunResult( run_index=run_index, model_spec=model_spec, @@ -269,15 +459,16 @@ async def run_repeated_benchmarks( timestamp = datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ") output_dir = output_root / timestamp output_dir.mkdir(parents=True, exist_ok=True) - - metadata = { - "created_at_utc": timestamp, - "runs_per_model": runs, - "models": models, - "min_scores": min_scores, - "max_hits_values": max_hits_values, - "batch_size": batch_size, - } + repo_root = Path(__file__).resolve().parent.parent + metadata = build_run_suite_metadata( + repo_root=repo_root, + timestamp=timestamp, + models=models, + runs=runs, + min_scores=min_scores, + max_hits_values=max_hits_values, + batch_size=batch_size, + ) write_json(output_dir / "metadata.json", metadata) summaries: list[dict[str, object]] = [] From 5e5b637faad992c529decab32aff229b61ff1399 Mon Sep 17 00:00:00 2001 From: shreejaykurhade Date: Wed, 6 May 2026 09:59:12 +0530 Subject: [PATCH 6/7] Fix CI failure from platform-specific path separators. Benchmark metadata now serializes repository-relative paths with POSIX separators, so JSON output is stable across Windows and Linux. Updated the repeat benchmark metadata test to expect the portable tests/testdata/... form instead of Windows-only backslashes. --- tests/test_repeat_embedding_benchmarks.py | 4 ++-- tools/repeat_embedding_benchmarks.py | 16 ++++++++++++---- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/tests/test_repeat_embedding_benchmarks.py b/tests/test_repeat_embedding_benchmarks.py index 47f3e80d..982039ba 100644 --- a/tests/test_repeat_embedding_benchmarks.py +++ b/tests/test_repeat_embedding_benchmarks.py @@ -36,10 +36,10 @@ def test_build_run_suite_metadata_records_ignored_sidecar() -> None: assert metadata["ignored_serialized_related_embedding_count"] == 1188 assert metadata["ignored_serialized_total_embedding_count"] == 1294 assert metadata["pipeline_source_json"] == ( - "tests\\testdata\\Episode_53_Search_results.json" + "tests/testdata/Episode_53_Search_results.json" ) assert metadata["related_term_source_json"] == ( - "tests\\testdata\\Episode_53_Search_results.json" + "tests/testdata/Episode_53_Search_results.json" ) assert metadata["pipeline_scoring_paths"] == [ "src/typeagent/knowpro/search.py::run_search_query", diff --git a/tools/repeat_embedding_benchmarks.py b/tools/repeat_embedding_benchmarks.py index f6590da6..5673bca4 100644 --- a/tools/repeat_embedding_benchmarks.py +++ b/tools/repeat_embedding_benchmarks.py @@ -229,6 +229,12 @@ def write_json(path: Path, data: object) -> None: path.write_text(json.dumps(data, indent=2), encoding="utf-8") +def serialize_metadata_path(path: Path) -> str: + """Serialize repository-relative metadata paths consistently across OSes.""" + + return path.as_posix() + + def write_markdown_summary(path: Path, summaries: list[dict[str, object]]) -> None: """Write the reviewer-facing markdown summary for all benchmarked models.""" @@ -290,11 +296,13 @@ def build_run_suite_metadata( "created_at_utc": timestamp, "runs_per_model": runs, "models": models, - "message_source_json": str(INDEX_DATA_PATH), - "pipeline_source_json": str(SEARCH_RESULTS_PATH), - "related_term_source_json": str(SEARCH_RESULTS_PATH), + "message_source_json": serialize_metadata_path(INDEX_DATA_PATH), + "pipeline_source_json": 
serialize_metadata_path(SEARCH_RESULTS_PATH), + "related_term_source_json": serialize_metadata_path(SEARCH_RESULTS_PATH), "pipeline_scoring_paths": PIPELINE_SCORING_PATHS, - "ignored_serialized_embedding_sidecar": str(INDEX_EMBEDDINGS_PATH), + "ignored_serialized_embedding_sidecar": serialize_metadata_path( + INDEX_EMBEDDINGS_PATH + ), "ignored_serialized_embedding_size": ( corpus_metadata.serialized_embedding_size ), From b869f47eed9126ffce6cf0fb63de24725646b107 Mon Sep 17 00:00:00 2001 From: shreejaykurhade Date: Wed, 6 May 2026 10:15:40 +0530 Subject: [PATCH 7/7] Merge latest PR updates and stabilize benchmark metadata paths --- AGENTS.md | 37 +- docs/high-level-api.md | 31 +- make.bat | 3 + pyproject.toml | 29 +- src/typeagent/aitools/model_adapters.py | 87 +- src/typeagent/aitools/utils.py | 62 +- src/typeagent/aitools/vectorbase.py | 7 +- src/typeagent/emails/email_memory.py | 4 +- src/typeagent/emails/email_message.py | 1 + src/typeagent/knowpro/conversation_base.py | 289 +++++- src/typeagent/knowpro/convsettings.py | 15 +- src/typeagent/knowpro/factory.py | 8 +- src/typeagent/knowpro/interfaces_core.py | 12 +- src/typeagent/knowpro/interfaces_storage.py | 79 +- src/typeagent/knowpro/knowledge.py | 42 +- src/typeagent/knowpro/messageutils.py | 85 +- src/typeagent/knowpro/secindex.py | 36 +- src/typeagent/knowpro/universal_message.py | 5 + src/typeagent/podcasts/podcast.py | 4 +- src/typeagent/podcasts/podcast_ingest.py | 55 +- src/typeagent/storage/memory/messageindex.py | 2 +- src/typeagent/storage/memory/provider.py | 67 +- src/typeagent/storage/memory/semrefindex.py | 270 +++--- src/typeagent/storage/sqlite/provider.py | 137 ++- src/typeagent/storage/sqlite/reltermsindex.py | 40 +- src/typeagent/storage/sqlite/schema.py | 24 + .../storage/sqlite/timestampindex.py | 11 +- src/typeagent/transcripts/transcript.py | 4 +- tests/conftest.py | 6 +- tests/test_add_messages_streaming.py | 842 ++++++++++++++++++ tests/test_convthreads.py | 122 +++ tests/test_convutils.py | 60 ++ tests/test_email_message.py | 223 +++++ tests/test_knowledge.py | 7 - tests/test_mcp_server.py | 181 ++-- tests/test_memory_semrefindex.py | 201 +++++ tests/test_message_text_index_population.py | 6 +- tests/test_messageutils.py | 23 + tests/test_podcasts.py | 8 +- tests/test_property_index_population.py | 4 +- tests/test_related_terms_index_population.py | 12 +- tests/test_reltermsindex.py | 4 +- tests/test_search.py | 113 +++ tests/test_searchlang_compile.py | 638 +++++++++++++ tests/test_secindex.py | 10 +- tests/test_secindex_storage_integration.py | 14 +- tests/test_semrefindex.py | 14 +- tests/test_serialization.py | 140 +++ tests/test_source_id_ingestion.py | 159 ++++ tests/test_sqlitestore.py | 12 +- tests/test_storage_providers_unified.py | 72 +- tests/test_textlocindex.py | 146 +++ tests/test_transcripts.py | 226 ++++- tests/test_utils.py | 128 ++- tools/benchmark_query.py | 233 +++++ tools/benchmark_semref_writes.py | 306 +++++++ tools/benchmark_vectorbase.py | 167 ++++ tools/ingest_email.py | 307 +++++-- tools/ingest_podcast.py | 9 +- tools/ingest_vtt.py | 308 ++++--- tools/load_json.py | 2 +- tools/query.py | 162 ++-- uv.lock | 117 ++- 63 files changed, 5436 insertions(+), 992 deletions(-) create mode 100644 tests/test_add_messages_streaming.py create mode 100644 tests/test_convthreads.py create mode 100644 tests/test_convutils.py create mode 100644 tests/test_email_message.py create mode 100644 tests/test_memory_semrefindex.py create mode 100644 tests/test_messageutils.py create mode 100644 
tests/test_search.py
 create mode 100644 tests/test_searchlang_compile.py
 create mode 100644 tests/test_source_id_ingestion.py
 create mode 100644 tests/test_textlocindex.py
 create mode 100644 tools/benchmark_query.py
 create mode 100644 tools/benchmark_semref_writes.py
 create mode 100644 tools/benchmark_vectorbase.py

diff --git a/AGENTS.md b/AGENTS.md
index f9a4fb3b..ad337a70 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -3,16 +3,38 @@
 **NEVER use TEST_MODEL_NAME or "test" embedding model outside of test files**
 
 Never run git commands that make any changes. (`git status` and `git diff` are fine)
+Exceptions: `git push`, `git worktree`, `git branch` (for tracking setup), as instructed below.
 
-**NEVER COMMIT CODE. Do not run `git commit` or any other git commands
-that make changes to the repository. Not even `git add`**
+**NEVER COMMIT CODE.** Do not run `git commit` or any other git commands
+that make changes to the repository. Exception: Worktrees/Branches below.
+`git add` is fine.
 
 When moving, copying or deleting files, use the git commands:
 `git mv`, `git cp`, `git rm`
 
-When I ask to update AGENTS.md (even if maybe) extract a general rule from what I said
-before and update AGENTS.md (unless it's already in there -- maybe reformulate since
-it apparently didn't work). Also, when it looks like I state a general rule, add it to
-AGENTS.md. In all cases show what you added to AGENTS.md.
+## Worktrees and Branches
+
+- Each session uses its own worktree with a feature branch
+- Create worktrees with: `git worktree add ../<repo>-<branch> -b <branch>`
+- Push the branch to the `me` remote: `git push me <branch>`
+- Set upstream to `me/<branch>`: `git branch --set-upstream-to me/<branch>`
+- **Never** upstream to `me/main` — that must stay identical to `origin/main`
+- The worktree directory name should be `<repo>-<branch>` (sibling of the main checkout)
+- **Work in the worktree directory**, not the main checkout — edit files there, run tests there
+- VS Code may show buffers from the main checkout; ignore those when working in a worktree.
+  When in doubt, verify edits landed on disk with `cat` or `grep` in the terminal.
+
+## Debugging discipline
+
+- When a bug seems impossible, suspect stale files or wrong working directory — not exotic causes.
+- If you're tempted to blame installed package versions, `__pycache__`, or similar,
+  **stop and ask the user** before investigating further. You're probably on the wrong track.
+- When fixing CI failures on a PR, compare against the latest PR head and the exact failing
+  CI job context before declaring the fix complete.
+
+**Whenever the user tells you how to do something, states a preference, or corrects you,
+extract a general rule and add it to AGENTS.md** (unless it's already covered -- maybe
+reformulate since it apparently didn't work). This applies even without being asked.
+In all cases show what you added to AGENTS.md.
 
 - Don't use '!' on the command line, it's some bash magic (even inside single quotes)
 - When running 'make' commands, do not use the venv (the Makefile uses 'uv run')
@@ -25,6 +47,7 @@ AGENTS.md. In all cases show what you added to AGENTS.md.
 - Use `make test` to run all tests
 - Use `make check test` to run `make check` and if it passes also run `make test`
 - Use `make format` to format all files using `black`. Do this before reporting success.
+- When validating changes, first run `pytest` only on new/modified test files, then run `make format check test` once at the end.
 - Keep ad-hoc and performance benchmarks under `tools/`, not `tests/`, so `make test` does not run them.
 ## Package Management with uv
@@ -36,7 +59,7 @@ AGENTS.md. In all cases show what you added to AGENTS.md.
 - uv maintains consistency between `pyproject.toml`, `uv.lock`, and installed packages
 - Trust uv's automatic version resolution and file management
 
-**IMPORTANT! YOU ARE NOT DONE UNTIL `make check test format` PASSES**
+**IMPORTANT! YOU ARE NOT DONE UNTIL `make format check test` PASSES**
 
 # Code generation
 
diff --git a/docs/high-level-api.md b/docs/high-level-api.md
index 4978f619..fe8b4a7e 100644
--- a/docs/high-level-api.md
+++ b/docs/high-level-api.md
@@ -20,7 +20,7 @@ class ConversationMessage(
     text_chunks: list[str],  # Text of the message, 1 or more chunks
     tags: list[str] = [],  # Optional tags
     timestamp: str | None = None,  # ISO timestamp in UTC with 'z' suffix
-    metadata: ConversationMessageMeta, # See below
+    metadata: ConversationMessageMeta,  # See below
 )
 ```
@@ -64,7 +64,32 @@ extracted and indexed knowledge thereof.
 It is constructed by calling the factory function
 `typeagent.create_conversation` described below.
 
-It has one public method:
+It has these public methods:
+
+- `add_messages_with_indexing`
+  ```py
+  async def add_messages_with_indexing(
+      messages: list[TMessage],
+      *,
+      source_ids: list[str] | None = None,
+  ) -> AddMessagesResult
+  ```
+
+  Adds messages and updates all indexes in a single transaction.
+  For SQLite storage this is all-or-nothing.
+
+- `add_messages_streaming`
+  ```py
+  async def add_messages_streaming(
+      messages: AsyncIterable[TMessage],
+      *,
+      batch_size: int = 100,
+      on_batch_committed: Callable[[AddMessagesResult], None] | None = None,
+  ) -> AddMessagesResult
+  ```
+
+  Adds messages from an async stream, committing each batch separately.
+  Useful for very large ingestions where one large transaction is impractical.
 
 - `query`
   ```py
@@ -80,7 +105,7 @@ It has one public method:
 
 ## Functions
 
-There is currently only one public function.
+There is currently only one public top-level function.
 
 #### Factory function
 
diff --git a/make.bat b/make.bat
index be66b3bb..2a2c85a6 100644
--- a/make.bat
+++ b/make.bat
@@ -45,12 +45,15 @@ uv run pytest
 goto end
 
 :coverage
+setlocal
 if not exist ".venv\" call make.bat venv
 echo Running test coverage...
uv run coverage erase +set "COVERAGE_PROCESS_START=.coveragerc" uv run coverage run -m pytest uv run coverage combine uv run coverage report +endlocal goto end diff --git a/pyproject.toml b/pyproject.toml index b994d3f0..328ff6f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ dependencies = [ "pyreadline3>=3.5.4 ; sys_platform == 'win32'", "pyright>=1.1.409", "python-dotenv>=1.1.0", + "stamina>=26.1.0", "tiktoken>=0.12.0", "typechat>=0.0.4", "webvtt-py>=0.5.1", @@ -81,18 +82,18 @@ known_local_folder = ["conftest"] [dependency-groups] dev = [ "azure-mgmt-authorization>=4.0.0", - "azure-mgmt-keyvault>=12.1.1", - "black>=25.12.0", - "coverage[toml]>=7.9.1", - "google-api-python-client>=2.184.0", - "google-auth-httplib2>=0.2.0", - "google-auth-oauthlib>=1.2.2", - "isort>=7.0.0", - "logfire>=4.1.0", # So 'make check' passes - "msgraph-sdk>=1.54.0", - "opentelemetry-instrumentation-httpx>=0.57b0", - "pyright>=1.1.408", # 407 has a regression - "pytest>=8.3.5", - "pytest-asyncio>=0.26.0", - "pytest-mock>=3.14.0", + "azure-mgmt-keyvault>=14.0.1", + "black>=26.3.1", + "coverage[toml]>=7.13.5", + "google-api-python-client>=2.194.0", + "google-auth-httplib2>=0.3.1", + "google-auth-oauthlib>=1.3.1", + "isort>=8.0.1", + "logfire>=4.32.1", # So 'make check' passes + "msgraph-sdk>=1.56.0", + "opentelemetry-instrumentation-httpx>=0.61b0", + "pyright>=1.1.409", + "pytest>=9.0.3", + "pytest-asyncio>=1.3.0", + "pytest-mock>=3.15.1", ] diff --git a/src/typeagent/aitools/model_adapters.py b/src/typeagent/aitools/model_adapters.py index 46208f6e..7a195d8c 100644 --- a/src/typeagent/aitools/model_adapters.py +++ b/src/typeagent/aitools/model_adapters.py @@ -27,15 +27,21 @@ """ from collections.abc import Sequence +import logging import os import numpy as np from numpy.typing import NDArray +import stamina +from stamina import BoundAsyncRetryingCaller +from stamina.instrumentation import RetryDetails, set_on_retry_hooks +import openai from pydantic_ai import Embedder as _PydanticAIEmbedder from pydantic_ai.embeddings.base import EmbeddingModel as _PydanticAIEmbeddingModelBase from pydantic_ai.embeddings.result import EmbeddingResult, EmbedInputType from pydantic_ai.embeddings.settings import EmbeddingSettings +from pydantic_ai.exceptions import ModelAPIError from pydantic_ai.messages import ( ModelMessage, ModelRequest, @@ -52,6 +58,47 @@ NormalizedEmbeddings, ) +_TRANSIENT_ERRORS = ( + openai.RateLimitError, + openai.APIConnectionError, + openai.APITimeoutError, + openai.InternalServerError, + ModelAPIError, +) + +DEFAULT_CHAT_RETRIER = stamina.AsyncRetryingCaller(attempts=6, timeout=120).on( + _TRANSIENT_ERRORS +) +DEFAULT_EMBED_RETRIER = stamina.AsyncRetryingCaller(attempts=6, timeout=120).on( + _TRANSIENT_ERRORS +) + +_logger = logging.getLogger("stamina") + +_CALLABLE_LABELS: dict[str, str] = { + "request": "chat", + "embed_documents": "embedding", +} + + +def _on_retry(details: RetryDetails) -> None: + kind = _CALLABLE_LABELS.get(details.name, details.name) + caused = details.caused_by + exc_summary = repr(caused)[:200] + _logger.warning( + "stamina: retrying %s request (attempt %d, waited %.1fs so far, " + "waiting %.1fs): %s", + kind, + details.retry_num, + details.waited_so_far, + details.wait_for, + exc_summary, + ) + + +set_on_retry_hooks([_on_retry]) + + # --------------------------------------------------------------------------- # Chat model adapter # --------------------------------------------------------------------------- @@ -65,8 +112,13 @@ class 
PydanticAIChatModel(typechat.TypeChatLanguageModel): used wherever TypeChat expects a ``TypeChatLanguageModel``. """ - def __init__(self, model: Model) -> None: + def __init__( + self, + model: Model, + retrier: BoundAsyncRetryingCaller | None = None, + ) -> None: self._model = model + self._retrier = retrier or DEFAULT_CHAT_RETRIER async def complete( self, prompt: str | list[typechat.PromptSection] @@ -84,7 +136,7 @@ async def complete( messages: list[ModelMessage] = [ModelRequest(parts=parts)] params = ModelRequestParameters() - response = await self._model.request(messages, None, params) + response = await self._retrier(self._model.request, messages, None, params) text_parts = [p.content for p in response.parts if isinstance(p, TextPart)] if text_parts: return typechat.Success("".join(text_parts)) @@ -111,24 +163,20 @@ def __init__( self, embedder: _PydanticAIEmbedder, model_name: str, + retrier: BoundAsyncRetryingCaller | None = None, ) -> None: self._embedder = embedder self.model_name = model_name + self._retrier = retrier or DEFAULT_EMBED_RETRIER async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding: - result = await self._embedder.embed_documents([input]) - embedding: NDArray[np.float32] = np.array( - result.embeddings[0], dtype=np.float32 - ) - norm = float(np.linalg.norm(embedding)) - if norm > 0: - embedding = (embedding / norm).astype(np.float32) - return embedding + embeddings = await self.get_embeddings_nocache([input]) + return embeddings[0] async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings: if not input: raise ValueError("Cannot embed an empty list") - result = await self._embedder.embed_documents(input) + result = await self._retrier(self._embedder.embed_documents, input) embeddings: NDArray[np.float32] = np.array(result.embeddings, dtype=np.float32) norms = np.linalg.norm(embeddings, axis=1, keepdims=True).astype(np.float32) norms = np.where(norms > 0, norms, np.float32(1.0)) @@ -182,7 +230,7 @@ def _make_azure_provider( azure_endpoint=azure_endpoint, api_version=api_version, azure_ad_token_provider=token_provider.get_token, - max_retries=5, + max_retries=0, ) else: apim_key = os.getenv("AZURE_APIM_SUBSCRIPTION_KEY") @@ -193,7 +241,7 @@ def _make_azure_provider( default_headers=( {"Ocp-Apim-Subscription-Key": apim_key} if apim_key else None ), - max_retries=5, + max_retries=0, ) return AzureProvider(openai_client=client) @@ -208,6 +256,8 @@ def _make_azure_provider( def create_chat_model( model_spec: str | None = None, + *, + retrier: BoundAsyncRetryingCaller | None = None, ) -> PydanticAIChatModel: """Create a chat model from a ``provider:model`` spec. @@ -249,7 +299,7 @@ def create_chat_model( ) else: model = infer_model(model_spec) - return PydanticAIChatModel(model) + return PydanticAIChatModel(model, retrier) DEFAULT_EMBEDDING_SPEC = "openai:text-embedding-ada-002" @@ -257,6 +307,7 @@ def create_chat_model( def create_embedding_model( model_spec: str | None = None, + retrier: BoundAsyncRetryingCaller | None = None, ) -> CachingEmbeddingModel: """Create an embedding model from a ``provider:model`` spec. 
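The batched embedding path above normalizes all rows in one shot with a zero-norm guard, so the single-input path can simply delegate to it. The same arithmetic in isolation, with made-up data:

```python
import numpy as np

embeddings = np.array([[3.0, 4.0], [0.0, 0.0]], dtype=np.float32)
norms = np.linalg.norm(embeddings, axis=1, keepdims=True).astype(np.float32)
norms = np.where(norms > 0, norms, np.float32(1.0))  # avoid division by zero
normalized = (embeddings / norms).astype(np.float32)
# normalized -> [[0.6, 0.8], [0.0, 0.0]]; all-zero rows pass through unchanged
```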
@@ -313,7 +364,7 @@ def create_embedding_model( embedder = _PydanticAIEmbedder(embedding_model) else: embedder = _PydanticAIEmbedder(model_spec) - return CachingEmbeddingModel(PydanticAIEmbedder(embedder, model_name)) + return CachingEmbeddingModel(PydanticAIEmbedder(embedder, model_name, retrier)) # --------------------------------------------------------------------------- @@ -400,6 +451,8 @@ def create_test_embedding_model( def configure_models( chat_model_spec: str, embedding_model_spec: str, + chat_retrier: BoundAsyncRetryingCaller | None = None, + embed_retrier: BoundAsyncRetryingCaller | None = None, ) -> tuple[PydanticAIChatModel, CachingEmbeddingModel]: """Configure both a chat model and an embedding model at once. @@ -416,6 +469,6 @@ def configure_models( extractor = KnowledgeExtractor(model=chat) """ return ( - create_chat_model(chat_model_spec), - create_embedding_model(embedding_model_spec), + create_chat_model(chat_model_spec, retrier=chat_retrier), + create_embedding_model(embedding_model_spec, retrier=embed_retrier), ) diff --git a/src/typeagent/aitools/utils.py b/src/typeagent/aitools/utils.py index cc6bcbb8..e421e8d5 100644 --- a/src/typeagent/aitools/utils.py +++ b/src/typeagent/aitools/utils.py @@ -5,6 +5,7 @@ from contextlib import contextmanager import difflib +import io import os import re import shutil @@ -43,12 +44,17 @@ def timelog(label: str, verbose: bool = True): ) -def pretty_print(obj: object, prefix: str = "", suffix: str = "") -> None: +def pretty_print( + obj: object, + prefix: str = "", + suffix: str = "", + file: io.StringIO | None = None, +) -> None: """Pretty-print an object using pprint.""" import pprint line_width = min(200, shutil.get_terminal_size().columns) - print(prefix + pprint.pformat(obj, width=line_width) + suffix) + print(prefix + pprint.pformat(obj, width=line_width) + suffix, file=file) def format_code(text: str, line_width=None) -> str: @@ -91,7 +97,7 @@ def create_translator[T]( # Vibe-coded by o4-mini-high -def list_diff(label_a, a, label_b, b, max_items): +def list_diff(label_a, a, label_b, b, max_items, file=None): """Print colorized diff between two sorted list of numbers.""" sm = difflib.SequenceMatcher(None, a, b) a_out, b_out = [], [] @@ -145,8 +151,8 @@ def fmt(row, seg_widths): # print each segment for start, end in segments: seg_widths = widths[start:end] - print(la, fmt(a_cols[start:end], seg_widths)) - print(lb, fmt(b_cols[start:end], seg_widths)) + print(la, fmt(a_cols[start:end], seg_widths), file=file) + print(lb, fmt(b_cols[start:end], seg_widths), file=file) def setup_logfire(): @@ -252,52 +258,6 @@ def get_azure_api_key(azure_api_key: str) -> str: return azure_api_key -def create_async_openai_client( - endpoint_envvar: str = "AZURE_OPENAI_ENDPOINT", - base_url: str | None = None, -): - """Create AsyncOpenAI or AsyncAzureOpenAI client based on environment variables. - - Returns the appropriate async OpenAI client based on what credentials are available. - Prefers OPENAI_API_KEY over AZURE_OPENAI_API_KEY. - - Args: - endpoint_envvar: Environment variable name for Azure endpoint (default: AZURE_OPENAI_ENDPOINT). - base_url: Optional base URL override for OpenAI client. - - Returns: - AsyncOpenAI or AsyncAzureOpenAI client instance. - - Raises: - RuntimeError: If neither OPENAI_API_KEY nor AZURE_OPENAI_API_KEY is set. 
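With `configure_models` now accepting retriers, callers can swap the module-level defaults for their own retry budgets. A sketch using only the hooks this patch adds; the attempt/timeout numbers and model specs are illustrative:

```python
import openai
import stamina

from typeagent.aitools.model_adapters import configure_models

# A stricter budget than DEFAULT_CHAT_RETRIER / DEFAULT_EMBED_RETRIER.
quick_retrier = stamina.AsyncRetryingCaller(attempts=3, timeout=30).on(
    (openai.RateLimitError, openai.APITimeoutError)
)

chat_model, embedding_model = configure_models(
    "openai:gpt-4o",                   # illustrative model specs
    "openai:text-embedding-3-small",
    chat_retrier=quick_retrier,
    embed_retrier=quick_retrier,
)
```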
- """ - from openai import AsyncAzureOpenAI, AsyncOpenAI - - if openai_api_key := os.getenv("OPENAI_API_KEY"): - return AsyncOpenAI(api_key=openai_api_key, base_url=base_url, max_retries=5) - - elif azure_api_key := os.getenv("AZURE_OPENAI_API_KEY"): - azure_api_key = get_azure_api_key(azure_api_key) - azure_endpoint, api_version = parse_azure_endpoint(endpoint_envvar) - - apim_key = os.getenv("AZURE_APIM_SUBSCRIPTION_KEY") - - return AsyncAzureOpenAI( - api_version=api_version, - azure_endpoint=azure_endpoint, - api_key=azure_api_key, - default_headers=( - {"Ocp-Apim-Subscription-Key": apim_key} if apim_key else None - ), - max_retries=5, - ) - - else: - raise RuntimeError( - "Neither OPENAI_API_KEY nor AZURE_OPENAI_API_KEY was provided." - ) - - def resolve_azure_model_name( model_name: str, endpoint_envvar: str = "AZURE_OPENAI_ENDPOINT", diff --git a/src/typeagent/aitools/vectorbase.py b/src/typeagent/aitools/vectorbase.py index 7b5d3448..7a3c4549 100644 --- a/src/typeagent/aitools/vectorbase.py +++ b/src/typeagent/aitools/vectorbase.py @@ -151,11 +151,14 @@ async def add_key(self, key: str, cache: bool = True) -> None: embedding = await self.get_embedding(key, cache=cache) self.add_embedding(key if cache else None, embedding) - async def add_keys(self, keys: list[str], cache: bool = True) -> None: + async def add_keys( + self, keys: list[str], cache: bool = True + ) -> NormalizedEmbeddings | None: if not keys: - return + return None embeddings = await self.get_embeddings(keys, cache=cache) self.add_embeddings(keys if cache else None, embeddings) + return embeddings def fuzzy_lookup_embedding( self, diff --git a/src/typeagent/emails/email_memory.py b/src/typeagent/emails/email_memory.py index 6dd50cc4..1523f3a8 100644 --- a/src/typeagent/emails/email_memory.py +++ b/src/typeagent/emails/email_memory.py @@ -23,7 +23,9 @@ class EmailMemorySettings: def __init__(self, conversation_settings: ConversationSettings) -> None: - self.language_model = model_adapters.create_chat_model() + self.language_model = model_adapters.create_chat_model( + retrier=conversation_settings.chat_retrier + ) self.query_translator = utils.create_translator( self.language_model, search_query_schema.SearchQuery ) diff --git a/src/typeagent/emails/email_message.py b/src/typeagent/emails/email_message.py index 0a469f49..47abdbec 100644 --- a/src/typeagent/emails/email_message.py +++ b/src/typeagent/emails/email_message.py @@ -161,6 +161,7 @@ def __init__(self, **data: Any) -> None: ) timestamp: str | None = None # Use metadata.sent_on for the actual sent time src_url: str | None = None # Source file or uri for this email + source_id: str | None = None # External source id (see IMessage.source_id) def get_knowledge(self) -> kplib.KnowledgeResponse: return self.metadata.get_knowledge() diff --git a/src/typeagent/knowpro/conversation_base.py b/src/typeagent/knowpro/conversation_base.py index 8026472a..673695d2 100644 --- a/src/typeagent/knowpro/conversation_base.py +++ b/src/typeagent/knowpro/conversation_base.py @@ -3,6 +3,9 @@ """Base class for conversations with incremental indexing support.""" +import asyncio +from collections.abc import AsyncIterable, Callable, Sequence +import contextlib from dataclasses import dataclass from datetime import datetime, timezone from typing import Generic, Self, TypeVar @@ -36,10 +39,22 @@ MessageOrdinal, Topic, ) +from .interfaces_core import TextLocation +from .knowledge import extract_knowledge_from_text_batch +from .messageutils import get_all_message_chunk_locations TMessage = 
TypeVar("TMessage", bound=IMessage) +@dataclass(frozen=True) +class _ExtractionResult: + """Pre-extracted knowledge for a batch, ready to commit.""" + + messages: Sequence[IMessage] + text_locations: list[TextLocation] + knowledge_results: list[typechat.Result[kplib.KnowledgeResponse]] + + @dataclass(init=False) class ConversationBase( Generic[TMessage], IConversation[TMessage, ITermToSemanticRefIndex] @@ -93,10 +108,10 @@ async def create( tags if tags is not None else [], ) instance.storage_provider = storage_provider - instance.messages = await storage_provider.get_message_collection() - instance.semantic_refs = await storage_provider.get_semantic_ref_collection() - instance.semantic_ref_index = await storage_provider.get_semantic_ref_index() - instance.secondary_indexes = await secindex.ConversationSecondaryIndexes.create( + instance.messages = storage_provider.messages + instance.semantic_refs = storage_provider.semantic_refs + instance.semantic_ref_index = storage_provider.semantic_ref_index + instance.secondary_indexes = secindex.ConversationSecondaryIndexes( storage_provider, settings.related_term_index_settings ) return instance @@ -132,9 +147,12 @@ async def add_messages_with_indexing( Args: messages: Messages to add - source_ids: Optional list of source IDs to mark as ingested. These are - marked within the same transaction, so if the indexing fails, the - source IDs won't be marked as ingested (for SQLite storage). + source_ids: Optional explicit list of source IDs to mark as ingested, + one per message. When ``None`` (the default), each message's + ``source_id`` attribute is used instead — messages whose + ``source_id`` is ``None`` are silently skipped. These are marked + within the same transaction, so if the indexing fails, the source + IDs won't be marked as ingested (for SQLite storage). Returns: Result with counts of messages/semrefs added @@ -143,7 +161,7 @@ async def add_messages_with_indexing( Exception: Any error """ storage = await self.settings.get_storage_provider() - if source_ids: + if source_ids is not None: if len(source_ids) != len(messages): raise ValueError( f"Length of source_ids {len(source_ids)} " @@ -152,9 +170,13 @@ async def add_messages_with_indexing( async with storage: # Mark source IDs as ingested (will be rolled back on error) - if source_ids: - for source_id in source_ids: - await storage.mark_source_ingested(source_id) + sids = ( + source_ids + if source_ids is not None + else [m.source_id for m in messages if m.source_id is not None] + ) + if sids: + await storage.mark_sources_ingested_batch(sids) start_points = IndexingStartPoints( message_count=await self.messages.size(), @@ -173,8 +195,11 @@ async def add_messages_with_indexing( await self._update_secondary_indexes_incremental(start_points) + messages_added = await self.messages.size() - start_points.message_count + chunks_added = sum(len(m.text_chunks) for m in messages[:messages_added]) result = AddMessagesResult( - messages_added=await self.messages.size() - start_points.message_count, + messages_added=messages_added, + chunks_added=chunks_added, semrefs_added=await self.semantic_refs.size() - start_points.semref_count, ) @@ -186,6 +211,222 @@ async def add_messages_with_indexing( return result + async def add_messages_streaming( + self, + messages: AsyncIterable[TMessage], + *, + batch_size: int = 100, + on_batch_committed: Callable[[AddMessagesResult], None] | None = None, + ) -> AddMessagesResult: + """Add messages from an async iterable, committing in batches. 
+
+        Uses a two-stage pipeline: while batch N is being committed (DB writes,
+        embeddings, secondary indexes), batch N+1's LLM extraction runs
+        concurrently. LLM extraction typically dominates wall time, so
+        overlapping it with commits hides most of the commit cost.
+
+        **Source-ID tracking**: each message's ``source_id`` (if not ``None``)
+        is marked as ingested within the commit transaction. Callers are
+        responsible for filtering duplicates before yielding messages.
+
+        **Extraction failures**: when knowledge extraction returns a
+        ``Failure`` for a chunk, the failure is recorded via
+        ``storage.record_chunk_failure`` and processing continues with the
+        remaining chunks. Raised exceptions (HTTP errors, timeouts, etc.)
+        are treated as systemic and stop the run immediately; the current
+        batch is rolled back and the exception propagates.
+
+        Args:
+            messages: An async iterable of messages to ingest.
+            batch_size: Target number of text chunks per commit batch.
+                Messages are never split across batches, so the actual
+                chunk count may exceed ``batch_size`` if a single message
+                has more chunks than that.
+            on_batch_committed: Optional callback invoked after each batch is
+                committed, receiving the batch's ``AddMessagesResult``.
+
+        Returns:
+            Cumulative ``AddMessagesResult`` across all committed batches.
+        """
+        storage = await self.settings.get_storage_provider()
+        should_extract = (
+            self.settings.semantic_ref_index_settings.auto_extract_knowledge
+        )
+        total = AddMessagesResult()
+
+        def _accumulate(result: AddMessagesResult) -> None:
+            total.messages_added += result.messages_added
+            total.semrefs_added += result.semrefs_added
+            total.chunks_added += result.chunks_added
+            if on_batch_committed:
+                on_batch_committed(result)
+
+        pending_commit: asyncio.Task[AddMessagesResult] | None = None
+        pending_extraction: asyncio.Task[_ExtractionResult | None] | None = None
+
+        async def _drain_commit() -> None:
+            nonlocal pending_commit
+            if pending_commit is not None:
+                _accumulate(await pending_commit)
+                pending_commit = None
+
+        async def _submit_batch(batch: list[TMessage]) -> None:
+            nonlocal pending_commit, pending_extraction
+            if not batch:
+                return
+
+            if should_extract:
+                next_extraction = asyncio.create_task(
+                    self._extract_knowledge_for_batch(batch)
+                )
+            else:
+                next_extraction = None
+            pending_extraction = next_extraction
+
+            await _drain_commit()
+
+            extraction = await next_extraction if next_extraction is not None else None
+            pending_extraction = None
+
+            pending_commit = asyncio.create_task(
+                self._commit_batch_streaming(storage, batch, extraction)
+            )
+
+        try:
+            batch: list[TMessage] = []
+            batch_chunks = 0
+            async for msg in messages:
+                msg_chunks = len(msg.text_chunks)
+                if batch and batch_chunks + msg_chunks > batch_size:
+                    await _submit_batch(batch)
+                    batch = []
+                    batch_chunks = 0
+                batch.append(msg)
+                batch_chunks += msg_chunks
+                if batch_chunks >= batch_size:
+                    await _submit_batch(batch)
+                    batch = []
+                    batch_chunks = 0
+
+            if batch:
+                await _submit_batch(batch)
+
+            await _drain_commit()
+        except BaseException:
+            if pending_extraction is not None and not pending_extraction.done():
+                pending_extraction.cancel()
+                with contextlib.suppress(asyncio.CancelledError):
+                    await pending_extraction
+            if pending_commit is not None and not pending_commit.done():
+                pending_commit.cancel()
+                with contextlib.suppress(asyncio.CancelledError):
+                    await pending_commit
+            raise
+
+        return total
+
+    async def _extract_knowledge_for_batch(
+        self,
+        messages: list[TMessage],
+    ) ->
_ExtractionResult | None: + """Run LLM extraction on message texts — no DB access. + + Uses 0-based ordinals; the caller remaps to global ordinals at commit + time. Safe to run concurrently with a DB transaction on another batch. + """ + text_locations = get_all_message_chunk_locations(messages, 0) + if not text_locations: + return None + + settings = self.settings.semantic_ref_index_settings + knowledge_extractor = ( + settings.knowledge_extractor or convknowledge.KnowledgeExtractor() + ) + + text_batch = [ + messages[tl.message_ordinal].text_chunks[tl.chunk_ordinal].strip() + for tl in text_locations + ] + + knowledge_results = await extract_knowledge_from_text_batch( + knowledge_extractor, + text_batch, + settings.concurrency, + ) + return _ExtractionResult( + messages=messages, + text_locations=text_locations, + knowledge_results=knowledge_results, + ) + + async def _apply_extraction_results( + self, + storage: IStorageProvider[TMessage], + extraction: _ExtractionResult, + global_message_start: int, + ) -> None: + """Write pre-extracted knowledge into the DB. Must be inside a transaction.""" + bulk_items: list[tuple[int, int, kplib.KnowledgeResponse]] = [] + for i, knowledge_result in enumerate(extraction.knowledge_results): + tl = extraction.text_locations[i] + global_msg_ord = tl.message_ordinal + global_message_start + if isinstance(knowledge_result, typechat.Failure): + await storage.record_chunk_failure( + global_msg_ord, + tl.chunk_ordinal, + type(knowledge_result).__name__, + knowledge_result.message[:500], + ) + continue + bulk_items.append( + (global_msg_ord, tl.chunk_ordinal, knowledge_result.value) + ) + if bulk_items: + await semrefindex.add_knowledge_batch_to_semantic_ref_index( + self, bulk_items + ) + + async def _commit_batch_streaming( + self, + storage: IStorageProvider[TMessage], + filtered: list[TMessage], + extraction: _ExtractionResult | None, + ) -> AddMessagesResult: + """Commit a single batch with pre-extracted knowledge.""" + async with storage: + start_points = IndexingStartPoints( + message_count=await self.messages.size(), + semref_count=await self.semantic_refs.size(), + ) + + await self.messages.extend(filtered) + + source_ids = [m.source_id for m in filtered if m.source_id is not None] + if source_ids: + await storage.mark_sources_ingested_batch(source_ids) + + await self._add_metadata_knowledge_incremental(start_points.message_count) + + if extraction is not None: + await self._apply_extraction_results( + storage, extraction, start_points.message_count + ) + + await self._update_secondary_indexes_incremental(start_points) + + await storage.update_conversation_timestamps( + updated_at=datetime.now(timezone.utc) + ) + + messages_added = await self.messages.size() - start_points.message_count + chunks_added = sum(len(m.text_chunks) for m in filtered[:messages_added]) + return AddMessagesResult( + messages_added=messages_added, + chunks_added=chunks_added, + semrefs_added=await self.semantic_refs.size() + - start_points.semref_count, + ) + async def _add_metadata_knowledge_incremental( self, start_from_message_ordinal: int, @@ -216,21 +457,17 @@ async def _add_llm_knowledge_incremental( settings.knowledge_extractor or convknowledge.KnowledgeExtractor() ) - # Get batches of text locations from the message list - from .messageutils import get_message_chunk_batch_from_list - - batches = get_message_chunk_batch_from_list( + text_locations = get_all_message_chunk_locations( messages, start_from_message_ordinal, - settings.batch_size, ) - for text_location_batch 
in batches: - await semrefindex.add_batch_to_semantic_ref_index_from_list( - self, - messages, - text_location_batch, - knowledge_extractor, - ) + await semrefindex.add_batch_to_semantic_ref_index_from_list( + self, + messages, + text_locations, + knowledge_extractor, + concurrency=settings.concurrency, + ) async def _update_secondary_indexes_incremental( self, @@ -354,12 +591,12 @@ async def query( """ # Create translators lazily (once per conversation instance) if self._query_translator is None: - model = model_adapters.create_chat_model() + model = model_adapters.create_chat_model(retrier=self.settings.chat_retrier) self._query_translator = utils.create_translator( model, search_query_schema.SearchQuery ) if self._answer_translator is None: - model = model_adapters.create_chat_model() + model = model_adapters.create_chat_model(retrier=self.settings.chat_retrier) self._answer_translator = utils.create_translator( model, answer_response_schema.AnswerResponse ) diff --git a/src/typeagent/knowpro/convsettings.py b/src/typeagent/knowpro/convsettings.py index 30aa22ad..bd05c19e 100644 --- a/src/typeagent/knowpro/convsettings.py +++ b/src/typeagent/knowpro/convsettings.py @@ -5,6 +5,8 @@ from dataclasses import dataclass +from stamina import BoundAsyncRetryingCaller + from ..aitools.embeddings import IEmbeddingModel from ..aitools.model_adapters import create_embedding_model from ..aitools.vectorbase import TextEmbeddingIndexSettings @@ -32,7 +34,7 @@ def __init__(self, embedding_index_settings: TextEmbeddingIndexSettings): @dataclass class SemanticRefIndexSettings: - batch_size: int + concurrency: int auto_extract_knowledge: bool knowledge_extractor: IKnowledgeExtractor | None = None @@ -44,9 +46,16 @@ def __init__( self, model: IEmbeddingModel | None = None, storage_provider: IStorageProvider | None = None, + *, + chat_retrier: BoundAsyncRetryingCaller | None = None, + embed_retrier: BoundAsyncRetryingCaller | None = None, ): + # Retry callers -- None means "use the default" in model_adapters. + self.chat_retrier = chat_retrier + self.embed_retrier = embed_retrier + # All settings share the same model, so they share the embedding cache. 
- model = model or create_embedding_model() + model = model or create_embedding_model(retrier=embed_retrier) self.embedding_model = model min_score = DEFAULT_RELATED_TERM_MIN_SCORE self.related_term_index_settings = RelatedTermIndexSettings( @@ -57,7 +66,7 @@ def __init__( TextEmbeddingIndexSettings(model, min_score=DEFAULT_MESSAGE_TEXT_MIN_SCORE) ) self.semantic_ref_index_settings = SemanticRefIndexSettings( - batch_size=4, # Effectively max concurrency + concurrency=4, auto_extract_knowledge=True, # The high-level API wants this ) diff --git a/src/typeagent/knowpro/factory.py b/src/typeagent/knowpro/factory.py index bdebf8f2..5c94bce1 100644 --- a/src/typeagent/knowpro/factory.py +++ b/src/typeagent/knowpro/factory.py @@ -60,10 +60,10 @@ async def create_conversation[TMessage: IMessage]( tags=tags if tags is not None else [], ) conversation.storage_provider = storage_provider - conversation.messages = await storage_provider.get_message_collection() - conversation.semantic_refs = await storage_provider.get_semantic_ref_collection() - conversation.semantic_ref_index = await storage_provider.get_semantic_ref_index() - conversation.secondary_indexes = await secindex.ConversationSecondaryIndexes.create( + conversation.messages = storage_provider.messages + conversation.semantic_refs = storage_provider.semantic_refs + conversation.semantic_ref_index = storage_provider.semantic_ref_index + conversation.secondary_indexes = secindex.ConversationSecondaryIndexes( storage_provider, settings.related_term_index_settings ) return conversation diff --git a/src/typeagent/knowpro/interfaces_core.py b/src/typeagent/knowpro/interfaces_core.py index 4dc8fc8e..10d1765c 100644 --- a/src/typeagent/knowpro/interfaces_core.py +++ b/src/typeagent/knowpro/interfaces_core.py @@ -90,8 +90,10 @@ class IndexingStartPoints: class AddMessagesResult: """Result of add_messages_with_indexing operation.""" - messages_added: int - semrefs_added: int + messages_added: int = 0 + chunks_added: int = 0 + semrefs_added: int = 0 + messages_skipped: int = 0 # Messages are referenced by their sequential ordinal numbers. @@ -129,6 +131,12 @@ class IMessage[TMetadata: IMessageMetadata](IKnowledgeSource, Protocol): # Metadata associated with the message such as its source. metadata: TMetadata | None = None + # Optional external identifier of the source this message was ingested from + # (e.g., an email ID, a file path, a URL). Used by ingestion pipelines to + # detect already-ingested sources for restartability. None means the message + # is not associated with an external source (e.g., synthesized in tests). + source_id: str | None = None + # Semantic references are also ordinal. type SemanticRefOrdinal = int diff --git a/src/typeagent/knowpro/interfaces_storage.py b/src/typeagent/knowpro/interfaces_storage.py index 97f7b600..9f17574d 100644 --- a/src/typeagent/knowpro/interfaces_storage.py +++ b/src/typeagent/knowpro/interfaces_storage.py @@ -67,6 +67,21 @@ class SemanticRefMetadata(NamedTuple): knowledge_type: KnowledgeType +@dataclass +class ChunkFailure: + """Record of a single failed knowledge-extraction attempt for one chunk. + + Stored in the storage provider so that ingestion pipelines can retry just + the failed chunks without re-processing whole messages. + """ + + message_ordinal: int + chunk_ordinal: int + error_class: str + error_message: str + failed_at: Datetime + + class IReadonlyCollection[T, TOrdinal](AsyncIterable[T], Protocol): async def size(self) -> int: ... 
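The intended consumer of these ``ChunkFailure`` records is a later retry pass that re-extracts only the failed chunks. A minimal sketch of such a pass, assuming the APIs introduced in this patch (``get_chunk_failures``, ``clear_chunk_failure``, the reworked ``add_knowledge_to_semantic_ref_index``, and the ``get_slice`` accessor used elsewhere in this diff); the ``retry_failed_chunks`` function itself is hypothetical, not part of the patch:

    # Hypothetical retry pass over recorded chunk failures (not in this patch).
    import typechat

    from typeagent.knowpro import convknowledge
    from typeagent.storage.memory.semrefindex import (
        add_knowledge_to_semantic_ref_index,
    )

    async def retry_failed_chunks(conversation, storage) -> int:
        """Re-extract knowledge only for chunks whose extraction failed."""
        extractor = convknowledge.KnowledgeExtractor()
        retried = 0
        failures = await storage.get_chunk_failures()
        async with storage:  # one transaction; rolls back on systemic errors
            for failure in failures:
                # get_slice(n, n + 1) fetches the single affected message.
                [message] = await conversation.messages.get_slice(
                    failure.message_ordinal, failure.message_ordinal + 1
                )
                text = message.text_chunks[failure.chunk_ordinal].strip()
                result = await extractor.extract(text)
                if isinstance(result, typechat.Failure):
                    continue  # keep the record for a later pass
                await add_knowledge_to_semantic_ref_index(
                    conversation,
                    failure.message_ordinal,
                    failure.chunk_ordinal,
                    result.value,
                )
                await storage.clear_chunk_failure(
                    failure.message_ordinal, failure.chunk_ordinal
                )
                retried += 1
        return retried

A chunk that fails again keeps its record (``record_chunk_failure`` is idempotent), so repeated passes are safe.
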
@@ -111,23 +126,31 @@ async def get_metadata_multiple( class IStorageProvider[TMessage: IMessage](Protocol): """API spec for storage providers -- maybe in-memory or persistent.""" - async def get_message_collection(self) -> IMessageCollection[TMessage]: ... + @property + def messages(self) -> IMessageCollection[TMessage]: ... - async def get_semantic_ref_collection(self) -> ISemanticRefCollection: ... + @property + def semantic_refs(self) -> ISemanticRefCollection: ... - # Index getters - ALL 6 index types for this conversation + # Index properties - ALL 6 index types for this conversation - async def get_semantic_ref_index(self) -> ITermToSemanticRefIndex: ... + @property + def semantic_ref_index(self) -> ITermToSemanticRefIndex: ... - async def get_property_index(self) -> IPropertyToSemanticRefIndex: ... + @property + def property_index(self) -> IPropertyToSemanticRefIndex: ... - async def get_timestamp_index(self) -> ITimestampToTextRangeIndex: ... + @property + def timestamp_index(self) -> ITimestampToTextRangeIndex: ... - async def get_message_text_index(self) -> IMessageTextIndex[TMessage]: ... + @property + def message_text_index(self) -> IMessageTextIndex[TMessage]: ... - async def get_related_terms_index(self) -> ITermToRelatedTermsIndex: ... + @property + def related_terms_index(self) -> ITermToRelatedTermsIndex: ... - async def get_conversation_threads(self) -> IConversationThreads: ... + @property + def conversation_threads(self) -> IConversationThreads: ... # Metadata management @@ -158,6 +181,10 @@ async def is_source_ingested(self, source_id: str) -> bool: """Check if a source has already been ingested.""" ... + async def are_sources_ingested(self, source_ids: list[str]) -> set[str]: + """Return the subset of source_ids that have already been ingested.""" + ... + async def get_source_status(self, source_id: str) -> str | None: """Get the ingestion status of a source.""" ... @@ -168,6 +195,39 @@ async def mark_source_ingested( """Mark a source as ingested (no commit; call within transaction context).""" ... + async def mark_sources_ingested_batch( + self, source_ids: list[str], status: str = STATUS_INGESTED + ) -> None: + """Mark multiple sources as ingested in one operation.""" + ... + + # Chunk-level extraction failure tracking + + async def record_chunk_failure( + self, + message_ordinal: int, + chunk_ordinal: int, + error_class: str, + error_message: str, + ) -> None: + """Record an extraction failure for a single chunk. + + Idempotent: re-recording overwrites any prior entry for the same + (message_ordinal, chunk_ordinal). No commit; call within transaction + context. + """ + ... + + async def clear_chunk_failure( + self, message_ordinal: int, chunk_ordinal: int + ) -> None: + """Remove the failure record for one chunk (e.g., after a retry succeeds).""" + ... + + async def get_chunk_failures(self) -> list[ChunkFailure]: + """Return all recorded chunk failures, ordered by message and chunk.""" + ... + # Transaction management async def __aenter__(self) -> Self: """Enter transaction context. 
Calls begin_transaction().""" @@ -198,6 +258,7 @@ class IConversation[ __all__ = [ + "ChunkFailure", "ConversationMetadata", "ICollection", "IConversation", diff --git a/src/typeagent/knowpro/knowledge.py b/src/typeagent/knowpro/knowledge.py index bb889e58..9dedefbe 100644 --- a/src/typeagent/knowpro/knowledge.py +++ b/src/typeagent/knowpro/knowledge.py @@ -5,31 +5,17 @@ from collections.abc import Callable from dataclasses import dataclass -from typechat import Result, TypeChatLanguageModel +from typechat import Result -from . import convknowledge from . import knowledge_schema as kplib -from ..aitools import model_adapters from .interfaces import IKnowledgeExtractor -def create_knowledge_extractor( - chat_model: TypeChatLanguageModel | None = None, -) -> convknowledge.KnowledgeExtractor: - """Create a knowledge extractor using the given Chat Model.""" - chat_model = chat_model or model_adapters.create_chat_model() - extractor = convknowledge.KnowledgeExtractor( - chat_model, max_chars_per_chunk=4096, merge_action_knowledge=False - ) - return extractor - - async def extract_knowledge_from_text( knowledge_extractor: IKnowledgeExtractor, text: str, ) -> Result[kplib.KnowledgeResponse]: - """Extract knowledge from a single text input with retries.""" - # TODO: Add a retry mechanism to handle transient errors. + """Extract knowledge from a single text input.""" return await knowledge_extractor.extract(text) @@ -47,7 +33,7 @@ async def batch_worker( async def extract_knowledge_from_text_batch( knowledge_extractor: IKnowledgeExtractor, text_batch: list[str], - concurrency: int = 2, + concurrency: int = 4, ) -> list[Result[kplib.KnowledgeResponse]]: """Extract knowledge from a batch of text inputs concurrently.""" if not text_batch: @@ -193,25 +179,3 @@ def merge_topics(topics: list[str]) -> list[str]: # TODO: Preserve order of first occurrence? merged_topics = set(topics) return list(merged_topics) - - -async def extract_knowledge_for_text_batch_q( - knowledge_extractor: convknowledge.KnowledgeExtractor, - text_batch: list[str], - concurrency: int = 2, -) -> list[Result[kplib.KnowledgeResponse]]: - """Extract knowledge for a batch of text inputs using a task queue.""" - raise NotImplementedError("TODO") - # TODO: BatchTask etc. - # task_batch = [BatchTask(task=text) for text in text_batch] - - # await run_in_batches( - # task_batch, - # lambda text: extract_knowledge_from_text(knowledge_extractor, text), - # concurrency, - # ) - - # results = [] - # for task in task_batch: - # results.append(task.result if task.result else Failure("No result")) - # return results diff --git a/src/typeagent/knowpro/messageutils.py b/src/typeagent/knowpro/messageutils.py index bd7cf879..6b4afb6e 100644 --- a/src/typeagent/knowpro/messageutils.py +++ b/src/typeagent/knowpro/messageutils.py @@ -5,7 +5,6 @@ from .interfaces import ( IMessage, - IMessageCollection, MessageOrdinal, TextLocation, TextRange, @@ -23,90 +22,28 @@ def text_range_from_message_chunk( ) -async def get_message_chunk_batch[TMessage: IMessage]( - messages: IMessageCollection[TMessage], - message_ordinal_start_at: MessageOrdinal, - batch_size: int, -) -> list[list[TextLocation]]: - """ - Get batches of message chunk locations for processing. 
- - Args: - messages: Collection of messages to process - message_ordinal_start_at: Starting message ordinal - batch_size: Number of message chunks per batch - - Yields: - Lists of TextLocation objects, each representing a message chunk - """ - batches: list[list[TextLocation]] = [] - current_batch: list[TextLocation] = [] - - message_ordinal = message_ordinal_start_at - async for message in messages: - if message_ordinal < message_ordinal_start_at: - message_ordinal += 1 - continue - - # Process each text chunk in the message - for chunk_ordinal in range(len(message.text_chunks)): - text_location = TextLocation( - message_ordinal=message_ordinal, - chunk_ordinal=chunk_ordinal, - ) - current_batch.append(text_location) - - # When batch is full, yield it and start a new one - if len(current_batch) >= batch_size: - batches.append(current_batch) - current_batch = [] - - message_ordinal += 1 - - # Don't forget the last batch if it has items - if current_batch: - batches.append(current_batch) - - return batches - - -def get_message_chunk_batch_from_list[TMessage: IMessage]( +def get_all_message_chunk_locations[TMessage: IMessage]( messages: list[TMessage], message_ordinal_start_at: MessageOrdinal, - batch_size: int, -) -> list[list[TextLocation]]: +) -> list[TextLocation]: """ - Get batches of message chunk locations for processing from a list of messages. + Get a flat list of all message chunk locations from a list of messages. Args: messages: List of messages to process message_ordinal_start_at: Starting message ordinal (ordinal of first message in list) - batch_size: Number of message chunks per batch Returns: - Lists of TextLocation objects, each representing a message chunk + Flat list of TextLocation objects, one per message chunk """ - batches: list[list[TextLocation]] = [] - current_batch: list[TextLocation] = [] - + locations: list[TextLocation] = [] for idx, message in enumerate(messages): message_ordinal = message_ordinal_start_at + idx - - # Process each text chunk in the message for chunk_ordinal in range(len(message.text_chunks)): - text_location = TextLocation( - message_ordinal=message_ordinal, - chunk_ordinal=chunk_ordinal, + locations.append( + TextLocation( + message_ordinal=message_ordinal, + chunk_ordinal=chunk_ordinal, + ) ) - current_batch.append(text_location) - - # When batch is full, yield it and start a new one - if len(current_batch) >= batch_size: - batches.append(current_batch) - current_batch = [] - - # Don't forget the last batch if it has items - if current_batch: - batches.append(current_batch) - - return batches + return locations diff --git a/src/typeagent/knowpro/secindex.py b/src/typeagent/knowpro/secindex.py index baee18b9..f101f9cb 100644 --- a/src/typeagent/knowpro/secindex.py +++ b/src/typeagent/knowpro/secindex.py @@ -22,32 +22,12 @@ def __init__( settings: RelatedTermIndexSettings, ): self._storage_provider = storage_provider - # Initialize all indexes through storage provider immediately - self.property_to_semantic_ref_index = None - self.timestamp_index = None - self.term_to_related_terms_index = None - self.threads = None - self.message_index = None - - @classmethod - async def create( - cls, - storage_provider: IStorageProvider, - settings: RelatedTermIndexSettings, - ) -> "ConversationSecondaryIndexes": - """Create and initialize a ConversationSecondaryIndexes with all indexes.""" - self = cls(storage_provider, settings) # Initialize all indexes from storage provider - self.property_to_semantic_ref_index = ( - await 
storage_provider.get_property_index() - ) - self.timestamp_index = await storage_provider.get_timestamp_index() - self.term_to_related_terms_index = ( - await storage_provider.get_related_terms_index() - ) - self.threads = await storage_provider.get_conversation_threads() - self.message_index = await storage_provider.get_message_text_index() - return self + self.property_to_semantic_ref_index = storage_provider.property_index + self.timestamp_index = storage_provider.timestamp_index + self.term_to_related_terms_index = storage_provider.related_terms_index + self.threads = storage_provider.conversation_threads + self.message_index = storage_provider.message_text_index async def build_secondary_indexes[ @@ -59,7 +39,7 @@ async def build_secondary_indexes[ ) -> None: if conversation.secondary_indexes is None: storage_provider = await conversation_settings.get_storage_provider() - conversation.secondary_indexes = await ConversationSecondaryIndexes.create( + conversation.secondary_indexes = ConversationSecondaryIndexes( storage_provider, conversation_settings.related_term_index_settings ) else: @@ -82,9 +62,9 @@ async def build_transient_secondary_indexes[ settings: ConversationSettings, ) -> None: if conversation.secondary_indexes is None: - conversation.secondary_indexes = await ConversationSecondaryIndexes.create( + conversation.secondary_indexes = ConversationSecondaryIndexes( await settings.get_storage_provider(), - (settings.related_term_index_settings), + settings.related_term_index_settings, ) await build_property_index(conversation) await build_timestamp_index(conversation) diff --git a/src/typeagent/knowpro/universal_message.py b/src/typeagent/knowpro/universal_message.py index 01abfdf9..c5008fe2 100644 --- a/src/typeagent/knowpro/universal_message.py +++ b/src/typeagent/knowpro/universal_message.py @@ -204,6 +204,11 @@ class ConversationMessage(IMessage): Format: "2024-01-01T12:34:56Z" or "1970-01-01T00:01:23Z" (epoch-based) MUST include "Z" suffix to explicitly indicate UTC timezone. """ + source_id: str | None = None + """ + Optional external identifier of the source this message was ingested from + (e.g., a transcript file path or podcast episode id). See ``IMessage.source_id``. + """ def get_knowledge(self) -> kplib.KnowledgeResponse: return self.metadata.get_knowledge() diff --git a/src/typeagent/podcasts/podcast.py b/src/typeagent/podcasts/podcast.py index 5376d20e..8038ea60 100644 --- a/src/typeagent/podcasts/podcast.py +++ b/src/typeagent/podcasts/podcast.py @@ -187,8 +187,8 @@ async def read_from_file( data = Podcast._read_conversation_data_from_file(filename_prefix) provider = await settings.get_storage_provider() - msgs = await provider.get_message_collection() - semrefs = await provider.get_semantic_ref_collection() + msgs = provider.messages + semrefs = provider.semantic_refs if await msgs.size() or await semrefs.size(): raise RuntimeError( f"Database {dbname!r} already has messages or semantic refs." diff --git a/src/typeagent/podcasts/podcast_ingest.py b/src/typeagent/podcasts/podcast_ingest.py index b124f003..2ff5c764 100644 --- a/src/typeagent/podcasts/podcast_ingest.py +++ b/src/typeagent/podcasts/podcast_ingest.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
+from collections.abc import AsyncIterator
 from datetime import timedelta
 import os
 import re
@@ -8,6 +9,7 @@
 from ..knowpro.convsettings import ConversationSettings
 from ..knowpro.interfaces import Datetime
+from ..knowpro.interfaces_core import AddMessagesResult
 from ..knowpro.universal_message import format_timestamp_utc, UNIX_EPOCH
 from ..storage.utils import create_storage_provider
 from .podcast import Podcast, PodcastMessage, PodcastMessageMeta
@@ -22,6 +24,7 @@ async def ingest_podcast(
     dbname: str | None = None,
     batch_size: int = 0,
     start_message: int = 0,
+    concurrency: int = 0,
     verbose: bool = False,
 ) -> Podcast:
     """
@@ -37,8 +40,10 @@
            date is unknown (Unix "timestamp left at zero" convention).
        length_minutes: Total length of podcast in minutes (for proportional timestamp allocation)
        dbname: Database name or None (to use in-memory non-persistent storage)
-       batch_size: Number of messages to index per batch (default all messages)
+       batch_size: Chunks per commit batch for add_messages_streaming (default:
+           the whole transcript in one batch). Used for recoverability on crash.
        start_message: Number of initial messages to skip (for resuming interrupted ingests)
+       concurrency: Max concurrent knowledge extractions (0 = use settings default)
        verbose: Whether to print progress information (default False)

    Returns:
@@ -109,7 +114,7 @@
             PodcastMessage,
         )
         settings.storage_provider = provider
-        msg_coll = await provider.get_message_collection()
+        msg_coll = provider.messages
         if (msg_size := await msg_coll.size()) > start_message:
             raise RuntimeError(
                 f"{dbname!r} has {msg_size} messages; start_message ({start_message}) should be at least that."
             )
@@ -121,20 +126,46 @@
         tags=[podcast_name],
     )

-    # Add messages with indexing to build embeddings, using batch_size
-    batch_size = batch_size or len(msgs)
-    settings.semantic_ref_index_settings.batch_size = batch_size
-    for i in range(start_message, len(msgs), batch_size):
-        batch = msgs[i : i + batch_size]
-        t0 = time.time()
-        await pod.add_messages_with_indexing(batch)
-        t1 = time.time()
+    # Set source_id on each message for restartability
+    for i, msg in enumerate(msgs):
+        msg.source_id = f"{transcript_file_path}#{i}"
+
+    # Add messages using the streaming API (commit-per-batch)
+    if concurrency:
+        settings.semantic_ref_index_settings.concurrency = concurrency
+
+    async def _message_stream() -> AsyncIterator[PodcastMessage]:
+        for msg in msgs[start_message:]:
+            yield msg
+
+    cumulative_messages = 0
+    t0 = time.time()
+
+    def _on_batch_committed(result: AddMessagesResult) -> None:
+        nonlocal cumulative_messages
+        batch_start = cumulative_messages
+        cumulative_messages += result.messages_added
         if verbose:
             print(
-                f"Indexed messages {i} to {i + len(batch) - 1} "
-                f"in {t1 - t0:.1f} seconds."
+                f"Indexed messages {batch_start}-{cumulative_messages - 1} "
+                f"({result.chunks_added} chunks, {result.semrefs_added} semrefs) "
+                f"at t={time.time() - t0:.1f} seconds."
             )
+    batch_size = batch_size or len(msgs)
+    result = await pod.add_messages_streaming(
+        _message_stream(),
+        batch_size=batch_size,
+        on_batch_committed=_on_batch_committed,
+    )
+    t1 = time.time()
+    if verbose:
+        print(
+            f"Indexed {result.messages_added} messages "
+            f"({result.chunks_added} chunks, {result.semrefs_added} semrefs) "
+            f"in {t1 - t0:.1f} seconds."
+ ) + return pod diff --git a/src/typeagent/storage/memory/messageindex.py b/src/typeagent/storage/memory/messageindex.py index 8d742794..efcc4ddf 100644 --- a/src/typeagent/storage/memory/messageindex.py +++ b/src/typeagent/storage/memory/messageindex.py @@ -30,7 +30,7 @@ async def build_message_index[ if csi is None: return if csi.message_index is None: - csi.message_index = await storage_provider.get_message_text_index() + csi.message_index = storage_provider.message_text_index messages = conversation.messages # Convert collection to list for add_messages messages_list = await messages.get_slice(0, await messages.size()) diff --git a/src/typeagent/storage/memory/provider.py b/src/typeagent/storage/memory/provider.py index 540d31b8..e697fe01 100644 --- a/src/typeagent/storage/memory/provider.py +++ b/src/typeagent/storage/memory/provider.py @@ -3,10 +3,11 @@ """In-memory storage provider implementation.""" -from datetime import datetime +from datetime import datetime, timezone from ...knowpro.convsettings import MessageTextIndexSettings, RelatedTermIndexSettings from ...knowpro.interfaces import ( + ChunkFailure, ConversationMetadata, IConversationThreads, IMessage, @@ -40,6 +41,7 @@ class MemoryStorageProvider[TMessage: IMessage](IStorageProvider[TMessage]): _related_terms_index: RelatedTermsIndex _conversation_threads: ConversationThreads _ingested_sources: set[str] + _chunk_failures: dict[tuple[int, int], ChunkFailure] def __init__( self, @@ -60,6 +62,7 @@ def __init__( thread_settings = message_text_settings.embedding_index_settings self._conversation_threads = ConversationThreads(thread_settings) self._ingested_sources = set() + self._chunk_failures = {} async def __aenter__(self) -> "MemoryStorageProvider[TMessage]": """Enter transaction context. No-op for in-memory storage.""" @@ -74,30 +77,36 @@ async def __aexit__( """Exit transaction context. 
No-op for in-memory storage.""" pass - async def get_semantic_ref_index(self) -> ITermToSemanticRefIndex: + @property + def semantic_ref_index(self) -> ITermToSemanticRefIndex: return self._conversation_index - async def get_property_index(self) -> IPropertyToSemanticRefIndex: + @property + def property_index(self) -> IPropertyToSemanticRefIndex: return self._property_index - async def get_timestamp_index(self) -> ITimestampToTextRangeIndex: + @property + def timestamp_index(self) -> ITimestampToTextRangeIndex: return self._timestamp_index - async def get_message_text_index(self) -> IMessageTextIndex[TMessage]: + @property + def message_text_index(self) -> IMessageTextIndex[TMessage]: return self._message_text_index - async def get_related_terms_index(self) -> ITermToRelatedTermsIndex: + @property + def related_terms_index(self) -> ITermToRelatedTermsIndex: return self._related_terms_index - async def get_conversation_threads(self) -> IConversationThreads: + @property + def conversation_threads(self) -> IConversationThreads: return self._conversation_threads - async def get_message_collection( - self, message_type: type[TMessage] | None = None - ) -> MemoryMessageCollection[TMessage]: + @property + def messages(self) -> MemoryMessageCollection[TMessage]: return self._message_collection - async def get_semantic_ref_collection(self) -> MemorySemanticRefCollection: + @property + def semantic_refs(self) -> MemorySemanticRefCollection: return self._semantic_ref_collection async def close(self) -> None: @@ -150,6 +159,10 @@ async def is_source_ingested(self, source_id: str) -> bool: """ return source_id in self._ingested_sources + async def are_sources_ingested(self, source_ids: list[str]) -> set[str]: + """Return the subset of source_ids that have already been ingested.""" + return self._ingested_sources & set(source_ids) + async def get_source_status(self, source_id: str) -> str | None: """Get the ingestion status of a source. @@ -172,3 +185,35 @@ async def mark_source_ingested( source_id: External source identifier (email ID, file path, etc.) 
""" self._ingested_sources.add(source_id) + + async def mark_sources_ingested_batch( + self, source_ids: list[str], status: str = STATUS_INGESTED + ) -> None: + """Mark multiple sources as ingested in one operation.""" + self._ingested_sources.update(source_ids) + + async def record_chunk_failure( + self, + message_ordinal: int, + chunk_ordinal: int, + error_class: str, + error_message: str, + ) -> None: + """Record a knowledge-extraction failure for a single chunk.""" + self._chunk_failures[(message_ordinal, chunk_ordinal)] = ChunkFailure( + message_ordinal=message_ordinal, + chunk_ordinal=chunk_ordinal, + error_class=error_class, + error_message=error_message, + failed_at=datetime.now(timezone.utc), + ) + + async def clear_chunk_failure( + self, message_ordinal: int, chunk_ordinal: int + ) -> None: + """Remove a previously recorded chunk failure (no-op if absent).""" + self._chunk_failures.pop((message_ordinal, chunk_ordinal), None) + + async def get_chunk_failures(self) -> list[ChunkFailure]: + """Return all recorded chunk failures, ordered by (msg_ordinal, chunk_ordinal).""" + return [self._chunk_failures[k] for k in sorted(self._chunk_failures)] diff --git a/src/typeagent/storage/memory/semrefindex.py b/src/typeagent/storage/memory/semrefindex.py index 773f9212..76feb2d1 100644 --- a/src/typeagent/storage/memory/semrefindex.py +++ b/src/typeagent/storage/memory/semrefindex.py @@ -30,7 +30,6 @@ ) from ...knowpro.knowledge import extract_knowledge_from_text_batch from ...knowpro.messageutils import ( - get_message_chunk_batch, text_range_from_message_chunk, ) @@ -49,8 +48,9 @@ async def add_batch_to_semantic_ref_index[ conversation: IConversation[TMessage, TTermToSemanticRefIndex], batch: list[TextLocation], knowledge_extractor: IKnowledgeExtractor, - terms_added: set[str] | None = None, + concurrency: int = 4, ) -> None: + """Extract knowledge and bulk-add to the semantic ref index.""" messages = conversation.messages text_batch = [ @@ -63,22 +63,20 @@ async def add_batch_to_semantic_ref_index[ knowledge_results = await extract_knowledge_from_text_batch( knowledge_extractor, text_batch, - len(text_batch), + concurrency, ) + bulk_items: list[tuple[int, int, kplib.KnowledgeResponse]] = [] for i, knowledge_result in enumerate(knowledge_results): if isinstance(knowledge_result, Failure): raise RuntimeError( f"Knowledge extraction failed: {knowledge_result.message}" ) - text_location = batch[i] - knowledge = knowledge_result.value - await add_knowledge_to_semantic_ref_index( - conversation, - text_location.message_ordinal, - text_location.chunk_ordinal, - knowledge, - terms_added, + tl = batch[i] + bulk_items.append( + (tl.message_ordinal, tl.chunk_ordinal, knowledge_result.value) ) + if bulk_items: + await add_knowledge_batch_to_semantic_ref_index(conversation, bulk_items) async def add_batch_to_semantic_ref_index_from_list[ @@ -88,55 +86,40 @@ async def add_batch_to_semantic_ref_index_from_list[ messages: list[TMessage], batch: list[TextLocation], knowledge_extractor: IKnowledgeExtractor, - terms_added: set[str] | None = None, + concurrency: int = 4, ) -> None: - """ - Add a batch of knowledge to semantic ref index, extracting from provided message list. 
- - Args: - conversation: The conversation containing semantic refs and index - messages: List of messages containing the text to extract from - batch: List of text locations (ordinals) to process - knowledge_extractor: Extractor for LLM-based knowledge extraction - terms_added: Optional set to track newly added terms - """ - # Get the starting ordinal of the message list + """Extract knowledge from messages and bulk-add to the semantic ref index.""" if not batch: return start_ordinal = batch[0].message_ordinal - # Extract text from the messages list text_batch = [] for tl in batch: - # Calculate index in the list from the ordinal list_index = tl.message_ordinal - start_ordinal if list_index < 0 or list_index >= len(messages): raise IndexError( - f"Message ordinal {tl.message_ordinal} out of range for list starting at {start_ordinal}" + f"Message ordinal {tl.message_ordinal} out of range " + f"for list starting at {start_ordinal}" ) - message = messages[list_index] - text = message.text_chunks[tl.chunk_ordinal].strip() - text_batch.append(text) + text_batch.append(messages[list_index].text_chunks[tl.chunk_ordinal].strip()) knowledge_results = await extract_knowledge_from_text_batch( knowledge_extractor, text_batch, - len(text_batch), + concurrency, ) + bulk_items: list[tuple[int, int, kplib.KnowledgeResponse]] = [] for i, knowledge_result in enumerate(knowledge_results): if isinstance(knowledge_result, Failure): raise RuntimeError( f"Knowledge extraction failed: {knowledge_result.message:.150}" ) - text_location = batch[i] - knowledge = knowledge_result.value - await add_knowledge_to_semantic_ref_index( - conversation, - text_location.message_ordinal, - text_location.chunk_ordinal, - knowledge, - terms_added, + tl = batch[i] + bulk_items.append( + (tl.message_ordinal, tl.chunk_ordinal, knowledge_result.value) ) + if bulk_items: + await add_knowledge_batch_to_semantic_ref_index(conversation, bulk_items) async def add_term_to_index( @@ -356,22 +339,89 @@ async def add_action( # TODO:L KnowledgeValidator +def _collect_knowledge_refs_and_terms( + base_ordinal: SemanticRefOrdinal, + message_ordinal: MessageOrdinal, + chunk_ordinal: int, + knowledge: kplib.KnowledgeResponse, +) -> tuple[list[SemanticRef], list[tuple[str, SemanticRefOrdinal]]]: + """Collect SemanticRefs and index terms without writing to storage.""" + refs: list[SemanticRef] = [] + terms: list[tuple[str, SemanticRefOrdinal]] = [] + ordinal = base_ordinal + text_range = text_range_from_message_chunk(message_ordinal, chunk_ordinal) + + for entity in knowledge.entities: + if not validate_entity(entity): + continue + refs.append( + SemanticRef( + semantic_ref_ordinal=ordinal, + range=text_range, + knowledge=entity, + ) + ) + terms.append((entity.name, ordinal)) + for type_name in entity.type: + terms.append((type_name, ordinal)) + if entity.facets: + for facet in entity.facets: + if facet is not None: + terms.append((facet.name, ordinal)) + if facet.value is not None: + terms.append((str(facet.value), ordinal)) + ordinal += 1 + + for action in list(knowledge.actions) + list(knowledge.inverse_actions): + refs.append( + SemanticRef( + semantic_ref_ordinal=ordinal, + range=text_range, + knowledge=action, + ) + ) + terms.append((" ".join(action.verbs), ordinal)) + if action.subject_entity_name != "none": + terms.append((action.subject_entity_name, ordinal)) + if action.object_entity_name != "none": + terms.append((action.object_entity_name, ordinal)) + if action.indirect_object_entity_name != "none": + 
terms.append((action.indirect_object_entity_name, ordinal)) + if action.params: + for param in action.params: + if isinstance(param, str): + terms.append((param, ordinal)) + else: + terms.append((param.name, ordinal)) + if isinstance(param.value, str): + terms.append((param.value, ordinal)) + if action.subject_entity_facet is not None: + terms.append((action.subject_entity_facet.name, ordinal)) + if action.subject_entity_facet.value is not None: + terms.append((str(action.subject_entity_facet.value), ordinal)) + ordinal += 1 + + for topic_text in knowledge.topics: + refs.append( + SemanticRef( + semantic_ref_ordinal=ordinal, + range=text_range, + knowledge=Topic(text=topic_text), + ) + ) + terms.append((topic_text, ordinal)) + ordinal += 1 + + return refs, terms + + async def add_knowledge_to_semantic_ref_index( conversation: IConversation, message_ordinal: MessageOrdinal, chunk_ordinal: int, knowledge: kplib.KnowledgeResponse, - terms_added: set[str] | None = None, ) -> None: - """Add knowledge to the semantic reference index of a conversation. - - Args: - conversation: The conversation to add knowledge to - message_ordinal: Ordinal of the message containing the knowledge - chunk_ordinal: Ordinal of the chunk within the message - knowledge: Knowledge response containing entities, actions and topics - terms_added: Optional set to track terms added to the index - """ + """Add knowledge to the semantic reference index of a conversation.""" verify_has_semantic_ref_index(conversation) semantic_refs = conversation.semantic_refs @@ -379,47 +429,52 @@ async def add_knowledge_to_semantic_ref_index( semantic_ref_index = conversation.semantic_ref_index assert semantic_ref_index is not None - for entity in knowledge.entities: - if validate_entity(entity): - await add_entity( - entity, - semantic_refs, - semantic_ref_index, - message_ordinal, - chunk_ordinal, - terms_added, - ) + base_ordinal = await semantic_refs.size() + refs, terms = _collect_knowledge_refs_and_terms( + base_ordinal, + message_ordinal, + chunk_ordinal, + knowledge, + ) - for action in knowledge.actions: - await add_action( - action, - semantic_refs, - semantic_ref_index, - message_ordinal, - chunk_ordinal, - terms_added, - ) + if refs: + await semantic_refs.extend(refs) + if terms: + await semantic_ref_index.add_terms_batch(terms) - for inverse_action in knowledge.inverse_actions: - await add_action( - inverse_action, - semantic_refs, - semantic_ref_index, - message_ordinal, - chunk_ordinal, - terms_added, - ) - for topic in knowledge.topics: - topic_obj = Topic(text=topic) - await add_topic( - topic_obj, - semantic_refs, - semantic_ref_index, - message_ordinal, - chunk_ordinal, - terms_added, +async def add_knowledge_batch_to_semantic_ref_index( + conversation: IConversation, + items: list[tuple[MessageOrdinal, int, kplib.KnowledgeResponse]], +) -> None: + """Bulk-add knowledge from multiple chunks in two DB round-trips.""" + if not items: + return + verify_has_semantic_ref_index(conversation) + + semantic_refs = conversation.semantic_refs + assert semantic_refs is not None + semantic_ref_index = conversation.semantic_ref_index + assert semantic_ref_index is not None + + all_refs: list[SemanticRef] = [] + all_terms: list[tuple[str, SemanticRefOrdinal]] = [] + base_ordinal = await semantic_refs.size() + + for msg_ord, chunk_ord, knowledge in items: + refs, terms = _collect_knowledge_refs_and_terms( + base_ordinal + len(all_refs), + msg_ord, + chunk_ord, + knowledge, ) + all_refs.extend(refs) + all_terms.extend(terms) + + if 
all_refs: + await semantic_refs.extend(all_refs) + if all_terms: + await semantic_ref_index.add_terms_batch(all_terms) def validate_entity(entity: kplib.ConcreteEntity) -> bool: @@ -721,30 +776,37 @@ async def add_to_semantic_ref_index[ conversation: IConversation[TMessage, TTermToSemanticRefIndex], settings: SemanticRefIndexSettings, message_ordinal_start_at: MessageOrdinal, - terms_added: set[str] | None = None, ) -> None: """Add semantic references to the conversation's semantic reference index.""" + if not settings.auto_extract_knowledge: + return - # Only create knowledge extractor if auto extraction is enabled - knowledge_extractor = None - if settings.auto_extract_knowledge: - knowledge_extractor = ( - settings.knowledge_extractor or convknowledge.KnowledgeExtractor() - ) + knowledge_extractor = ( + settings.knowledge_extractor or convknowledge.KnowledgeExtractor() + ) - # Process messages in batches for LLM knowledge extraction - batches = await get_message_chunk_batch( - conversation.messages, - message_ordinal_start_at, - settings.batch_size, - ) - for text_location_batch in batches: - await add_batch_to_semantic_ref_index( - conversation, - text_location_batch, - knowledge_extractor, - terms_added, + text_locations: list[TextLocation] = [] + message_ordinal = message_ordinal_start_at + async for message in conversation.messages: + if message_ordinal < message_ordinal_start_at: + message_ordinal += 1 + continue + for chunk_ordinal in range(len(message.text_chunks)): + text_locations.append( + TextLocation( + message_ordinal=message_ordinal, + chunk_ordinal=chunk_ordinal, + ) ) + message_ordinal += 1 + + if text_locations: + await add_batch_to_semantic_ref_index( + conversation, + text_locations, + knowledge_extractor, + concurrency=settings.concurrency, + ) def verify_has_semantic_ref_index(conversation: IConversation) -> None: diff --git a/src/typeagent/storage/sqlite/provider.py b/src/typeagent/storage/sqlite/provider.py index 3d5a3185..a8ba9c06 100644 --- a/src/typeagent/storage/sqlite/provider.py +++ b/src/typeagent/storage/sqlite/provider.py @@ -11,6 +11,8 @@ from ...knowpro import interfaces from ...knowpro.convsettings import MessageTextIndexSettings, RelatedTermIndexSettings from ...knowpro.interfaces import ConversationMetadata, STATUS_INGESTED +from ...knowpro.interfaces_storage import ChunkFailure +from ..memory.convthreads import ConversationThreads from .collections import SqliteMessageCollection, SqliteSemanticRefCollection from .messageindex import SqliteMessageTextIndex from .propindex import SqlitePropertyIndex @@ -99,6 +101,11 @@ def __init__( self.db, self.related_term_index_settings.embedding_index_settings ) + # Initialize conversation threads + self._conversation_threads = ConversationThreads( + self.message_text_index_settings.embedding_index_settings + ) + # Connect message collection to message text index for automatic indexing self._message_collection.set_message_text_index(self._message_text_index) @@ -324,7 +331,7 @@ def semantic_refs(self) -> SqliteSemanticRefCollection: return self._semantic_ref_collection @property - def term_to_semantic_ref_index(self) -> SqliteTermToSemanticRefIndex: + def semantic_ref_index(self) -> SqliteTermToSemanticRefIndex: return self._term_to_semantic_ref_index @property @@ -343,46 +350,9 @@ def message_text_index(self) -> SqliteMessageTextIndex: def related_terms_index(self) -> SqliteRelatedTermsIndex: return self._related_terms_index - # Async getters required by base class - async def get_message_collection( - self, 
message_type: type[TMessage] | None = None - ) -> interfaces.IMessageCollection[TMessage]: - """Get the message collection.""" - return self._message_collection - - async def get_semantic_ref_collection(self) -> interfaces.ISemanticRefCollection: - """Get the semantic reference collection.""" - return self._semantic_ref_collection - - async def get_semantic_ref_index(self) -> interfaces.ITermToSemanticRefIndex: - """Get the semantic reference index.""" - return self._term_to_semantic_ref_index - - async def get_property_index(self) -> interfaces.IPropertyToSemanticRefIndex: - """Get the property index.""" - return self._property_index - - async def get_timestamp_index(self) -> interfaces.ITimestampToTextRangeIndex: - """Get the timestamp index.""" - return self._timestamp_index - - async def get_message_text_index(self) -> interfaces.IMessageTextIndex[TMessage]: - """Get the message text index.""" - return self._message_text_index - - async def get_related_terms_index(self) -> interfaces.ITermToRelatedTermsIndex: - """Get the related terms index.""" - return self._related_terms_index - - async def get_conversation_threads(self) -> interfaces.IConversationThreads: - """Get the conversation threads.""" - # For now, return a simple implementation - # In a full implementation, this would be stored/retrieved from SQLite - from ...storage.memory.convthreads import ConversationThreads - - return ConversationThreads( - self.message_text_index_settings.embedding_index_settings - ) + @property + def conversation_threads(self) -> ConversationThreads: + return self._conversation_threads async def clear(self) -> None: """Clear all data from the storage provider.""" @@ -594,6 +564,26 @@ async def is_source_ingested(self, source_id: str) -> bool: row = cursor.fetchone() return row is not None and row[0] == STATUS_INGESTED + async def are_sources_ingested(self, source_ids: list[str]) -> set[str]: + """Return the subset of source_ids that have already been ingested.""" + if not source_ids: + return set() + cursor = self.db.cursor() + result: set[str] = set() + # Chunk to stay within SQLite's SQLITE_MAX_VARIABLE_NUMBER + # (999 on older builds, 32766 on 3.32.0+). + chunk_size = 500 + for i in range(0, len(source_ids), chunk_size): + chunk = source_ids[i : i + chunk_size] + placeholders = ",".join("?" for _ in chunk) + cursor.execute( + f"SELECT source_id FROM IngestedSources" + f" WHERE source_id IN ({placeholders}) AND status = ?", + [*chunk, STATUS_INGESTED], + ) + result.update(row[0] for row in cursor.fetchall()) + return result + async def get_source_status(self, source_id: str) -> str | None: """Get the ingestion status of a source. @@ -627,3 +617,68 @@ async def mark_source_ingested( "INSERT OR REPLACE INTO IngestedSources (source_id, status) VALUES (?, ?)", (source_id, status), ) + + async def mark_sources_ingested_batch( + self, source_ids: list[str], status: str = STATUS_INGESTED + ) -> None: + """Mark multiple sources as ingested in one operation.""" + if not source_ids: + return + cursor = self.db.cursor() + cursor.executemany( + "INSERT OR REPLACE INTO IngestedSources (source_id, status) VALUES (?, ?)", + [(sid, status) for sid in source_ids], + ) + + async def record_chunk_failure( + self, + message_ordinal: int, + chunk_ordinal: int, + error_class: str, + error_message: str, + ) -> None: + """Record a knowledge-extraction failure for a single chunk. + + Idempotent: re-recording overwrites any prior entry for the same + (message_ordinal, chunk_ordinal). 
No commit; call within a transaction + context. + """ + failed_at = datetime.now(timezone.utc).isoformat() + cursor = self.db.cursor() + cursor.execute( + """ + INSERT OR REPLACE INTO ChunkFailures + (msg_id, chunk_ordinal, error_class, error_message, failed_at) + VALUES (?, ?, ?, ?, ?) + """, + (message_ordinal, chunk_ordinal, error_class, error_message, failed_at), + ) + + async def clear_chunk_failure( + self, message_ordinal: int, chunk_ordinal: int + ) -> None: + """Remove a previously recorded chunk failure (no-op if absent).""" + cursor = self.db.cursor() + cursor.execute( + "DELETE FROM ChunkFailures WHERE msg_id = ? AND chunk_ordinal = ?", + (message_ordinal, chunk_ordinal), + ) + + async def get_chunk_failures(self) -> list[ChunkFailure]: + """Return all recorded chunk failures, ordered by (msg_id, chunk_ordinal).""" + cursor = self.db.cursor() + cursor.execute(""" + SELECT msg_id, chunk_ordinal, error_class, error_message, failed_at + FROM ChunkFailures + ORDER BY msg_id, chunk_ordinal + """) + return [ + ChunkFailure( + message_ordinal=row[0], + chunk_ordinal=row[1], + error_class=row[2], + error_message=row[3], + failed_at=datetime.fromisoformat(row[4]), + ) + for row in cursor.fetchall() + ] diff --git a/src/typeagent/storage/sqlite/reltermsindex.py b/src/typeagent/storage/sqlite/reltermsindex.py index dec29db2..cf5b201b 100644 --- a/src/typeagent/storage/sqlite/reltermsindex.py +++ b/src/typeagent/storage/sqlite/reltermsindex.py @@ -209,30 +209,24 @@ async def get_terms(self) -> list[str]: return [row[0] for row in cursor.fetchall()] async def add_terms(self, texts: list[str]) -> None: - """Add terms.""" + """Add terms with batched embedding generation and DB writes.""" + new_terms = [t for t in texts if t not in self._added_terms] + if not new_terms: + return + + embeddings = await self._vector_base.add_keys(new_terms) + assert embeddings is not None + cursor = self.db.cursor() - # TODO: Batch additions to database - for text in texts: - if text in self._added_terms: - continue - - # Add to VectorBase for fuzzy lookup - await self._vector_base.add_key(text) - self._terms_list.append(text) - self._added_terms.add(text) - - # Generate embedding for term and store in database - embedding = await self._vector_base.get_embedding(text) # Cached - serialized_embedding = serialize_embedding(embedding) - # Insert term and embedding - cursor.execute( - """ - INSERT OR REPLACE INTO RelatedTermsFuzzy - (term, term_embedding) - VALUES (?, ?) - """, - (text, serialized_embedding), - ) + cursor.executemany( + "INSERT OR REPLACE INTO RelatedTermsFuzzy (term, term_embedding) VALUES (?, ?)", + [ + (term, serialize_embedding(embeddings[i])) + for i, term in enumerate(new_terms) + ], + ) + self._terms_list.extend(new_terms) + self._added_terms.update(new_terms) async def lookup_terms( self, diff --git a/src/typeagent/storage/sqlite/schema.py b/src/typeagent/storage/sqlite/schema.py index db6933db..99117c24 100644 --- a/src/typeagent/storage/sqlite/schema.py +++ b/src/typeagent/storage/sqlite/schema.py @@ -148,6 +148,28 @@ ); """ +# Table for tracking knowledge-extraction failures at the chunk level. +# Each row records a (message_ordinal, chunk_ordinal) pair whose extraction +# failed (typically because the LLM returned malformed JSON or an invalid +# schema). The message text itself is still stored in the Messages table; only +# the *enrichment* of that chunk is missing. A future "re-extract" tool can +# read this table to retry just the failed chunks. 
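+#
+# The provider already exposes the two statements a re-extract tool needs.
+# get_chunk_failures() reads the work queue:
+#
+#   SELECT msg_id, chunk_ordinal, error_class, error_message, failed_at
+#   FROM ChunkFailures ORDER BY msg_id, chunk_ordinal
+#
+# and clear_chunk_failure() retires an entry once a retry succeeds:
+#
+#   DELETE FROM ChunkFailures WHERE msg_id = ? AND chunk_ordinal = ?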
+CHUNK_FAILURES_SCHEMA = """ +CREATE TABLE IF NOT EXISTS ChunkFailures ( + msg_id INTEGER NOT NULL, -- Message ordinal (matches Messages.msg_id) + chunk_ordinal INTEGER NOT NULL, -- 0-based index into the message's text_chunks + error_class TEXT NOT NULL, -- Fully-qualified class name of the failure + error_message TEXT NOT NULL, -- Human-readable failure description + failed_at TEXT NOT NULL, -- ISO-8601 UTC timestamp of the failure + + PRIMARY KEY (msg_id, chunk_ordinal) +); +""" + +CHUNK_FAILURES_MSG_INDEX = """ +CREATE INDEX IF NOT EXISTS idx_chunk_failures_msg ON ChunkFailures(msg_id); +""" + # Type aliases for database row tuples type ShreddedMessage = tuple[ str | None, str | None, str | None, str | None, str | None, str | None @@ -271,6 +293,7 @@ def init_db_schema(db: sqlite3.Connection) -> None: cursor.execute(RELATED_TERMS_FUZZY_SCHEMA) cursor.execute(TIMESTAMP_INDEX_SCHEMA) cursor.execute(INGESTED_SOURCES_SCHEMA) + cursor.execute(CHUNK_FAILURES_SCHEMA) # Create additional indexes cursor.execute(SEMANTIC_REF_INDEX_TERM_INDEX) @@ -279,6 +302,7 @@ def init_db_schema(db: sqlite3.Connection) -> None: cursor.execute(RELATED_TERMS_ALIASES_TERM_INDEX) cursor.execute(RELATED_TERMS_ALIASES_ALIAS_INDEX) cursor.execute(RELATED_TERMS_FUZZY_TERM_INDEX) + cursor.execute(CHUNK_FAILURES_MSG_INDEX) def get_db_schema_version(db: sqlite3.Connection) -> int: diff --git a/src/typeagent/storage/sqlite/timestampindex.py b/src/typeagent/storage/sqlite/timestampindex.py index 1419b340..8fe017dd 100644 --- a/src/typeagent/storage/sqlite/timestampindex.py +++ b/src/typeagent/storage/sqlite/timestampindex.py @@ -88,12 +88,13 @@ async def add_timestamps( self, message_timestamps: list[tuple[interfaces.MessageOrdinal, str]] ) -> None: """Add multiple timestamps.""" + if not message_timestamps: + return cursor = self.db.cursor() - for message_ordinal, timestamp in message_timestamps: - cursor.execute( - "UPDATE Messages SET start_timestamp = ? WHERE msg_id = ?", - (timestamp, message_ordinal), - ) + cursor.executemany( + "UPDATE Messages SET start_timestamp = ? WHERE msg_id = ?", + [(ts, ordinal) for ordinal, ts in message_timestamps], + ) async def lookup_range( self, date_range: interfaces.DateRange diff --git a/src/typeagent/transcripts/transcript.py b/src/typeagent/transcripts/transcript.py index 5033e293..08c4fdae 100644 --- a/src/typeagent/transcripts/transcript.py +++ b/src/typeagent/transcripts/transcript.py @@ -187,8 +187,8 @@ async def read_from_file( data = Transcript._read_conversation_data_from_file(filename_prefix) provider = await settings.get_storage_provider() - msgs = await provider.get_message_collection() - semrefs = await provider.get_semantic_ref_collection() + msgs = provider.messages + semrefs = provider.semantic_refs if await msgs.size() or await semrefs.size(): raise RuntimeError( f"Database {dbname!r} already has messages or semantic refs." 
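Before the test changes, a caller-side sketch of how the pieces above compose into restartable ingestion: pre-filter already-ingested sources with ``are_sources_ingested`` (the streaming API does not deduplicate for you), then stream the remainder through ``add_messages_streaming``. The ``ingest_new`` wrapper is hypothetical; every API it calls is defined in this patch:

    # Hypothetical caller-side dedup + streaming ingest (not in this patch).
    from collections.abc import AsyncIterator

    async def ingest_new(conversation, storage, candidates) -> None:
        """Stream only not-yet-ingested messages into the conversation."""
        ids = [m.source_id for m in candidates if m.source_id is not None]
        already = await storage.are_sources_ingested(ids)  # set[str]

        async def stream() -> AsyncIterator:
            for msg in candidates:
                if msg.source_id in already:
                    continue  # skip duplicates; messages without a
                    # source_id have source_id None and are never skipped
                yield msg

        result = await conversation.add_messages_streaming(
            stream(),
            batch_size=100,  # target chunks per commit batch
            on_batch_committed=lambda r: print(
                f"committed {r.messages_added} messages, "
                f"{r.semrefs_added} semrefs"
            ),
        )
        print(
            f"total: {result.messages_added} messages, "
            f"{result.chunks_added} chunks"
        )

Because each batch's source IDs are marked inside the same transaction as its messages, a crash mid-run leaves the database consistent and a re-run picks up where the last committed batch left off.
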
diff --git a/tests/conftest.py b/tests/conftest.py index 7f0f11f5..cd707c09 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,6 +10,10 @@ from dotenv import load_dotenv import pytest import pytest_asyncio +import stamina + +stamina.set_testing(True) + from typeagent.aitools.embeddings import IEmbeddingModel from typeagent.aitools.model_adapters import create_test_embedding_model @@ -328,7 +332,7 @@ async def ensure_initialized(self): storage_provider = await self.settings.get_storage_provider() self._storage_provider = storage_provider if self.semantic_ref_index is None: - self.semantic_ref_index = await storage_provider.get_semantic_ref_index() # type: ignore + self.semantic_ref_index = storage_provider.semantic_ref_index # type: ignore if self._has_secondary_indexes: # Set up secondary indexes diff --git a/tests/test_add_messages_streaming.py b/tests/test_add_messages_streaming.py new file mode 100644 index 00000000..9a8b1fd6 --- /dev/null +++ b/tests/test_add_messages_streaming.py @@ -0,0 +1,842 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Tests for add_messages_streaming.""" + +import asyncio +from collections.abc import AsyncIterator +import os +import tempfile + +import pytest + +import typechat + +from typeagent.aitools.model_adapters import create_test_embedding_model +from typeagent.knowpro import knowledge_schema as kplib +from typeagent.knowpro.convsettings import ConversationSettings +from typeagent.knowpro.interfaces_core import AddMessagesResult, IKnowledgeExtractor +from typeagent.storage.sqlite.provider import SqliteStorageProvider +from typeagent.transcripts.transcript import ( + Transcript, + TranscriptMessage, + TranscriptMessageMeta, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_message( + text: str, + speaker: str = "Alice", + source_id: str | None = None, +) -> TranscriptMessage: + return TranscriptMessage( + text_chunks=[text], + metadata=TranscriptMessageMeta(speaker=speaker), + tags=["test"], + source_id=source_id, + ) + + +async def _create_transcript( + db_path: str, + *, + auto_extract: bool = False, + knowledge_extractor: IKnowledgeExtractor | None = None, +) -> tuple[Transcript, SqliteStorageProvider]: + model = create_test_embedding_model() + settings = ConversationSettings(model=model) + settings.semantic_ref_index_settings.auto_extract_knowledge = auto_extract + if knowledge_extractor is not None: + settings.semantic_ref_index_settings.knowledge_extractor = knowledge_extractor + storage = SqliteStorageProvider( + db_path, + message_type=TranscriptMessage, + message_text_index_settings=settings.message_text_index_settings, + related_term_index_settings=settings.related_term_index_settings, + ) + settings.storage_provider = storage + transcript = await Transcript.create(settings, name="test") + return transcript, storage + + +async def _async_iter( + items: list[TranscriptMessage], +) -> AsyncIterator[TranscriptMessage]: + for item in items: + yield item + + +def _ingested_count(storage: SqliteStorageProvider) -> int: + cursor = storage.db.cursor() + cursor.execute("SELECT COUNT(*) FROM IngestedSources") + return cursor.fetchone()[0] + + +def _failure_count(storage: SqliteStorageProvider) -> int: + cursor = storage.db.cursor() + cursor.execute("SELECT COUNT(*) FROM ChunkFailures") + return cursor.fetchone()[0] + + +# 
--------------------------------------------------------------------------- +# A test IKnowledgeExtractor that lets us control per-call results +# --------------------------------------------------------------------------- + +_EMPTY_RESPONSE = kplib.KnowledgeResponse( + entities=[], actions=[], inverse_actions=[], topics=[] +) + + +class ControlledExtractor: + """An IKnowledgeExtractor that returns Success or Failure per call. + + ``fail_on`` is a set of 0-based call indices for which the extractor + returns a Failure instead of a Success. + ``raise_on`` is a set of call indices that raise an exception. + """ + + def __init__( + self, + *, + fail_on: set[int] | None = None, + raise_on: set[int] | None = None, + ) -> None: + self.fail_on = fail_on or set() + self.raise_on = raise_on or set() + self.call_count = 0 + + async def extract(self, message: str) -> typechat.Result[kplib.KnowledgeResponse]: + idx = self.call_count + self.call_count += 1 + if idx in self.raise_on: + raise RuntimeError(f"Systemic failure at call {idx}") + if idx in self.fail_on: + return typechat.Failure(f"Extraction failed for call {idx}") + return typechat.Success(_EMPTY_RESPONSE) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_streaming_basic() -> None: + """Streaming ingest of a few messages with no extraction.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + transcript, storage = await _create_transcript(db_path) + + msgs = [_make_message(f"msg-{i}") for i in range(5)] + result = await transcript.add_messages_streaming(_async_iter(msgs)) + + assert result.messages_added == 5 + assert await transcript.messages.size() == 5 + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_batching() -> None: + """Messages are committed in batches of the requested size.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + transcript, storage = await _create_transcript(db_path) + + msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(7)] + result = await transcript.add_messages_streaming( + _async_iter(msgs), batch_size=3 + ) + + # 3 batches: [0,1,2], [3,4,5], [6] + assert result.messages_added == 7 + assert await transcript.messages.size() == 7 + # All 7 sources marked + assert _ingested_count(storage) == 7 + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_no_source_id_always_ingested() -> None: + """Messages without source_id are always ingested (never skipped).""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + transcript, storage = await _create_transcript(db_path) + + msgs = [_make_message(f"msg-{i}") for i in range(3)] + result = await transcript.add_messages_streaming(_async_iter(msgs)) + + assert result.messages_added == 3 + assert _ingested_count(storage) == 0 # no source IDs to track + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_records_chunk_failures() -> None: + """Extraction Failure results are recorded, not raised.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + extractor = ControlledExtractor(fail_on={1}) # second chunk fails + transcript, storage = await _create_transcript( + db_path, auto_extract=True, knowledge_extractor=extractor + ) + + msgs = [ + _make_message("good 
chunk 0"), + _make_message("bad chunk 1"), + _make_message("good chunk 2"), + ] + result = await transcript.add_messages_streaming(_async_iter(msgs)) + + assert result.messages_added == 3 + assert _failure_count(storage) == 1 + + failures = await storage.get_chunk_failures() + assert len(failures) == 1 + assert failures[0].message_ordinal == 1 + assert failures[0].chunk_ordinal == 0 + assert "Extraction failed" in failures[0].error_message + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_exception_stops_run() -> None: + """A raised exception stops processing; committed batches survive.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + # Raise on the 4th extract call (first chunk of second batch) + extractor = ControlledExtractor(raise_on={3}) + transcript, storage = await _create_transcript( + db_path, auto_extract=True, knowledge_extractor=extractor + ) + + msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(6)] + + with pytest.raises(ExceptionGroup) as exc_info: + await transcript.add_messages_streaming(_async_iter(msgs), batch_size=3) + + # Verify the wrapped exception is our RuntimeError + assert any( + isinstance(e, RuntimeError) and "Systemic failure" in str(e) + for e in exc_info.value.exceptions + ) + + # First batch (3 messages, 3 extract calls 0-2) committed + assert await transcript.messages.size() == 3 + assert _ingested_count(storage) == 3 + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_empty_iterable() -> None: + """Streaming with no messages returns zeros.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + transcript, storage = await _create_transcript(db_path) + + result = await transcript.add_messages_streaming(_async_iter([])) + + assert result.messages_added == 0 + assert result.semrefs_added == 0 + + await storage.close() + + +@pytest.mark.asyncio +# --------------------------------------------------------------------------- +# Pipeline overlap and DB batching tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_streaming_on_batch_committed_fires_per_batch() -> None: + """on_batch_committed fires once per non-empty batch with the pipelined approach.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + transcript, storage = await _create_transcript(db_path) + + msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(7)] + batch_results: list[int] = [] + result = await transcript.add_messages_streaming( + _async_iter(msgs), + batch_size=3, + on_batch_committed=lambda r: batch_results.append(r.messages_added), + ) + + assert result.messages_added == 7 + # 3 batches: [0,1,2], [3,4,5], [6] + assert batch_results == [3, 3, 1] + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_extraction_with_multiple_batches() -> None: + """Extraction results are correctly applied across batches with ordinal remapping.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + extractor = ControlledExtractor() + transcript, storage = await _create_transcript( + db_path, auto_extract=True, knowledge_extractor=extractor + ) + + msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(6)] + result = await transcript.add_messages_streaming( + _async_iter(msgs), batch_size=3 + ) + + assert result.messages_added == 6 + assert 
await transcript.messages.size() == 6 + # All 6 chunks extracted (no failures) + assert extractor.call_count == 6 + assert _failure_count(storage) == 0 + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_extraction_failure_across_batches() -> None: + """Extraction failures are recorded with correct global ordinals across batches.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + # Fail on call index 1 (batch 0, msg 1) and 4 (batch 1, msg 1) + extractor = ControlledExtractor(fail_on={1, 4}) + transcript, storage = await _create_transcript( + db_path, auto_extract=True, knowledge_extractor=extractor + ) + + msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(6)] + result = await transcript.add_messages_streaming( + _async_iter(msgs), batch_size=3 + ) + + assert result.messages_added == 6 + assert _failure_count(storage) == 2 + + failures = await storage.get_chunk_failures() + failure_ordinals = sorted(f.message_ordinal for f in failures) + # msg 1 in batch 0 → global ordinal 1, msg 1 in batch 1 → global ordinal 4 + assert failure_ordinals == [1, 4] + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_exception_in_later_batch_preserves_earlier() -> None: + """A raised exception in batch 1 stops processing; batch 0 is committed.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + # Raise on call 4 (first call of batch 1, since batch 0 has 3 msgs) + extractor = ControlledExtractor(raise_on={3}) + transcript, storage = await _create_transcript( + db_path, auto_extract=True, knowledge_extractor=extractor + ) + + msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(6)] + with pytest.raises(ExceptionGroup) as exc_info: + await transcript.add_messages_streaming(_async_iter(msgs), batch_size=3) + + assert any( + isinstance(e, RuntimeError) and "Systemic failure" in str(e) + for e in exc_info.value.exceptions + ) + + # Batch 0 committed (3 messages), batch 1 rolled back + assert await transcript.messages.size() == 3 + assert _ingested_count(storage) == 3 + + await storage.close() + + +@pytest.mark.asyncio +async def test_mark_sources_ingested_batch_sqlite() -> None: + """mark_sources_ingested_batch marks multiple sources in one call.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + _, storage = await _create_transcript(db_path) + + async with storage: + await storage.mark_sources_ingested_batch(["a", "b", "c"]) + + assert await storage.is_source_ingested("a") + assert await storage.is_source_ingested("b") + assert await storage.is_source_ingested("c") + assert not await storage.is_source_ingested("d") + assert _ingested_count(storage) == 3 + + await storage.close() + + +@pytest.mark.asyncio +async def test_mark_sources_ingested_batch_empty() -> None: + """mark_sources_ingested_batch with empty list is a no-op.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + _, storage = await _create_transcript(db_path) + + async with storage: + await storage.mark_sources_ingested_batch([]) + + assert _ingested_count(storage) == 0 + + await storage.close() + + +@pytest.mark.asyncio +async def test_mark_sources_ingested_batch_idempotent() -> None: + """mark_sources_ingested_batch is idempotent via INSERT OR REPLACE.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + _, storage = await 
_create_transcript(db_path) + + async with storage: + await storage.mark_sources_ingested_batch(["a", "b"]) + async with storage: + await storage.mark_sources_ingested_batch(["b", "c"]) + + assert _ingested_count(storage) == 3 + assert await storage.is_source_ingested("a") + assert await storage.is_source_ingested("b") + assert await storage.is_source_ingested("c") + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_extraction_with_empty_text_chunks() -> None: + """Messages with empty text_chunks skip extraction gracefully.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + extractor = ControlledExtractor() + transcript, storage = await _create_transcript( + db_path, auto_extract=True, knowledge_extractor=extractor + ) + + msgs = [ + TranscriptMessage( + text_chunks=[], + metadata=TranscriptMessageMeta(speaker="Alice"), + tags=["test"], + source_id="empty-chunks", + ), + _make_message("has content", source_id="has-content"), + ] + result = await transcript.add_messages_streaming(_async_iter(msgs)) + + assert result.messages_added == 2 + # Only the message with content triggers extraction + assert extractor.call_count == 1 + + await storage.close() + + +# --------------------------------------------------------------------------- +# Multi-chunk messages and chunk-based batching +# --------------------------------------------------------------------------- + + +def _make_multi_chunk_message( + chunks: list[str], + speaker: str = "Alice", + source_id: str | None = None, +) -> TranscriptMessage: + return TranscriptMessage( + text_chunks=chunks, + metadata=TranscriptMessageMeta(speaker=speaker), + tags=["test"], + source_id=source_id, + ) + + +@pytest.mark.asyncio +async def test_streaming_multi_chunk_extraction() -> None: + """Each chunk in a multi-chunk message triggers a separate extraction call.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + extractor = ControlledExtractor() + transcript, storage = await _create_transcript( + db_path, auto_extract=True, knowledge_extractor=extractor + ) + + msgs = [ + _make_multi_chunk_message(["c0", "c1", "c2"], source_id="s-0"), + _make_message("single chunk", source_id="s-1"), + ] + result = await transcript.add_messages_streaming(_async_iter(msgs)) + + assert result.messages_added == 2 + assert result.chunks_added == 4 # 3 + 1 + # 4 extraction calls: one per chunk + assert extractor.call_count == 4 + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_batch_size_counts_chunks() -> None: + """batch_size counts chunks, not messages — a 3-chunk message fills batch_size=3.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + transcript, storage = await _create_transcript(db_path) + + msgs = [ + _make_multi_chunk_message(["a", "b", "c"], source_id="s-0"), # 3 chunks + _make_message("d", source_id="s-1"), # 1 chunk + ] + batch_results: list[int] = [] + result = await transcript.add_messages_streaming( + _async_iter(msgs), + batch_size=3, + on_batch_committed=lambda r: batch_results.append(r.messages_added), + ) + + assert result.messages_added == 2 + # First message (3 chunks) fills batch_size=3, second message goes to batch 2 + assert batch_results == [1, 1] + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_large_message_exceeds_batch_size() -> None: + """A single message with more chunks than batch_size becomes its own batch.""" + with 
tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + transcript, storage = await _create_transcript(db_path) + + msgs = [ + _make_multi_chunk_message( + [f"chunk-{i}" for i in range(5)], source_id="s-big" + ), + _make_message("small", source_id="s-small"), + ] + batch_results: list[int] = [] + result = await transcript.add_messages_streaming( + _async_iter(msgs), + batch_size=3, + on_batch_committed=lambda r: batch_results.append(r.messages_added), + ) + + assert result.messages_added == 2 + # 5-chunk msg in batch 1, then 1-chunk msg in batch 2 + assert batch_results == [1, 1] + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_mixed_chunk_sizes_batching() -> None: + """Messages of varying chunk counts are batched by cumulative chunk count.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + transcript, storage = await _create_transcript(db_path) + + msgs = [ + _make_message("a", source_id="s-0"), # 1 chunk, total=1 + _make_multi_chunk_message( + ["b1", "b2"], source_id="s-1" + ), # 2 chunks, total=3 → flush + _make_message("c", source_id="s-2"), # 1 chunk, total=1 + _make_message("d", source_id="s-3"), # 1 chunk, total=2 + _make_message("e", source_id="s-4"), # 1 chunk, total=3 → flush + ] + batch_results: list[int] = [] + result = await transcript.add_messages_streaming( + _async_iter(msgs), + batch_size=3, + on_batch_committed=lambda r: batch_results.append(r.messages_added), + ) + + assert result.messages_added == 5 + assert result.chunks_added == 6 + # Batch 1: msgs 0+1 (3 chunks), Batch 2: msgs 2+3+4 (3 chunks) + assert batch_results == [2, 3] + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_multi_chunk_failure_ordinals() -> None: + """Extraction failures in multi-chunk messages record correct ordinals.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + # Fail on call index 1 (chunk 1 of first message) and 3 (chunk 0 of second message) + extractor = ControlledExtractor(fail_on={1, 3}) + transcript, storage = await _create_transcript( + db_path, auto_extract=True, knowledge_extractor=extractor + ) + + msgs = [ + _make_multi_chunk_message( + ["c0", "c1", "c2"], source_id="s-0" + ), # calls 0,1,2 + _make_multi_chunk_message(["d0", "d1"], source_id="s-1"), # calls 3,4 + ] + result = await transcript.add_messages_streaming(_async_iter(msgs)) + + assert result.messages_added == 2 + assert extractor.call_count == 5 + assert _failure_count(storage) == 2 + + failures = await storage.get_chunk_failures() + failure_locs = sorted((f.message_ordinal, f.chunk_ordinal) for f in failures) + # call 1 → msg 0, chunk 1; call 3 → msg 1, chunk 0 + assert failure_locs == [(0, 1), (1, 0)] + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_multi_chunk_exception_preserves_earlier_batch() -> None: + """Exception during extraction of multi-chunk batch preserves committed batches.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + # Batch 1: 3-chunk msg (calls 0,1,2). 
Batch 2: 2-chunk msg (calls 3,4) — raise on 3 + extractor = ControlledExtractor(raise_on={3}) + transcript, storage = await _create_transcript( + db_path, auto_extract=True, knowledge_extractor=extractor + ) + + msgs = [ + _make_multi_chunk_message(["a", "b", "c"], source_id="s-0"), # batch 1 + _make_multi_chunk_message(["d", "e"], source_id="s-1"), # batch 2 + ] + + with pytest.raises(ExceptionGroup): + await transcript.add_messages_streaming(_async_iter(msgs), batch_size=3) + + # Batch 1 committed (1 message, 3 chunks), batch 2 rolled back + assert await transcript.messages.size() == 1 + assert _ingested_count(storage) == 1 + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_batch_size_1_separates_all() -> None: + """batch_size=1 commits every single-chunk message individually.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + transcript, storage = await _create_transcript(db_path) + + msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(4)] + batch_results: list[int] = [] + result = await transcript.add_messages_streaming( + _async_iter(msgs), + batch_size=1, + on_batch_committed=lambda r: batch_results.append(r.messages_added), + ) + + assert result.messages_added == 4 + assert batch_results == [1, 1, 1, 1] + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_preflush_avoids_oversized_batch() -> None: + """Adding a message that would exceed batch_size flushes first. + + With batch_size=10 and four 3-chunk messages, batches should be + [msg0,msg1,msg2] (9 chunks) and [msg3] (3 chunks) — never a single + batch of 12 chunks. + """ + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + transcript, storage = await _create_transcript(db_path) + + msgs = [ + _make_multi_chunk_message( + [f"m{i}c{j}" for j in range(3)], source_id=f"s-{i}" + ) + for i in range(4) + ] + batch_chunks: list[int] = [] + result = await transcript.add_messages_streaming( + _async_iter(msgs), + batch_size=10, + on_batch_committed=lambda r: batch_chunks.append(r.chunks_added), + ) + + assert result.messages_added == 4 + assert result.chunks_added == 12 + # Batch 1: 3 msgs × 3 chunks = 9, Batch 2: 1 msg × 3 chunks = 3 + assert batch_chunks == [9, 3] + + await storage.close() + + +# --------------------------------------------------------------------------- +# Coverage gap tests +# --------------------------------------------------------------------------- + + +class SlowExtractor: + """Extractor that blocks on an event, allowing tests to control timing.""" + + def __init__(self, block_from: int) -> None: + self.call_count = 0 + self.block_from = block_from + self.blocked = asyncio.Event() + self.cancelled = False + + async def extract(self, message: str) -> typechat.Result[kplib.KnowledgeResponse]: + idx = self.call_count + self.call_count += 1 + if idx >= self.block_from: + self.blocked.set() + try: + await asyncio.sleep(60) + except asyncio.CancelledError: + self.cancelled = True + raise + return typechat.Success(_EMPTY_RESPONSE) + + +@pytest.mark.asyncio +async def test_streaming_pending_extraction_cancelled_on_commit_failure() -> None: + """pending_extraction is cancelled when a prior commit raises during _drain_commit. + + Timeline: + 1. Batch 0: extraction succeeds (calls 0-2, fast), commit task created + (pending_commit = failing_commit) + 2. 
Batch 1: extraction task created (pending_extraction, calls 3+, slow), + _drain_commit awaits batch 0's pending_commit which raises + 3. except block: pending_extraction (batch 1's) is still in-flight → cancelled + """ + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + # Block extraction starting from call 3 (first call of batch 1) + # so that pending_extraction is still running when the except fires + extractor = SlowExtractor(block_from=3) + transcript, storage = await _create_transcript( + db_path, auto_extract=True, knowledge_extractor=extractor + ) + + async def failing_commit(*args, **kwargs): + raise RuntimeError("Simulated commit failure") + + transcript._commit_batch_streaming = failing_commit # type: ignore[assignment] + + msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(6)] + + with pytest.raises(RuntimeError, match="Simulated commit failure"): + await transcript.add_messages_streaming(_async_iter(msgs), batch_size=3) + + assert extractor.cancelled + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_pending_commit_cancelled_on_iterator_error() -> None: + """pending_commit is cancelled when the message iterator raises. + + After batch 0 is submitted (pending_commit in flight), the async iterator + raises on the next message. The except block must cancel the still-running + pending_commit. + """ + + async def _error_after( + items: list[TranscriptMessage], error_after: int + ) -> AsyncIterator[TranscriptMessage]: + for i, item in enumerate(items): + if i == error_after: + # Yield to event loop so pending tasks start running + await asyncio.sleep(0) + raise ValueError("Iterator error") + yield item + + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + transcript, storage = await _create_transcript(db_path) + + commit_cancelled = False + + async def slow_commit(*args, **kwargs): + nonlocal commit_cancelled + try: + await asyncio.sleep(60) + except asyncio.CancelledError: + commit_cancelled = True + raise + return AddMessagesResult() + + transcript._commit_batch_streaming = slow_commit # type: ignore[assignment] + + msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(6)] + + with pytest.raises(ValueError, match="Iterator error"): + await transcript.add_messages_streaming( + _error_after(msgs, error_after=4), batch_size=3 + ) + + assert commit_cancelled + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_empty_iterator() -> None: + """Streaming with an empty iterator returns zeros.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + transcript, storage = await _create_transcript(db_path) + + # Ingest one real message, then do a second call with an empty iterator + msgs = [_make_message("msg-0", source_id="s-0")] + r1 = await transcript.add_messages_streaming(_async_iter(msgs)) + assert r1.messages_added == 1 + + # Empty iterator → _submit_batch never called with content + r2 = await transcript.add_messages_streaming(_async_iter([])) + assert r2.messages_added == 0 + assert r2.messages_skipped == 0 + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_extraction_returns_none_for_empty_chunks() -> None: + """_extract_knowledge_for_batch returns None when no text_locations exist. + + Messages with empty text_chunks produce no TextLocations, so extraction + should be skipped entirely. 
+ """ + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + extractor = ControlledExtractor() + transcript, storage = await _create_transcript( + db_path, auto_extract=True, knowledge_extractor=extractor + ) + + msgs = [ + TranscriptMessage( + text_chunks=[], + metadata=TranscriptMessageMeta(speaker="Alice"), + tags=["test"], + source_id="empty-0", + ), + TranscriptMessage( + text_chunks=[], + metadata=TranscriptMessageMeta(speaker="Bob"), + tags=["test"], + source_id="empty-1", + ), + ] + result = await transcript.add_messages_streaming(_async_iter(msgs)) + + assert result.messages_added == 2 + assert result.chunks_added == 0 + # No extraction calls since there are no chunks + assert extractor.call_count == 0 + + await storage.close() diff --git a/tests/test_convthreads.py b/tests/test_convthreads.py new file mode 100644 index 00000000..e4d5e2d5 --- /dev/null +++ b/tests/test_convthreads.py @@ -0,0 +1,122 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Tests for storage/memory/convthreads.py.""" + +import pytest + +from typeagent.aitools.model_adapters import create_test_embedding_model +from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings +from typeagent.knowpro.interfaces import TextLocation, TextRange, Thread +from typeagent.knowpro.interfaces_serialization import ConversationThreadData +from typeagent.storage.memory.convthreads import ConversationThreads + + +@pytest.fixture +def settings() -> TextEmbeddingIndexSettings: + return TextEmbeddingIndexSettings(create_test_embedding_model()) + + +@pytest.fixture +def threads(settings: TextEmbeddingIndexSettings) -> ConversationThreads: + return ConversationThreads(settings) + + +def make_thread(description: str, start: int = 0, end: int = 1) -> Thread: + return Thread( + description=description, + ranges=[ + TextRange(start=TextLocation(start), end=TextLocation(end)), + ], + ) + + +@pytest.mark.asyncio +async def test_add_thread_appends(threads: ConversationThreads) -> None: + await threads.add_thread(make_thread("topic one")) + assert len(threads.threads) == 1 + assert threads.threads[0].description == "topic one" + + +@pytest.mark.asyncio +async def test_add_multiple_threads(threads: ConversationThreads) -> None: + await threads.add_thread(make_thread("alpha")) + await threads.add_thread(make_thread("beta")) + await threads.add_thread(make_thread("gamma")) + assert len(threads.threads) == 3 + + +@pytest.mark.asyncio +async def test_clear_resets_state(threads: ConversationThreads) -> None: + await threads.add_thread(make_thread("something")) + threads.clear() + assert len(threads.threads) == 0 + assert len(threads.vector_base) == 0 + + +@pytest.mark.asyncio +async def test_build_index_rebuilds_from_threads(threads: ConversationThreads) -> None: + # Manually add threads without building the vector index. + t1 = make_thread("python programming") + t2 = make_thread("data science") + threads.threads.append(t1) + threads.threads.append(t2) + # build_index should embed all existing threads. 
+ await threads.build_index() + assert len(threads.vector_base) == 2 + + +@pytest.mark.asyncio +async def test_serialize_roundtrip(threads: ConversationThreads) -> None: + await threads.add_thread(make_thread("episode one", 0, 5)) + await threads.add_thread(make_thread("episode two", 5, 10)) + + data = threads.serialize() + assert "threads" in data + thread_list = data["threads"] + assert thread_list is not None + assert len(thread_list) == 2 + + # Deserialize into a fresh instance. + settings = TextEmbeddingIndexSettings(create_test_embedding_model()) + fresh = ConversationThreads(settings) + fresh.deserialize(data) + assert len(fresh.threads) == 2 + assert fresh.threads[0].description == "episode one" + assert fresh.threads[1].description == "episode two" + + +@pytest.mark.asyncio +async def test_deserialize_empty_data(threads: ConversationThreads) -> None: + data: ConversationThreadData = {} # type: ignore[typeddict-item] + threads.deserialize(data) + assert len(threads.threads) == 0 + + +@pytest.mark.asyncio +async def test_serialize_without_embeddings(threads: ConversationThreads) -> None: + # Add a thread without going through add_thread (so no embedding yet). + threads.threads.append(make_thread("bare thread")) + data = threads.serialize() + thread_list = data["threads"] + assert thread_list is not None + assert len(thread_list) == 1 + # Embedding may be None because vector_base has no entries for this slot. + assert thread_list[0]["embedding"] is None or isinstance( + thread_list[0]["embedding"], list + ) + + +@pytest.mark.asyncio +async def test_lookup_thread_returns_matches(threads: ConversationThreads) -> None: + await threads.add_thread(make_thread("machine learning and AI")) + await threads.add_thread(make_thread("cooking recipes")) + results = await threads.lookup_thread("artificial intelligence") + assert len(results) > 0 + assert results[0].thread_ordinal == 0 # ordinal of the matching thread + + +@pytest.mark.asyncio +async def test_lookup_thread_empty_index(threads: ConversationThreads) -> None: + results = await threads.lookup_thread("anything") + assert results == [] diff --git a/tests/test_convutils.py b/tests/test_convutils.py new file mode 100644 index 00000000..b9ac654c --- /dev/null +++ b/tests/test_convutils.py @@ -0,0 +1,60 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
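+
+"""Tests for knowpro/convutils.py."""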
+ +import pytest + +from typeagent.knowpro.convutils import ( + get_time_range_for_conversation, + get_time_range_prompt_section_for_conversation, +) + +from conftest import FakeConversation, FakeMessage + + +class TestGetTimeRangeForConversation: + @pytest.mark.asyncio + async def test_empty_conversation_returns_none(self) -> None: + conv = FakeConversation(messages=[]) + result = await get_time_range_for_conversation(conv) + assert result is None + + @pytest.mark.asyncio + async def test_message_without_timestamp_returns_none(self) -> None: + msg = FakeMessage("hello") # no message_ordinal → timestamp=None + conv = FakeConversation(messages=[msg]) + result = await get_time_range_for_conversation(conv) + assert result is None + + @pytest.mark.asyncio + async def test_single_message_with_timestamp(self) -> None: + msg = FakeMessage("hello", message_ordinal=0) + conv = FakeConversation(messages=[msg]) + result = await get_time_range_for_conversation(conv) + assert result is not None + assert result.start.isoformat().startswith("2020-01-01T00") + + @pytest.mark.asyncio + async def test_multiple_messages_range_start_end(self) -> None: + msgs = [FakeMessage(f"msg{i}", message_ordinal=i) for i in range(3)] + conv = FakeConversation(messages=msgs) + result = await get_time_range_for_conversation(conv) + assert result is not None + assert result.start < result.end # type: ignore[operator] + + +class TestGetTimeRangePromptSection: + @pytest.mark.asyncio + async def test_no_timestamps_returns_none(self) -> None: + conv = FakeConversation(messages=[FakeMessage("hello")]) + result = await get_time_range_prompt_section_for_conversation(conv) + assert result is None + + @pytest.mark.asyncio + async def test_with_timestamps_returns_prompt_section(self) -> None: + msgs = [FakeMessage(f"msg{i}", message_ordinal=i) for i in range(2)] + conv = FakeConversation(messages=msgs) + result = await get_time_range_prompt_section_for_conversation(conv) + assert result is not None + assert result["role"] == "system" + assert "CONVERSATION TIME RANGE" in result["content"] + assert "2020-01-01" in result["content"] diff --git a/tests/test_email_message.py b/tests/test_email_message.py new file mode 100644 index 00000000..17930486 --- /dev/null +++ b/tests/test_email_message.py @@ -0,0 +1,223 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
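+
+"""Tests for emails/email_message.py."""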
+
+from typeagent.emails.email_message import EmailMessage, EmailMessageMeta
+
+
+def make_meta(
+    sender: str = "Alice <alice@example.com>",
+    recipients: list[str] | None = None,
+    cc: list[str] | None = None,
+    bcc: list[str] | None = None,
+    subject: str | None = None,
+) -> EmailMessageMeta:
+    return EmailMessageMeta(
+        sender=sender,
+        recipients=recipients or [],
+        cc=cc or [],
+        bcc=bcc or [],
+        subject=subject,
+    )
+
+
+class TestEmailMessageMetaProperties:
+    def test_source_returns_sender(self) -> None:
+        meta = make_meta(sender="bob@example.com")
+        assert meta.source == "bob@example.com"
+
+    def test_dest_returns_recipients(self) -> None:
+        meta = make_meta(recipients=["a@b.com", "c@d.com"])
+        assert meta.dest == ["a@b.com", "c@d.com"]
+
+    def test_dest_empty_list(self) -> None:
+        meta = make_meta(recipients=[])
+        assert meta.dest == []
+
+
+class TestEmailAddressToEntities:
+    def test_plain_address_no_display_name(self) -> None:
+        meta = make_meta()
+        entities = meta._email_address_to_entities("bob@example.com")
+        names = [e.name for e in entities]
+        assert "bob@example.com" in names
+        assert len(entities) == 1
+
+    def test_address_with_display_name(self) -> None:
+        meta = make_meta()
+        entities = meta._email_address_to_entities("Alice <alice@example.com>")
+        names = [e.name for e in entities]
+        assert "Alice" in names
+        assert "alice@example.com" in names
+        assert len(entities) == 2
+
+    def test_display_name_entity_has_email_facet(self) -> None:
+        meta = make_meta()
+        entities = meta._email_address_to_entities("Alice <alice@example.com>")
+        person_entity = next(e for e in entities if e.name == "Alice")
+        assert person_entity.facets is not None
+        assert len(person_entity.facets) == 1
+        assert person_entity.facets[0].name == "email_address"
+        assert person_entity.facets[0].value == "alice@example.com"
+
+    def test_display_name_only_no_address(self) -> None:
+        # parseaddr("Alice") returns ("", "Alice") — treated as address only
+        meta = make_meta()
+        entities = meta._email_address_to_entities("Alice")
+        # No display name, just the address "Alice"
+        assert len(entities) == 1
+        assert entities[0].name == "Alice"
+
+
+class TestToEntities:
+    def test_entities_include_sender(self) -> None:
+        meta = make_meta(sender="Alice <alice@example.com>")
+        entities = meta.to_entities()
+        names = [e.name for e in entities]
+        assert "Alice" in names
+        assert "alice@example.com" in names
+
+    def test_entities_include_recipient(self) -> None:
+        meta = make_meta(
+            sender="alice@example.com",
+            recipients=["Bob <bob@example.com>"],
+        )
+        entities = meta.to_entities()
+        names = [e.name for e in entities]
+        assert "Bob" in names
+        assert "bob@example.com" in names
+
+    def test_entities_include_cc(self) -> None:
+        meta = make_meta(
+            sender="a@x.com",
+            cc=["cc@example.com"],
+        )
+        entities = meta.to_entities()
+        names = [e.name for e in entities]
+        assert "cc@example.com" in names
+
+    def test_entities_include_bcc(self) -> None:
+        meta = make_meta(
+            sender="a@x.com",
+            bcc=["bcc@example.com"],
+        )
+        entities = meta.to_entities()
+        names = [e.name for e in entities]
+        assert "bcc@example.com" in names
+
+    def test_entities_always_include_email_message_entity(self) -> None:
+        meta = make_meta()
+        entities = meta.to_entities()
+        msg_entity = next((e for e in entities if e.name == "email"), None)
+        assert msg_entity is not None
+        assert "message" in msg_entity.type
+
+
+class TestToTopics:
+    def test_no_subject_returns_empty(self) -> None:
+        meta = make_meta(subject=None)
+        assert meta.to_topics() == []
+
+    def test_subject_returned_as_topic(self) -> None:
+        meta = make_meta(subject="Hello World")
+        topics = meta.to_topics()
+        assert topics == ["Hello World"]
+
+
+class TestToActions:
+    def test_no_recipients_returns_empty(self) -> None:
+        meta = make_meta(sender="alice@example.com", recipients=[])
+        assert meta.to_actions() == []
+
+    def test_sent_and_received_actions_created(self) -> None:
+        meta = make_meta(
+            sender="Alice <alice@example.com>",
+            recipients=["Bob <bob@example.com>"],
+        )
+        actions = meta.to_actions()
+        verbs = [a.verbs[0] for a in actions]
+        assert "sent" in verbs
+        assert "received" in verbs
+
+    def test_multiple_recipients_produce_actions(self) -> None:
+        meta = make_meta(
+            sender="alice@example.com",
+            recipients=["bob@example.com", "carol@example.com"],
+        )
+        actions = meta.to_actions()
+        assert len(actions) > 0
+
+    def test_action_subject_is_sender(self) -> None:
+        meta = make_meta(
+            sender="alice@example.com",
+            recipients=["bob@example.com"],
+        )
+        actions = meta.to_actions()
+        sent_actions = [a for a in actions if "sent" in a.verbs]
+        assert all(a.subject_entity_name == "alice@example.com" for a in sent_actions)
+
+
+class TestGetKnowledge:
+    def test_get_knowledge_returns_response(self) -> None:
+        meta = make_meta(
+            sender="Alice <alice@example.com>",
+            recipients=["Bob <bob@example.com>"],
+            subject="Test Subject",
+        )
+        result = meta.get_knowledge()
+        assert result is not None
+        assert len(result.entities) > 0
+        assert len(result.topics) > 0
+        assert len(result.actions) > 0
+
+
+class TestEmailMessage:
+    def test_basic_construction(self) -> None:
+        meta = make_meta(sender="alice@example.com")
+        msg = EmailMessage(
+            text_chunks=["Hello world"],
+            metadata=meta,
+        )
+        assert msg.text_chunks == ["Hello world"]
+        assert msg.metadata is meta
+
+    def test_get_knowledge_delegates_to_metadata(self) -> None:
+        meta = make_meta(
+            sender="Alice <alice@example.com>",
+            recipients=["bob@example.com"],
+            subject="Hi",
+        )
+        msg = EmailMessage(text_chunks=["body"], metadata=meta)
+        result = msg.get_knowledge()
+        assert result is not None
+
+    def test_add_timestamp(self) -> None:
+        meta = make_meta()
+        msg = EmailMessage(text_chunks=["body"], metadata=meta)
+        msg.add_timestamp("2025-01-01T00:00:00")
+        assert msg.timestamp == "2025-01-01T00:00:00"
+
+    def test_add_content_empty_chunks(self) -> None:
+        meta = make_meta()
+        msg = EmailMessage(text_chunks=[], metadata=meta)
+        msg.add_content("new content")
+        assert msg.text_chunks == ["new content"]
+
+    def test_add_content_existing_chunk(self) -> None:
+        meta = make_meta()
+        msg = EmailMessage(text_chunks=["existing"], metadata=meta)
+        msg.add_content(" more")
+        assert msg.text_chunks[0] == "existing more"
+
+    def test_serialize_roundtrip(self) -> None:
+        meta = make_meta(
+            sender="Alice <alice@example.com>",
+            recipients=["bob@example.com"],
+            subject="Hi",
+        )
+        msg = EmailMessage(text_chunks=["Hello"], metadata=meta, tags=["work"])
+        data = msg.serialize()
+        assert isinstance(data, dict)
+        restored = EmailMessage.deserialize(data)
+        assert restored.text_chunks == msg.text_chunks
+        assert restored.metadata.sender == msg.metadata.sender
+        assert restored.tags == msg.tags
diff --git a/tests/test_knowledge.py b/tests/test_knowledge.py
index 0374dfec..44515ada 100644
--- a/tests/test_knowledge.py
+++ b/tests/test_knowledge.py
@@ -8,7 +8,6 @@
 from typeagent.knowpro import convknowledge
 from typeagent.knowpro import knowledge_schema as kplib
 from typeagent.knowpro.knowledge import (
-    create_knowledge_extractor,
     extract_knowledge_from_text,
     extract_knowledge_from_text_batch,
     merge_concrete_entities,
@@ -34,12 +33,6 @@ def mock_knowledge_extractor() -> convknowledge.KnowledgeExtractor:
     return MockKnowledgeExtractor()  # 
type: ignore -def test_create_knowledge_extractor(really_needs_auth: None): - """Test creating a knowledge extractor.""" - extractor = create_knowledge_extractor() - assert isinstance(extractor, convknowledge.KnowledgeExtractor) - - @pytest.mark.asyncio async def test_extract_knowledge_from_text( mock_knowledge_extractor: convknowledge.KnowledgeExtractor, diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py index 16577f86..7e3bd3b7 100644 --- a/tests/test_mcp_server.py +++ b/tests/test_mcp_server.py @@ -6,9 +6,8 @@ import json import os import sys -from types import SimpleNamespace from typing import Any, cast -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, MagicMock import pytest @@ -22,11 +21,19 @@ SamplingMessage, TextContent, ) -from openai.types.chat import ChatCompletionMessageParam import typechat -from typeagent.aitools.utils import create_async_openai_client, resolve_azure_model_name -from typeagent.mcp.server import MCPTypeChatModel, QuestionResponse +from typeagent.aitools.model_adapters import create_chat_model +from typeagent.knowpro import answers, searchlang +from typeagent.knowpro.answer_response_schema import AnswerResponse +from typeagent.knowpro.convsettings import ConversationSettings +import typeagent.mcp.server as typeagent_mcp_server +from typeagent.mcp.server import ( + load_podcast_database_or_index, + MCPTypeChatModel, + ProcessingContext, + QuestionResponse, +) from conftest import EPISODE_53_INDEX @@ -52,49 +59,32 @@ async def sampling_callback( params: CreateMessageRequestParams, ) -> CreateMessageResult: """Sampling callback that uses OpenAI to generate responses.""" - client = create_async_openai_client() + model = create_chat_model() - # Convert MCP SamplingMessage to OpenAI format - messages: list[ChatCompletionMessageParam] = [] + # Convert MCP SamplingMessage to TypeChat PromptSection list + sections: list[typechat.PromptSection] = [] + if params.systemPrompt: + sections.append({"role": "system", "content": params.systemPrompt}) for msg in params.messages: - # Handle TextContent - content: str if isinstance(msg.content, TextContent): content = msg.content.text else: raise ValueError( f"Unsupported content type in sampling message: {type(msg.content)}" ) + role = "user" if msg.role == "user" else "assistant" + sections.append({"role": role, "content": content}) - # MCP roles are "user" or "assistant", which are compatible with OpenAI - if msg.role == "user": - messages.append({"role": "user", "content": content}) - else: - messages.append({"role": "assistant", "content": content}) + result = await model.complete(sections) + if isinstance(result, typechat.Success): + text = result.value + else: + text = result.message - # Add system prompt if provided - if params.systemPrompt: - messages.insert(0, {"role": "system", "content": params.systemPrompt}) - - # Call OpenAI - model_name = "gpt-4o" - if os.getenv("AZURE_OPENAI_API_KEY") and not os.getenv("OPENAI_API_KEY"): - model_name = resolve_azure_model_name(model_name) - - response = await client.chat.completions.create( - model=model_name, - messages=messages, - max_tokens=params.maxTokens, - temperature=params.temperature if params.temperature is not None else 1.0, - ) - - # Convert response to MCP format return CreateMessageResult( role="assistant", - content=TextContent( - type="text", text=response.choices[0].message.content or "" - ), - model=response.model, + content=TextContent(type="text", text=text), + model="gpt-4o", stopReason="endTurn", ) @@ -204,9 +194,7 @@ 
async def test_mcp_server_empty_question(server_params: StdioServerParameters):
 
 def test_server_module_imports() -> None:
     """Importing the server module should not raise even without coverage."""
-    import typeagent.mcp.server as mod
-
-    assert hasattr(mod, "mcp")  # The FastMCP instance exists
+    assert hasattr(typeagent_mcp_server, "mcp")  # The FastMCP instance exists
 
 
 # ---------------------------------------------------------------------------
@@ -342,8 +330,6 @@ def test_known_types(self) -> None:
 
     def test_answer_type_coverage(self) -> None:
         """AnswerResponse.type should only be 'Answered' or 'NoAnswer'."""
-        from typeagent.knowpro.answer_response_schema import AnswerResponse
-
         answered = AnswerResponse(type="Answered", answer="yes")
         assert answered.type == "Answered"
         no_answer = AnswerResponse(type="NoAnswer", why_no_answer="dunno")
@@ -351,39 +337,17 @@ def test_answer_type_coverage(self) -> None:
 
 
 @pytest.mark.asyncio
-async def test_sampling_callback_uses_azure_deployment_name(
+async def test_sampling_callback_delegates_to_chat_model(
     monkeypatch: pytest.MonkeyPatch,
 ) -> None:
-    """Azure-only sampling should send the resolved deployment name."""
-    create = AsyncMock(
-        return_value=SimpleNamespace(
-            choices=[
-                SimpleNamespace(
-                    message=SimpleNamespace(content="response"),
-                )
-            ],
-            model="gpt-4o-2",
-        )
-    )
-    fake_client = SimpleNamespace(
-        chat=SimpleNamespace(
-            completions=SimpleNamespace(
-                create=create,
-            )
-        )
-    )
+    """sampling_callback should delegate to create_chat_model().complete()."""
+    fake_model = AsyncMock()
+    fake_model.complete.return_value = typechat.Success("response")
 
-    monkeypatch.delenv("OPENAI_API_KEY", raising=False)
-    monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-key")
-    monkeypatch.setattr(
-        sys.modules[__name__],
-        "create_async_openai_client",
-        lambda: fake_client,
-    )
     monkeypatch.setattr(
         sys.modules[__name__],
-        "resolve_azure_model_name",
-        lambda model_name: f"{model_name}-2",
+        "create_chat_model",
+        lambda: fake_model,
    )
 
     params = CreateMessageRequestParams(
@@ -401,10 +365,75 @@ async def test_sampling_callback_uses_azure_deployment_name(
         params,
     )
 
-    create.assert_awaited_once_with(
-        model="gpt-4o-2",
-        messages=[{"role": "user", "content": "hello"}],
-        max_tokens=32,
-        temperature=1.0,
-    )
-    assert result.model == "gpt-4o-2"
+    fake_model.complete.assert_awaited_once()
+    call_args = fake_model.complete.call_args[0][0]
+    assert call_args == [{"role": "user", "content": "hello"}]
+    assert isinstance(result.content, TextContent)
+    assert result.content.text == "response"
+
+
+# ---------------------------------------------------------------------------
+# MCPTypeChatModel — additional response format coverage
+# ---------------------------------------------------------------------------
+
+
+class TestMCPTypeChatModelResponseFormats:
+    @staticmethod
+    def _make_model_with_result(content: Any) -> MCPTypeChatModel:
+        session = AsyncMock()
+        session.create_message.return_value = AsyncMock(content=content)
+        return MCPTypeChatModel(session)
+
+    @pytest.mark.asyncio
+    async def test_list_content_no_text_items_returns_failure(self) -> None:
+        """A list response with no TextContent items should return Failure."""
+        # An empty list has no TextContent items, so complete() finds no text
+        model = self._make_model_with_result([])
+        result = await model.complete("test")
+        assert isinstance(result, typechat.Failure)
+        assert "No text content" in result.message
+
+    @pytest.mark.asyncio
+    async def 
test_unknown_content_type_returns_failure(self) -> None: + """A response with an unrecognized content type should return Failure.""" + # Simulate some unknown object that is neither TextContent nor list + model = self._make_model_with_result(42) + result = await model.complete("test") + assert isinstance(result, typechat.Failure) + assert "No text content" in result.message + + +# --------------------------------------------------------------------------- +# ProcessingContext.__repr__ +# --------------------------------------------------------------------------- + + +class TestProcessingContextRepr: + def test_repr_contains_options(self) -> None: + lang_opts = searchlang.LanguageSearchOptions(max_message_matches=10) + ctx_opts = answers.AnswerContextOptions(entities_top_k=5) + + proc = ProcessingContext( + lang_search_options=lang_opts, + answer_context_options=ctx_opts, + query_context=MagicMock(), + embedding_model=MagicMock(), + query_translator=MagicMock(), + answer_translator=MagicMock(), + ) + r = repr(proc) + assert r.startswith("Context(") + assert "LanguageSearchOptions" in r + + +# --------------------------------------------------------------------------- +# load_podcast_database_or_index — ValueError path +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_load_podcast_no_args_raises() -> None: + """Passing neither dbname nor podcast_index must raise ValueError.""" + settings = ConversationSettings() + with pytest.raises(ValueError, match="Either --database or --podcast-index"): + await load_podcast_database_or_index(settings, dbname=None, podcast_index=None) diff --git a/tests/test_memory_semrefindex.py b/tests/test_memory_semrefindex.py new file mode 100644 index 00000000..723cdd41 --- /dev/null +++ b/tests/test_memory_semrefindex.py @@ -0,0 +1,201 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""Tests for storage/memory/semrefindex.py helper functions.""" + +import pytest + +from typeagent.knowpro import knowledge_schema as kplib +from typeagent.knowpro.interfaces import Topic +from typeagent.storage.memory import MemorySemanticRefCollection +from typeagent.storage.memory.semrefindex import ( + add_action, + add_entity, + add_facet, + add_term_to_index, + add_topic, +) + +from conftest import FakeTermIndex + + +def make_semrefs() -> MemorySemanticRefCollection: + return MemorySemanticRefCollection([]) + + +def make_index() -> FakeTermIndex: + return FakeTermIndex() + + +# --------------------------------------------------------------------------- +# add_term_to_index +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_add_term_to_index_basic() -> None: + index = make_index() + terms_added: set[str] = set() + await add_term_to_index(index, "hello", 0, terms_added) + assert "hello" in terms_added + assert await index.size() == 1 + + +@pytest.mark.asyncio +async def test_add_term_to_index_no_terms_added_set() -> None: + index = make_index() + await add_term_to_index(index, "world", 1) + assert await index.size() == 1 + + +@pytest.mark.asyncio +async def test_add_term_empty_string_is_stored() -> None: + """The function does not filter empty terms — delegated to the index.""" + index = make_index() + await add_term_to_index(index, "", 0) + # FakeTermIndex stores empty strings too + assert await index.size() == 1 + + +# --------------------------------------------------------------------------- +# add_facet +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_add_facet_none_does_nothing() -> None: + index = make_index() + await add_facet(None, 0, index) + assert await index.size() == 0 + + +@pytest.mark.asyncio +async def test_add_facet_string_value() -> None: + index = make_index() + facet = kplib.Facet(name="colour", value="red") + await add_facet(facet, 0, index) + terms = await index.get_terms() + assert "colour" in terms + assert "red" in terms + + +@pytest.mark.asyncio +async def test_add_facet_numeric_value() -> None: + index = make_index() + facet = kplib.Facet(name="count", value=42.0) + await add_facet(facet, 0, index) + terms = await index.get_terms() + assert "count" in terms + assert "42.0" in terms + + +# --------------------------------------------------------------------------- +# add_entity +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_add_entity_registers_name_and_types() -> None: + semrefs = make_semrefs() + index = make_index() + entity = kplib.ConcreteEntity(name="Alice", type=["person", "employee"]) + terms_added: set[str] = set() + await add_entity( + entity, + semrefs, + index, + message_ordinal=0, + chunk_ordinal=0, + terms_added=terms_added, + ) + assert "Alice" in terms_added + assert "person" in terms_added + assert "employee" in terms_added + assert await semrefs.size() == 1 + + +@pytest.mark.asyncio +async def test_add_entity_with_facets() -> None: + semrefs = make_semrefs() + index = make_index() + entity = kplib.ConcreteEntity( + name="Book", + type=["item"], + facets=[kplib.Facet(name="genre", value="fiction")], + ) + await add_entity(entity, semrefs, index, message_ordinal=1, chunk_ordinal=0) + terms = await index.get_terms() + assert "genre" in terms + assert "fiction" in terms + + +# 
--------------------------------------------------------------------------- +# add_topic +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_add_topic_registers_text() -> None: + semrefs = make_semrefs() + index = make_index() + topic = Topic(text="machine learning") + terms_added: set[str] = set() + await add_topic( + topic, + semrefs, + index, + message_ordinal=2, + chunk_ordinal=0, + terms_added=terms_added, + ) + assert "machine learning" in terms_added + assert await semrefs.size() == 1 + + +# --------------------------------------------------------------------------- +# add_action +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_add_action_registers_verbs() -> None: + semrefs = make_semrefs() + index = make_index() + action = kplib.Action( + verbs=["run", "execute"], + verb_tense="present", + subject_entity_name="Alice", + object_entity_name="script", + indirect_object_entity_name="none", + ) + terms_added: set[str] = set() + await add_action( + action, + semrefs, + index, + message_ordinal=0, + chunk_ordinal=0, + terms_added=terms_added, + ) + terms = set(await index.get_terms()) + assert "run execute" in terms + assert "Alice" in terms + assert "script" in terms + assert await semrefs.size() == 1 + + +@pytest.mark.asyncio +async def test_add_action_none_entities_skipped() -> None: + semrefs = make_semrefs() + index = make_index() + action = kplib.Action( + verbs=["go"], + verb_tense="present", + subject_entity_name="none", + object_entity_name="none", + indirect_object_entity_name="none", + ) + await add_action(action, semrefs, index, message_ordinal=0, chunk_ordinal=0) + terms = await index.get_terms() + assert "none" not in terms + assert "go" in terms diff --git a/tests/test_message_text_index_population.py b/tests/test_message_text_index_population.py index 13d53c00..a069d979 100644 --- a/tests/test_message_text_index_population.py +++ b/tests/test_message_text_index_population.py @@ -59,7 +59,7 @@ async def test_message_text_index_population_from_database(): ), ] - msg_collection = await storage1.get_message_collection() + msg_collection = storage1.messages await msg_collection.extend(test_messages) assert await msg_collection.size() == len(test_messages) @@ -74,7 +74,7 @@ async def test_message_text_index_population_from_database(): ) # Check message collection size - msg_collection2 = await storage2.get_message_collection() + msg_collection2 = storage2.messages msg_count = await msg_collection2.size() print(f"Message collection size: {msg_count}") assert msg_count == len( @@ -82,7 +82,7 @@ async def test_message_text_index_population_from_database(): ), f"Expected {len(test_messages)} messages, got {msg_count}" # Check message text index - msg_text_index = await storage2.get_message_text_index() + msg_text_index = storage2.message_text_index # Check that it implements the interface correctly from typeagent.knowpro.interfaces import IMessageTextIndex diff --git a/tests/test_messageutils.py b/tests/test_messageutils.py new file mode 100644 index 00000000..97c61c13 --- /dev/null +++ b/tests/test_messageutils.py @@ -0,0 +1,23 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
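+
+"""Tests for knowpro/messageutils.py."""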
+
+from typeagent.knowpro.interfaces import TextLocation, TextRange
+from typeagent.knowpro.messageutils import (
+    text_range_from_message_chunk,
+)
+
+
+class TestTextRangeFromMessageChunk:
+    def test_default_chunk_ordinal(self) -> None:
+        tr = text_range_from_message_chunk(message_ordinal=3)
+        assert tr.start == TextLocation(3, 0)
+        assert tr.end is None
+
+    def test_explicit_chunk_ordinal(self) -> None:
+        tr = text_range_from_message_chunk(message_ordinal=5, chunk_ordinal=2)
+        assert tr.start == TextLocation(5, 2)
+        assert tr.end is None
+
+    def test_returns_text_range(self) -> None:
+        tr = text_range_from_message_chunk(0)
+        assert isinstance(tr, TextRange)
diff --git a/tests/test_podcasts.py b/tests/test_podcasts.py
index 02ccf8e8..d77f6ba3 100644
--- a/tests/test_podcasts.py
+++ b/tests/test_podcasts.py
@@ -118,7 +118,7 @@ async def test_ingest_podcast(
 
 
 @pytest.mark.asyncio
-async def test_ingest_podcast_parallelism_uses_batch_size(
+async def test_ingest_podcast_parallelism_uses_concurrency(
     temp_dir: str, embedding_model: IEmbeddingModel
 ) -> None:
     transcript_path = os.path.join(temp_dir, "parallel_podcast.txt")
@@ -130,15 +130,15 @@ async def test_ingest_podcast_parallelism_uses_concurrency(
     extractor = TrackingKnowledgeExtractor()
     settings.semantic_ref_index_settings.knowledge_extractor = extractor
 
-    batch_size = 20
+    concurrency = 5
     podcast = await podcast_ingest.ingest_podcast(
         transcript_path,
         settings,
         start_date=Datetime.now(timezone.utc),
         length_minutes=5.0,
-        batch_size=batch_size,
+        concurrency=concurrency,
     )
 
     assert await podcast.messages.size() == 25
-    assert extractor.max_concurrency == batch_size
+    assert extractor.max_concurrency == concurrency
     assert len(extractor.started_texts) == 25
diff --git a/tests/test_property_index_population.py b/tests/test_property_index_population.py
index 5a158353..f1f24c7c 100644
--- a/tests/test_property_index_population.py
+++ b/tests/test_property_index_population.py
@@ -79,7 +79,7 @@ async def test_property_index_population_from_database(really_needs_auth):
         ),
     ]
 
-    sem_ref_collection = await storage1.get_semantic_ref_collection()
+    sem_ref_collection = storage1.semantic_refs
     for sem_ref in test_data:
         await sem_ref_collection.append(sem_ref)
 
@@ -111,7 +111,7 @@
     # Build property index from the semantic refs
     await build_property_index(conversation)
 
-    prop_index = await storage2.get_property_index()
+    prop_index = storage2.property_index
     from typeagent.knowpro.interfaces import IPropertyToSemanticRefIndex
 
     assert isinstance(prop_index, IPropertyToSemanticRefIndex)
diff --git a/tests/test_related_terms_index_population.py b/tests/test_related_terms_index_population.py
index bf40722e..a4e615c4 100644
--- a/tests/test_related_terms_index_population.py
+++ b/tests/test_related_terms_index_population.py
@@ -61,12 +61,12 @@ async def test_related_terms_index_population_from_database(really_needs_auth):
         ),
     ]
 
-    msg_collection = await storage1.get_message_collection()
+    msg_collection = storage1.messages
     for message in test_messages:
         await msg_collection.append(message)
 
     # Add some semantic refs to create terms for the related terms index
-    sem_ref_collection = await storage1.get_semantic_ref_collection()
+    sem_ref_collection = storage1.semantic_refs
 
     # Add some entities
     entity_refs = [
@@ -97,7 +97,7 @@
         await sem_ref_collection.append(sem_ref)
 
     # Manually populate the semantic ref index since the user guarantees it's complete externally
-    semantic_ref_index = await storage1.get_semantic_ref_index()
+    semantic_ref_index = storage1.semantic_ref_index
 
     for sem_ref in entity_refs:
         knowledge = sem_ref.knowledge
@@ -119,7 +119,7 @@
     )
 
     # Check message collection size
-    msg_collection2 = await storage2.get_message_collection()
+    msg_collection2 = storage2.messages
     msg_count = await msg_collection2.size()
     print(f"Message collection size: {msg_count}")
     assert msg_count == len(
@@ -127,7 +127,7 @@
     ), f"Expected {len(test_messages)} messages, got {msg_count}"
 
     # Check semantic ref collection size
-    sem_ref_collection2 = await storage2.get_semantic_ref_collection()
+    sem_ref_collection2 = storage2.semantic_refs
     sem_ref_count = await sem_ref_collection2.size()
     print(f"Semantic ref collection size: {sem_ref_count}")
     assert sem_ref_count == len(
@@ -148,7 +148,7 @@
     await build_related_terms_index(conversation, related_terms_settings)
 
     # Check related terms index
-    related_terms_index = await storage2.get_related_terms_index()
+    related_terms_index = storage2.related_terms_index
     assert isinstance(related_terms_index, SqliteRelatedTermsIndex)
 
     # Check if fuzzy index has entries
diff --git a/tests/test_reltermsindex.py b/tests/test_reltermsindex.py
index 20752084..20c0f0b7 100644
--- a/tests/test_reltermsindex.py
+++ b/tests/test_reltermsindex.py
@@ -50,7 +50,7 @@ def get_knowledge(self):
             message_text_settings=message_text_settings,
             related_terms_settings=related_terms_settings,
         )
-        index = await storage_provider.get_related_terms_index()
+        index = storage_provider.related_terms_index
         yield index
     else:
         provider = SqliteStorageProvider(
@@ -59,7 +59,7 @@ def get_knowledge(self):
             message_text_index_settings=message_text_settings,
             related_term_index_settings=related_terms_settings,
         )
-        index = await provider.get_related_terms_index()
+        index = provider.related_terms_index
         yield index
         await provider.close()
diff --git a/tests/test_search.py b/tests/test_search.py
new file mode 100644
index 00000000..7028403d
--- /dev/null
+++ b/tests/test_search.py
@@ -0,0 +1,113 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Tests for knowpro/search.py — SearchOptions, ConversationSearchResult."""
+
+import pytest
+
+from typeagent.knowpro.interfaces import (
+    SearchTerm,
+    SearchTermGroup,
+    Term,
+)
+from typeagent.knowpro.interfaces_core import ScoredMessageOrdinal
+from typeagent.knowpro.query import is_conversation_searchable
+from typeagent.knowpro.search import (
+    ConversationSearchResult,
+    search_conversation_knowledge,
+    SearchOptions,
+)
+
+from conftest import FakeConversation, FakeMessage, FakeTermIndex
+
+# ---------------------------------------------------------------------------
+# SearchOptions
+# ---------------------------------------------------------------------------
+
+
+def test_search_options_defaults() -> None:
+    opts = SearchOptions()
+    assert opts.max_knowledge_matches is None
+    assert opts.exact_match is False
+    assert opts.max_message_matches is None
+    assert opts.max_chars_in_budget is None
+    assert opts.threshold_score is None
+
+
+def test_search_options_repr_empty() -> None:
+    opts = SearchOptions()
+    # Only non-None values appear in repr; exact_match=False is still included.
+    r = repr(opts)
+    assert r.startswith("SearchOptions(")
+
+
+def test_search_options_repr_with_fields() -> None:
+    opts = SearchOptions(max_knowledge_matches=5, exact_match=True)
+    r = repr(opts)
+    assert "max_knowledge_matches=5" in r
+    assert "exact_match=True" in r
+
+
+# ---------------------------------------------------------------------------
+# ConversationSearchResult
+# ---------------------------------------------------------------------------
+
+
+def test_conversation_search_result_basic() -> None:
+    result = ConversationSearchResult(
+        message_matches=[ScoredMessageOrdinal(0, 0.9)],
+        knowledge_matches={},
+        raw_query_text="test",
+    )
+    assert len(result.message_matches) == 1
+    assert result.raw_query_text == "test"
+
+
+def test_conversation_search_result_defaults() -> None:
+    result = ConversationSearchResult(message_matches=[], knowledge_matches={})
+    assert result.raw_query_text is None
+
+
+# ---------------------------------------------------------------------------
+# is_conversation_searchable (from query.py, used heavily in search.py)
+# ---------------------------------------------------------------------------
+
+
+def test_is_conversation_searchable_true() -> None:
+    conv = FakeConversation(
+        messages=[FakeMessage("hello", 0)],
+        has_secondary_indexes=False,
+    )
+    conv.semantic_ref_index = FakeTermIndex()
+    assert is_conversation_searchable(conv) is True
+
+
+def test_is_conversation_searchable_no_index() -> None:
+    conv = FakeConversation(has_secondary_indexes=False)
+    conv.semantic_ref_index = None
+    assert is_conversation_searchable(conv) is False
+
+
+def test_is_conversation_searchable_no_semrefs() -> None:
+    conv = FakeConversation(has_secondary_indexes=False)
+    conv.semantic_refs = None  # type: ignore[assignment]
+    assert is_conversation_searchable(conv) is False
+
+
+# ---------------------------------------------------------------------------
+# search_conversation_knowledge returns None when not searchable
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_search_conversation_knowledge_non_searchable_returns_none() -> None:
+    """When the conversation has no semantic ref index, result should be None."""
+    conv = FakeConversation(has_secondary_indexes=False)
+    conv.semantic_ref_index = None
+
+    group = SearchTermGroup(
+        boolean_op="or",
+        terms=[SearchTerm(term=Term("hello"))],
+    )
+    result = await search_conversation_knowledge(conv, group)
+    assert result is None
diff --git a/tests/test_searchlang_compile.py b/tests/test_searchlang_compile.py
new file mode 100644
index 00000000..9b208fbb
--- /dev/null
+++ b/tests/test_searchlang_compile.py
@@ -0,0 +1,638 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
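+#
+# For orientation: the tests below drive compile_search_query roughly as in
+# this sketch (a usage outline; FakeConversation standing in for a real
+# conversation is an assumption of these tests):
+#
+#     query = SearchQuery(search_expressions=[
+#         SearchExpr(rewritten_query="q",
+#                    filters=[SearchFilter(search_terms=["ai"])]),
+#     ])
+#     exprs = compile_search_query(FakeConversation(), query)
+#     # One compiled expression per SearchExpr; each carries
+#     # select_expressions whose search_term_group holds the terms to match.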
+
+"""Unit tests for searchlang.py — compile_search_query, SearchQueryCompiler,
+and related helper functions that don't require a live LLM."""
+
+import datetime
+from typing import Literal
+
+from typeagent.knowpro.date_time_schema import DateTime, DateTimeRange, DateVal, TimeVal
+from typeagent.knowpro.interfaces import SearchTerm, SearchTermGroup
+from typeagent.knowpro.search_query_schema import (
+    ActionTerm,
+    EntityTerm,
+    FacetTerm,
+    SearchExpr,
+    SearchFilter,
+    SearchQuery,
+    VerbsTerm,
+)
+from typeagent.knowpro.searchlang import (
+    _compile_fallback_query,
+    compile_search_filter,
+    compile_search_query,
+    date_range_from_datetime_range,
+    datetime_from_date_time,
+    is_entity_term_list,
+    LanguageQueryCompileOptions,
+    LanguageSearchFilter,
+    optimize_or_max,
+    SearchQueryCompiler,
+)
+
+from conftest import FakeConversation
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def make_entity(
+    name: str,
+    types: list[str] | None = None,
+    facets: list[FacetTerm] | None = None,
+    is_pronoun: bool = False,
+) -> EntityTerm:
+    return EntityTerm(name=name, is_name_pronoun=is_pronoun, type=types, facets=facets)
+
+
+def make_action(
+    actor: list[EntityTerm] | Literal["*"] = "*",
+    verbs: list[str] | None = None,
+    targets: list[EntityTerm] | None = None,
+    additional: list[EntityTerm] | None = None,
+    is_informational: bool = False,
+) -> ActionTerm:
+    return ActionTerm(
+        actor_entities=actor,
+        is_informational=is_informational,
+        action_verbs=VerbsTerm(words=verbs) if verbs else None,
+        target_entities=targets,
+        additional_entities=additional,
+    )
+
+
+def make_filter(
+    entities: list[EntityTerm] | None = None,
+    action: ActionTerm | None = None,
+    search_terms: list[str] | None = None,
+    time_range: DateTimeRange | None = None,
+) -> SearchFilter:
+    return SearchFilter(
+        entity_search_terms=entities,
+        action_search_term=action,
+        search_terms=search_terms,
+        time_range=time_range,
+    )
+
+
+def make_query(filters: list[SearchFilter]) -> SearchQuery:
+    expr = SearchExpr(
+        rewritten_query="test query",
+        filters=filters,
+    )
+    return SearchQuery(search_expressions=[expr])
+
+
+def make_compiler(
+    options: LanguageQueryCompileOptions | None = None,
+    lang_filter: LanguageSearchFilter | None = None,
+) -> SearchQueryCompiler:
+    conv = FakeConversation()
+    return SearchQueryCompiler(conv, options, lang_filter)
+
+
+# ---------------------------------------------------------------------------
+# is_entity_term_list
+# ---------------------------------------------------------------------------
+
+
+class TestIsEntityTermList:
+    def test_list_returns_true(self) -> None:
+        terms = [make_entity("Alice")]
+        assert is_entity_term_list(terms) is True
+
+    def test_empty_list_returns_true(self) -> None:
+        assert is_entity_term_list([]) is True
+
+    def test_star_returns_false(self) -> None:
+        assert is_entity_term_list("*") is False
+
+    def test_none_returns_false(self) -> None:
+        assert is_entity_term_list(None) is False
+
+
+# ---------------------------------------------------------------------------
+# optimize_or_max
+# ---------------------------------------------------------------------------
+
+
+class TestOptimizeOrMax:
+    def test_single_term_unwrapped(self) -> None:
+        inner = SearchTermGroup(boolean_op="and", terms=[])
+        group = SearchTermGroup(boolean_op="or_max", terms=[inner])
+        result = optimize_or_max(group)
+        assert result is inner
+
+    def test_multiple_terms_kept_as_group(self) -> None:
+        inner1 = SearchTermGroup(boolean_op="and", terms=[])
+        inner2 = SearchTermGroup(boolean_op="and", terms=[])
+        group = SearchTermGroup(boolean_op="or_max", terms=[inner1, inner2])
+        result = optimize_or_max(group)
+        assert result is group
+
+
+# ---------------------------------------------------------------------------
+# date_range_from_datetime_range / datetime_from_date_time
+# ---------------------------------------------------------------------------
+
+
+class TestDatetimeFromDateTime:
+    def test_date_only_zeros_time(self) -> None:
+        dt = datetime_from_date_time(DateTime(date=DateVal(day=15, month=6, year=2024)))
+        assert dt.year == 2024
+        assert dt.month == 6
+        assert dt.day == 15
+        assert dt.hour == 0
+        assert dt.minute == 0
+        assert dt.second == 0
+        assert dt.tzinfo == datetime.timezone.utc
+
+    def test_with_time(self) -> None:
+        dt = datetime_from_date_time(
+            DateTime(
+                date=DateVal(day=1, month=1, year=2020),
+                time=TimeVal(hour=14, minute=30, seconds=45),
+            )
+        )
+        assert dt.hour == 14
+        assert dt.minute == 30
+        assert dt.second == 45
+
+
+class TestDateRangeFromDatetimeRange:
+    def test_start_only(self) -> None:
+        dtr = DateTimeRange(
+            start_date=DateTime(date=DateVal(day=1, month=1, year=2023))
+        )
+        dr = date_range_from_datetime_range(dtr)
+        assert dr.start.year == 2023
+        assert dr.end is None
+
+    def test_start_and_stop(self) -> None:
+        dtr = DateTimeRange(
+            start_date=DateTime(date=DateVal(day=1, month=1, year=2023)),
+            stop_date=DateTime(date=DateVal(day=31, month=12, year=2023)),
+        )
+        dr = date_range_from_datetime_range(dtr)
+        assert dr.start.year == 2023
+        assert dr.end is not None
+        assert dr.end.year == 2023
+        assert dr.end.month == 12
+        assert dr.end.day == 31
+
+
+# ---------------------------------------------------------------------------
+# compile_search_query (standalone function)
+# ---------------------------------------------------------------------------
+
+
+class TestCompileSearchQuery:
+    def test_empty_search_expressions(self) -> None:
+        conv = FakeConversation()
+        query = SearchQuery(search_expressions=[])
+        result = compile_search_query(conv, query)
+        assert result == []
+
+    def test_single_search_terms_filter(self) -> None:
+        conv = FakeConversation()
+        query = make_query([make_filter(search_terms=["robots", "AI"])])
+        result = compile_search_query(conv, query)
+        assert len(result) == 1
+        expr = result[0]
+        assert len(expr.select_expressions) == 1
+        terms_in_group = expr.select_expressions[0].search_term_group.terms
+        assert any(
+            isinstance(t, SearchTerm) and t.term.text == "robots"
+            for t in terms_in_group
+        )
+
+    def test_entity_filter_produces_expr(self) -> None:
+        conv = FakeConversation()
+        query = make_query([make_filter(entities=[make_entity("Alice", ["person"])])])
+        result = compile_search_query(conv, query)
+        assert len(result) == 1
+
+    def test_multiple_filters_produce_multiple_select_exprs(self) -> None:
+        conv = FakeConversation()
+        filter1 = make_filter(search_terms=["alpha"])
+        filter2 = make_filter(search_terms=["beta"])
+        expr = SearchExpr(rewritten_query="test", filters=[filter1, filter2])
+        query = SearchQuery(search_expressions=[expr])
+        result = compile_search_query(conv, query)
+        assert len(result) == 1
+        assert len(result[0].select_expressions) == 2
+
+    def test_raw_query_preserved(self) -> None:
+        conv = FakeConversation()
+        query = make_query([make_filter(search_terms=["foo"])])
+        query.search_expressions[0].rewritten_query = "my rewritten query"
+        result = compile_search_query(conv, query)
+        assert result[0].raw_query == "my rewritten query"
+
+
+# ---------------------------------------------------------------------------
+# compile_search_filter (standalone function)
+# ---------------------------------------------------------------------------
+
+
+class TestCompileSearchFilter:
+    def test_entity_filter(self) -> None:
+        conv = FakeConversation()
+        f = make_filter(entities=[make_entity("Bob")])
+        result = compile_search_filter(conv, f)
+        assert result.search_term_group is not None
+
+    def test_search_terms_filter(self) -> None:
+        conv = FakeConversation()
+        f = make_filter(search_terms=["climate", "change"])
+        result = compile_search_filter(conv, f)
+        terms = result.search_term_group.terms
+        assert len(terms) == 2
+
+    def test_empty_filter_uses_topic_wildcard(self) -> None:
+        """A filter with no entity, action, or search_terms should produce a topic:* term."""
+        conv = FakeConversation()
+        f = SearchFilter()
+        result = compile_search_filter(conv, f)
+        # Should produce a single topic:* property search term
+        terms = result.search_term_group.terms
+        assert len(terms) == 1
+
+    def test_time_range_produces_when(self) -> None:
+        conv = FakeConversation()
+        dtr = DateTimeRange(
+            start_date=DateTime(date=DateVal(day=1, month=1, year=2024))
+        )
+        f = make_filter(search_terms=["foo"], time_range=dtr)
+        result = compile_search_filter(conv, f)
+        assert result.when is not None
+        assert result.when.date_range is not None
+
+    def test_no_time_range_when_is_none(self) -> None:
+        conv = FakeConversation()
+        f = make_filter(search_terms=["foo"])
+        result = compile_search_filter(conv, f)
+        assert result.when is None
+
+
+# ---------------------------------------------------------------------------
+# SearchQueryCompiler — compile_term_group and related
+# ---------------------------------------------------------------------------
+
+
+class TestSearchQueryCompilerTermGroup:
+    def test_search_terms_added(self) -> None:
+        compiler = make_compiler()
+        f = make_filter(search_terms=["hello", "world"])
+        group = compiler.compile_term_group(f)
+        texts = [t.term.text for t in group.terms if isinstance(t, SearchTerm)]
+        assert "hello" in texts
+        assert "world" in texts
+
+    def test_entity_name_added_as_property_term(self) -> None:
+        compiler = make_compiler()
+        f = make_filter(entities=[make_entity("Ada")])
+        group = compiler.compile_term_group(f)
+        # Should have at least one term
+        assert len(group.terms) > 0
+
+    def test_empty_entity_name_ignored(self) -> None:
+        compiler = make_compiler()
+        f = make_filter(entities=[make_entity("")])
+        group = compiler.compile_term_group(f)
+        # An empty name is not searchable; compile_entity_terms still adds
+        # topic terms for the entity, so we only check that compilation
+        # succeeds and returns a group.
+        assert group is not None
+
+    def test_star_entity_name_ignored(self) -> None:
+        compiler = make_compiler()
+        f = make_filter(entities=[make_entity("*")])
+        group = compiler.compile_term_group(f)
+        assert group is not None
+
+    def test_noise_term_ignored(self) -> None:
+        compiler = make_compiler()
+        f = make_filter(search_terms=["thing", "object", "hello"])
+        group = compiler.compile_term_group(f)
+        texts = [t.term.text for t in group.terms if isinstance(t, SearchTerm)]
+        # Noise terms are filtered from property groups, but the search_terms
+        # path does not call add_property_term_to_group, so "hello" survives.
+        assert "hello" in texts
+
+    def test_custom_term_filter_excludes_property_terms(self) -> None:
+        # term_filter applies to add_property_term_to_group, not compile_search_terms.
+        options = LanguageQueryCompileOptions(term_filter=lambda t: t != "excluded")
+        compiler = make_compiler(options=options)
+        group = SearchTermGroup(boolean_op="or", terms=[])
+        compiler.add_property_term_to_group("name", "excluded", group)
+        compiler.add_property_term_to_group("name", "included", group)
+        assert len(group.terms) == 1
+
+
+# ---------------------------------------------------------------------------
+# SearchQueryCompiler — entity terms with facets
+# ---------------------------------------------------------------------------
+
+
+class TestEntityTermsWithFacets:
+    def test_entity_with_type(self) -> None:
+        compiler = make_compiler()
+        entity = make_entity("Alice", types=["person"])
+        f = make_filter(entities=[entity])
+        group = compiler.compile_term_group(f)
+        assert len(group.terms) > 0
+
+    def test_entity_with_facet_name_and_value(self) -> None:
+        compiler = make_compiler()
+        facet = FacetTerm(facet_name="profession", facet_value="writer")
+        entity = make_entity("Bob", facets=[facet])
+        f = make_filter(entities=[entity])
+        group = compiler.compile_term_group(f)
+        assert len(group.terms) > 0
+
+    def test_entity_with_wildcard_facet_value(self) -> None:
+        compiler = make_compiler()
+        facet = FacetTerm(facet_name="profession", facet_value="*")
+        entity = make_entity("Bob", facets=[facet])
+        f = make_filter(entities=[entity])
+        group = compiler.compile_term_group(f)
+        assert len(group.terms) > 0
+
+    def test_entity_with_wildcard_facet_name(self) -> None:
+        compiler = make_compiler()
+        facet = FacetTerm(facet_name="*", facet_value="writer")
+        entity = make_entity("Bob", facets=[facet])
+        f = make_filter(entities=[entity])
+        group = compiler.compile_term_group(f)
+        assert len(group.terms) > 0
+
+    def test_entity_with_both_wildcards_no_facet_term(self) -> None:
+        compiler = make_compiler()
+        facet = FacetTerm(facet_name="*", facet_value="*")
+        entity = make_entity("Bob", facets=[facet])
+        f = make_filter(entities=[entity])
+        group = compiler.compile_term_group(f)
+        # Both wildcards => no facet term added, but entity name term (or_max)
+        # and topic term for "Bob" are still generated — 2 terms total.
+        assert len(group.terms) == 2
+
+    def test_pronoun_entity_skipped(self) -> None:
+        compiler = make_compiler()
+        pronoun = make_entity("it", is_pronoun=True)
+        normal = make_entity("Alice")
+        f = make_filter(entities=[pronoun, normal])
+        group = compiler.compile_term_group(f)
+        # Only Alice's term should be added
+        assert len(group.terms) > 0
+
+
+# ---------------------------------------------------------------------------
+# SearchQueryCompiler — action terms
+# ---------------------------------------------------------------------------
+
+
+class TestActionTerms:
+    def test_action_with_verbs_adds_verb_terms(self) -> None:
+        compiler = make_compiler()
+        actor = make_entity("Alice")
+        action = make_action(actor=[actor], verbs=["sent", "emailed"])
+        f = make_filter(action=action)
+        group = compiler.compile_term_group(f)
+        assert len(group.terms) > 0
+
+    def test_action_with_target_entities(self) -> None:
+        compiler = make_compiler()
+        actor = make_entity("Alice")
+        target = make_entity("Bob")
+        action = make_action(actor=[actor], verbs=["sent"], targets=[target])
+        f = make_filter(action=action)
+        group = compiler.compile_term_group(f)
+        assert len(group.terms) > 0
+
+    def test_action_with_additional_entities(self) -> None:
+        compiler = make_compiler()
+        actor = make_entity("Alice")
+        extra = make_entity("Charlie")
+        action = make_action(actor=[actor], verbs=["spoke"], additional=[extra])
+        f = make_filter(action=action)
+        group = compiler.compile_term_group(f)
+        assert len(group.terms) > 0
+
+    def test_action_star_actor_no_scope(self) -> None:
+        """When actor_entities is '*', scope is not applied."""
+        action = make_action(actor="*", verbs=["played"])
+        f = make_filter(action=action)
+        result = compile_search_filter(FakeConversation(), f)
+        # should have no scope (when is None or when.scope_defining_terms is empty)
+        when = result.when
+        assert when is None or (
+            when.scope_defining_terms is None
+            or len(when.scope_defining_terms.terms) == 0
+        )
+
+
+# ---------------------------------------------------------------------------
+# SearchQueryCompiler — compile_when with scope
+# ---------------------------------------------------------------------------
+
+
+class TestCompileWhen:
+    def test_no_action_no_when(self) -> None:
+        compiler = make_compiler()
+        f = make_filter(search_terms=["foo"])
+        when = compiler.compile_when(f)
+        assert when is None
+
+    def test_time_range_produces_date_range(self) -> None:
+        compiler = make_compiler()
+        dtr = DateTimeRange(
+            start_date=DateTime(date=DateVal(day=1, month=3, year=2025)),
+            stop_date=DateTime(date=DateVal(day=31, month=3, year=2025)),
+        )
+        f = make_filter(search_terms=["foo"], time_range=dtr)
+        when = compiler.compile_when(f)
+        assert when is not None
+        assert when.date_range is not None
+        assert when.date_range.start.month == 3
+
+    def test_informational_action_no_scope(self) -> None:
+        compiler = make_compiler()
+        actor = make_entity("Alice")
+        action = make_action(actor=[actor], verbs=["spoke"], is_informational=True)
+        f = make_filter(action=action)
+        when = compiler.compile_when(f)
+        # is_informational = True → should_add_scope returns False → no scope in when
+        assert when is None or (
+            when.scope_defining_terms is None
+            or len(when.scope_defining_terms.terms) == 0
+        )
+
+    def test_actor_entities_list_adds_scope(self) -> None:
+        compiler = make_compiler()
+        actor = make_entity("Alice")
+        action = make_action(actor=[actor], verbs=["sent"])
+        f = make_filter(action=action)
+        when = compiler.compile_when(f)
+        assert when is not None
+        assert when.scope_defining_terms is not None
+        assert len(when.scope_defining_terms.terms) > 0
+
+
+# ---------------------------------------------------------------------------
+# SearchQueryCompiler — compile_search_terms
+# ---------------------------------------------------------------------------
+
+
+class TestCompileSearchTerms:
+    def test_returns_search_term_group(self) -> None:
+        compiler = make_compiler()
+        group = compiler.compile_search_terms(["alpha", "beta"])
+        texts = [t.term.text for t in group.terms if isinstance(t, SearchTerm)]
+        assert "alpha" in texts
+        assert "beta" in texts
+
+    def test_appends_to_existing_group(self) -> None:
+        compiler = make_compiler()
+        existing = SearchTermGroup(boolean_op="or", terms=[])
+        compiler.compile_search_terms(["gamma"], existing)
+        texts = [t.term.text for t in existing.terms if isinstance(t, SearchTerm)]
+        assert "gamma" in texts
+
+
+# ---------------------------------------------------------------------------
+# SearchQueryCompiler — is_searchable_string / is_noise_term
+# ---------------------------------------------------------------------------
+
+
+class TestIsSearchableString:
+    def test_normal_string_is_searchable(self) -> None:
+        compiler = make_compiler()
+        assert compiler.is_searchable_string("hello") is True
+
+    def test_empty_string_not_searchable(self) -> None:
+        compiler = make_compiler()
+        assert compiler.is_searchable_string("") is False
+
+    def test_star_not_searchable(self) -> None:
+        compiler = make_compiler()
+        assert compiler.is_searchable_string("*") is False
+
+    def test_term_filter_respected(self) -> None:
+        options = LanguageQueryCompileOptions(term_filter=lambda t: t != "skip")
+        compiler = make_compiler(options=options)
+        assert compiler.is_searchable_string("skip") is False
+        assert compiler.is_searchable_string("keep") is True
+
+
+class TestIsNoiseTerm:
+    def test_noise_words(self) -> None:
+        compiler = make_compiler()
+        for word in ("thing", "object", "concept", "idea", "entity"):
+            assert compiler.is_noise_term(word) is True
+
+    def test_non_noise_word(self) -> None:
+        compiler = make_compiler()
+        assert compiler.is_noise_term("robot") is False
+
+    def test_case_insensitive(self) -> None:
+        compiler = make_compiler()
+        assert compiler.is_noise_term("THING") is True
+
+
+# ---------------------------------------------------------------------------
+# SearchQueryCompiler — deduplication
+# ---------------------------------------------------------------------------
+
+
+class TestDeduplication:
+    def test_duplicate_property_term_not_added_twice(self) -> None:
+        compiler = make_compiler()
+        group = SearchTermGroup(boolean_op="or", terms=[])
+        compiler.add_property_term_to_group("name", "Alice", group)
+        compiler.add_property_term_to_group("name", "Alice", group)
+        assert len(group.terms) == 1
+
+    def test_different_property_names_both_added(self) -> None:
+        compiler = make_compiler()
+        group = SearchTermGroup(boolean_op="or", terms=[])
+        compiler.add_property_term_to_group("name", "Alice", group)
+        compiler.add_property_term_to_group("topic", "Alice", group)
+        assert len(group.terms) == 2
+
+    def test_dedupe_disabled_allows_duplicates(self) -> None:
+        compiler = make_compiler()
+        compiler.dedupe = False
+        group = SearchTermGroup(boolean_op="or", terms=[])
+        compiler.add_property_term_to_group("name", "Alice", group)
+        compiler.add_property_term_to_group("name", "Alice", group)
+        assert len(group.terms) == 2
+
+
+# ---------------------------------------------------------------------------
+# _compile_fallback_query
+# ---------------------------------------------------------------------------
+
+
+class TestCompileFallbackQuery:
+    def test_exact_scope_no_fallback(self) -> None:
+        conv = FakeConversation()
+        options = LanguageQueryCompileOptions(exact_scope=True, verb_scope=True)
+        query = make_query([make_filter(search_terms=["foo"])])
+        result = _compile_fallback_query(conv, query, options)
+        assert result is None
+
+    def test_no_verb_scope_no_fallback(self) -> None:
+        conv = FakeConversation()
+        options = LanguageQueryCompileOptions(exact_scope=False, verb_scope=False)
+        query = make_query([make_filter(search_terms=["foo"])])
+        result = _compile_fallback_query(conv, query, options)
+        assert result is None
+
+    def test_verb_scope_and_not_exact_produces_fallback(self) -> None:
+        conv = FakeConversation()
+        options = LanguageQueryCompileOptions(exact_scope=False, verb_scope=True)
+        query = make_query([make_filter(search_terms=["foo"])])
+        result = _compile_fallback_query(conv, query, options)
+        # Should return a list of SearchQueryExpr (fallback without verb matching)
+        assert result is not None
+        assert isinstance(result, list)
+        assert len(result) == 1
+
+
+# ---------------------------------------------------------------------------
+# SearchQueryCompiler — compile_action_term_as_search_terms (use_or_max=False)
+# ---------------------------------------------------------------------------
+
+
+class TestCompileActionTermAsSearchTerms:
+    def test_no_verbs_no_actor_empty_group(self) -> None:
+        compiler = make_compiler()
+        action = ActionTerm(
+            actor_entities="*",
+            is_informational=False,
+        )
+        group = compiler.compile_action_term_as_search_terms(action, use_or_max=False)
+        # actor is "*" so no actor entities; no verbs; result depends on implementation
+        assert group is not None
+
+    def test_use_or_max_false_merges_into_same_group(self) -> None:
+        compiler = make_compiler()
+        actor = make_entity("Alice")
+        action = make_action(actor=[actor], verbs=["sent"])
+        group = compiler.compile_action_term_as_search_terms(action, use_or_max=False)
+        assert len(group.terms) > 0
+
+    def test_empty_or_max_not_appended(self) -> None:
+        """With use_or_max=True but no verbs/actors, or_max wrapper should not be appended."""
+        compiler = make_compiler()
+        action = ActionTerm(
+            actor_entities="*",
+            is_informational=False,
+        )
+        outer = SearchTermGroup(boolean_op="or", terms=[])
+        compiler.compile_action_term_as_search_terms(action, outer, use_or_max=True)
+        # or_max only appended if non-empty
+        assert len(outer.terms) == 0
diff --git a/tests/test_secindex.py b/tests/test_secindex.py
index 39665b05..730cb831 100644
--- a/tests/test_secindex.py
+++ b/tests/test_secindex.py
@@ -45,15 +45,15 @@ def test_conversation_secondary_indexes_initialization(
     embedding_settings = TextEmbeddingIndexSettings(test_model)
     settings = RelatedTermIndexSettings(embedding_settings)
     indexes = ConversationSecondaryIndexes(storage_provider, settings)
-    # Note: indexes are None until initialize() is called
-    assert indexes.property_to_semantic_ref_index is None
-    assert indexes.timestamp_index is None
-    assert indexes.term_to_related_terms_index is None
+    # Indexes are initialized from storage provider in __init__
+    assert indexes.property_to_semantic_ref_index is not None
+    assert indexes.timestamp_index is not None
+    assert indexes.term_to_related_terms_index is not None
 
     # Test with custom settings
     settings2 = RelatedTermIndexSettings(embedding_settings)
     indexes_with_settings = ConversationSecondaryIndexes(storage_provider, settings2)
-    assert indexes_with_settings.property_to_semantic_ref_index is None
+    assert indexes_with_settings.property_to_semantic_ref_index is not None
 
 
 @pytest.mark.asyncio
diff --git a/tests/test_secindex_storage_integration.py b/tests/test_secindex_storage_integration.py
index 15738bb6..0c3751fe 100644
--- a/tests/test_secindex_storage_integration.py
+++ b/tests/test_secindex_storage_integration.py
@@ -23,9 +23,7 @@ async def test_secondary_indexes_use_storage_provider(
     embedding_settings = TextEmbeddingIndexSettings(test_model)
     related_terms_settings = RelatedTermIndexSettings(embedding_settings)
 
-    indexes = await ConversationSecondaryIndexes.create(
-        storage_provider, related_terms_settings
-    )
+    indexes = ConversationSecondaryIndexes(storage_provider, related_terms_settings)
 
     assert indexes.property_to_semantic_ref_index is not None
     assert indexes.timestamp_index is not None
@@ -34,11 +32,11 @@ async def test_secondary_indexes_use_storage_provider(
     assert indexes.message_index is not None
 
     # Verify they are the same instances as those from storage provider
-    storage_prop_index = await storage_provider.get_property_index()
-    storage_timestamp_index = await storage_provider.get_timestamp_index()
-    storage_related_terms = await storage_provider.get_related_terms_index()
-    storage_threads = await storage_provider.get_conversation_threads()
-    storage_message_index = await storage_provider.get_message_text_index()
+    storage_prop_index = storage_provider.property_index
+    storage_timestamp_index = storage_provider.timestamp_index
+    storage_related_terms = storage_provider.related_terms_index
+    storage_threads = storage_provider.conversation_threads
+    storage_message_index = storage_provider.message_text_index
 
     assert indexes.property_to_semantic_ref_index is storage_prop_index
     assert indexes.timestamp_index is storage_timestamp_index
diff --git a/tests/test_semrefindex.py b/tests/test_semrefindex.py
index 5f580992..f12de683 100644
--- a/tests/test_semrefindex.py
+++ b/tests/test_semrefindex.py
@@ -65,7 +65,7 @@ def get_knowledge(self):
             message_text_settings=message_text_settings,
             related_terms_settings=related_terms_settings,
         )
-        index = await provider.get_semantic_ref_index()
+        index = provider.semantic_ref_index
         yield index
     else:
         provider = SqliteStorageProvider(
@@ -83,7 +83,7 @@ def get_knowledge(self):
             Topic,
         )
 
-        collection = await provider.get_semantic_ref_collection()
+        collection = provider.semantic_refs
 
         # Create semantic refs with ordinals 1, 2, 3 that the tests expect
         for i in range(1, 4):
@@ -94,7 +94,7 @@ def get_knowledge(self):
             )
             await collection.append(ref)
 
-        index = await provider.get_semantic_ref_index()
+        index = provider.semantic_ref_index
         yield index
         await provider.close()
 
@@ -125,8 +125,8 @@ def get_knowledge(self):
             message_text_settings=message_text_settings,
             related_terms_settings=related_terms_settings,
         )
-        index = await provider.get_semantic_ref_index()
-        collection = await provider.get_semantic_ref_collection()
+        index = provider.semantic_ref_index
+        collection = provider.semantic_refs
         yield {"index": index, "collection": collection}
     else:
         provider = SqliteStorageProvider(
@@ -135,8 +135,8 @@ def get_knowledge(self):
             message_text_index_settings=message_text_settings,
             related_term_index_settings=related_terms_settings,
         )
-        index = await provider.get_semantic_ref_index()
-        collection = await provider.get_semantic_ref_collection()
+        index = provider.semantic_ref_index
+        collection = provider.semantic_refs
         yield {"index": index, "collection": collection}
         await provider.close()
diff --git a/tests/test_serialization.py b/tests/test_serialization.py
index eec08fda..d32b9526 100644
--- a/tests/test_serialization.py
+++ b/tests/test_serialization.py
@@ -10,15 +10,22 @@
 from typeagent.knowpro.interfaces import (
     ConversationDataWithIndexes,
     MessageTextIndexData,
+    Tag,
     TermsToRelatedTermsIndexData,
     TextToTextLocationIndexData,
+    Topic,
 )
 from typeagent.knowpro.knowledge_schema import ConcreteEntity, Quantity
 from typeagent.knowpro.serialization import (
+    ConversationBinaryData,
+    ConversationFileData,
+    ConversationJsonData,
     create_file_header,
     DeserializationError,
+    deserialize_knowledge,
     deserialize_object,
     from_conversation_file_data,
+    is_primitive,
     serialize_embeddings,
     serialize_object,
     to_conversation_file_data,
@@ -133,3 +140,136 @@ def test_deserialization_error():
     """Test that DeserializationError is raised for invalid data."""
     with pytest.raises(DeserializationError, match="Pydantic validation failed"):
         deserialize_object(Quantity, {"invalid_key": "value"})
+
+
+# ---------------------------------------------------------------------------
+# Additional tests for broader coverage
+# ---------------------------------------------------------------------------
+
+
+def test_from_conversation_file_data_missing_header_raises():
+    """from_conversation_file_data raises when fileHeader is absent."""
+    json_data: ConversationJsonData[Any] = ConversationJsonData(
+        nameTag="x", messages=[], tags=[], semanticRefs=None
+    )
+    file_data: ConversationFileData[Any] = ConversationFileData(
+        jsonData=json_data,
+        binaryData=ConversationBinaryData(embeddingsList=[]),
+    )
+    with pytest.raises(DeserializationError, match="Missing file header"):
+        from_conversation_file_data(file_data)
+
+
+def test_from_conversation_file_data_bad_version_raises():
+    """from_conversation_file_data raises on unsupported version."""
+    json_data: ConversationJsonData[Any] = ConversationJsonData(
+        nameTag="x",
+        messages=[],
+        tags=[],
+        semanticRefs=None,
+        fileHeader={"version": "99.9"},
+        embeddingFileHeader={},
+    )
+    file_data: ConversationFileData[Any] = ConversationFileData(
+        jsonData=json_data,
+        binaryData=ConversationBinaryData(embeddingsList=[]),
+    )
+    with pytest.raises(DeserializationError, match="Unsupported file version"):
+        from_conversation_file_data(file_data)
+
+
+def test_from_conversation_file_data_missing_embedding_header_raises():
+    """from_conversation_file_data raises when embeddingFileHeader is absent."""
+    json_data: ConversationJsonData[Any] = ConversationJsonData(
+        nameTag="x",
+        messages=[],
+        tags=[],
+        semanticRefs=None,
+        fileHeader={"version": "0.1"},
+    )
+    file_data: ConversationFileData[Any] = ConversationFileData(
+        jsonData=json_data,
+        binaryData=ConversationBinaryData(embeddingsList=[]),
+    )
+    with pytest.raises(DeserializationError, match="Missing embedding file header"):
+        from_conversation_file_data(file_data)
+
+
+def test_from_conversation_file_data_missing_embeddings_list_raises():
+    """from_conversation_file_data raises when embeddingsList is None."""
+    json_data: ConversationJsonData[Any] = ConversationJsonData(
+        nameTag="x",
+        messages=[],
+        tags=[],
+        semanticRefs=None,
+        fileHeader={"version": "0.1"},
+        embeddingFileHeader={},
+    )
+    file_data: ConversationFileData[Any] = ConversationFileData(
+        jsonData=json_data,
+        binaryData=ConversationBinaryData(embeddingsList=None),
+    )
+    with pytest.raises(DeserializationError, match="Missing embeddings list"):
+        from_conversation_file_data(file_data)
+
+
+def test_from_conversation_file_data_success_empty():
+    """from_conversation_file_data succeeds with minimal valid data."""
+    emb = np.zeros((0, 4), dtype=np.float32)
+    json_data: ConversationJsonData[Any] = ConversationJsonData(
+        nameTag="test",
+        messages=[],
+        tags=[],
+        semanticRefs=None,
+        fileHeader={"version": "0.1"},
+        embeddingFileHeader={},
+    )
+    file_data: ConversationFileData[Any] = ConversationFileData(
+        jsonData=json_data,
+        binaryData=ConversationBinaryData(embeddingsList=[emb]),
+    )
+    result = from_conversation_file_data(file_data)
+    assert result["nameTag"] == "test"
+
+
+def test_is_primitive():
+    """Test is_primitive classification."""
+    for t in (int, float, bool, str, type(None)):
+        assert is_primitive(t), f"Expected {t} to be primitive"
+    assert not is_primitive(list)
+    assert not is_primitive(dict)
+
+
+def test_deserialize_object_union_none():
+    """deserialize_object handles optional (X | None) type with None input."""
+    result = deserialize_object(int | None, None)
+    assert result is None
+
+
+def test_deserialize_object_list_of_int():
+    """deserialize_object can deserialize a list of ints."""
+    result = deserialize_object(list[int], [1, 2, 3])
+    assert result == [1, 2, 3]
+
+
+def test_deserialize_knowledge_entity():
+    """deserialize_knowledge reconstructs a ConcreteEntity."""
+    obj = {"name": "Bob", "type": ["person"]}
+    result = deserialize_knowledge("entity", obj)
+    assert isinstance(result, ConcreteEntity)
+    assert result.name == "Bob"
+
+
+def test_deserialize_knowledge_topic():
+    """deserialize_knowledge reconstructs a Topic."""
+    obj = {"text": "AI ethics"}
+    result = deserialize_knowledge("topic", obj)
+    assert isinstance(result, Topic)
+    assert result.text == "AI ethics"
+
+
+def test_deserialize_knowledge_tag():
+    """deserialize_knowledge reconstructs a Tag."""
+    obj = {"text": "important"}
+    result = deserialize_knowledge("tag", obj)
+    assert isinstance(result, Tag)
diff --git a/tests/test_source_id_ingestion.py b/tests/test_source_id_ingestion.py
new file mode 100644
index 00000000..1886a7b7
--- /dev/null
+++ b/tests/test_source_id_ingestion.py
@@ -0,0 +1,159 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
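+#
+# These tests rely on an IngestedSources table kept by the SQLite provider.
+# A minimal sketch of the assumed shape (inferred from the assertions below,
+# not copied from the actual schema):
+#
+#     CREATE TABLE IF NOT EXISTS IngestedSources (
+#         source_id TEXT PRIMARY KEY,          -- assumed column name
+#         status TEXT NOT NULL DEFAULT 'ingested'
+#     );
+#
+# add_messages_with_indexing(msgs, source_ids=...) marks each id as ingested;
+# when source_ids is omitted, each message's own source_id (if any) is used.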
+
+"""Tests for source_id-based ingestion tracking in add_messages_with_indexing."""
+
+import os
+import tempfile
+
+import pytest
+
+from typeagent.aitools.model_adapters import create_test_embedding_model
+from typeagent.knowpro.convsettings import ConversationSettings
+from typeagent.storage.sqlite.provider import SqliteStorageProvider
+from typeagent.transcripts.transcript import (
+    Transcript,
+    TranscriptMessage,
+    TranscriptMessageMeta,
+)
+
+
+def _make_message(
+    text: str, speaker: str = "Alice", source_id: str | None = None
+) -> TranscriptMessage:
+    return TranscriptMessage(
+        text_chunks=[text],
+        metadata=TranscriptMessageMeta(speaker=speaker),
+        tags=["test"],
+        source_id=source_id,
+    )
+
+
+async def _create_transcript(
+    db_path: str,
+) -> tuple[Transcript, SqliteStorageProvider]:
+    model = create_test_embedding_model()
+    settings = ConversationSettings(model=model)
+    settings.semantic_ref_index_settings.auto_extract_knowledge = False
+    storage = SqliteStorageProvider(
+        db_path,
+        message_type=TranscriptMessage,
+        message_text_index_settings=settings.message_text_index_settings,
+        related_term_index_settings=settings.related_term_index_settings,
+    )
+    settings.storage_provider = storage
+    transcript = await Transcript.create(settings, name="test")
+    return transcript, storage
+
+
+def _ingested_count(storage: SqliteStorageProvider) -> int:
+    """Count rows in IngestedSources table."""
+    cursor = storage.db.cursor()
+    cursor.execute("SELECT COUNT(*) FROM IngestedSources")
+    return cursor.fetchone()[0]
+
+
+@pytest.mark.asyncio
+async def test_explicit_source_ids_marks_ingested() -> None:
+    """Passing source_ids= explicitly marks those IDs as ingested."""
+    with tempfile.TemporaryDirectory() as tmpdir:
+        db_path = os.path.join(tmpdir, "test.db")
+        transcript, storage = await _create_transcript(db_path)
+
+        msgs = [_make_message("Hello"), _make_message("World")]
+        await transcript.add_messages_with_indexing(msgs, source_ids=["src-1", "src-2"])
+
+        assert await storage.is_source_ingested("src-1")
+        assert await storage.is_source_ingested("src-2")
+        assert not await storage.is_source_ingested("src-3")
+
+        await storage.close()
+
+
+@pytest.mark.asyncio
+async def test_message_source_id_marks_ingested() -> None:
+    """When source_ids is omitted, message.source_id is used."""
+    with tempfile.TemporaryDirectory() as tmpdir:
+        db_path = os.path.join(tmpdir, "test.db")
+        transcript, storage = await _create_transcript(db_path)
+
+        msgs = [
+            _make_message("Hello", source_id="msg-src-1"),
+            _make_message("World", source_id="msg-src-2"),
+        ]
+        await transcript.add_messages_with_indexing(msgs)
+
+        assert await storage.is_source_ingested("msg-src-1")
+        assert await storage.is_source_ingested("msg-src-2")
+
+        await storage.close()
+
+
+@pytest.mark.asyncio
+async def test_message_source_id_none_skipped() -> None:
+    """Messages with source_id=None are silently skipped (no ingestion mark)."""
+    with tempfile.TemporaryDirectory() as tmpdir:
+        db_path = os.path.join(tmpdir, "test.db")
+        transcript, storage = await _create_transcript(db_path)
+
+        msgs = [
+            _make_message("Hello", source_id="only-one"),
+            _make_message("World"),  # source_id=None
+        ]
+        await transcript.add_messages_with_indexing(msgs)
+
+        assert await storage.is_source_ingested("only-one")
+        # The second message had no source_id, so nothing extra was marked
+        assert await storage.get_source_status("only-one") == "ingested"
+        assert _ingested_count(storage) == 1
+
+        await storage.close()
+
+
+@pytest.mark.asyncio
+async def test_explicit_source_ids_overrides_message_source_id() -> None:
+    """Passing source_ids= takes precedence; message.source_id is ignored."""
+    with tempfile.TemporaryDirectory() as tmpdir:
+        db_path = os.path.join(tmpdir, "test.db")
+        transcript, storage = await _create_transcript(db_path)
+
+        msgs = [
+            _make_message("Hello", source_id="msg-level"),
+        ]
+        await transcript.add_messages_with_indexing(msgs, source_ids=["explicit-id"])
+
+        assert await storage.is_source_ingested("explicit-id")
+        assert not await storage.is_source_ingested("msg-level")
+
+        await storage.close()
+
+
+@pytest.mark.asyncio
+async def test_source_ids_length_mismatch_raises() -> None:
+    """Passing source_ids with wrong length raises ValueError."""
+    with tempfile.TemporaryDirectory() as tmpdir:
+        db_path = os.path.join(tmpdir, "test.db")
+        transcript, storage = await _create_transcript(db_path)
+
+        msgs = [_make_message("Hello"), _make_message("World")]
+        with pytest.raises(ValueError, match="Length of source_ids"):
+            await transcript.add_messages_with_indexing(msgs, source_ids=["only-one"])
+
+        await storage.close()
+
+
+@pytest.mark.asyncio
+async def test_no_source_ids_no_message_source_id() -> None:
+    """When neither source_ids nor message.source_id is set, nothing is marked."""
+    with tempfile.TemporaryDirectory() as tmpdir:
+        db_path = os.path.join(tmpdir, "test.db")
+        transcript, storage = await _create_transcript(db_path)
+
+        msgs = [_make_message("Hello"), _make_message("World")]
+        result = await transcript.add_messages_with_indexing(msgs)
+
+        assert result.messages_added == 2
+        # No source tracking happened
+        assert _ingested_count(storage) == 0
+
+        await storage.close()
diff --git a/tests/test_sqlitestore.py b/tests/test_sqlitestore.py
index 27a522d0..3d973215 100644
--- a/tests/test_sqlitestore.py
+++ b/tests/test_sqlitestore.py
@@ -74,7 +74,7 @@ def make_dummy_semantic_ref(ordinal: int = 0) -> SemanticRef:
 async def test_sqlite_message_collection_append_and_get(
     dummy_sqlite_storage_provider: SqliteStorageProvider[DummyMessage],
 ):
-    store = await dummy_sqlite_storage_provider.get_message_collection()
+    store = dummy_sqlite_storage_provider.messages
     msg = DummyMessage(["foo"])
     await store.append(msg)
     assert await store.size() == 1
@@ -90,7 +90,7 @@ async def test_sqlite_message_collection_append_and_get(
 async def test_sqlite_message_collection_iter(
     dummy_sqlite_storage_provider: SqliteStorageProvider[DummyMessage],
 ):
-    collection = await dummy_sqlite_storage_provider.get_message_collection()
+    collection = dummy_sqlite_storage_provider.messages
     msgs = [DummyMessage([f"msg{i}"]) for i in range(3)]
     for m in msgs:
         await collection.append(m)
@@ -101,7 +101,7 @@ async def test_sqlite_message_collection_iter(
 async def test_sqlite_semantic_ref_collection_append_and_get(
     dummy_sqlite_storage_provider: SqliteStorageProvider[DummyMessage],
 ):
-    collection = await dummy_sqlite_storage_provider.get_semantic_ref_collection()
+    collection = dummy_sqlite_storage_provider.semantic_refs
     ref = make_dummy_semantic_ref(123)
     await collection.append(ref)
     assert await collection.size() == 1
@@ -117,7 +117,7 @@ async def test_sqlite_semantic_ref_collection_append_and_get(
 async def test_sqlite_semantic_ref_collection_iter(
     dummy_sqlite_storage_provider: SqliteStorageProvider[DummyMessage],
 ):
-    collection = await dummy_sqlite_storage_provider.get_semantic_ref_collection()
+    collection = dummy_sqlite_storage_provider.semantic_refs
     refs = [make_dummy_semantic_ref(i) for i in range(2)]
     for r in refs:
         await collection.append(r)
@@ -132,7 +132,7 @@ async def test_sqlite_timestamp_index(
     from typeagent.knowpro.interfaces import DateRange
 
     # Set up database with some messages
-    message_collection = await dummy_sqlite_storage_provider.get_message_collection()
+    message_collection = dummy_sqlite_storage_provider.messages
 
     # Add test messages
     messages = [
@@ -145,7 +145,7 @@ async def test_sqlite_timestamp_index(
         await message_collection.append(msg)
 
     # Create timestamp index
-    timestamp_index = await dummy_sqlite_storage_provider.get_timestamp_index()
+    timestamp_index = dummy_sqlite_storage_provider.timestamp_index
 
     # Test add_timestamp - use actual message ordinals from the database
     test_timestamps = [
diff --git a/tests/test_storage_providers_unified.py b/tests/test_storage_providers_unified.py
index 179b1a7b..f12f6fde 100644
--- a/tests/test_storage_providers_unified.py
+++ b/tests/test_storage_providers_unified.py
@@ -102,27 +102,27 @@ async def test_all_index_creation(
     storage_provider, _ = storage_provider_type
 
     # Test all index types are created and return proper interface objects
-    conv_index = await storage_provider.get_semantic_ref_index()
+    conv_index = storage_provider.semantic_ref_index
     assert conv_index is not None
     assert hasattr(conv_index, "lookup_term")  # Basic interface check
 
-    prop_index = await storage_provider.get_property_index()
+    prop_index = storage_provider.property_index
     assert prop_index is not None
     assert hasattr(prop_index, "lookup_property")  # Basic interface check
 
-    time_index = await storage_provider.get_timestamp_index()
+    time_index = storage_provider.timestamp_index
     assert time_index is not None
     assert hasattr(time_index, "lookup_range")  # Basic interface check
 
-    msg_index = await storage_provider.get_message_text_index()
+    msg_index = storage_provider.message_text_index
     assert msg_index is not None
     assert hasattr(msg_index, "lookup_messages")  # Basic interface check
 
-    rel_index = await storage_provider.get_related_terms_index()
+    rel_index = storage_provider.related_terms_index
     assert rel_index is not None
     assert hasattr(rel_index, "aliases")  # Basic interface check
 
-    threads = await storage_provider.get_conversation_threads()
+    threads = storage_provider.conversation_threads
     assert threads is not None
     assert hasattr(threads, "threads")  # Basic interface check
 
@@ -135,16 +135,16 @@ async def test_index_persistence(
     storage_provider, _ = storage_provider_type
 
     # All index types should return same instance across calls
-    conv1 = await storage_provider.get_semantic_ref_index()
-    conv2 = await storage_provider.get_semantic_ref_index()
+    conv1 = storage_provider.semantic_ref_index
+    conv2 = storage_provider.semantic_ref_index
     assert conv1 is conv2
 
-    prop1 = await storage_provider.get_property_index()
-    prop2 = await storage_provider.get_property_index()
+    prop1 = storage_provider.property_index
+    prop2 = storage_provider.property_index
     assert prop1 is prop2
 
-    time1 = await storage_provider.get_timestamp_index()
-    time2 = await storage_provider.get_timestamp_index()
+    time1 = storage_provider.timestamp_index
+    time2 = storage_provider.timestamp_index
     assert time1 is time2
 
 
@@ -156,7 +156,7 @@ async def test_message_collection_basic_operations(
     storage_provider, _ = storage_provider_type
 
     # Create message collection
-    collection = await storage_provider.get_message_collection()
+    collection = storage_provider.messages
 
     # Test initial state
     assert await collection.size() == 0
@@ -196,7 +196,7 @@ async def test_semantic_ref_collection_basic_operations(
     storage_provider, _ = storage_provider_type
 
     # Create semantic ref collection
-    collection = await storage_provider.get_semantic_ref_collection()
+    collection = storage_provider.semantic_refs
 
     # Test initial state
     assert await collection.size() == 0
@@ -251,7 +251,7 @@ async def test_semantic_ref_index_behavior_parity(
     """Test that semantic ref index behaves identically in both providers."""
     storage_provider, _ = storage_provider_type
 
-    conv_index = await storage_provider.get_semantic_ref_index()
+    conv_index = storage_provider.semantic_ref_index
 
     # Test empty state
     empty_results = await conv_index.lookup_term("nonexistent")
@@ -269,7 +269,7 @@ async def test_timestamp_index_behavior_parity(
     """Test that timestamp index behaves identically in both providers."""
    storage_provider, _provider_type = storage_provider_type
 
-    time_index = await storage_provider.get_timestamp_index()
+    time_index = storage_provider.timestamp_index
 
     # Test empty lookup_range interface
     start_time = Datetime.fromisoformat("2024-01-01T00:00:00Z")
@@ -288,7 +288,7 @@ async def test_message_text_index_interface_parity(
     """Test that message text index interface works identically in both providers."""
     storage_provider, _ = storage_provider_type
 
-    msg_index = await storage_provider.get_message_text_index()
+    msg_index = storage_provider.message_text_index
 
     # Test empty lookup_messages
     empty_results = await msg_index.lookup_messages("nonexistent query", 10)
@@ -303,7 +303,7 @@ async def test_related_terms_index_interface_parity(
     """Test that related terms index interface works identically in both providers."""
     storage_provider, _ = storage_provider_type
 
-    rel_index = await storage_provider.get_related_terms_index()
+    rel_index = storage_provider.related_terms_index
 
     # Test interface properties
     aliases = rel_index.aliases
@@ -321,7 +321,7 @@ async def test_conversation_threads_interface_parity(
     """Test that conversation threads interface works identically in both providers."""
     storage_provider, _ = storage_provider_type
 
-    threads = await storage_provider.get_conversation_threads()
+    threads = storage_provider.conversation_threads
 
     # Test initial empty state
     assert len(threads.threads) == 0
@@ -352,8 +352,8 @@ async def test_cross_provider_message_collection_equivalence(
 
     try:
         # Create collections in both
-        memory_collection = await memory_provider.get_message_collection()
-        sqlite_collection = await sqlite_provider.get_message_collection()
+        memory_collection = memory_provider.messages
+        sqlite_collection = sqlite_provider.messages
 
         # Add identical data to both
         test_messages = [
@@ -394,8 +394,8 @@ async def test_property_index_population_from_semantic_refs(
     storage_provider, provider_type = storage_provider_type
 
     # Get collections
-    sem_ref_collection = await storage_provider.get_semantic_ref_collection()
-    prop_index = await storage_provider.get_property_index()
+    sem_ref_collection = storage_provider.semantic_refs
+    prop_index = storage_provider.property_index
 
     # Check initial state
     initial_sem_ref_count = await sem_ref_collection.size()
@@ -476,7 +476,7 @@ async def test_property_index_basic_operations(
     """Test basic property index operations work identically in both providers."""
     storage_provider, _ = storage_provider_type
 
-    prop_index = await storage_provider.get_property_index()
+    prop_index = storage_provider.property_index
 
     # Test initial state - should be able to handle lookups even when empty
     empty_results = await prop_index.lookup_property("name", "nonexistent")
@@ -495,7 +495,7 @@ async def test_timestamp_index_range_queries(
     """Test timestamp index range query functionality in both providers."""
     storage_provider, _ = storage_provider_type
 
-    timestamp_index = await storage_provider.get_timestamp_index()
+    timestamp_index = storage_provider.timestamp_index
 
     # Test basic interface - empty range query
     start_time = Datetime.fromisoformat("2024-01-01T00:00:00Z")
@@ -526,8 +526,8 @@ async def test_timestamp_index_with_data(
     storage_provider, provider_type = storage_provider_type
 
     # First add some messages to work with
-    message_collection = await storage_provider.get_message_collection()
-    timestamp_index = await storage_provider.get_timestamp_index()
+    message_collection = storage_provider.messages
+    timestamp_index = storage_provider.timestamp_index
 
     # Add test messages
     test_messages = [
@@ -631,12 +631,12 @@ async def test_storage_provider_independence(
     )
 
     # Test memory provider independence
-    memory_index1 = await memory_provider1.get_semantic_ref_index()
-    memory_index2 = await memory_provider2.get_semantic_ref_index()
+    memory_index1 = memory_provider1.semantic_ref_index
+    memory_index2 = memory_provider2.semantic_ref_index
     assert memory_index1 is not memory_index2
 
-    memory_collection1 = await memory_provider1.get_message_collection()
-    memory_collection2 = await memory_provider2.get_message_collection()
+    memory_collection1 = memory_provider1.messages
+    memory_collection2 = memory_provider2.messages
 
     # Add data to first memory provider
     await memory_collection1.append(DummyTestMessage(["memory test 1"]))
@@ -644,12 +644,12 @@ async def test_storage_provider_independence(
     assert await memory_collection2.size() == 0  # Second provider unaffected
 
     # Test sqlite provider independence
-    sqlite_index1 = await sqlite_provider1.get_semantic_ref_index()
-    sqlite_index2 = await sqlite_provider2.get_semantic_ref_index()
+    sqlite_index1 = sqlite_provider1.semantic_ref_index
+    sqlite_index2 = sqlite_provider2.semantic_ref_index
     assert sqlite_index1 is not sqlite_index2
 
-    sqlite_collection1 = await sqlite_provider1.get_message_collection()
-    sqlite_collection2 = await sqlite_provider2.get_message_collection()
+    sqlite_collection1 = sqlite_provider1.messages
+    sqlite_collection2 = sqlite_provider2.messages
 
     # Add data to first sqlite provider
     await sqlite_collection1.append(DummyTestMessage(["sqlite test 1"]))
@@ -681,7 +681,7 @@ async def test_collection_operations_comprehensive(
     storage_provider, _ = storage_provider_type
 
     # Test message collection operations
-    message_collection = await storage_provider.get_message_collection()
+    message_collection = storage_provider.messages
 
     # Test initial state
     assert await message_collection.size() == 0
diff --git a/tests/test_textlocindex.py b/tests/test_textlocindex.py
new file mode 100644
index 00000000..a9e6454f
--- /dev/null
+++ b/tests/test_textlocindex.py
@@ -0,0 +1,146 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
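+#
+# The serialized form exercised below is a TextToTextLocationIndexData
+# mapping, roughly this shape (inferred from the tests, not copied from the
+# source):
+#
+#     {
+#         "textLocations": [{"messageOrdinal": 0, "chunkOrdinal": 0}, ...],
+#         "embeddings": ndarray of shape (len(textLocations), embedding_size),
+#     }
+#
+# deserialize() is expected to reject a missing embeddings array and a row
+# count that does not match len(textLocations).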
+ +"""Tests for knowpro/textlocindex.py (TextToTextLocationIndex).""" + +import numpy as np +import pytest + +from typeagent.aitools.model_adapters import create_test_embedding_model +from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings +from typeagent.knowpro.interfaces import TextLocation, TextToTextLocationIndexData +from typeagent.knowpro.textlocindex import TextToTextLocationIndex + + +@pytest.fixture +def settings() -> TextEmbeddingIndexSettings: + return TextEmbeddingIndexSettings(create_test_embedding_model()) + + +@pytest.fixture +def index(settings: TextEmbeddingIndexSettings) -> TextToTextLocationIndex: + return TextToTextLocationIndex(settings) + + +# --------------------------------------------------------------------------- +# Empty index +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_empty_size(index: TextToTextLocationIndex) -> None: + assert await index.size() == 0 + + +@pytest.mark.asyncio +async def test_empty_is_empty(index: TextToTextLocationIndex) -> None: + assert await index.is_empty() + + +def test_get_out_of_range_returns_default(index: TextToTextLocationIndex) -> None: + assert index.get(0) is None + assert index.get(-1) is None + assert index.get(0, TextLocation(99)) == TextLocation(99) + + +# --------------------------------------------------------------------------- +# clear() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_clear_resets(index: TextToTextLocationIndex) -> None: + loc = TextLocation(message_ordinal=0) + await index.add_text_location("hello world", loc) + assert await index.size() == 1 + index.clear() + assert await index.size() == 0 + assert await index.is_empty() + + +# --------------------------------------------------------------------------- +# serialize / deserialize round-trip (no real embeddings needed) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_serialize_empty(index: TextToTextLocationIndex) -> None: + data = index.serialize() + assert data["textLocations"] == [] + # embeddings may be None or an empty ndarray + emb = data["embeddings"] + assert emb is None or (hasattr(emb, "shape") and emb.size == 0) + + +def test_deserialize_raises_on_no_embeddings( + index: TextToTextLocationIndex, +) -> None: + data: TextToTextLocationIndexData = { + "textLocations": [{"messageOrdinal": 0, "chunkOrdinal": 0}], + "embeddings": None, + } + with pytest.raises(ValueError, match="No embeddings found"): + index.deserialize(data) + + +def test_deserialize_raises_on_length_mismatch( + index: TextToTextLocationIndex, settings: TextEmbeddingIndexSettings +) -> None: + # The test embedding model uses size 3 by default. 
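+    # emb_size below is hard-coded to match that default; if
+    # create_test_embedding_model() ever changes its size, update it here
+    # and in the two tests below as well.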
+ emb_size = 3 + fake_emb = np.zeros((3, emb_size), dtype=np.float32) + data: TextToTextLocationIndexData = { + # 2 locations but 3 embeddings → mismatch + "textLocations": [ + {"messageOrdinal": 0, "chunkOrdinal": 0}, + {"messageOrdinal": 1, "chunkOrdinal": 0}, + ], + "embeddings": fake_emb, + } + with pytest.raises(ValueError): + index.deserialize(data) + + +def test_deserialize_valid_data( + index: TextToTextLocationIndex, settings: TextEmbeddingIndexSettings +) -> None: + emb_size = 3 # default size for create_test_embedding_model() + n = 2 + fake_emb = np.zeros((n, emb_size), dtype=np.float32) + data: TextToTextLocationIndexData = { + "textLocations": [ + {"messageOrdinal": 0, "chunkOrdinal": 0}, + {"messageOrdinal": 1, "chunkOrdinal": 0}, + ], + "embeddings": fake_emb, + } + index.deserialize(data) + assert index.get(0) == TextLocation(0) + assert index.get(1) == TextLocation(1) + assert index.get(2) is None + + +# --------------------------------------------------------------------------- +# get() helper +# --------------------------------------------------------------------------- + + +def test_get_returns_correct_location( + index: TextToTextLocationIndex, settings: TextEmbeddingIndexSettings +) -> None: + emb_size = 3 # default size for create_test_embedding_model() + n = 3 + fake_emb = np.zeros((n, emb_size), dtype=np.float32) + data: TextToTextLocationIndexData = { + "textLocations": [ + {"messageOrdinal": 10, "chunkOrdinal": 0}, + {"messageOrdinal": 20, "chunkOrdinal": 1}, + {"messageOrdinal": 30, "chunkOrdinal": 0}, + ], + "embeddings": fake_emb, + } + index.deserialize(data) + assert index.get(0) == TextLocation(10, 0) + assert index.get(1) == TextLocation(20, 1) + assert index.get(2) == TextLocation(30, 0) + assert index.get(3) is None diff --git a/tests/test_transcripts.py b/tests/test_transcripts.py index 9d98ae88..e2354405 100644 --- a/tests/test_transcripts.py +++ b/tests/test_transcripts.py @@ -11,7 +11,13 @@ from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.knowpro.convsettings import ConversationSettings from typeagent.knowpro.universal_message import format_timestamp_utc, UNIX_EPOCH +from typeagent.storage.memory.collections import ( + MemoryMessageCollection, + MemorySemanticRefCollection, +) +from typeagent.storage.memory.semrefindex import TermToSemanticRefIndex from typeagent.transcripts.transcript import ( + split_speaker_name, Transcript, TranscriptMessage, TranscriptMessageMeta, @@ -20,6 +26,7 @@ extract_speaker_from_text, get_transcript_duration, get_transcript_speakers, + parse_voice_tags, webvtt_timestamp_to_seconds, ) @@ -103,13 +110,6 @@ def conversation_settings( @pytest.mark.asyncio async def test_ingest_vtt_transcript(conversation_settings: ConversationSettings): """Test importing a VTT file into a Transcript object.""" - from typeagent.storage.memory.collections import ( - MemoryMessageCollection, - MemorySemanticRefCollection, - ) - from typeagent.storage.memory.semrefindex import TermToSemanticRefIndex - from typeagent.transcripts.transcript_ingest import parse_voice_tags - vtt_file = CONFUSE_A_CAT_VTT # Use in-memory storage to avoid database cleanup issues @@ -252,12 +252,6 @@ async def test_transcript_knowledge_extraction_slow( 4. 
Verifies both mechanical extraction (entities/actions from metadata) and LLM extraction (topics from content) work correctly """ - from typeagent.storage.memory.collections import ( - MemoryMessageCollection, - MemorySemanticRefCollection, - ) - from typeagent.storage.memory.semrefindex import TermToSemanticRefIndex - # Use in-memory storage for speed settings = ConversationSettings(embedding_model) @@ -311,7 +305,7 @@ async def test_transcript_knowledge_extraction_slow( # Enable knowledge extraction settings.semantic_ref_index_settings.auto_extract_knowledge = True - settings.semantic_ref_index_settings.batch_size = 10 + settings.semantic_ref_index_settings.concurrency = 10 # Add messages with indexing (this should extract knowledge) result = await transcript.add_messages_with_indexing(messages_list) @@ -345,3 +339,207 @@ async def test_transcript_knowledge_extraction_slow( ) print(f"Knowledge types: {knowledge_types}") print(f"Indexed terms: {len(terms)}") + + +# --------------------------------------------------------------------------- +# split_speaker_name +# --------------------------------------------------------------------------- + + +class TestSplitSpeakerName: + def test_single_word(self) -> None: + result = split_speaker_name("alice") + assert result is not None + assert result.first_name == "alice" + assert result.last_name is None + assert result.middle_name is None + + def test_two_words(self) -> None: + result = split_speaker_name("john smith") + assert result is not None + assert result.first_name == "john" + assert result.last_name == "smith" + assert result.middle_name is None + + def test_three_words(self) -> None: + result = split_speaker_name("john michael smith") + assert result is not None + assert result.first_name == "john" + assert result.middle_name == "michael" + assert result.last_name == "smith" + + def test_van_prefix_merged_into_last_name(self) -> None: + result = split_speaker_name("jan van eyck") + assert result is not None + assert result.first_name == "jan" + assert result.last_name == "van eyck" + assert result.middle_name is None + + def test_empty_string_returns_none(self) -> None: + result = split_speaker_name("") + assert result is None + + +# --------------------------------------------------------------------------- +# Serialize / deserialize roundtrip (in-memory, no LLM) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_transcript_serialize_deserialize_roundtrip() -> None: + """Serialize a transcript and deserialize into a fresh one — data is preserved.""" + embedding_model = create_test_embedding_model() + settings = ConversationSettings(embedding_model) + settings.semantic_ref_index_settings.auto_extract_knowledge = False + + # Build original transcript — use add_messages_with_indexing so the + # message text index (and its embeddings) are populated before serializing. + original = await Transcript.create(settings, name="roundtrip-test", tags=["foo"]) + msg1 = TranscriptMessage( + text_chunks=["Hello world"], + metadata=TranscriptMessageMeta(speaker="Alice", recipients=["Bob"]), + tags=["t1"], + timestamp="2024-01-01T00:00:00Z", + ) + msg2 = TranscriptMessage( + text_chunks=["Goodbye"], + metadata=TranscriptMessageMeta(speaker="Bob", recipients=[]), + tags=[], + timestamp="2024-01-01T00:01:00Z", + ) + await original.add_messages_with_indexing([msg1, msg2]) + data = await original.serialize() + + # Deserialize into a fresh transcript. 
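+    # Fresh settings (same test embedding model, extraction still off)
+    # ensure everything asserted below comes from the serialized payload
+    # rather than state shared with `original`.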
+ fresh_settings = ConversationSettings(embedding_model) + fresh_settings.semantic_ref_index_settings.auto_extract_knowledge = False + fresh = await Transcript.create(fresh_settings, name="", tags=[]) + await fresh.deserialize(data) + + assert fresh.name_tag == "roundtrip-test" + assert "foo" in fresh.tags + assert await fresh.messages.size() == 2 + + first = await fresh.messages.get_item(0) + assert first.text_chunks == ["Hello world"] + assert first.metadata.speaker == "Alice" + assert first.metadata.recipients == ["Bob"] + assert first.timestamp == "2024-01-01T00:00:00Z" + + +@pytest.mark.asyncio +async def test_transcript_deserialize_non_empty_raises() -> None: + """Deserializing into a non-empty Transcript raises RuntimeError.""" + embedding_model = create_test_embedding_model() + settings = ConversationSettings(embedding_model) + + transcript = await Transcript.create(settings, name="test", tags=[]) + await transcript.messages.append( + TranscriptMessage( + text_chunks=["existing"], + metadata=TranscriptMessageMeta(speaker=None, recipients=[]), + ) + ) + data = await transcript.serialize() + + # Trying to deserialize into it again must raise. + with pytest.raises(RuntimeError): + await transcript.deserialize(data) + + +# --------------------------------------------------------------------------- +# write_to_file / read_from_file roundtrip +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_write_and_read_from_file(tmp_path: os.PathLike[str]) -> None: + """write_to_file + read_from_file preserves names, tags, and messages.""" + embedding_model = create_test_embedding_model() + settings = ConversationSettings(embedding_model) + settings.semantic_ref_index_settings.auto_extract_knowledge = False + + original = await Transcript.create(settings, name="file-test", tags=["persisted"]) + msg = TranscriptMessage( + text_chunks=["Persisted message"], + metadata=TranscriptMessageMeta(speaker="Eve", recipients=[]), + timestamp="2024-06-01T12:00:00Z", + ) + # Use add_messages_with_indexing so embeddings are built before writing. + await original.add_messages_with_indexing([msg]) + prefix = os.path.join(str(tmp_path), "test_transcript") + await original.write_to_file(prefix) + + # Verify the _data.json file was written. + assert os.path.exists(prefix + "_data.json") + + # Read it back. 
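+    # As in the roundtrip test above, fresh settings keep the read path
+    # independent of the writer's in-memory state.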
+ fresh_settings = ConversationSettings(embedding_model) + fresh_settings.semantic_ref_index_settings.auto_extract_knowledge = False + loaded = await Transcript.read_from_file(prefix, fresh_settings) + + assert loaded.name_tag == "file-test" + assert "persisted" in loaded.tags + assert await loaded.messages.size() == 1 + first = await loaded.messages.get_item(0) + assert first.text_chunks == ["Persisted message"] + assert first.metadata.speaker == "Eve" + + +# --------------------------------------------------------------------------- +# Speaker alias building +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_build_speaker_aliases_full_name() -> None: + """Full-name speakers create first-name ↔ full-name aliases.""" + embedding_model = create_test_embedding_model() + settings = ConversationSettings(embedding_model) + + transcript = await Transcript.create(settings, name="alias-test", tags=[]) + msg = TranscriptMessage( + text_chunks=["Hi"], + metadata=TranscriptMessageMeta(speaker="John Smith", recipients=[]), + ) + await transcript.messages.append(msg) + + # Rebuild aliases explicitly. + await transcript._build_speaker_aliases() + + secondary = transcript._get_secondary_indexes() + assert secondary.term_to_related_terms_index is not None + aliases = secondary.term_to_related_terms_index.aliases + + # "john" should be aliased to "john smith" and vice-versa. + john_aliases = await aliases.lookup_term("john") + assert john_aliases is not None + alias_texts = [t.text for t in john_aliases] + assert "john smith" in alias_texts + + full_aliases = await aliases.lookup_term("john smith") + assert full_aliases is not None + assert "john" in [t.text for t in full_aliases] + + +@pytest.mark.asyncio +async def test_build_speaker_aliases_single_name_no_alias() -> None: + """Single-word speaker names produce no aliases.""" + embedding_model = create_test_embedding_model() + settings = ConversationSettings(embedding_model) + + transcript = await Transcript.create(settings, name="alias-test2", tags=[]) + msg = TranscriptMessage( + text_chunks=["Hello"], + metadata=TranscriptMessageMeta(speaker="Alice", recipients=[]), + ) + await transcript.messages.append(msg) + await transcript._build_speaker_aliases() + + secondary = transcript._get_secondary_indexes() + assert secondary.term_to_related_terms_index is not None + aliases = secondary.term_to_related_terms_index.aliases + + # Single-name speaker — no alias entry expected. 
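+    # lookup_term may return None or an empty list for a missing entry;
+    # `assert not result` accepts either.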
+ result = await aliases.lookup_term("alice") + assert not result diff --git a/tests/test_utils.py b/tests/test_utils.py index ad526129..8eac9307 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -208,36 +208,134 @@ def test_no_api_version_raises(self, monkeypatch: pytest.MonkeyPatch) -> None: with pytest.raises(RuntimeError, match="doesn't contain valid api-version"): utils.parse_azure_endpoint("TEST_ENDPOINT") + def test_no_deployment_returns_none(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Endpoint without /deployments/ yields deployment_name=None.""" + monkeypatch.setenv( + "TEST_ENDPOINT", + "https://myhost.openai.azure.com/openai?api-version=2024-06-01", + ) + endpoint, version, deployment = utils.parse_azure_endpoint_parts( + "TEST_ENDPOINT" + ) + assert endpoint == "https://myhost.openai.azure.com" + assert version == "2024-06-01" + assert deployment is None -class TestResolveAzureModelName: - """Tests for resolve_azure_model_name.""" - - def test_returns_deployment_name_from_endpoint( + def test_apim_style_deployment_extracted( self, monkeypatch: pytest.MonkeyPatch ) -> None: - """Deployment name in the endpoint takes precedence over the fallback.""" + """APIM-style URL: prefix before /openai kept, deployment name extracted.""" monkeypatch.setenv( "TEST_ENDPOINT", - "https://myhost.openai.azure.com/openai/deployments/gpt-4o-custom?api-version=2025-01-01-preview", + "https://apim.net/openai/openai/deployments/gpt-4o/chat/completions?api-version=2025-01-01-preview", ) - result = utils.resolve_azure_model_name("gpt-4o", "TEST_ENDPOINT") - assert result == "gpt-4o-custom" + endpoint, version, deployment = utils.parse_azure_endpoint_parts( + "TEST_ENDPOINT" + ) + assert endpoint == "https://apim.net/openai" + assert version == "2025-01-01-preview" + assert deployment == "gpt-4o" + + +class TestReindent: + def test_four_spaces_to_two(self) -> None: + text = "def foo():\n pass\n return 1" + result = utils.reindent(text) + assert result == "def foo():\n pass\n return 1" + + def test_empty_string(self) -> None: + assert utils.reindent("") == "" + + def test_no_indent(self) -> None: + assert utils.reindent("hello") == "hello" + + def test_nested_indent(self) -> None: + text = "a\n b\n c" + result = utils.reindent(text) + assert result == "a\n b\n c" + + +class TestTimelog: + def test_verbose_false_no_output(self) -> None: + buf = StringIO() + with redirect_stderr(buf): + with utils.timelog("silent", verbose=False): + pass + assert buf.getvalue() == "" + + def test_verbose_true_shows_label(self) -> None: + buf = StringIO() + with redirect_stderr(buf): + with utils.timelog("myblock", verbose=True): + pass + assert "myblock" in buf.getvalue() + + +class TestListDiff: + def test_identical_lists(self) -> None: + buf = StringIO() + with redirect_stdout(buf): + utils.list_diff("a", [1, 2, 3], "b", [1, 2, 3], max_items=10) + out = buf.getvalue() + assert "1" in out + assert "2" in out + + def test_different_lists(self) -> None: + buf = StringIO() + with redirect_stdout(buf): + utils.list_diff("left", [1, 2], "right", [1, 3], max_items=10) + assert buf.getvalue() != "" + + def test_no_max_items(self) -> None: + buf = StringIO() + with redirect_stdout(buf): + utils.list_diff("a", [1], "b", [2], max_items=0) + assert "1" in buf.getvalue() or "2" in buf.getvalue() + + def test_empty_lists(self) -> None: + buf = StringIO() + with redirect_stdout(buf): + utils.list_diff("a", [], "b", [], max_items=10) + # No output expected (nothing to diff) + assert buf.getvalue() == "" + + +class 
TestGetAzureApiKey: + def test_plain_key_returned_as_is(self) -> None: + assert utils.get_azure_api_key("my-secret-key") == "my-secret-key" + + def test_uppercase_identity_not_plain(self) -> None: + # "IDENTITY" as a plain key is not routed to token provider; only "identity" + # (lowercased) triggers that path. Since we can't call the identity provider + # in tests, just verify non-identity keys pass through unchanged. + assert utils.get_azure_api_key("APIKEY123") == "APIKEY123" + + +class TestMakeAgent: + def test_no_keys_raises(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + monkeypatch.delenv("AZURE_OPENAI_API_KEY", raising=False) + with pytest.raises(RuntimeError, match="Neither OPENAI_API_KEY"): + utils.make_agent(str) + - def test_falls_back_to_model_name(self, monkeypatch: pytest.MonkeyPatch) -> None: - """Falls back to the provided model_name when no deployment in the endpoint.""" +class TestResolveAzureModelName: + def test_returns_model_name_when_no_deployment( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: monkeypatch.setenv( - "TEST_ENDPOINT", "https://myhost.openai.azure.com?api-version=2024-06-01" + "AZURE_OPENAI_ENDPOINT", + "https://myhost.openai.azure.com/openai?api-version=2024-06-01", ) - result = utils.resolve_azure_model_name("gpt-4o", "TEST_ENDPOINT") + result = utils.resolve_azure_model_name("gpt-4o") assert result == "gpt-4o" - def test_uses_default_endpoint_envvar( + def test_returns_deployment_when_present( self, monkeypatch: pytest.MonkeyPatch ) -> None: - """Uses AZURE_OPENAI_ENDPOINT by default.""" monkeypatch.setenv( "AZURE_OPENAI_ENDPOINT", - "https://myhost.openai.azure.com/openai/deployments/my-deploy?api-version=2024-06-01", + "https://myhost.openai.azure.com/openai/deployments/my-deploy/chat?api-version=2024-06-01", ) result = utils.resolve_azure_model_name("gpt-4o") assert result == "my-deploy" diff --git a/tools/benchmark_query.py b/tools/benchmark_query.py new file mode 100644 index 00000000..fbebe869 --- /dev/null +++ b/tools/benchmark_query.py @@ -0,0 +1,233 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Benchmark lookup_term_filtered as a standalone script. 
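+
+No API keys or network access are needed: embeddings come from a local,
+deterministic SHA-256-based embedder, so the indexed data is identical
+on every run.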
+ +Usage: + uv run python tools/benchmark_query.py +""" + +from __future__ import annotations + +import argparse +from collections.abc import Awaitable, Callable +import hashlib +import os +import shutil +import statistics +import tempfile +import time + +import numpy as np + +from typeagent.aitools.embeddings import ( + CachingEmbeddingModel, + NormalizedEmbedding, + NormalizedEmbeddings, +) +from typeagent.knowpro.convsettings import ConversationSettings +from typeagent.knowpro.interfaces_core import Term +from typeagent.knowpro.query import lookup_term_filtered +from typeagent.storage.sqlite.provider import SqliteStorageProvider +from typeagent.transcripts.transcript import ( + Transcript, + TranscriptMessage, + TranscriptMessageMeta, +) + + +class DeterministicBenchmarkEmbedder: + def __init__(self, embedding_size: int) -> None: + self._embedding_size = embedding_size + + @property + def model_name(self) -> str: + return "benchmark-local" + + async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding: + return _compute_embedding(input, self._embedding_size) + + async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings: + if not input: + raise ValueError("Cannot embed an empty list") + return np.stack( + [_compute_embedding(value, self._embedding_size) for value in input] + ).astype(np.float32) + + +def _compute_embedding(text: str, embedding_size: int) -> NormalizedEmbedding: + digest = hashlib.sha256(text.encode("utf-8")).digest() + repeats = (embedding_size + len(digest) - 1) // len(digest) + data = (digest * repeats)[:embedding_size] + embedding = np.frombuffer(data, dtype=np.uint8).astype(np.float32) + embedding = embedding - np.float32(127.5) + norm = np.float32(np.linalg.norm(embedding)) + if norm > 0: + embedding = embedding / norm + return embedding.astype(np.float32) + + +def create_benchmark_embedding_model(embedding_size: int) -> CachingEmbeddingModel: + return CachingEmbeddingModel(DeterministicBenchmarkEmbedder(embedding_size)) + + +def create_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Benchmark lookup_term_filtered with a synthetic transcript.", + ) + parser.add_argument( + "--messages", + type=int, + default=200, + help="Number of synthetic messages to index before running the benchmark.", + ) + parser.add_argument( + "--rounds", + type=int, + default=200, + help="Number of timed rounds to run.", + ) + parser.add_argument( + "--warmup-rounds", + type=int, + default=20, + help="Number of untimed warmup rounds to run first.", + ) + parser.add_argument( + "--embedding-size", + type=int, + default=16, + help="Embedding size for the local deterministic benchmark model.", + ) + return parser + + +def make_settings(embedding_size: int) -> ConversationSettings: + settings = ConversationSettings( + model=create_benchmark_embedding_model(embedding_size) + ) + settings.semantic_ref_index_settings.auto_extract_knowledge = False + return settings + + +def synthetic_messages(count: int) -> list[TranscriptMessage]: + return [ + TranscriptMessage( + text_chunks=[f"Message {i} about topic {i % 10}"], + metadata=TranscriptMessageMeta(speaker=f"Speaker{i % 3}"), + tags=[f"tag{i % 5}"], + ) + for i in range(count) + ] + + +async def create_indexed_transcript( + settings: ConversationSettings, + storage: SqliteStorageProvider, + message_count: int, +) -> Transcript: + settings.storage_provider = storage + transcript = await Transcript.create(settings, name="benchmark-query") + await 
transcript.add_messages_with_indexing(synthetic_messages(message_count)) + return transcript + + +async def find_best_term(transcript: Transcript) -> tuple[str, int]: + semref_index = transcript.semantic_ref_index + assert semref_index is not None + + best_term: str | None = None + best_count = 0 + + for term in await semref_index.get_terms(): + refs = await semref_index.lookup_term(term) + ref_count = len(refs) if refs is not None else 0 + if ref_count > best_count: + best_count = ref_count + best_term = term + + if best_term is None: + raise ValueError("No terms found after indexing") + + return best_term, best_count + + +async def run_benchmark( + target: Callable[[], Awaitable[None]], + rounds: int, + warmup_rounds: int, +) -> list[float]: + for _ in range(warmup_rounds): + await target() + + samples_us: list[float] = [] + for _ in range(rounds): + start = time.perf_counter_ns() + await target() + elapsed_us = (time.perf_counter_ns() - start) / 1_000 + samples_us.append(elapsed_us) + return samples_us + + +def print_report( + label: str, samples_us: list[float], rounds: int, warmup_rounds: int +) -> None: + print(label) + print(f" rounds: {rounds} ({warmup_rounds} warmup)") + print(f" min: {min(samples_us):9.3f} us") + print(f" mean: {statistics.fmean(samples_us):9.3f} us") + print(f" median: {statistics.median(samples_us):9.3f} us") + print(f" max: {max(samples_us):9.3f} us") + + +async def main() -> None: + args = create_arg_parser().parse_args() + temp_dir = tempfile.mkdtemp(prefix="benchmark-query-") + db_path = os.path.join(temp_dir, "query_bench.db") + + settings = make_settings(args.embedding_size) + storage = SqliteStorageProvider( + db_path, + message_type=TranscriptMessage, + message_text_index_settings=settings.message_text_index_settings, + related_term_index_settings=settings.related_term_index_settings, + ) + + try: + transcript = await create_indexed_transcript(settings, storage, args.messages) + best_term, best_count = await find_best_term(transcript) + print(f"Benchmarking term {best_term!r} with {best_count} matches") + + term = Term(text=best_term) + semref_index = transcript.semantic_ref_index + semantic_refs = transcript.semantic_refs + assert semref_index is not None + assert semantic_refs is not None + + async def target() -> None: + results = await lookup_term_filtered( + semref_index, + term, + semantic_refs, + lambda _metadata, _scored: True, + ) + if results is None: + raise ValueError(f"No results found for {best_term!r}") + + samples_us = await run_benchmark(target, args.rounds, args.warmup_rounds) + print_report( + "lookup_term_filtered (accept-all filter)", + samples_us, + args.rounds, + args.warmup_rounds, + ) + finally: + await storage.close() + shutil.rmtree(temp_dir) + + +if __name__ == "__main__": + import asyncio + + asyncio.run(main()) diff --git a/tools/benchmark_semref_writes.py b/tools/benchmark_semref_writes.py new file mode 100644 index 00000000..3162f495 --- /dev/null +++ b/tools/benchmark_semref_writes.py @@ -0,0 +1,306 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Benchmark semref index write strategies: per-item vs batched. + +No API keys or network access required — uses synthetic knowledge data +and the deterministic test embedding model. + +The "individual" path inlines the pre-optimization logic (one append + +add_term per entity/action/topic) so results are comparable on any +branch without switching. 
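+
+Each timed round writes into its own temporary SQLite database, so
+rounds share no state and the two strategies can be compared directly.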
+ +Usage: + uv run python tools/benchmark_semref_writes.py + uv run python tools/benchmark_semref_writes.py --chunks 100 --rounds 20 +""" + +from __future__ import annotations + +import argparse +import asyncio +import os +import shutil +import statistics +import tempfile +import time + +from typeagent.aitools.model_adapters import create_test_embedding_model +from typeagent.knowpro import knowledge_schema as kplib +from typeagent.knowpro.convsettings import ConversationSettings +from typeagent.knowpro.interfaces_core import SemanticRef, Topic +from typeagent.storage.memory.semrefindex import ( + add_knowledge_batch_to_semantic_ref_index, + text_range_from_message_chunk, + validate_entity, + verify_has_semantic_ref_index, +) +from typeagent.storage.sqlite.provider import SqliteStorageProvider +from typeagent.transcripts.transcript import ( + Transcript, + TranscriptMessage, + TranscriptMessageMeta, +) + +# --------------------------------------------------------------------------- +# Inlined pre-optimization write path (one append + add_term per item) +# --------------------------------------------------------------------------- + + +async def _individual_add_knowledge( + conversation, + message_ordinal, + chunk_ordinal, + knowledge, +): + """Reproduces the pre-optimization per-item write logic.""" + verify_has_semantic_ref_index(conversation) + semantic_refs = conversation.semantic_refs + assert semantic_refs is not None + semantic_ref_index = conversation.semantic_ref_index + assert semantic_ref_index is not None + + for entity in knowledge.entities: + if not validate_entity(entity): + continue + ordinal = await semantic_refs.size() + await semantic_refs.append( + SemanticRef( + semantic_ref_ordinal=ordinal, + range=text_range_from_message_chunk(message_ordinal, chunk_ordinal), + knowledge=entity, + ) + ) + await semantic_ref_index.add_term(entity.name, ordinal) + for type_name in entity.type: + await semantic_ref_index.add_term(type_name, ordinal) + if entity.facets: + for facet in entity.facets: + if facet is not None: + await semantic_ref_index.add_term(facet.name, ordinal) + if facet.value is not None: + await semantic_ref_index.add_term(str(facet.value), ordinal) + + for action in list(knowledge.actions) + list(knowledge.inverse_actions): + ordinal = await semantic_refs.size() + await semantic_refs.append( + SemanticRef( + semantic_ref_ordinal=ordinal, + range=text_range_from_message_chunk(message_ordinal, chunk_ordinal), + knowledge=action, + ) + ) + await semantic_ref_index.add_term(" ".join(action.verbs), ordinal) + if action.subject_entity_name != "none": + await semantic_ref_index.add_term(action.subject_entity_name, ordinal) + if action.object_entity_name != "none": + await semantic_ref_index.add_term(action.object_entity_name, ordinal) + if action.indirect_object_entity_name != "none": + await semantic_ref_index.add_term( + action.indirect_object_entity_name, ordinal + ) + if action.params: + for param in action.params: + if isinstance(param, str): + await semantic_ref_index.add_term(param, ordinal) + else: + await semantic_ref_index.add_term(param.name, ordinal) + if isinstance(param.value, str): + await semantic_ref_index.add_term(param.value, ordinal) + if action.subject_entity_facet is not None: + await semantic_ref_index.add_term(action.subject_entity_facet.name, ordinal) + if action.subject_entity_facet.value is not None: + await semantic_ref_index.add_term( + str(action.subject_entity_facet.value), ordinal + ) + + for topic_text in knowledge.topics: + ordinal = await 
semantic_refs.size() + await semantic_refs.append( + SemanticRef( + semantic_ref_ordinal=ordinal, + range=text_range_from_message_chunk(message_ordinal, chunk_ordinal), + knowledge=Topic(text=topic_text), + ) + ) + await semantic_ref_index.add_term(topic_text, ordinal) + + +# --------------------------------------------------------------------------- +# Synthetic data +# --------------------------------------------------------------------------- + + +def synthetic_knowledge(chunk_index: int) -> kplib.KnowledgeResponse: + return kplib.KnowledgeResponse( + entities=[ + kplib.ConcreteEntity( + name=f"entity_{chunk_index}_{j}", + type=[f"type_{j}", f"category_{chunk_index % 5}"], + facets=[ + kplib.Facet(name=f"facet_{j}", value=f"value_{j}") for j in range(2) + ], + ) + for j in range(3) + ], + actions=[ + kplib.Action( + verbs=[f"verb_{chunk_index}"], + verb_tense="past", + subject_entity_name=f"entity_{chunk_index}_0", + object_entity_name=f"entity_{chunk_index}_1", + indirect_object_entity_name="none", + params=[f"param_{chunk_index}"], + ) + ], + inverse_actions=[], + topics=[f"topic_{chunk_index}", f"theme_{chunk_index % 3}"], + ) + + +def synthetic_messages(count: int) -> list[TranscriptMessage]: + return [ + TranscriptMessage( + text_chunks=[f"Message {i} about topic {i % 10}"], + metadata=TranscriptMessageMeta(speaker=f"Speaker{i % 3}"), + tags=[f"tag{i % 5}"], + ) + for i in range(count) + ] + + +# --------------------------------------------------------------------------- +# Benchmark harness +# --------------------------------------------------------------------------- + + +async def create_transcript(db_path: str) -> Transcript: + model = create_test_embedding_model() + settings = ConversationSettings(model=model) + settings.semantic_ref_index_settings.auto_extract_knowledge = False + storage = SqliteStorageProvider( + db_path, + message_type=TranscriptMessage, + message_text_index_settings=settings.message_text_index_settings, + related_term_index_settings=settings.related_term_index_settings, + ) + settings.storage_provider = storage + return await Transcript.create(settings, name="bench-semref") + + +async def bench_individual(transcript: Transcript, chunks: int) -> None: + for i in range(chunks): + await _individual_add_knowledge(transcript, i, 0, synthetic_knowledge(i)) + + +async def bench_batched(transcript: Transcript, chunks: int) -> None: + items = [(i, 0, synthetic_knowledge(i)) for i in range(chunks)] + await add_knowledge_batch_to_semantic_ref_index(transcript, items) + + +async def run_benchmark( + label: str, + factory, + chunks: int, + rounds: int, + warmup: int, +) -> list[float]: + samples_us: list[float] = [] + for r in range(warmup + rounds): + temp_dir = tempfile.mkdtemp(prefix="bench-semref-") + db_path = os.path.join(temp_dir, "bench.db") + try: + transcript = await create_transcript(db_path) + msgs = synthetic_messages(chunks) + await transcript.add_messages_with_indexing(msgs) + + start = time.perf_counter_ns() + await factory(transcript, chunks) + elapsed_us = (time.perf_counter_ns() - start) / 1_000 + + if r >= warmup: + samples_us.append(elapsed_us) + finally: + shutil.rmtree(temp_dir, ignore_errors=True) + return samples_us + + +def print_report(label: str, samples_us: list[float], rounds: int, warmup: int) -> None: + print(f"\n{label}") + print(f" rounds: {rounds} ({warmup} warmup)") + print(f" min: {min(samples_us):12.1f} us") + print(f" mean: {statistics.fmean(samples_us):12.1f} us") + print(f" median: {statistics.median(samples_us):12.1f} us") + 
print(f" max: {max(samples_us):12.1f} us") + + +async def main() -> None: + parser = argparse.ArgumentParser( + description="Benchmark semref index write strategies.", + ) + parser.add_argument( + "--chunks", + type=int, + default=50, + help="Number of knowledge chunks to write per run (default: 50).", + ) + parser.add_argument( + "--rounds", + type=int, + default=10, + help="Number of timed rounds (default: 10).", + ) + parser.add_argument( + "--warmup", + type=int, + default=2, + help="Number of untimed warmup rounds (default: 2).", + ) + args = parser.parse_args() + + knowledge_sample = synthetic_knowledge(0) + refs_per_chunk = ( + len([e for e in knowledge_sample.entities if e.name]) + + len(knowledge_sample.actions) + + len(knowledge_sample.inverse_actions) + + len(knowledge_sample.topics) + ) + print(f"Chunks per run: {args.chunks}") + print(f"Semrefs per chunk: ~{refs_per_chunk}") + print(f"Total semrefs per run: ~{refs_per_chunk * args.chunks}") + + individual = await run_benchmark( + "Individual writes", + bench_individual, + args.chunks, + args.rounds, + args.warmup, + ) + print_report( + "Individual writes (per-entity append + add_term)", + individual, + args.rounds, + args.warmup, + ) + + batched = await run_benchmark( + "Batched writes", + bench_batched, + args.chunks, + args.rounds, + args.warmup, + ) + print_report( + "Batched writes (bulk extend + add_terms_batch)", + batched, + args.rounds, + args.warmup, + ) + + speedup = statistics.fmean(individual) / statistics.fmean(batched) + print(f"\nSpeedup: {speedup:.2f}x") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tools/benchmark_vectorbase.py b/tools/benchmark_vectorbase.py new file mode 100644 index 00000000..d14314a7 --- /dev/null +++ b/tools/benchmark_vectorbase.py @@ -0,0 +1,167 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Benchmark VectorBase lookup methods as a standalone script. 
+ +Usage: + uv run python tools/benchmark_vectorbase.py +""" + +from __future__ import annotations + +import argparse +from collections.abc import Callable +import statistics +import time + +import numpy as np + +from typeagent.aitools.embeddings import NormalizedEmbedding, NormalizedEmbeddings +from typeagent.aitools.vectorbase import ( + ScoredInt, + TextEmbeddingIndexSettings, + VectorBase, +) + + +class NullEmbeddingModel: + @property + def model_name(self) -> str: + return "benchmark-local" + + def add_embedding(self, key: str, embedding: NormalizedEmbedding) -> None: + return None + + async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding: + raise RuntimeError("VectorBase benchmark does not use embedding generation") + + async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings: + raise RuntimeError("VectorBase benchmark does not use embedding generation") + + async def get_embedding(self, key: str) -> NormalizedEmbedding: + raise RuntimeError("VectorBase benchmark does not use embedding generation") + + async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings: + raise RuntimeError("VectorBase benchmark does not use embedding generation") + + +def create_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Benchmark VectorBase lookup methods with synthetic vectors.", + ) + parser.add_argument( + "--rounds", + type=int, + default=200, + help="Number of timed rounds to run for each benchmark.", + ) + parser.add_argument( + "--warmup-rounds", + type=int, + default=20, + help="Number of untimed warmup rounds to run first.", + ) + parser.add_argument( + "--dim", + type=int, + default=384, + help="Embedding dimension to generate.", + ) + parser.add_argument( + "--subset-size", + type=int, + default=1_000, + help="Subset size for fuzzy_lookup_embedding_in_subset.", + ) + return parser + + +def make_vectorbase( + vector_count: int, dim: int, seed: int +) -> tuple[VectorBase, NormalizedEmbedding]: + rng = np.random.default_rng(seed) + vectors = rng.standard_normal((vector_count, dim)).astype(np.float32) + norms = np.linalg.norm(vectors, axis=1, keepdims=True) + vectors /= norms + + settings = TextEmbeddingIndexSettings(embedding_model=NullEmbeddingModel()) + vectorbase = VectorBase(settings) + vectorbase.add_embeddings(None, vectors) + + query = rng.standard_normal(dim).astype(np.float32) + query /= np.linalg.norm(query) + return vectorbase, query + + +def run_benchmark( + target: Callable[[], list[ScoredInt]], rounds: int, warmup_rounds: int +) -> list[float]: + for _ in range(warmup_rounds): + target() + + samples_us: list[float] = [] + for _ in range(rounds): + start = time.perf_counter_ns() + target() + elapsed_us = (time.perf_counter_ns() - start) / 1_000 + samples_us.append(elapsed_us) + return samples_us + + +def validate_result(result: list[ScoredInt]) -> None: + if len(result) != 10: + raise ValueError(f"Expected 10 hits, got {len(result)}") + if not all(isinstance(item, ScoredInt) for item in result): + raise TypeError("Expected every result item to be a ScoredInt") + + +def print_report( + label: str, samples_us: list[float], rounds: int, warmup_rounds: int +) -> None: + print(label) + print(f" rounds: {rounds} ({warmup_rounds} warmup)") + print(f" min: {min(samples_us):9.3f} us") + print(f" mean: {statistics.fmean(samples_us):9.3f} us") + print(f" median: {statistics.median(samples_us):9.3f} us") + print(f" max: {max(samples_us):9.3f} us") + + +def main() -> None: + args = 
create_arg_parser().parse_args() + + vb_1k, query_1k = make_vectorbase(1_000, args.dim, seed=42) + vb_10k, query_10k = make_vectorbase(10_000, args.dim, seed=43) + subset_rng = np.random.default_rng(99) + subset = subset_rng.choice(10_000, size=args.subset_size, replace=False).tolist() + + benchmarks: list[tuple[str, Callable[[], list[ScoredInt]]]] = [ + ( + "fuzzy_lookup_embedding (1k vectors)", + lambda: vb_1k.fuzzy_lookup_embedding(query_1k, max_hits=10, min_score=0.0), + ), + ( + "fuzzy_lookup_embedding (10k vectors)", + lambda: vb_10k.fuzzy_lookup_embedding( + query_10k, max_hits=10, min_score=0.0 + ), + ), + ( + f"fuzzy_lookup_embedding_in_subset ({args.subset_size} of 10k)", + lambda: vb_10k.fuzzy_lookup_embedding_in_subset( + query_10k, + subset, + max_hits=10, + min_score=0.0, + ), + ), + ] + + for label, target in benchmarks: + validate_result(target()) + samples_us = run_benchmark(target, args.rounds, args.warmup_rounds) + print_report(label, samples_us, args.rounds, args.warmup_rounds) + + +if __name__ == "__main__": + main() diff --git a/tools/ingest_email.py b/tools/ingest_email.py index b7768e8a..34c4c6ce 100644 --- a/tools/ingest_email.py +++ b/tools/ingest_email.py @@ -17,25 +17,17 @@ python tools/query.py --database email.db --query "What was discussed?" """ -""" -TODO - -- Collect knowledge outside db transaction to reduce lock time -""" - import argparse import asyncio +from collections.abc import AsyncIterator from datetime import datetime from pathlib import Path import sys import time -import traceback from typing import Iterable from dotenv import load_dotenv -import openai - from typeagent.aitools import utils from typeagent.emails.email_import import ( decode_encoded_words, @@ -45,6 +37,7 @@ from typeagent.emails.email_memory import EmailMemory from typeagent.emails.email_message import EmailMessage from typeagent.knowpro.convsettings import ConversationSettings +from typeagent.knowpro.interfaces_core import AddMessagesResult from typeagent.storage.utils import create_storage_provider @@ -134,6 +127,35 @@ def create_arg_parser() -> argparse.ArgumentParser: ), ) + # Concurrency / batching + parser.add_argument( + "--concurrency", + type=int, + default=None, + metavar="N", + help=( + "Number of concurrent LLM extraction requests. " + "Default: 4 (from ConversationSettings)." + ), + ) + parser.add_argument( + "--batch-size", + type=int, + default=100, + metavar="N", + help="Number of chunks per commit batch. Default: 100.", + ) + parser.add_argument( + "--max-chunks", + type=int, + default=20, + metavar="N", + help=( + "Maximum number of text chunks to keep per email. " + "Extra chunks are silently dropped. Default: 20." 
+ ), + ) + return parser @@ -149,8 +171,17 @@ def _validate_args(args: argparse.Namespace) -> None: if args.limit is not None and args.limit <= 0: errors.append("--limit must be a positive integer.") - # --offset without --limit is allowed (skip first N, ingest the rest) - # --limit without --offset is allowed (ingest at most N) + # --concurrency must be positive when given + if args.concurrency is not None and args.concurrency <= 0: + errors.append("--concurrency must be a positive integer.") + + # --batch-size must be positive + if args.batch_size <= 0: + errors.append("--batch-size must be a positive integer.") + + # --max-chunks must be positive when given + if args.max_chunks is not None and args.max_chunks <= 0: + errors.append("--max-chunks must be a positive integer.") # --start-date must be before --stop-date when both are given if args.start_date and args.stop_date: @@ -239,7 +270,7 @@ def _iter_emails( sliced_total = len(email_files) for i, email_file in enumerate(email_files): label = f"[{i + 1}/{sliced_total}] {email_file}" - yield str(email_file), email_file, label + yield str(email_file.resolve()), email_file, label def _print_email_verbose(email: EmailMessage) -> None: @@ -267,6 +298,65 @@ def _print_email_verbose(email: EmailMessage) -> None: print(f" {preview}") +async def _email_generator( + email_entries: list[tuple[str, Path, str]], + verbose: bool, + start_date: datetime | None, + stop_date: datetime | None, + max_chunks: int | None, + counters: dict[str, int], + already_ingested: set[str], +) -> AsyncIterator[EmailMessage]: + """Async generator that parses and yields EmailMessage objects. + + *email_entries* is a pre-collected list of ``(source_id, file_path, label)`` + tuples produced by :func:`_iter_emails`. + + *already_ingested* is the set of source_ids known to be in the DB at + the start of this run (one bulk query). Files in this set are skipped + before parsing. + + *counters* is mutated in place to track ``parsed``, ``skipped``, + ``date_skipped``, and ``failed`` counts for the caller's summary. 
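+
+    Messages are yielded lazily; batching and commit cadence are decided
+    by the consumer (``add_messages_streaming``).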
+ """ + for source_id, email_file, label in email_entries: + if source_id in already_ingested: + counters["skipped"] += 1 + continue + + try: + email = import_email_from_file(str(email_file)) + except Exception as e: + counters["failed"] += 1 + print( + f"Error parsing {source_id}: {e!r:.150s}", + file=sys.stderr, + ) + continue + + # Apply date filter + if not email_matches_date_filter(email.timestamp, start_date, stop_date): + counters["date_skipped"] += 1 + if verbose: + print(f"{label} [Outside date range, skipping]") + continue + + if verbose: + print(label) + _print_email_verbose(email) + + # Truncate chunks if --max-chunks is set + if max_chunks is not None and len(email.text_chunks) > max_chunks: + if verbose: + print(f" Truncating {len(email.text_chunks)} chunks to {max_chunks}") + email.text_chunks = email.text_chunks[:max_chunks] + + # Set source_id so streaming API handles dedup and tracking + email.source_id = source_id + counters["parsed"] += 1 + yield email + + async def ingest_emails( eml_paths: list[str], database: str, @@ -275,6 +365,9 @@ async def ingest_emails( stop_date: datetime | None = None, offset: int = 0, limit: int | None = None, + concurrency: int | None = None, + batch_size: int = 100, + max_chunks: int | None = 20, ) -> None: """Ingest email files into a database.""" @@ -288,6 +381,11 @@ async def ingest_emails( print("Setting up conversation settings...") settings = ConversationSettings() + + # Override concurrency if specified + if concurrency is not None: + settings.semantic_ref_index_settings.concurrency = concurrency + settings.storage_provider = await create_storage_provider( settings.message_text_index_settings, settings.related_term_index_settings, @@ -301,97 +399,129 @@ async def ingest_emails( if verbose: print(f"Target database: {database}") - batch_size = settings.semantic_ref_index_settings.batch_size + effective_concurrency = settings.semantic_ref_index_settings.concurrency + if verbose: + print(f"Concurrency: {effective_concurrency}") + print(f"Batch size: {batch_size} chunks") + + # One bulk query: collect all source_ids, then ask the DB which are + # already ingested. This replaces N per-file is_source_ingested calls + # with a single are_sources_ingested call. 
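+    # Illustration with hypothetical file names: if a.eml was ingested in
+    # an earlier run and b.eml was not,
+    #     await storage.are_sources_ingested(["a.eml", "b.eml"])
+    # yields a container holding just "a.eml", so only b.eml is parsed
+    # by the generator below.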
+ storage = settings.storage_provider + email_entries = list(_iter_emails(eml_paths, verbose, offset, limit)) + all_source_ids = [sid for sid, _, _ in email_entries] + already_ingested = await storage.are_sources_ingested(all_source_ids) + if already_ingested and verbose: + print( + f"Pre-filter: {len(already_ingested)} of {len(all_source_ids)} already ingested" + ) + if verbose: - print(f"Batch size: {batch_size}") print("\nParsing and importing emails...") - success_count = 0 - failed_count = 0 - skipped_count = 0 start_time = time.time() + last_batch_time = start_time + + # Counters mutated by the generator and callback + counters: dict[str, int] = { + "parsed": 0, + "skipped": 0, + "date_skipped": 0, + "failed": 0, + "ingested": 0, + "chunks": 0, + "semrefs": 0, + "batches": 0, + } + + def on_batch_committed(result: AddMessagesResult) -> None: + nonlocal last_batch_time + counters["ingested"] += result.messages_added + counters["chunks"] += result.chunks_added + counters["semrefs"] += result.semrefs_added + counters["batches"] += 1 + now = time.time() + batch_secs = now - last_batch_time + last_batch_time = now + elapsed = now - start_time + per_chunk = batch_secs / result.chunks_added if result.chunks_added else 0 + parts = [ + f" Batch {counters['batches']}:", + f"+{result.messages_added} messages,", + f"+{result.chunks_added} chunks,", + f"+{result.semrefs_added} semrefs", + ] + print( + f"{' '.join(parts)} | " + f"{batch_secs:.1f}s ({per_chunk:.2f}s/chunk) | " + f"{counters['ingested']} total ingested | " + f"{elapsed:.1f}s elapsed", + flush=True, + ) - semref_coll = await settings.storage_provider.get_semantic_ref_collection() - storage_provider = settings.storage_provider - - for source_id, email_file, label in _iter_emails(eml_paths, verbose, offset, limit): - try: - if verbose: - print(label, end="", flush=True) - - # Parse the email only after confirming it hasn't been ingested - email = import_email_from_file(str(email_file)) - - # Apply date filter - if not email_matches_date_filter(email.timestamp, start_date, stop_date): - skipped_count += 1 - if verbose: - print(" [Outside date range, skipping]") - continue - - if verbose: - _print_email_verbose(email) - - # Ingest the email - try: - await email_memory.add_messages_with_indexing( - [email], source_ids=[source_id] - ) - success_count += 1 - except openai.AuthenticationError as e: - if verbose: - traceback.print_exc() - sys.exit(f"Authentication error: {e!r}") - - # Print progress periodically - if (success_count + failed_count) % batch_size == 0: - elapsed = time.time() - start_time - semref_count = await semref_coll.size() - print( - f"\n{label} " - f"{success_count} imported | " - f"{failed_count} failed | " - f"{skipped_count} skipped | " - f"{semref_count} semrefs | " - f"{elapsed:.1f}s elapsed\n" - ) + message_stream = _email_generator( + email_entries, + verbose, + start_date, + stop_date, + max_chunks, + counters, + already_ingested, + ) - except Exception as e: - failed_count += 1 - print( - f"Error processing {source_id}: {e!r:.150s}", - file=sys.stderr, - ) - mod = e.__class__.__module__ - qual = e.__class__.__qualname__ - exc_name = qual if mod == "builtins" else f"{mod}.{qual}" - async with storage_provider: - await storage_provider.mark_source_ingested(source_id, exc_name) - if verbose: - traceback.print_exc(limit=10) + result: AddMessagesResult | None = None + interrupted = False + try: + result = await email_memory.add_messages_streaming( + message_stream, + batch_size=batch_size, + 
on_batch_committed=on_batch_committed, + ) + except (KeyboardInterrupt, asyncio.CancelledError): + interrupted = True # Final summary elapsed = time.time() - start_time - semref_count = await semref_coll.size() + if interrupted and counters["batches"] == 0: + print() + print("Interrupted before any batches were committed.") + return + + messages_ingested = ( + result.messages_added if result is not None else counters["ingested"] + ) + total_chunks = result.chunks_added if result is not None else counters["chunks"] + semrefs_added = result.semrefs_added if result is not None else counters["semrefs"] + total_skipped = counters["skipped"] + overall_per_chunk = elapsed / total_chunks if total_chunks else 0 print() if verbose: - print(f"Successfully imported {success_count} email(s)") - if skipped_count: - print(f"Skipped {skipped_count} already-ingested email(s)") - if failed_count: - print(f"Failed to import {failed_count} email(s)") - print(f"Extracted {semref_count} semantic references") + if interrupted: + print("Ingestion interrupted by user (^C).") + print(f"Successfully ingested {messages_ingested} email(s)") + print(f"Ingested {total_chunks} chunk(s)") + if total_skipped: + print(f"Skipped {total_skipped} already-ingested email(s)") + if counters["date_skipped"]: + print(f"Skipped {counters['date_skipped']} email(s) outside date range") + if counters["failed"]: + print(f"Failed to parse {counters['failed']} email(s)") + print(f"Extracted {semrefs_added} semantic references") print(f"Total time: {elapsed:.1f}s") + print(f"Overall time per chunk: {overall_per_chunk:.2f}s/chunk") else: print( - f"Imported {success_count} emails to {database} " - f"({semref_count} refs, {elapsed:.1f}s)" + f"Ingested {messages_ingested} emails to {database} " + f"({total_chunks} chunks, {semrefs_added} refs added, {elapsed:.1f}s, " + f"{overall_per_chunk:.2f}s/chunk)" ) - if skipped_count: - print(f"Skipped: {skipped_count} (already ingested)") - if failed_count: - print(f"Failed: {failed_count}") + if total_skipped: + print(f"Skipped: {total_skipped} (already ingested)") + if counters["date_skipped"]: + print(f"Skipped: {counters['date_skipped']} (outside date range)") + if counters["failed"]: + print(f"Failed: {counters['failed']}") # Show usage information print() @@ -419,6 +549,9 @@ def main() -> None: stop_date=stop_date, offset=args.offset, limit=args.limit, + concurrency=args.concurrency, + batch_size=args.batch_size, + max_chunks=args.max_chunks, ) ) diff --git a/tools/ingest_podcast.py b/tools/ingest_podcast.py index 39195145..c0f7303d 100644 --- a/tools/ingest_podcast.py +++ b/tools/ingest_podcast.py @@ -31,7 +31,13 @@ async def main(): "--batch-size", type=int, default=10, - help="Batch size for message indexing (default 10)", + help="Number of messages per indexing call (default 10)", + ) + parser.add_argument( + "--concurrency", + type=int, + default=0, + help="Max concurrent knowledge extractions (0 = use settings default)", ) parser.add_argument( "--start-message", @@ -75,6 +81,7 @@ async def main(): dbname=args.database, batch_size=args.batch_size, start_message=args.start_message, + concurrency=args.concurrency, verbose=not args.quiet, ) except (RuntimeError, ValueError) as err: diff --git a/tools/ingest_vtt.py b/tools/ingest_vtt.py index 7fbf38fc..b1448e6d 100644 --- a/tools/ingest_vtt.py +++ b/tools/ingest_vtt.py @@ -15,6 +15,7 @@ import argparse import asyncio +from collections.abc import AsyncIterator from datetime import timedelta import os from pathlib import Path @@ -27,6 +28,7 @@ 
from typeagent.aitools.model_adapters import create_embedding_model from typeagent.knowpro.convsettings import ConversationSettings from typeagent.knowpro.interfaces import ConversationMetadata +from typeagent.knowpro.interfaces_core import AddMessagesResult from typeagent.knowpro.universal_message import format_timestamp_utc, UNIX_EPOCH from typeagent.storage.utils import create_storage_provider from typeagent.transcripts.transcript import ( @@ -75,10 +77,17 @@ def create_arg_parser() -> argparse.ArgumentParser: ) parser.add_argument( - "--batchsize", + "--concurrency", type=int, default=None, - help="Batch size for knowledge extraction (default: from settings)", + help="Max concurrent knowledge extractions (default: from settings)", + ) + + parser.add_argument( + "--batch-size", + type=int, + default=100, + help="Number of chunks per commit batch (default: 100)", ) parser.add_argument( @@ -132,7 +141,8 @@ async def ingest_vtt_files( name: str | None = None, merge_consecutive: bool = False, verbose: bool = False, - batchsize: int | None = None, + concurrency: int | None = None, + batch_size: int = 100, embedding_name: str | None = None, ) -> None: """Ingest one or more VTT files into a database.""" @@ -227,9 +237,9 @@ async def ingest_vtt_files( # Update settings to use our storage provider settings.storage_provider = storage_provider - # Override batch size if specified - if batchsize is not None: - settings.semantic_ref_index_settings.batch_size = batchsize + # Override concurrency if specified + if concurrency is not None: + settings.semantic_ref_index_settings.concurrency = concurrency if verbose: print("Settings and storage provider configured") @@ -242,8 +252,8 @@ async def ingest_vtt_files( print(f"\nParsing VTT files and creating messages...") try: # Get collections from our storage provider - msg_coll = await storage_provider.get_message_collection() - semref_coll = await storage_provider.get_semantic_ref_collection() + msg_coll = storage_provider.messages + semref_coll = storage_provider.semantic_refs # Database should be empty (we checked it doesn't exist earlier) # But verify collections are empty just in case @@ -254,110 +264,100 @@ async def ingest_vtt_files( ) sys.exit(1) - # Process all VTT files and collect messages - all_messages: list[TranscriptMessage] = [] - time_offset = 0.0 # Cumulative time offset for multiple files + msg_count = 0 - for file_idx, vtt_file in enumerate(vtt_files): - if verbose: - print(f" Processing {vtt_file}...") - if file_idx > 0: - print(f" Time offset: {time_offset:.2f} seconds") + async def _message_stream() -> AsyncIterator[TranscriptMessage]: + nonlocal msg_count + time_offset = 0.0 - # Parse VTT file - try: - vtt = webvtt.read(vtt_file) - except Exception as e: - print( - f"Error: Failed to parse VTT file {vtt_file}: {e}", file=sys.stderr - ) - sys.exit(1) - - current_speaker = None - current_text_chunks = [] - current_start_time = None - file_max_end_time = 0.0 # Track the maximum end time in this file - - def save_current_message(): - """Helper to save the current message and add to all_messages.""" - if current_text_chunks and current_start_time is not None: - combined_text = " ".join(current_text_chunks).strip() - if combined_text: - # Calculate timestamp from WebVTT start time - offset_seconds = webvtt_timestamp_to_seconds(current_start_time) - timestamp = format_timestamp_utc( - UNIX_EPOCH + timedelta(seconds=offset_seconds) - ) - metadata = TranscriptMessageMeta( - speaker=current_speaker, - recipients=[], - ) - message = 
TranscriptMessage( - text_chunks=[combined_text], - metadata=metadata, - timestamp=timestamp, - ) - all_messages.append(message) - - for caption in vtt: - # Skip empty captions - if not caption.text.strip(): - continue - - # Parse raw text for voice tags (handles multiple speakers per cue) - raw_text = getattr(caption, "raw_text", caption.text) - voice_segments = parse_voice_tags(raw_text) - - # Convert WebVTT timestamps and apply offset for multi-file continuity - start_time_seconds = ( - vtt_timestamp_to_seconds(caption.start) + time_offset - ) - end_time_seconds = vtt_timestamp_to_seconds(caption.end) + time_offset - start_time = seconds_to_vtt_timestamp(start_time_seconds) - - # Track the maximum end time for this file - if end_time_seconds > file_max_end_time: - file_max_end_time = end_time_seconds + for file_idx, vtt_file in enumerate(vtt_files): + if verbose: + print(f" Processing {vtt_file}...") + if file_idx > 0: + print(f" Time offset: {time_offset:.2f} seconds") - # Process each voice segment in this caption - for speaker, text in voice_segments: - if not text.strip(): + try: + vtt = webvtt.read(vtt_file) + except Exception as e: + print( + f"Error: Failed to parse VTT file {vtt_file}: {e}", + file=sys.stderr, + ) + sys.exit(1) + + current_speaker: str | None = None + current_text_chunks: list[str] = [] + current_start_time: str | None = None + file_max_end_time = 0.0 + + def _build_message() -> TranscriptMessage | None: + if current_text_chunks and current_start_time is not None: + combined_text = " ".join(current_text_chunks).strip() + if combined_text: + offset_seconds = webvtt_timestamp_to_seconds( + current_start_time + ) + timestamp = format_timestamp_utc( + UNIX_EPOCH + timedelta(seconds=offset_seconds) + ) + metadata = TranscriptMessageMeta( + speaker=current_speaker, + recipients=[], + ) + return TranscriptMessage( + text_chunks=[combined_text], + metadata=metadata, + timestamp=timestamp, + source_id=f"{vtt_file}#{msg_count}", + ) + return None + + for caption in vtt: + if not caption.text.strip(): continue - # If we should merge consecutive segments from the same speaker - if ( - merge_consecutive - and speaker == current_speaker - and current_text_chunks - ): - # Merge with current message - current_text_chunks.append(text) - else: - # Save previous message if it exists - save_current_message() - - # Start new message - current_speaker = speaker - current_text_chunks = [text] if text.strip() else [] - current_start_time = start_time - - # Don't forget the last message from this file - save_current_message() + raw_text = getattr(caption, "raw_text", caption.text) + voice_segments = parse_voice_tags(raw_text) - if verbose: - print(f" Extracted {len(all_messages)} messages so far") - if file_max_end_time > 0: - print( - f" File time range: 0.00s to {file_max_end_time - time_offset:.2f}s (with offset: {time_offset:.2f}s to {file_max_end_time:.2f}s)" + start_time_seconds = ( + vtt_timestamp_to_seconds(caption.start) + time_offset ) + end_time_seconds = ( + vtt_timestamp_to_seconds(caption.end) + time_offset + ) + start_time = seconds_to_vtt_timestamp(start_time_seconds) + + if end_time_seconds > file_max_end_time: + file_max_end_time = end_time_seconds + + for speaker, text in voice_segments: + if not text.strip(): + continue + + if ( + merge_consecutive + and speaker == current_speaker + and current_text_chunks + ): + current_text_chunks.append(text) + else: + msg = _build_message() + if msg is not None: + msg_count += 1 + yield msg + + current_speaker = speaker + 
current_text_chunks = [text] if text.strip() else [] + current_start_time = start_time + + # Last message from this file + msg = _build_message() + if msg is not None: + msg_count += 1 + yield msg - # Update time offset for next file: add 5 seconds gap - if file_max_end_time > 0: - time_offset = file_max_end_time + 5.0 - - # Add all messages to the database in batches with indexing - if verbose: - print(f"\nAdding {len(all_messages)} total messages to database...") + if file_max_end_time > 0: + time_offset = file_max_end_time + 5.0 try: # Enable knowledge extraction for index building @@ -368,7 +368,7 @@ def save_current_message(): f" auto_extract_knowledge = {settings.semantic_ref_index_settings.auto_extract_knowledge}" ) print( - f" batch_size = {settings.semantic_ref_index_settings.batch_size}" + f" concurrency = {settings.semantic_ref_index_settings.concurrency}" ) # Create a Transcript object @@ -378,42 +378,91 @@ def save_current_message(): tags=[name, "vtt-transcript"], ) - # Process messages in batches - batch_size = settings.semantic_ref_index_settings.batch_size - successful_count = 0 start_time = time.time() + last_batch_time = start_time + + counters: dict[str, int] = { + "ingested": 0, + "chunks": 0, + "semrefs": 0, + "batches": 0, + } + + def on_batch_committed(result: AddMessagesResult) -> None: + nonlocal last_batch_time + counters["ingested"] += result.messages_added + counters["chunks"] += result.chunks_added + counters["semrefs"] += result.semrefs_added + counters["batches"] += 1 + now = time.time() + batch_secs = now - last_batch_time + last_batch_time = now + elapsed = now - start_time + per_chunk = ( + batch_secs / result.chunks_added if result.chunks_added else 0 + ) + parts = [ + f" {counters['ingested']} messages", + f"+{result.chunks_added} chunks", + f"+{result.semrefs_added} semrefs", + ] + print( + f"{' | '.join(parts)} | " + f"{batch_secs:.1f}s ({per_chunk:.2f}s/chunk) | " + f"{elapsed:.1f}s elapsed", + flush=True, + ) print( - f" Processing {len(all_messages)} messages in batches of {batch_size}..." + f" Processing messages in batches of {batch_size}" + f" (concurrency={settings.semantic_ref_index_settings.concurrency})..." 
) - for i in range(0, len(all_messages), batch_size): - batch = all_messages[i : i + batch_size] - batch_start = time.time() - - result = await transcript.add_messages_with_indexing(batch) + result: AddMessagesResult | None = None + interrupted = False + try: + result = await transcript.add_messages_streaming( + _message_stream(), + batch_size=batch_size, + on_batch_committed=on_batch_committed, + ) + except (KeyboardInterrupt, asyncio.CancelledError): + interrupted = True - successful_count += result.messages_added - batch_time = time.time() - batch_start + elapsed = time.time() - start_time + if interrupted and counters["batches"] == 0: + print() + print("Interrupted before any batches were committed.") + return - elapsed = time.time() - start_time - print( - f" {successful_count}/{len(all_messages)} messages | " - f"{await semref_coll.size()} refs | " - f"{batch_time:.1f}s/batch | " - f"{elapsed:.1f}s elapsed" - ) + messages_ingested = ( + result.messages_added if result is not None else counters["ingested"] + ) + total_chunks = ( + result.chunks_added if result is not None else counters["chunks"] + ) + semrefs_added = ( + result.semrefs_added if result is not None else counters["semrefs"] + ) + overall_per_chunk = elapsed / total_chunks if total_chunks else 0 if verbose: - semref_count = await semref_coll.size() - print(f" Successfully added {successful_count} messages") - print(f" Extracted {semref_count} semantic references") + if interrupted: + print("Ingestion interrupted by user (^C).") + print(f" Successfully added {messages_ingested} messages") + print(f" Ingested {total_chunks} chunk(s)") + print(f" Extracted {semrefs_added} semantic references") + print(f" Total time: {elapsed:.1f}s") + print(f" Overall time per chunk: {overall_per_chunk:.2f}s/chunk") else: print( - f"Imported {successful_count} messages from {len(vtt_files)} file(s) to {database}" + f"Imported {messages_ingested} messages from {len(vtt_files)} file(s) to {database}" + f" ({total_chunks} chunks, {semrefs_added} refs, {elapsed:.1f}s," + f" {overall_per_chunk:.2f}s/chunk)" ) - print("All indexes built successfully") + if not interrupted: + print("All indexes built successfully") except BaseException as e: print(f"\nError: Failed to process messages: {e}", file=sys.stderr) @@ -449,7 +498,8 @@ def main(): database=args.database, name=args.name, merge_consecutive=args.merge, - batchsize=args.batchsize, + concurrency=args.concurrency, + batch_size=args.batch_size, embedding_name=args.embedding_name, verbose=args.verbose, ) diff --git a/tools/load_json.py b/tools/load_json.py index 8a885047..a1f01bc2 100644 --- a/tools/load_json.py +++ b/tools/load_json.py @@ -54,7 +54,7 @@ async def load_json_to_database( # Get the storage provider to check if database is empty provider = await settings.get_storage_provider() - msgs = await provider.get_message_collection() + msgs = provider.messages # Check if database already has data msg_count = await msgs.size() diff --git a/tools/query.py b/tools/query.py index c7ec7908..dc5c90c5 100644 --- a/tools/query.py +++ b/tools/query.py @@ -8,8 +8,10 @@ import argparse import asyncio from collections.abc import Mapping +import contextvars from dataclasses import dataclass, field, replace import difflib +import io import json import os import re @@ -59,6 +61,21 @@ from typeagent.podcasts import podcast from typeagent.storage.utils import create_storage_provider +### Output redirection for batch concurrency ### + +_output_buf: contextvars.ContextVar[io.StringIO | None] = 
contextvars.ContextVar( + "_output_buf", default=None +) + + +def qprint(*args: object, **kwargs: typing.Any) -> None: + """Print wrapper that redirects to per-task StringIO when set.""" + buf = _output_buf.get() + if buf is not None and "file" not in kwargs: + kwargs["file"] = buf + print(*args, **kwargs) + + ### Classes ### @@ -550,7 +567,7 @@ async def main(): # Load existing database provider = await settings.get_storage_provider() - msgs = await provider.get_message_collection() + msgs = provider.messages if await msgs.size() == 0: raise SystemExit(f"Error: Database '{args.database}' is empty.") @@ -577,7 +594,7 @@ async def main(): "Error: non-empty --search-results required for batch mode." ) - model = model_adapters.create_chat_model() + model = model_adapters.create_chat_model(retrier=settings.chat_retrier) query_translator = utils.create_translator(model, search_query_schema.SearchQuery) if args.alt_schema: if args.verbose: @@ -631,7 +648,9 @@ async def main(): + f"Running in batch mode [{args.offset}:{args.offset + args.limit if args.limit else ''}]." + Fore.RESET ) - await batch_loop(context, args.offset, args.limit, args.skip_counters) + await batch_loop( + context, args.offset, args.limit, args.skip_counters, args.concurrency + ) else: if args.verbose: print(Fore.YELLOW + "Running in interactive mode." + Fore.RESET) @@ -696,7 +715,11 @@ async def print_conversation_stats(c: IConversation, verbose: bool = True) -> No async def batch_loop( - context: ProcessingContext, offset: int, limit: int, skip_counters: str + context: ProcessingContext, + offset: int, + limit: int, + skip_counters: str, + concurrency: int, ) -> None: skips = [] if skip_counters: @@ -704,15 +727,36 @@ async def batch_loop( if limit == 0: limit = len(context.ar_list) - offset sublist = context.ar_list[offset : offset + limit] - all_scores = [] + + semaphore = asyncio.Semaphore(concurrency) + + async def run_one(counter: int, question: str) -> tuple[int, float | None, str]: + buf = io.StringIO() + token = _output_buf.set(buf) + try: + qprint("-" * 20, counter, question, "-" * 20) + async with semaphore: + score = await process_query(context, question) + return (counter, score, buf.getvalue()) + finally: + _output_buf.reset(token) + + tasks: list[asyncio.Task[tuple[int, float | None, str]]] = [] for counter, qadata in enumerate(sublist, offset + 1): if counter in skips: continue question = qadata["question"] - print("-" * 20, counter, question, "-" * 20) - score = await process_query(context, question) + tasks.append(asyncio.create_task(run_one(counter, question))) + + results = await asyncio.gather(*tasks) + + # Output in canonical (counter) order + all_scores: list[tuple[float, int]] = [] + for counter, score, output in sorted(results): + sys.stdout.write(output) if score is not None: all_scores.append((score, counter)) + if not all_scores: return print("=" * 50) @@ -786,14 +830,14 @@ async def process_query(context: ProcessingContext, query_text: str) -> float | if not record or ( "searchQueryExpr" not in record or "compiledQueryExpr" not in record ): - print("Can't skip stages 1 or 2, no precomputed outcomes found.") + qprint("Can't skip stages 1 or 2, no precomputed outcomes found.") else: # Skipping stage 2 implies skipping stage 1, and we must supply the # precomputed results for both stages. 
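+        # A note on the mechanism, as inferred from the field names: the
+        # deserialized objects below are handed to the pipeline through
+        # debug_context, so stages 3+ consume the recorded stage-1/2 outputs
+        # instead of issuing fresh model calls.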
debug_context.use_search_query = serialization.deserialize_object( search_query_schema.SearchQuery, record["searchQueryExpr"] ) - print("Skipping stage 1, substituting precomputed search query.") + qprint("Skipping stage 1, substituting precomputed search query.") if context.debug2 == "skip": debug_context.use_compiled_search_query_exprs = ( serialization.deserialize_object( @@ -801,7 +845,7 @@ async def process_query(context: ProcessingContext, query_text: str) -> float | record["compiledQueryExpr"], ) ) - print( + qprint( "Skipping stage 2, substituting precomputed compiled query expressions." ) prsep() @@ -823,62 +867,62 @@ async def process_query(context: ProcessingContext, query_text: str) -> float | debug_context=debug_context, ) if isinstance(result, typechat.Failure): - print("Stages 1-3 failed:") - print(Fore.RED + str(result) + Fore.RESET) + qprint("Stages 1-3 failed:") + qprint(Fore.RED + str(result) + Fore.RESET) return search_results = result.value actual1 = debug_context.search_query if actual1: if context.debug1 == "full": - print("Stage 1 results:") - utils.pretty_print(actual1, Fore.GREEN, Fore.RESET) + qprint("Stage 1 results:") + utils.pretty_print(actual1, Fore.GREEN, Fore.RESET, file=_output_buf.get()) prsep() elif context.debug1 == "diff": if record and "searchQueryExpr" in record: - print("Stage 1 diff:") + qprint("Stage 1 diff:") expected1 = serialization.deserialize_object( search_query_schema.SearchQuery, record["searchQueryExpr"] ) compare_and_print_diff(expected1, actual1) else: - print("Stage 1 diff unavailable") + qprint("Stage 1 diff unavailable") prsep() actual2 = debug_context.search_query_expr if actual2: if context.debug2 == "full": - print("Stage 2 results:") - utils.pretty_print(actual2, Fore.GREEN, Fore.RESET) + qprint("Stage 2 results:") + utils.pretty_print(actual2, Fore.GREEN, Fore.RESET, file=_output_buf.get()) prsep() elif context.debug2 == "diff": if record and "compiledQueryExpr" in record: - print("Stage 2 diff:") + qprint("Stage 2 diff:") expected2 = serialization.deserialize_object( list[search.SearchQueryExpr], record["compiledQueryExpr"] ) compare_and_print_diff(expected2, actual2) else: - print("Stage 2 diff unavailable") + qprint("Stage 2 diff unavailable") prsep() actual3 = search_results if context.debug3 == "full": - print("Stage 3 full results:") - utils.pretty_print(actual3, Fore.GREEN, Fore.RESET) + qprint("Stage 3 full results:") + utils.pretty_print(actual3, Fore.GREEN, Fore.RESET, file=_output_buf.get()) prsep() elif context.debug3 == "nice": - print("Stage 3 nice results:") + qprint("Stage 3 nice results:") for sr in search_results: await print_result(sr, context.query_context.conversation) prsep() elif context.debug3 == "diff": if record and "results" in record: - print("Stage 3 diff:") + qprint("Stage 3 diff:") expected3: list[RawSearchResultData] = record["results"] compare_results(expected3, actual3) else: - print("Stage 3 diff unavailable") + qprint("Stage 3 diff unavailable") prsep() context.answer_context_options.debug = context.debug4 == "full" @@ -897,19 +941,19 @@ async def process_query(context: ProcessingContext, query_text: str) -> float | context.history.add(query_text, combined_answer.why_no_answer or "", False) if context.debug4 == "full": - utils.pretty_print(all_answers) + utils.pretty_print(all_answers, file=_output_buf.get()) prsep() if context.debug4 in ("full", "nice"): if combined_answer.type == "NoAnswer": - print(Fore.RED + f"Failure: {combined_answer.why_no_answer}" + Fore.RESET) + qprint(Fore.RED + 
f"Failure: {combined_answer.why_no_answer}" + Fore.RESET) else: - print(Fore.GREEN + f"{combined_answer.answer}" + Fore.RESET) + qprint(Fore.GREEN + f"{combined_answer.answer}" + Fore.RESET) prsep() elif context.debug4 == "diff": if query_text in context.ar_index: record = context.ar_index[query_text] expected4: tuple[str, bool] = (record["answer"], not record["hasNoAnswer"]) - print("Stage 4 diff:") + qprint("Stage 4 diff:") match combined_answer.type: case "NoAnswer": actual4 = (combined_answer.why_no_answer or "", False) @@ -917,23 +961,23 @@ async def process_query(context: ProcessingContext, query_text: str) -> float | actual4 = (combined_answer.answer or "", True) score = await compare_answers(context, expected4, actual4) if actual4[0].startswith("TypeChat failure:"): - print(Fore.YELLOW + "No answer received" + Fore.RESET) + qprint(Fore.YELLOW + "No answer received" + Fore.RESET) else: - print(f"Score: {score:.3f}; Question: {query_text}") + qprint(f"Score: {score:.3f}; Question: {query_text}") return score else: - print("Stage 4 diff unavailable; nice answer:") + qprint("Stage 4 diff unavailable; nice answer:") if combined_answer.type == "NoAnswer": - print( + qprint( Fore.RED + f"Failure: {combined_answer.why_no_answer}" + Fore.RESET ) else: - print(Fore.GREEN + f"{combined_answer.answer}" + Fore.RESET) + qprint(Fore.GREEN + f"{combined_answer.answer}" + Fore.RESET) prsep() def prsep(): - print("-" * 50) + qprint("-" * 50) ### CLI processing ### @@ -1019,6 +1063,12 @@ def make_arg_parser(description: str) -> argparse.ArgumentParser: default=0, help="Do just this question (similar to --offset START-1 --limit 1)", ) + batch.add_argument( + "--concurrency", + type=int, + default=10, + help="Max concurrent queries in batch mode (default 10)", + ) debug = parser.add_argument_group("Debug options") debug.add_argument( @@ -1134,13 +1184,13 @@ async def print_result[TMessage: IMessage, TIndex: ITermToSemanticRefIndex]( result: search.ConversationSearchResult, conversation: IConversation[TMessage, TIndex], ) -> None: - print( + qprint( f"Raw query: {result.raw_query_text};", f"{len(result.message_matches)} message matches,", f"{len(result.knowledge_matches)} knowledge matches", ) if result.message_matches: - print("Message matches:") + qprint("Message matches:") for scored_ord in sorted( result.message_matches, key=lambda x: x.score, reverse=True ): @@ -1149,25 +1199,25 @@ async def print_result[TMessage: IMessage, TIndex: ITermToSemanticRefIndex]( msg = await conversation.messages.get_item(msg_ord) assert msg.metadata is not None # For type checkers text = " ".join(msg.text_chunks).strip() - print( + qprint( f"({score:5.1f}) M={msg_ord:d}: " f"{msg.metadata.source!s:>15.15s}: " f"{repr(text)[1:-1]:<150.150s} " ) if result.knowledge_matches: - print(f"Knowledge matches ({', '.join(sorted(result.knowledge_matches))}):") + qprint(f"Knowledge matches ({', '.join(sorted(result.knowledge_matches))}):") for key, value in sorted(result.knowledge_matches.items()): - print(f"Type {key} -- {value.term_matches}:") + qprint(f"Type {key} -- {value.term_matches}:") for scored_sem_ref_ord in value.semantic_ref_matches: score = scored_sem_ref_ord.score sem_ref_ord = scored_sem_ref_ord.semantic_ref_ordinal if conversation.semantic_refs is None: - print(f" Ord: {sem_ref_ord} (score {score})") + qprint(f" Ord: {sem_ref_ord} (score {score})") else: sem_ref = await conversation.semantic_refs.get_item(sem_ref_ord) msg_ord = sem_ref.range.start.message_ordinal msg = await conversation.messages.get_item(msg_ord) 
- print( + qprint( f"({score:5.1f}) M={msg_ord}: " f"S={summarize_knowledge(sem_ref)}" ) @@ -1227,7 +1277,7 @@ def compare_results( results: list[search.ConversationSearchResult], ) -> bool: if len(results) != len(matches_records): - print(f"(Result sizes mismatch, {len(results)} != {len(matches_records)})") + qprint(f"(Result sizes mismatch, {len(results)} != {len(matches_records)})") return False res = True for result, record in zip(results, matches_records): @@ -1277,8 +1327,10 @@ def compare_message_ordinals(aa: list[ScoredMessageOrdinal], b: list[int]) -> bo a = [aai.message_ordinal for aai in aa] if set(a) ^ set(b) <= NOISE_MESSAGES: return True - print("Message ordinals do not match:") - utils.list_diff(" Expected:", b, " Actual:", a, max_items=20) + qprint("Message ordinals do not match:") + utils.list_diff( + " Expected:", b, " Actual:", a, max_items=20, file=_output_buf.get() + ) return False @@ -1288,8 +1340,10 @@ def compare_semantic_ref_ordinals( a = [aai.semantic_ref_ordinal for aai in aa] if sorted(a) == sorted(b): return True - print(f"{label.capitalize()} SemanticRef ordinals do not match:") - utils.list_diff(" Expected:", b, " Actual:", a, max_items=20) + qprint(f"{label.capitalize()} SemanticRef ordinals do not match:") + utils.list_diff( + " Expected:", b, " Actual:", a, max_items=20, file=_output_buf.get() + ) return False @@ -1319,18 +1373,18 @@ async def compare_answers( actual_text, actual_success = actual if expected_success != actual_success: - print( + qprint( f"Expected success: {Fore.RED}{expected_success}{Fore.RESET}; " f"actual: {Fore.GREEN}{actual_success}{Fore.RESET}" ) score = 0.000 if expected_success else 0.001 # 0.001 == Answer not expected elif not actual_success: - print(Fore.GREEN + f"Both failed" + Fore.RESET) + qprint(Fore.GREEN + f"Both failed" + Fore.RESET) score = 1.001 elif expected_text == actual_text: - print(Fore.GREEN + f"Both equal" + Fore.RESET) + qprint(Fore.GREEN + f"Both equal" + Fore.RESET) score = 1.000 else: @@ -1341,7 +1395,7 @@ async def compare_answers( else: n = 2 if score == 1.0: - print(actual_text) + qprint(actual_text) else: print_diff(expected_text, actual_text, n=n) @@ -1358,11 +1412,11 @@ def print_diff(a: str, b: str, n: int) -> None: ) for x in diff: if x.startswith("-"): - print(Fore.RED + x.rstrip("\n") + Fore.RESET) + qprint(Fore.RED + x.rstrip("\n") + Fore.RESET) elif x.startswith("+"): - print(Fore.GREEN + x.rstrip("\n") + Fore.RESET) + qprint(Fore.GREEN + x.rstrip("\n") + Fore.RESET) else: - print(x.rstrip("\n")) + qprint(x.rstrip("\n")) async def equality_score(context: ProcessingContext, a: str, b: str) -> float: diff --git a/uv.lock b/uv.lock index 86f16351..6e867d26 100644 --- a/uv.lock +++ b/uv.lock @@ -1015,7 +1015,7 @@ wheels = [ [[package]] name = "logfire" -version = "4.32.0" +version = "4.32.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "executing" }, @@ -1026,9 +1026,9 @@ dependencies = [ { name = "rich" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/98/64/f927d4f9de1f1371047b9016adba1ec2e08258301708d548d41f86f27772/logfire-4.32.0.tar.gz", hash = "sha256:f1dc9d756a4b28f0483645244aaf3ea8535b8e2ae5a1068442a968ca0c746304", size = 1088575, upload-time = "2026-04-10T19:36:54.172Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b5/d7/70c6def7f3f459b2d57aa7fb37863d31b8d877e391547f200ee8c31d2e30/logfire-4.32.1.tar.gz", hash = "sha256:8e7ff418b5f2629c8a8e9426283ff82c760a30f24516c4c389d6cbb1d9768c58", size = 1089612, 
upload-time = "2026-04-15T14:11:57.518Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/02/33/81b13e1f2044b5fe0112068a2494526db9cfdf784030a2ea57688279360a/logfire-4.32.0-py3-none-any.whl", hash = "sha256:d9cff51c3c093c4161ece87a65e6ac6e2d862258b62494c30d93d713e9858758", size = 312412, upload-time = "2026-04-10T19:36:50.97Z" }, + { url = "https://files.pythonhosted.org/packages/a4/77/70f6d97d7d74d2f2eeb695fe491b28906ae5c350b48516bb237ace9a1778/logfire-4.32.1-py3-none-any.whl", hash = "sha256:cb7873efec0e94a3de6e603539daaa6509a454599621c80dd227fbfa0ade37d4", size = 313021, upload-time = "2026-04-15T14:11:54.024Z" }, ] [[package]] @@ -1228,7 +1228,7 @@ wheels = [ [[package]] name = "msgraph-sdk" -version = "1.55.0" +version = "1.56.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "azure-identity" }, @@ -1238,9 +1238,9 @@ dependencies = [ { name = "microsoft-kiota-serialization-text" }, { name = "msgraph-core" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/10/44/0b5a188addf6341b3da10dd207e444417de255f7c1651902ba72016a2843/msgraph_sdk-1.55.0.tar.gz", hash = "sha256:6df691a31954a050d26b8a678968017e157d940fb377f2a8a4e17a9741b98756", size = 6295669, upload-time = "2026-02-20T00:32:29.378Z" } +sdist = { url = "https://files.pythonhosted.org/packages/8e/cd/a8a472679c01a62757d405fb4975862b0fe3bf5becf3be52709d60cf87db/msgraph_sdk-1.56.0.tar.gz", hash = "sha256:5c3f8efd4d45672b36f04acc5faca19cedc554cfd9c95d925459e948cacef35b", size = 6443012, upload-time = "2026-04-17T18:22:31.549Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/fb/a8/de807e62f8ff93003b573aa243cdcee2da2c0618b42efbc9a8e61aa7300d/msgraph_sdk-1.55.0-py3-none-any.whl", hash = "sha256:c8e68ebc4b88af5111de312e7fa910a4e76ddf48a4534feadb1fb8a411c48cfc", size = 25758742, upload-time = "2026-02-20T00:30:40.039Z" }, + { url = "https://files.pythonhosted.org/packages/56/1d/24da99bec3e419ed4da4bfdc50c0cd2406061f478d6225fb9af780099002/msgraph_sdk-1.56.0-py3-none-any.whl", hash = "sha256:9afe06413aed910095dd95de7a5d2d1ea14d6734cc68b778acdb2d9930bdafdc", size = 26201107, upload-time = "2026-04-17T18:22:27.317Z" }, ] [[package]] @@ -1451,32 +1451,32 @@ wheels = [ [[package]] name = "opentelemetry-api" -version = "1.39.1" +version = "1.40.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "importlib-metadata" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/97/b9/3161be15bb8e3ad01be8be5a968a9237c3027c5be504362ff800fca3e442/opentelemetry_api-1.39.1.tar.gz", hash = "sha256:fbde8c80e1b937a2c61f20347e91c0c18a1940cecf012d62e65a7caf08967c9c", size = 65767, upload-time = "2025-12-11T13:32:39.182Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2c/1d/4049a9e8698361cc1a1aa03a6c59e4fa4c71e0c0f94a30f988a6876a2ae6/opentelemetry_api-1.40.0.tar.gz", hash = "sha256:159be641c0b04d11e9ecd576906462773eb97ae1b657730f0ecf64d32071569f", size = 70851, upload-time = "2026-03-04T14:17:21.555Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/cf/df/d3f1ddf4bb4cb50ed9b1139cc7b1c54c34a1e7ce8fd1b9a37c0d1551a6bd/opentelemetry_api-1.39.1-py3-none-any.whl", hash = "sha256:2edd8463432a7f8443edce90972169b195e7d6a05500cd29e6d13898187c9950", size = 66356, upload-time = "2025-12-11T13:32:17.304Z" }, + { url = "https://files.pythonhosted.org/packages/5f/bf/93795954016c522008da367da292adceed71cca6ee1717e1d64c83089099/opentelemetry_api-1.40.0-py3-none-any.whl", hash = 
"sha256:82dd69331ae74b06f6a874704be0cfaa49a1650e1537d4a813b86ecef7d0ecf9", size = 68676, upload-time = "2026-03-04T14:17:01.24Z" }, ] [[package]] name = "opentelemetry-exporter-otlp-proto-common" -version = "1.39.1" +version = "1.40.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-proto" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/e9/9d/22d241b66f7bbde88a3bfa6847a351d2c46b84de23e71222c6aae25c7050/opentelemetry_exporter_otlp_proto_common-1.39.1.tar.gz", hash = "sha256:763370d4737a59741c89a67b50f9e39271639ee4afc999dadfe768541c027464", size = 20409, upload-time = "2025-12-11T13:32:40.885Z" } +sdist = { url = "https://files.pythonhosted.org/packages/51/bc/1559d46557fe6eca0b46c88d4c2676285f1f3be2e8d06bb5d15fbffc814a/opentelemetry_exporter_otlp_proto_common-1.40.0.tar.gz", hash = "sha256:1cbee86a4064790b362a86601ee7934f368b81cd4cc2f2e163902a6e7818a0fa", size = 20416, upload-time = "2026-03-04T14:17:23.801Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/8c/02/ffc3e143d89a27ac21fd557365b98bd0653b98de8a101151d5805b5d4c33/opentelemetry_exporter_otlp_proto_common-1.39.1-py3-none-any.whl", hash = "sha256:08f8a5862d64cc3435105686d0216c1365dc5701f86844a8cd56597d0c764fde", size = 18366, upload-time = "2025-12-11T13:32:20.2Z" }, + { url = "https://files.pythonhosted.org/packages/8b/ca/8f122055c97a932311a3f640273f084e738008933503d0c2563cd5d591fc/opentelemetry_exporter_otlp_proto_common-1.40.0-py3-none-any.whl", hash = "sha256:7081ff453835a82417bf38dccf122c827c3cbc94f2079b03bba02a3165f25149", size = 18369, upload-time = "2026-03-04T14:17:04.796Z" }, ] [[package]] name = "opentelemetry-exporter-otlp-proto-http" -version = "1.39.1" +version = "1.40.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "googleapis-common-protos" }, @@ -1487,14 +1487,14 @@ dependencies = [ { name = "requests" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/80/04/2a08fa9c0214ae38880df01e8bfae12b067ec0793446578575e5080d6545/opentelemetry_exporter_otlp_proto_http-1.39.1.tar.gz", hash = "sha256:31bdab9745c709ce90a49a0624c2bd445d31a28ba34275951a6a362d16a0b9cb", size = 17288, upload-time = "2025-12-11T13:32:42.029Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2e/fa/73d50e2c15c56be4d000c98e24221d494674b0cc95524e2a8cb3856d95a4/opentelemetry_exporter_otlp_proto_http-1.40.0.tar.gz", hash = "sha256:db48f5e0f33217588bbc00274a31517ba830da576e59503507c839b38fa0869c", size = 17772, upload-time = "2026-03-04T14:17:25.324Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/95/f1/b27d3e2e003cd9a3592c43d099d2ed8d0a947c15281bf8463a256db0b46c/opentelemetry_exporter_otlp_proto_http-1.39.1-py3-none-any.whl", hash = "sha256:d9f5207183dd752a412c4cd564ca8875ececba13be6e9c6c370ffb752fd59985", size = 19641, upload-time = "2025-12-11T13:32:22.248Z" }, + { url = "https://files.pythonhosted.org/packages/a0/3a/8865d6754e61c9fb170cdd530a124a53769ee5f740236064816eb0ca7301/opentelemetry_exporter_otlp_proto_http-1.40.0-py3-none-any.whl", hash = "sha256:a8d1dab28f504c5d96577d6509f80a8150e44e8f45f82cdbe0e34c99ab040069", size = 19960, upload-time = "2026-03-04T14:17:07.153Z" }, ] [[package]] name = "opentelemetry-instrumentation" -version = "0.60b1" +version = "0.61b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, @@ -1502,14 +1502,14 @@ dependencies = [ { name = "packaging" }, { name = "wrapt" }, ] -sdist = { url = 
"https://files.pythonhosted.org/packages/41/0f/7e6b713ac117c1f5e4e3300748af699b9902a2e5e34c9cf443dde25a01fa/opentelemetry_instrumentation-0.60b1.tar.gz", hash = "sha256:57ddc7974c6eb35865af0426d1a17132b88b2ed8586897fee187fd5b8944bd6a", size = 31706, upload-time = "2025-12-11T13:36:42.515Z" } +sdist = { url = "https://files.pythonhosted.org/packages/da/37/6bf8e66bfcee5d3c6515b79cb2ee9ad05fe573c20f7ceb288d0e7eeec28c/opentelemetry_instrumentation-0.61b0.tar.gz", hash = "sha256:cb21b48db738c9de196eba6b805b4ff9de3b7f187e4bbf9a466fa170514f1fc7", size = 32606, upload-time = "2026-03-04T14:20:16.825Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/77/d2/6788e83c5c86a2690101681aeef27eeb2a6bf22df52d3f263a22cee20915/opentelemetry_instrumentation-0.60b1-py3-none-any.whl", hash = "sha256:04480db952b48fb1ed0073f822f0ee26012b7be7c3eac1a3793122737c78632d", size = 33096, upload-time = "2025-12-11T13:35:33.067Z" }, + { url = "https://files.pythonhosted.org/packages/d8/3e/f6f10f178b6316de67f0dfdbbb699a24fbe8917cf1743c1595fb9dcdd461/opentelemetry_instrumentation-0.61b0-py3-none-any.whl", hash = "sha256:92a93a280e69788e8f88391247cc530fd81f16f2b011979d4d6398f805cfbc63", size = 33448, upload-time = "2026-03-04T14:19:02.447Z" }, ] [[package]] name = "opentelemetry-instrumentation-httpx" -version = "0.60b1" +version = "0.61b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, @@ -1518,57 +1518,57 @@ dependencies = [ { name = "opentelemetry-util-http" }, { name = "wrapt" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/86/08/11208bcfcab4fc2023252c3f322aa397fd9ad948355fea60f5fc98648603/opentelemetry_instrumentation_httpx-0.60b1.tar.gz", hash = "sha256:a506ebaf28c60112cbe70ad4f0338f8603f148938cb7b6794ce1051cd2b270ae", size = 20611, upload-time = "2025-12-11T13:37:01.661Z" } +sdist = { url = "https://files.pythonhosted.org/packages/cd/2a/e2becd55e33c29d1d9ef76e2579040ed1951cb33bacba259f6aff2fdd2a6/opentelemetry_instrumentation_httpx-0.61b0.tar.gz", hash = "sha256:6569ec097946c5551c2a4252f74c98666addd1bf047c1dde6b4ef426719ff8dd", size = 24104, upload-time = "2026-03-04T14:20:34.752Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/43/59/b98e84eebf745ffc75397eaad4763795bff8a30cbf2373a50ed4e70646c5/opentelemetry_instrumentation_httpx-0.60b1-py3-none-any.whl", hash = "sha256:f37636dd742ad2af83d896ba69601ed28da51fa4e25d1ab62fde89ce413e275b", size = 15701, upload-time = "2025-12-11T13:36:04.56Z" }, + { url = "https://files.pythonhosted.org/packages/af/88/dde310dce56e2d85cf1a09507f5888544955309edc4b8d22971d6d3d1417/opentelemetry_instrumentation_httpx-0.61b0-py3-none-any.whl", hash = "sha256:dee05c93a6593a5dc3ae5d9d5c01df8b4e2c5d02e49275e5558534ee46343d5e", size = 17198, upload-time = "2026-03-04T14:19:33.585Z" }, ] [[package]] name = "opentelemetry-proto" -version = "1.39.1" +version = "1.40.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "protobuf" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/49/1d/f25d76d8260c156c40c97c9ed4511ec0f9ce353f8108ca6e7561f82a06b2/opentelemetry_proto-1.39.1.tar.gz", hash = "sha256:6c8e05144fc0d3ed4d22c2289c6b126e03bcd0e6a7da0f16cedd2e1c2772e2c8", size = 46152, upload-time = "2025-12-11T13:32:48.681Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4c/77/dd38991db037fdfce45849491cb61de5ab000f49824a00230afb112a4392/opentelemetry_proto-1.40.0.tar.gz", hash = "sha256:03f639ca129ba513f5819810f5b1f42bcb371391405d99c168fe6937c62febcd", size = 45667, 
upload-time = "2026-03-04T14:17:31.194Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/51/95/b40c96a7b5203005a0b03d8ce8cd212ff23f1793d5ba289c87a097571b18/opentelemetry_proto-1.39.1-py3-none-any.whl", hash = "sha256:22cdc78efd3b3765d09e68bfbd010d4fc254c9818afd0b6b423387d9dee46007", size = 72535, upload-time = "2025-12-11T13:32:33.866Z" }, + { url = "https://files.pythonhosted.org/packages/b9/b2/189b2577dde745b15625b3214302605b1353436219d42b7912e77fa8dc24/opentelemetry_proto-1.40.0-py3-none-any.whl", hash = "sha256:266c4385d88923a23d63e353e9761af0f47a6ed0d486979777fe4de59dc9b25f", size = 72073, upload-time = "2026-03-04T14:17:16.673Z" }, ] [[package]] name = "opentelemetry-sdk" -version = "1.39.1" +version = "1.40.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, { name = "opentelemetry-semantic-conventions" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/eb/fb/c76080c9ba07e1e8235d24cdcc4d125ef7aa3edf23eb4e497c2e50889adc/opentelemetry_sdk-1.39.1.tar.gz", hash = "sha256:cf4d4563caf7bff906c9f7967e2be22d0d6b349b908be0d90fb21c8e9c995cc6", size = 171460, upload-time = "2025-12-11T13:32:49.369Z" } +sdist = { url = "https://files.pythonhosted.org/packages/58/fd/3c3125b20ba18ce2155ba9ea74acb0ae5d25f8cd39cfd37455601b7955cc/opentelemetry_sdk-1.40.0.tar.gz", hash = "sha256:18e9f5ec20d859d268c7cb3c5198c8d105d073714db3de50b593b8c1345a48f2", size = 184252, upload-time = "2026-03-04T14:17:31.87Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/7c/98/e91cf858f203d86f4eccdf763dcf01cf03f1dae80c3750f7e635bfa206b6/opentelemetry_sdk-1.39.1-py3-none-any.whl", hash = "sha256:4d5482c478513ecb0a5d938dcc61394e647066e0cc2676bee9f3af3f3f45f01c", size = 132565, upload-time = "2025-12-11T13:32:35.069Z" }, + { url = "https://files.pythonhosted.org/packages/2c/c5/6a852903d8bfac758c6dc6e9a68b015d3c33f2f1be5e9591e0f4b69c7e0a/opentelemetry_sdk-1.40.0-py3-none-any.whl", hash = "sha256:787d2154a71f4b3d81f20524a8ce061b7db667d24e46753f32a7bc48f1c1f3f1", size = 141951, upload-time = "2026-03-04T14:17:17.961Z" }, ] [[package]] name = "opentelemetry-semantic-conventions" -version = "0.60b1" +version = "0.61b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/91/df/553f93ed38bf22f4b999d9be9c185adb558982214f33eae539d3b5cd0858/opentelemetry_semantic_conventions-0.60b1.tar.gz", hash = "sha256:87c228b5a0669b748c76d76df6c364c369c28f1c465e50f661e39737e84bc953", size = 137935, upload-time = "2025-12-11T13:32:50.487Z" } +sdist = { url = "https://files.pythonhosted.org/packages/6d/c0/4ae7973f3c2cfd2b6e321f1675626f0dab0a97027cc7a297474c9c8f3d04/opentelemetry_semantic_conventions-0.61b0.tar.gz", hash = "sha256:072f65473c5d7c6dc0355b27d6c9d1a679d63b6d4b4b16a9773062cb7e31192a", size = 145755, upload-time = "2026-03-04T14:17:32.664Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/7a/5e/5958555e09635d09b75de3c4f8b9cae7335ca545d77392ffe7331534c402/opentelemetry_semantic_conventions-0.60b1-py3-none-any.whl", hash = "sha256:9fa8c8b0c110da289809292b0591220d3a7b53c1526a23021e977d68597893fb", size = 219982, upload-time = "2025-12-11T13:32:36.955Z" }, + { url = "https://files.pythonhosted.org/packages/b2/37/cc6a55e448deaa9b27377d087da8615a3416d8ad523d5960b78dbeadd02a/opentelemetry_semantic_conventions-0.61b0-py3-none-any.whl", hash = 
"sha256:fa530a96be229795f8cef353739b618148b0fe2b4b3f005e60e262926c4d38e2", size = 231621, upload-time = "2026-03-04T14:17:19.33Z" }, ] [[package]] name = "opentelemetry-util-http" -version = "0.60b1" +version = "0.61b0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/50/fc/c47bb04a1d8a941a4061307e1eddfa331ed4d0ab13d8a9781e6db256940a/opentelemetry_util_http-0.60b1.tar.gz", hash = "sha256:0d97152ca8c8a41ced7172d29d3622a219317f74ae6bb3027cfbdcf22c3cc0d6", size = 11053, upload-time = "2025-12-11T13:37:25.115Z" } +sdist = { url = "https://files.pythonhosted.org/packages/57/3c/f0196223efc5c4ca19f8fad3d5462b171ac6333013335ce540c01af419e9/opentelemetry_util_http-0.61b0.tar.gz", hash = "sha256:1039cb891334ad2731affdf034d8fb8b48c239af9b6dd295e5fabd07f1c95572", size = 11361, upload-time = "2026-03-04T14:20:57.01Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/16/5c/d3f1733665f7cd582ef0842fb1d2ed0bc1fba10875160593342d22bba375/opentelemetry_util_http-0.60b1-py3-none-any.whl", hash = "sha256:66381ba28550c91bee14dcba8979ace443444af1ed609226634596b4b0faf199", size = 8947, upload-time = "2025-12-11T13:36:37.151Z" }, + { url = "https://files.pythonhosted.org/packages/0d/e5/c08aaaf2f64288d2b6ef65741d2de5454e64af3e050f34285fb1907492fe/opentelemetry_util_http-0.61b0-py3-none-any.whl", hash = "sha256:8e715e848233e9527ea47e275659ea60a57a75edf5206a3b937e236a6da5fc33", size = 9281, upload-time = "2026-03-04T14:20:08.364Z" }, ] [[package]] @@ -2304,6 +2304,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f8/7f/3de5402f39890ac5660b86bcf5c03f9d855dad5c4ed764866d7b592b46fd/sse_starlette-3.3.4-py3-none-any.whl", hash = "sha256:84bb06e58939a8b38d8341f1bc9792f06c2b53f48c608dd207582b664fc8f3c1", size = 14330, upload-time = "2026-03-29T09:00:21.846Z" }, ] +[[package]] +name = "stamina" +version = "26.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "tenacity" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/80/bd/b2f71ae14368a066f103d182f25bbc6c3bf4aa695889f3ed3cba026d6f36/stamina-26.1.0.tar.gz", hash = "sha256:0214d05fdf5102c518194a4aac7520ce53cf660550ae3b940701aad88cf50c17", size = 568171, upload-time = "2026-04-13T17:44:31.012Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1d/f0/1ff90a1d1dd02de23feafdf9dffaecef3958348be5c192df56670ccb4f86/stamina-26.1.0-py3-none-any.whl", hash = "sha256:62e06829bec87c06d4cafde520b32a6097d1017c378a9eb63253c5bf5ebbbb88", size = 18508, upload-time = "2026-04-13T17:44:29.545Z" }, +] + [[package]] name = "starlette" version = "1.0.0" @@ -2326,6 +2338,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/64/97/b4f2f442fee92a1406f08b4fbc990bd7d02dc84b3b5e6315a59fa9b2a9f4/std_uritemplate-2.0.8-py3-none-any.whl", hash = "sha256:839807a7f9d07f0bad1a88977c3428bd97b9ff0d229412a0bf36123d8c724257", size = 6512, upload-time = "2025-10-16T15:51:28.713Z" }, ] +[[package]] +name = "tenacity" +version = "9.1.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/47/c6/ee486fd809e357697ee8a44d3d69222b344920433d3b6666ccd9b374630c/tenacity-9.1.4.tar.gz", hash = "sha256:adb31d4c263f2bd041081ab33b498309a57c77f9acf2db65aadf0898179cf93a", size = 49413, upload-time = "2026-02-07T10:45:33.841Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d7/c1/eb8f9debc45d3b7918a32ab756658a0904732f75e555402972246b0b8e71/tenacity-9.1.4-py3-none-any.whl", hash = 
"sha256:6095a360c919085f28c6527de529e76a06ad89b23659fa881ae0649b867a9d55", size = 28926, upload-time = "2026-02-07T10:45:32.24Z" }, +] + [[package]] name = "tiktoken" version = "0.12.0" @@ -2400,6 +2421,7 @@ dependencies = [ { name = "pyreadline3", marker = "sys_platform == 'win32'" }, { name = "pyright" }, { name = "python-dotenv" }, + { name = "stamina" }, { name = "tiktoken" }, { name = "typechat" }, { name = "webvtt-py" }, @@ -2444,6 +2466,7 @@ requires-dist = [ { name = "pyreadline3", marker = "sys_platform == 'win32'", specifier = ">=3.5.4" }, { name = "pyright", specifier = ">=1.1.409" }, { name = "python-dotenv", specifier = ">=1.1.0" }, + { name = "stamina", specifier = ">=26.1.0" }, { name = "tiktoken", specifier = ">=0.12.0" }, { name = "typechat", specifier = ">=0.0.4" }, { name = "webvtt-py", specifier = ">=0.5.1" }, @@ -2453,20 +2476,20 @@ provides-extras = ["logfire"] [package.metadata.requires-dev] dev = [ { name = "azure-mgmt-authorization", specifier = ">=4.0.0" }, - { name = "azure-mgmt-keyvault", specifier = ">=12.1.1" }, - { name = "black", specifier = ">=25.12.0" }, - { name = "coverage", extras = ["toml"], specifier = ">=7.9.1" }, - { name = "google-api-python-client", specifier = ">=2.184.0" }, - { name = "google-auth-httplib2", specifier = ">=0.2.0" }, - { name = "google-auth-oauthlib", specifier = ">=1.2.2" }, - { name = "isort", specifier = ">=7.0.0" }, - { name = "logfire", specifier = ">=4.1.0" }, - { name = "msgraph-sdk", specifier = ">=1.54.0" }, - { name = "opentelemetry-instrumentation-httpx", specifier = ">=0.57b0" }, - { name = "pyright", specifier = ">=1.1.408" }, - { name = "pytest", specifier = ">=8.3.5" }, - { name = "pytest-asyncio", specifier = ">=0.26.0" }, - { name = "pytest-mock", specifier = ">=3.14.0" }, + { name = "azure-mgmt-keyvault", specifier = ">=14.0.1" }, + { name = "black", specifier = ">=26.3.1" }, + { name = "coverage", extras = ["toml"], specifier = ">=7.13.5" }, + { name = "google-api-python-client", specifier = ">=2.194.0" }, + { name = "google-auth-httplib2", specifier = ">=0.3.1" }, + { name = "google-auth-oauthlib", specifier = ">=1.3.1" }, + { name = "isort", specifier = ">=8.0.1" }, + { name = "logfire", specifier = ">=4.32.1" }, + { name = "msgraph-sdk", specifier = ">=1.56.0" }, + { name = "opentelemetry-instrumentation-httpx", specifier = ">=0.61b0" }, + { name = "pyright", specifier = ">=1.1.409" }, + { name = "pytest", specifier = ">=9.0.3" }, + { name = "pytest-asyncio", specifier = ">=1.3.0" }, + { name = "pytest-mock", specifier = ">=3.15.1" }, ] [[package]]