diff --git a/AGENTS.md b/AGENTS.md index f9a4fb3b..6e391b01 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -25,6 +25,7 @@ AGENTS.md. In all cases show what you added to AGENTS.md. - Use `make test` to run all tests - Use `make check test` to run `make check` and if it passes also run `make test` - Use `make format` to format all files using `black`. Do this before reporting success. +- When validating changes, first run `pytest` only on new/modified test files, then run `make format check test` once at the end. - Keep ad-hoc and performance benchmarks under `tools/`, not `tests/`, so `make test` does not run them. ## Package Management with uv @@ -36,7 +37,7 @@ AGENTS.md. In all cases show what you added to AGENTS.md. - uv maintains consistency between `pyproject.toml`, `uv.lock`, and installed packages - Trust uv's automatic version resolution and file management -**IMPORTANT! YOU ARE NOT DONE UNTIL `make check test format` PASSES** +**IMPORTANT! YOU ARE NOT DONE UNTIL `make format check test` PASSES** # Code generation diff --git a/src/typeagent/emails/email_message.py b/src/typeagent/emails/email_message.py index 0a469f49..47abdbec 100644 --- a/src/typeagent/emails/email_message.py +++ b/src/typeagent/emails/email_message.py @@ -161,6 +161,7 @@ def __init__(self, **data: Any) -> None: ) timestamp: str | None = None # Use metadata.sent_on for the actual sent time src_url: str | None = None # Source file or uri for this email + source_id: str | None = None # External source id (see IMessage.source_id) def get_knowledge(self) -> kplib.KnowledgeResponse: return self.metadata.get_knowledge() diff --git a/src/typeagent/knowpro/conversation_base.py b/src/typeagent/knowpro/conversation_base.py index 8026472a..131b0ceb 100644 --- a/src/typeagent/knowpro/conversation_base.py +++ b/src/typeagent/knowpro/conversation_base.py @@ -3,6 +3,7 @@ """Base class for conversations with incremental indexing support.""" +from collections.abc import AsyncIterable, Callable from dataclasses import dataclass from datetime import datetime, timezone from typing import Generic, Self, TypeVar @@ -36,6 +37,8 @@ MessageOrdinal, Topic, ) +from .knowledge import extract_knowledge_from_text_batch +from .messageutils import get_all_message_chunk_locations TMessage = TypeVar("TMessage", bound=IMessage) @@ -132,9 +135,12 @@ async def add_messages_with_indexing( Args: messages: Messages to add - source_ids: Optional list of source IDs to mark as ingested. These are - marked within the same transaction, so if the indexing fails, the - source IDs won't be marked as ingested (for SQLite storage). + source_ids: Optional explicit list of source IDs to mark as ingested, + one per message. When ``None`` (the default), each message's + ``source_id`` attribute is used instead — messages whose + ``source_id`` is ``None`` are silently skipped. These are marked + within the same transaction, so if the indexing fails, the source + IDs won't be marked as ingested (for SQLite storage). 
Returns: Result with counts of messages/semrefs added @@ -143,7 +149,7 @@ async def add_messages_with_indexing( Exception: Any error """ storage = await self.settings.get_storage_provider() - if source_ids: + if source_ids is not None: if len(source_ids) != len(messages): raise ValueError( f"Length of source_ids {len(source_ids)} " @@ -152,9 +158,13 @@ async def add_messages_with_indexing( async with storage: # Mark source IDs as ingested (will be rolled back on error) - if source_ids: - for source_id in source_ids: - await storage.mark_source_ingested(source_id) + if source_ids is not None: + for sid in source_ids: + await storage.mark_source_ingested(sid) + else: + for msg in messages: + if msg.source_id is not None: + await storage.mark_source_ingested(msg.source_id) start_points = IndexingStartPoints( message_count=await self.messages.size(), @@ -173,8 +183,11 @@ async def add_messages_with_indexing( await self._update_secondary_indexes_incremental(start_points) + messages_added = await self.messages.size() - start_points.message_count + chunks_added = sum(len(m.text_chunks) for m in messages[:messages_added]) result = AddMessagesResult( - messages_added=await self.messages.size() - start_points.message_count, + messages_added=messages_added, + chunks_added=chunks_added, semrefs_added=await self.semantic_refs.size() - start_points.semref_count, ) @@ -186,6 +199,184 @@ async def add_messages_with_indexing( return result + async def add_messages_streaming( + self, + messages: AsyncIterable[TMessage], + *, + batch_size: int = 100, + on_batch_committed: Callable[[AddMessagesResult], None] | None = None, + ) -> AddMessagesResult: + """Add messages from an async iterable, committing in batches. + + Unlike ``add_messages_with_indexing`` which processes all messages in a + single transaction, this method buffers messages into batches of + ``batch_size``, processes each batch in its own transaction, and commits + after every batch. This makes it suitable for ingesting large streams + (millions of messages) where a single all-or-nothing transaction would + be impractical. + + **Source-ID tracking**: each message's ``source_id`` (if not ``None``) + is checked before ingestion. Already-ingested sources are silently + skipped. Newly ingested sources are marked within the same transaction. + + **Extraction failures**: when knowledge extraction returns a + ``Failure`` for a chunk, the failure is recorded via + ``storage.record_chunk_failure`` and processing continues with the + remaining chunks. Raised exceptions (HTTP errors, timeouts, etc.) + are treated as systemic and stop the run immediately — the current + batch is rolled back and the exception propagates. + + Args: + messages: An async iterable of messages to ingest. + batch_size: Number of messages per commit batch. + on_batch_committed: Optional callback invoked after each batch is + committed, receiving the batch's ``AddMessagesResult``. + + Returns: + Cumulative ``AddMessagesResult`` across all committed batches. 
+ """ + storage = await self.settings.get_storage_provider() + total_messages_added = 0 + total_semrefs_added = 0 + total_chunks_added = 0 + + batch: list[TMessage] = [] + async for msg in messages: + batch.append(msg) + if len(batch) >= batch_size: + result = await self._ingest_batch_streaming(storage, batch) + total_messages_added += result.messages_added + total_semrefs_added += result.semrefs_added + total_chunks_added += result.chunks_added + if on_batch_committed: + on_batch_committed(result) + batch = [] + + # Flush remaining messages + if batch: + result = await self._ingest_batch_streaming(storage, batch) + total_messages_added += result.messages_added + total_semrefs_added += result.semrefs_added + total_chunks_added += result.chunks_added + if on_batch_committed: + on_batch_committed(result) + + return AddMessagesResult( + messages_added=total_messages_added, + chunks_added=total_chunks_added, + semrefs_added=total_semrefs_added, + ) + + async def _ingest_batch_streaming( + self, + storage: IStorageProvider[TMessage], + batch: list[TMessage], + ) -> AddMessagesResult: + """Process and commit a single batch within a transaction. + + Messages whose ``source_id`` is already ingested are filtered out. + Extraction ``Failure``\\s are recorded as chunk failures. + """ + # Filter out already-ingested sources + filtered: list[TMessage] = [] + for msg in batch: + if msg.source_id is not None and await storage.is_source_ingested( + msg.source_id + ): + continue + filtered.append(msg) + + if not filtered: + return AddMessagesResult() + + async with storage: + start_points = IndexingStartPoints( + message_count=await self.messages.size(), + semref_count=await self.semantic_refs.size(), + ) + + await self.messages.extend(filtered) + + # Mark source IDs as ingested (rolled back on error) + for msg in filtered: + if msg.source_id is not None: + await storage.mark_source_ingested(msg.source_id) + + await self._add_metadata_knowledge_incremental(start_points.message_count) + + if self.settings.semantic_ref_index_settings.auto_extract_knowledge: + await self._add_llm_knowledge_streaming( + storage, filtered, start_points.message_count + ) + + await self._update_secondary_indexes_incremental(start_points) + + await storage.update_conversation_timestamps( + updated_at=datetime.now(timezone.utc) + ) + + messages_added = await self.messages.size() - start_points.message_count + chunks_added = sum(len(m.text_chunks) for m in filtered[:messages_added]) + return AddMessagesResult( + messages_added=messages_added, + chunks_added=chunks_added, + semrefs_added=await self.semantic_refs.size() + - start_points.semref_count, + ) + + async def _add_llm_knowledge_streaming( + self, + storage: IStorageProvider[TMessage], + messages: list[TMessage], + start_from_message_ordinal: int, + ) -> None: + """Extract LLM knowledge, recording failures instead of raising. + + On ``Failure``: records a chunk failure via the storage provider and + continues. On a raised exception: lets it propagate (the caller's + ``async with storage`` will roll back the transaction). 
+ """ + settings = self.settings.semantic_ref_index_settings + knowledge_extractor = ( + settings.knowledge_extractor or convknowledge.KnowledgeExtractor() + ) + + text_locations = get_all_message_chunk_locations( + messages, start_from_message_ordinal + ) + if not text_locations: + return + + start_ordinal = text_locations[0].message_ordinal + text_batch: list[str] = [] + for tl in text_locations: + list_index = tl.message_ordinal - start_ordinal + text_batch.append( + messages[list_index].text_chunks[tl.chunk_ordinal].strip() + ) + + knowledge_results = await extract_knowledge_from_text_batch( + knowledge_extractor, + text_batch, + settings.concurrency, + ) + for i, knowledge_result in enumerate(knowledge_results): + tl = text_locations[i] + if isinstance(knowledge_result, typechat.Failure): + await storage.record_chunk_failure( + tl.message_ordinal, + tl.chunk_ordinal, + type(knowledge_result).__name__, + knowledge_result.message[:500], + ) + continue + await semrefindex.add_knowledge_to_semantic_ref_index( + self, + tl.message_ordinal, + tl.chunk_ordinal, + knowledge_result.value, + ) + async def _add_metadata_knowledge_incremental( self, start_from_message_ordinal: int, @@ -216,21 +407,17 @@ async def _add_llm_knowledge_incremental( settings.knowledge_extractor or convknowledge.KnowledgeExtractor() ) - # Get batches of text locations from the message list - from .messageutils import get_message_chunk_batch_from_list - - batches = get_message_chunk_batch_from_list( + text_locations = get_all_message_chunk_locations( messages, start_from_message_ordinal, - settings.batch_size, ) - for text_location_batch in batches: - await semrefindex.add_batch_to_semantic_ref_index_from_list( - self, - messages, - text_location_batch, - knowledge_extractor, - ) + await semrefindex.add_batch_to_semantic_ref_index_from_list( + self, + messages, + text_locations, + knowledge_extractor, + concurrency=settings.concurrency, + ) async def _update_secondary_indexes_incremental( self, diff --git a/src/typeagent/knowpro/convsettings.py b/src/typeagent/knowpro/convsettings.py index 9dbf1214..97c2bee2 100644 --- a/src/typeagent/knowpro/convsettings.py +++ b/src/typeagent/knowpro/convsettings.py @@ -29,7 +29,7 @@ def __init__(self, embedding_index_settings: TextEmbeddingIndexSettings): @dataclass class SemanticRefIndexSettings: - batch_size: int + concurrency: int auto_extract_knowledge: bool knowledge_extractor: IKnowledgeExtractor | None = None @@ -54,7 +54,7 @@ def __init__( TextEmbeddingIndexSettings(model, min_score=0.7) ) self.semantic_ref_index_settings = SemanticRefIndexSettings( - batch_size=4, # Effectively max concurrency + concurrency=4, auto_extract_knowledge=True, # The high-level API wants this ) diff --git a/src/typeagent/knowpro/interfaces_core.py b/src/typeagent/knowpro/interfaces_core.py index 4dc8fc8e..87ef7329 100644 --- a/src/typeagent/knowpro/interfaces_core.py +++ b/src/typeagent/knowpro/interfaces_core.py @@ -90,8 +90,9 @@ class IndexingStartPoints: class AddMessagesResult: """Result of add_messages_with_indexing operation.""" - messages_added: int - semrefs_added: int + messages_added: int = 0 + chunks_added: int = 0 + semrefs_added: int = 0 # Messages are referenced by their sequential ordinal numbers. @@ -129,6 +130,12 @@ class IMessage[TMetadata: IMessageMetadata](IKnowledgeSource, Protocol): # Metadata associated with the message such as its source. 
metadata: TMetadata | None = None + # Optional external identifier of the source this message was ingested from + # (e.g., an email ID, a file path, a URL). Used by ingestion pipelines to + # detect already-ingested sources for restartability. None means the message + # is not associated with an external source (e.g., synthesized in tests). + source_id: str | None = None + # Semantic references are also ordinal. type SemanticRefOrdinal = int diff --git a/src/typeagent/knowpro/interfaces_storage.py b/src/typeagent/knowpro/interfaces_storage.py index 97f7b600..c0450f29 100644 --- a/src/typeagent/knowpro/interfaces_storage.py +++ b/src/typeagent/knowpro/interfaces_storage.py @@ -67,6 +67,21 @@ class SemanticRefMetadata(NamedTuple): knowledge_type: KnowledgeType +@dataclass +class ChunkFailure: + """Record of a single failed knowledge-extraction attempt for one chunk. + + Stored in the storage provider so that ingestion pipelines can retry just + the failed chunks without re-processing whole messages. + """ + + message_ordinal: int + chunk_ordinal: int + error_class: str + error_message: str + failed_at: Datetime + + class IReadonlyCollection[T, TOrdinal](AsyncIterable[T], Protocol): async def size(self) -> int: ... @@ -168,6 +183,33 @@ async def mark_source_ingested( """Mark a source as ingested (no commit; call within transaction context).""" ... + # Chunk-level extraction failure tracking + + async def record_chunk_failure( + self, + message_ordinal: int, + chunk_ordinal: int, + error_class: str, + error_message: str, + ) -> None: + """Record an extraction failure for a single chunk. + + Idempotent: re-recording overwrites any prior entry for the same + (message_ordinal, chunk_ordinal). No commit; call within transaction + context. + """ + ... + + async def clear_chunk_failure( + self, message_ordinal: int, chunk_ordinal: int + ) -> None: + """Remove the failure record for one chunk (e.g., after a retry succeeds).""" + ... + + async def get_chunk_failures(self) -> list[ChunkFailure]: + """Return all recorded chunk failures, ordered by message and chunk.""" + ... + # Transaction management async def __aenter__(self) -> Self: """Enter transaction context. Calls begin_transaction().""" @@ -198,6 +240,7 @@ class IConversation[ __all__ = [ + "ChunkFailure", "ConversationMetadata", "ICollection", "IConversation", diff --git a/src/typeagent/knowpro/knowledge.py b/src/typeagent/knowpro/knowledge.py index bb889e58..e2503967 100644 --- a/src/typeagent/knowpro/knowledge.py +++ b/src/typeagent/knowpro/knowledge.py @@ -47,7 +47,7 @@ async def batch_worker( async def extract_knowledge_from_text_batch( knowledge_extractor: IKnowledgeExtractor, text_batch: list[str], - concurrency: int = 2, + concurrency: int = 4, ) -> list[Result[kplib.KnowledgeResponse]]: """Extract knowledge from a batch of text inputs concurrently.""" if not text_batch: diff --git a/src/typeagent/knowpro/messageutils.py b/src/typeagent/knowpro/messageutils.py index bd7cf879..6b4afb6e 100644 --- a/src/typeagent/knowpro/messageutils.py +++ b/src/typeagent/knowpro/messageutils.py @@ -5,7 +5,6 @@ from .interfaces import ( IMessage, - IMessageCollection, MessageOrdinal, TextLocation, TextRange, @@ -23,90 +22,28 @@ def text_range_from_message_chunk( ) -async def get_message_chunk_batch[TMessage: IMessage]( - messages: IMessageCollection[TMessage], - message_ordinal_start_at: MessageOrdinal, - batch_size: int, -) -> list[list[TextLocation]]: - """ - Get batches of message chunk locations for processing. 
- - Args: - messages: Collection of messages to process - message_ordinal_start_at: Starting message ordinal - batch_size: Number of message chunks per batch - - Yields: - Lists of TextLocation objects, each representing a message chunk - """ - batches: list[list[TextLocation]] = [] - current_batch: list[TextLocation] = [] - - message_ordinal = message_ordinal_start_at - async for message in messages: - if message_ordinal < message_ordinal_start_at: - message_ordinal += 1 - continue - - # Process each text chunk in the message - for chunk_ordinal in range(len(message.text_chunks)): - text_location = TextLocation( - message_ordinal=message_ordinal, - chunk_ordinal=chunk_ordinal, - ) - current_batch.append(text_location) - - # When batch is full, yield it and start a new one - if len(current_batch) >= batch_size: - batches.append(current_batch) - current_batch = [] - - message_ordinal += 1 - - # Don't forget the last batch if it has items - if current_batch: - batches.append(current_batch) - - return batches - - -def get_message_chunk_batch_from_list[TMessage: IMessage]( +def get_all_message_chunk_locations[TMessage: IMessage]( messages: list[TMessage], message_ordinal_start_at: MessageOrdinal, - batch_size: int, -) -> list[list[TextLocation]]: +) -> list[TextLocation]: """ - Get batches of message chunk locations for processing from a list of messages. + Get a flat list of all message chunk locations from a list of messages. Args: messages: List of messages to process message_ordinal_start_at: Starting message ordinal (ordinal of first message in list) - batch_size: Number of message chunks per batch Returns: - Lists of TextLocation objects, each representing a message chunk + Flat list of TextLocation objects, one per message chunk """ - batches: list[list[TextLocation]] = [] - current_batch: list[TextLocation] = [] - + locations: list[TextLocation] = [] for idx, message in enumerate(messages): message_ordinal = message_ordinal_start_at + idx - - # Process each text chunk in the message for chunk_ordinal in range(len(message.text_chunks)): - text_location = TextLocation( - message_ordinal=message_ordinal, - chunk_ordinal=chunk_ordinal, + locations.append( + TextLocation( + message_ordinal=message_ordinal, + chunk_ordinal=chunk_ordinal, + ) ) - current_batch.append(text_location) - - # When batch is full, yield it and start a new one - if len(current_batch) >= batch_size: - batches.append(current_batch) - current_batch = [] - - # Don't forget the last batch if it has items - if current_batch: - batches.append(current_batch) - - return batches + return locations diff --git a/src/typeagent/knowpro/universal_message.py b/src/typeagent/knowpro/universal_message.py index 01abfdf9..c5008fe2 100644 --- a/src/typeagent/knowpro/universal_message.py +++ b/src/typeagent/knowpro/universal_message.py @@ -204,6 +204,11 @@ class ConversationMessage(IMessage): Format: "2024-01-01T12:34:56Z" or "1970-01-01T00:01:23Z" (epoch-based) MUST include "Z" suffix to explicitly indicate UTC timezone. """ + source_id: str | None = None + """ + Optional external identifier of the source this message was ingested from + (e.g., a transcript file path or podcast episode id). See ``IMessage.source_id``. 
+    """
 
     def get_knowledge(self) -> kplib.KnowledgeResponse:
         return self.metadata.get_knowledge()
diff --git a/src/typeagent/podcasts/podcast_ingest.py b/src/typeagent/podcasts/podcast_ingest.py
index b124f003..d2de7c82 100644
--- a/src/typeagent/podcasts/podcast_ingest.py
+++ b/src/typeagent/podcasts/podcast_ingest.py
@@ -1,6 +1,7 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 
+from collections.abc import AsyncIterator
 from datetime import timedelta
 import os
 import re
@@ -8,6 +9,7 @@
 from ..knowpro.convsettings import ConversationSettings
 from ..knowpro.interfaces import Datetime
+from ..knowpro.interfaces_core import AddMessagesResult
 from ..knowpro.universal_message import format_timestamp_utc, UNIX_EPOCH
 from ..storage.utils import create_storage_provider
 from .podcast import Podcast, PodcastMessage, PodcastMessageMeta
@@ -22,6 +24,7 @@ async def ingest_podcast(
     dbname: str | None = None,
     batch_size: int = 0,
     start_message: int = 0,
+    concurrency: int = 0,
     verbose: bool = False,
 ) -> Podcast:
     """
@@ -37,8 +40,10 @@ async def ingest_podcast(
             date is unknown (Unix "timestamp left at zero" convention).
         length_minutes: Total length of podcast in minutes (for proportional timestamp allocation)
         dbname: Database name or None (to use in-memory non-persistent storage)
-        batch_size: Number of messages to index per batch (default all messages)
+        batch_size: Number of messages per commit batch passed to add_messages_streaming
+            (default: all messages at once). Smaller batches improve crash recoverability.
         start_message: Number of initial messages to skip (for resuming interrupted ingests)
+        concurrency: Max concurrent knowledge extractions (0 = use settings default)
         verbose: Whether to print progress information (default False)
 
     Returns:
@@ -121,20 +126,46 @@ async def ingest_podcast(
         tags=[podcast_name],
     )
 
-    # Add messages with indexing to build embeddings, using batch_size
-    batch_size = batch_size or len(msgs)
-    settings.semantic_ref_index_settings.batch_size = batch_size
-    for i in range(start_message, len(msgs), batch_size):
-        batch = msgs[i : i + batch_size]
-        t0 = time.time()
-        await pod.add_messages_with_indexing(batch)
-        t1 = time.time()
+    # Set source_id on each message for restartability
+    for i, msg in enumerate(msgs):
+        msg.source_id = f"{transcript_file_path}#{i}"
+
+    # Add messages using the streaming API (commit-per-batch)
+    if concurrency:
+        settings.semantic_ref_index_settings.concurrency = concurrency
+
+    async def _message_stream() -> AsyncIterator[PodcastMessage]:
+        for msg in msgs[start_message:]:
+            yield msg
+
+    cumulative_messages = 0
+    t0 = time.time()
+
+    def _on_batch_committed(result: AddMessagesResult) -> None:
+        nonlocal cumulative_messages
+        batch_start = cumulative_messages
+        cumulative_messages += result.messages_added
         if verbose:
             print(
-                f"Indexed messages {i} to {i + len(batch) - 1} "
-                f"in {t1 - t0:.1f} seconds."
+                f"Indexed messages {batch_start}-{cumulative_messages - 1} "
+                f"({result.chunks_added} chunks, {result.semrefs_added} semrefs) "
+                f"at t={time.time() - t0:.1f} seconds."
             )
+    batch_size = batch_size or len(msgs)
+    result = await pod.add_messages_streaming(
+        _message_stream(),
+        batch_size=batch_size,
+        on_batch_committed=_on_batch_committed,
+    )
+    t1 = time.time()
+    if verbose:
+        print(
+            f"Indexed {result.messages_added} messages "
+            f"({result.chunks_added} chunks, {result.semrefs_added} semrefs) "
+            f"in {t1 - t0:.1f} seconds."
+ ) + return pod diff --git a/src/typeagent/storage/memory/provider.py b/src/typeagent/storage/memory/provider.py index 540d31b8..603fbd24 100644 --- a/src/typeagent/storage/memory/provider.py +++ b/src/typeagent/storage/memory/provider.py @@ -3,10 +3,11 @@ """In-memory storage provider implementation.""" -from datetime import datetime +from datetime import datetime, timezone from ...knowpro.convsettings import MessageTextIndexSettings, RelatedTermIndexSettings from ...knowpro.interfaces import ( + ChunkFailure, ConversationMetadata, IConversationThreads, IMessage, @@ -40,6 +41,7 @@ class MemoryStorageProvider[TMessage: IMessage](IStorageProvider[TMessage]): _related_terms_index: RelatedTermsIndex _conversation_threads: ConversationThreads _ingested_sources: set[str] + _chunk_failures: dict[tuple[int, int], ChunkFailure] def __init__( self, @@ -60,6 +62,7 @@ def __init__( thread_settings = message_text_settings.embedding_index_settings self._conversation_threads = ConversationThreads(thread_settings) self._ingested_sources = set() + self._chunk_failures = {} async def __aenter__(self) -> "MemoryStorageProvider[TMessage]": """Enter transaction context. No-op for in-memory storage.""" @@ -172,3 +175,29 @@ async def mark_source_ingested( source_id: External source identifier (email ID, file path, etc.) """ self._ingested_sources.add(source_id) + + async def record_chunk_failure( + self, + message_ordinal: int, + chunk_ordinal: int, + error_class: str, + error_message: str, + ) -> None: + """Record a knowledge-extraction failure for a single chunk.""" + self._chunk_failures[(message_ordinal, chunk_ordinal)] = ChunkFailure( + message_ordinal=message_ordinal, + chunk_ordinal=chunk_ordinal, + error_class=error_class, + error_message=error_message, + failed_at=datetime.now(timezone.utc), + ) + + async def clear_chunk_failure( + self, message_ordinal: int, chunk_ordinal: int + ) -> None: + """Remove a previously recorded chunk failure (no-op if absent).""" + self._chunk_failures.pop((message_ordinal, chunk_ordinal), None) + + async def get_chunk_failures(self) -> list[ChunkFailure]: + """Return all recorded chunk failures, ordered by (msg_ordinal, chunk_ordinal).""" + return [self._chunk_failures[k] for k in sorted(self._chunk_failures)] diff --git a/src/typeagent/storage/memory/semrefindex.py b/src/typeagent/storage/memory/semrefindex.py index 773f9212..8654e5a3 100644 --- a/src/typeagent/storage/memory/semrefindex.py +++ b/src/typeagent/storage/memory/semrefindex.py @@ -30,7 +30,6 @@ ) from ...knowpro.knowledge import extract_knowledge_from_text_batch from ...knowpro.messageutils import ( - get_message_chunk_batch, text_range_from_message_chunk, ) @@ -50,6 +49,7 @@ async def add_batch_to_semantic_ref_index[ batch: list[TextLocation], knowledge_extractor: IKnowledgeExtractor, terms_added: set[str] | None = None, + concurrency: int = 4, ) -> None: messages = conversation.messages @@ -63,7 +63,7 @@ async def add_batch_to_semantic_ref_index[ knowledge_results = await extract_knowledge_from_text_batch( knowledge_extractor, text_batch, - len(text_batch), + concurrency, ) for i, knowledge_result in enumerate(knowledge_results): if isinstance(knowledge_result, Failure): @@ -89,6 +89,7 @@ async def add_batch_to_semantic_ref_index_from_list[ batch: list[TextLocation], knowledge_extractor: IKnowledgeExtractor, terms_added: set[str] | None = None, + concurrency: int = 4, ) -> None: """ Add a batch of knowledge to semantic ref index, extracting from provided message list. 
@@ -121,7 +122,7 @@ async def add_batch_to_semantic_ref_index_from_list[ knowledge_results = await extract_knowledge_from_text_batch( knowledge_extractor, text_batch, - len(text_batch), + concurrency, ) for i, knowledge_result in enumerate(knowledge_results): if isinstance(knowledge_result, Failure): @@ -726,24 +727,34 @@ async def add_to_semantic_ref_index[ """Add semantic references to the conversation's semantic reference index.""" # Only create knowledge extractor if auto extraction is enabled - knowledge_extractor = None if settings.auto_extract_knowledge: knowledge_extractor = ( settings.knowledge_extractor or convknowledge.KnowledgeExtractor() ) - # Process messages in batches for LLM knowledge extraction - batches = await get_message_chunk_batch( - conversation.messages, - message_ordinal_start_at, - settings.batch_size, - ) - for text_location_batch in batches: + # Build a flat list of all text locations + text_locations: list[TextLocation] = [] + message_ordinal = message_ordinal_start_at + async for message in conversation.messages: + if message_ordinal < message_ordinal_start_at: + message_ordinal += 1 + continue + for chunk_ordinal in range(len(message.text_chunks)): + text_locations.append( + TextLocation( + message_ordinal=message_ordinal, + chunk_ordinal=chunk_ordinal, + ) + ) + message_ordinal += 1 + + if text_locations: await add_batch_to_semantic_ref_index( conversation, - text_location_batch, + text_locations, knowledge_extractor, terms_added, + concurrency=settings.concurrency, ) diff --git a/src/typeagent/storage/sqlite/provider.py b/src/typeagent/storage/sqlite/provider.py index 3d5a3185..2978c8ed 100644 --- a/src/typeagent/storage/sqlite/provider.py +++ b/src/typeagent/storage/sqlite/provider.py @@ -11,6 +11,7 @@ from ...knowpro import interfaces from ...knowpro.convsettings import MessageTextIndexSettings, RelatedTermIndexSettings from ...knowpro.interfaces import ConversationMetadata, STATUS_INGESTED +from ...knowpro.interfaces_storage import ChunkFailure from .collections import SqliteMessageCollection, SqliteSemanticRefCollection from .messageindex import SqliteMessageTextIndex from .propindex import SqlitePropertyIndex @@ -627,3 +628,56 @@ async def mark_source_ingested( "INSERT OR REPLACE INTO IngestedSources (source_id, status) VALUES (?, ?)", (source_id, status), ) + + async def record_chunk_failure( + self, + message_ordinal: int, + chunk_ordinal: int, + error_class: str, + error_message: str, + ) -> None: + """Record a knowledge-extraction failure for a single chunk. + + Idempotent: re-recording overwrites any prior entry for the same + (message_ordinal, chunk_ordinal). No commit; call within a transaction + context. + """ + failed_at = datetime.now(timezone.utc).isoformat() + cursor = self.db.cursor() + cursor.execute( + """ + INSERT OR REPLACE INTO ChunkFailures + (msg_id, chunk_ordinal, error_class, error_message, failed_at) + VALUES (?, ?, ?, ?, ?) + """, + (message_ordinal, chunk_ordinal, error_class, error_message, failed_at), + ) + + async def clear_chunk_failure( + self, message_ordinal: int, chunk_ordinal: int + ) -> None: + """Remove a previously recorded chunk failure (no-op if absent).""" + cursor = self.db.cursor() + cursor.execute( + "DELETE FROM ChunkFailures WHERE msg_id = ? 
AND chunk_ordinal = ?", + (message_ordinal, chunk_ordinal), + ) + + async def get_chunk_failures(self) -> list[ChunkFailure]: + """Return all recorded chunk failures, ordered by (msg_id, chunk_ordinal).""" + cursor = self.db.cursor() + cursor.execute(""" + SELECT msg_id, chunk_ordinal, error_class, error_message, failed_at + FROM ChunkFailures + ORDER BY msg_id, chunk_ordinal + """) + return [ + ChunkFailure( + message_ordinal=row[0], + chunk_ordinal=row[1], + error_class=row[2], + error_message=row[3], + failed_at=datetime.fromisoformat(row[4]), + ) + for row in cursor.fetchall() + ] diff --git a/src/typeagent/storage/sqlite/schema.py b/src/typeagent/storage/sqlite/schema.py index db6933db..99117c24 100644 --- a/src/typeagent/storage/sqlite/schema.py +++ b/src/typeagent/storage/sqlite/schema.py @@ -148,6 +148,28 @@ ); """ +# Table for tracking knowledge-extraction failures at the chunk level. +# Each row records a (message_ordinal, chunk_ordinal) pair whose extraction +# failed (typically because the LLM returned malformed JSON or an invalid +# schema). The message text itself is still stored in the Messages table; only +# the *enrichment* of that chunk is missing. A future "re-extract" tool can +# read this table to retry just the failed chunks. +CHUNK_FAILURES_SCHEMA = """ +CREATE TABLE IF NOT EXISTS ChunkFailures ( + msg_id INTEGER NOT NULL, -- Message ordinal (matches Messages.msg_id) + chunk_ordinal INTEGER NOT NULL, -- 0-based index into the message's text_chunks + error_class TEXT NOT NULL, -- Fully-qualified class name of the failure + error_message TEXT NOT NULL, -- Human-readable failure description + failed_at TEXT NOT NULL, -- ISO-8601 UTC timestamp of the failure + + PRIMARY KEY (msg_id, chunk_ordinal) +); +""" + +CHUNK_FAILURES_MSG_INDEX = """ +CREATE INDEX IF NOT EXISTS idx_chunk_failures_msg ON ChunkFailures(msg_id); +""" + # Type aliases for database row tuples type ShreddedMessage = tuple[ str | None, str | None, str | None, str | None, str | None, str | None @@ -271,6 +293,7 @@ def init_db_schema(db: sqlite3.Connection) -> None: cursor.execute(RELATED_TERMS_FUZZY_SCHEMA) cursor.execute(TIMESTAMP_INDEX_SCHEMA) cursor.execute(INGESTED_SOURCES_SCHEMA) + cursor.execute(CHUNK_FAILURES_SCHEMA) # Create additional indexes cursor.execute(SEMANTIC_REF_INDEX_TERM_INDEX) @@ -279,6 +302,7 @@ def init_db_schema(db: sqlite3.Connection) -> None: cursor.execute(RELATED_TERMS_ALIASES_TERM_INDEX) cursor.execute(RELATED_TERMS_ALIASES_ALIAS_INDEX) cursor.execute(RELATED_TERMS_FUZZY_TERM_INDEX) + cursor.execute(CHUNK_FAILURES_MSG_INDEX) def get_db_schema_version(db: sqlite3.Connection) -> int: diff --git a/tests/test_add_messages_streaming.py b/tests/test_add_messages_streaming.py new file mode 100644 index 00000000..bdc25ede --- /dev/null +++ b/tests/test_add_messages_streaming.py @@ -0,0 +1,293 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""Tests for add_messages_streaming.""" + +from collections.abc import AsyncIterator +import os +import tempfile + +import pytest + +import typechat + +from typeagent.aitools.model_adapters import create_test_embedding_model +from typeagent.knowpro import knowledge_schema as kplib +from typeagent.knowpro.convsettings import ConversationSettings +from typeagent.knowpro.interfaces_core import IKnowledgeExtractor +from typeagent.storage.sqlite.provider import SqliteStorageProvider +from typeagent.transcripts.transcript import ( + Transcript, + TranscriptMessage, + TranscriptMessageMeta, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_message( + text: str, + speaker: str = "Alice", + source_id: str | None = None, +) -> TranscriptMessage: + return TranscriptMessage( + text_chunks=[text], + metadata=TranscriptMessageMeta(speaker=speaker), + tags=["test"], + source_id=source_id, + ) + + +async def _create_transcript( + db_path: str, + *, + auto_extract: bool = False, + knowledge_extractor: IKnowledgeExtractor | None = None, +) -> tuple[Transcript, SqliteStorageProvider]: + model = create_test_embedding_model() + settings = ConversationSettings(model=model) + settings.semantic_ref_index_settings.auto_extract_knowledge = auto_extract + if knowledge_extractor is not None: + settings.semantic_ref_index_settings.knowledge_extractor = knowledge_extractor + storage = SqliteStorageProvider( + db_path, + message_type=TranscriptMessage, + message_text_index_settings=settings.message_text_index_settings, + related_term_index_settings=settings.related_term_index_settings, + ) + settings.storage_provider = storage + transcript = await Transcript.create(settings, name="test") + return transcript, storage + + +async def _async_iter( + items: list[TranscriptMessage], +) -> AsyncIterator[TranscriptMessage]: + for item in items: + yield item + + +def _ingested_count(storage: SqliteStorageProvider) -> int: + cursor = storage.db.cursor() + cursor.execute("SELECT COUNT(*) FROM IngestedSources") + return cursor.fetchone()[0] + + +def _failure_count(storage: SqliteStorageProvider) -> int: + cursor = storage.db.cursor() + cursor.execute("SELECT COUNT(*) FROM ChunkFailures") + return cursor.fetchone()[0] + + +# --------------------------------------------------------------------------- +# A test IKnowledgeExtractor that lets us control per-call results +# --------------------------------------------------------------------------- + +_EMPTY_RESPONSE = kplib.KnowledgeResponse( + entities=[], actions=[], inverse_actions=[], topics=[] +) + + +class ControlledExtractor: + """An IKnowledgeExtractor that returns Success or Failure per call. + + ``fail_on`` is a set of 0-based call indices for which the extractor + returns a Failure instead of a Success. + ``raise_on`` is a set of call indices that raise an exception. 
+ """ + + def __init__( + self, + *, + fail_on: set[int] | None = None, + raise_on: set[int] | None = None, + ) -> None: + self.fail_on = fail_on or set() + self.raise_on = raise_on or set() + self.call_count = 0 + + async def extract(self, message: str) -> typechat.Result[kplib.KnowledgeResponse]: + idx = self.call_count + self.call_count += 1 + if idx in self.raise_on: + raise RuntimeError(f"Systemic failure at call {idx}") + if idx in self.fail_on: + return typechat.Failure(f"Extraction failed for call {idx}") + return typechat.Success(_EMPTY_RESPONSE) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_streaming_basic() -> None: + """Streaming ingest of a few messages with no extraction.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + transcript, storage = await _create_transcript(db_path) + + msgs = [_make_message(f"msg-{i}") for i in range(5)] + result = await transcript.add_messages_streaming(_async_iter(msgs)) + + assert result.messages_added == 5 + assert await transcript.messages.size() == 5 + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_batching() -> None: + """Messages are committed in batches of the requested size.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + transcript, storage = await _create_transcript(db_path) + + msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(7)] + result = await transcript.add_messages_streaming( + _async_iter(msgs), batch_size=3 + ) + + # 3 batches: [0,1,2], [3,4,5], [6] + assert result.messages_added == 7 + assert await transcript.messages.size() == 7 + # All 7 sources marked + assert _ingested_count(storage) == 7 + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_skips_already_ingested() -> None: + """Messages whose source_id is already ingested are skipped.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + transcript, storage = await _create_transcript(db_path) + + # Pre-mark some sources as ingested + async with storage: + await storage.mark_source_ingested("s-1") + await storage.mark_source_ingested("s-3") + + msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(5)] + result = await transcript.add_messages_streaming(_async_iter(msgs)) + + # s-1 and s-3 skipped -> only 3 added + assert result.messages_added == 3 + assert await transcript.messages.size() == 3 + assert _ingested_count(storage) == 5 # 2 pre-existing + 3 new + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_no_source_id_always_ingested() -> None: + """Messages without source_id are always ingested (never skipped).""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + transcript, storage = await _create_transcript(db_path) + + msgs = [_make_message(f"msg-{i}") for i in range(3)] + result = await transcript.add_messages_streaming(_async_iter(msgs)) + + assert result.messages_added == 3 + assert _ingested_count(storage) == 0 # no source IDs to track + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_records_chunk_failures() -> None: + """Extraction Failure results are recorded, not raised.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + extractor = 
ControlledExtractor(fail_on={1}) # second chunk fails + transcript, storage = await _create_transcript( + db_path, auto_extract=True, knowledge_extractor=extractor + ) + + msgs = [ + _make_message("good chunk 0"), + _make_message("bad chunk 1"), + _make_message("good chunk 2"), + ] + result = await transcript.add_messages_streaming(_async_iter(msgs)) + + assert result.messages_added == 3 + assert _failure_count(storage) == 1 + + failures = await storage.get_chunk_failures() + assert len(failures) == 1 + assert failures[0].message_ordinal == 1 + assert failures[0].chunk_ordinal == 0 + assert "Extraction failed" in failures[0].error_message + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_exception_stops_run() -> None: + """A raised exception stops processing; committed batches survive.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + # Raise on the 4th extract call (first chunk of second batch) + extractor = ControlledExtractor(raise_on={3}) + transcript, storage = await _create_transcript( + db_path, auto_extract=True, knowledge_extractor=extractor + ) + + msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(6)] + + with pytest.raises(ExceptionGroup) as exc_info: + await transcript.add_messages_streaming(_async_iter(msgs), batch_size=3) + + # Verify the wrapped exception is our RuntimeError + assert any( + isinstance(e, RuntimeError) and "Systemic failure" in str(e) + for e in exc_info.value.exceptions + ) + + # First batch (3 messages, 3 extract calls 0-2) committed + assert await transcript.messages.size() == 3 + assert _ingested_count(storage) == 3 + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_empty_iterable() -> None: + """Streaming with no messages returns zeros.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + transcript, storage = await _create_transcript(db_path) + + result = await transcript.add_messages_streaming(_async_iter([])) + + assert result.messages_added == 0 + assert result.semrefs_added == 0 + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_all_skipped_batch() -> None: + """A batch where all messages are already ingested produces no commit.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + transcript, storage = await _create_transcript(db_path) + + # Pre-mark all sources + async with storage: + for i in range(3): + await storage.mark_source_ingested(f"s-{i}") + + msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(3)] + result = await transcript.add_messages_streaming(_async_iter(msgs)) + + assert result.messages_added == 0 + assert await transcript.messages.size() == 0 + + await storage.close() diff --git a/tests/test_messageutils.py b/tests/test_messageutils.py index 37b10d70..97c61c13 100644 --- a/tests/test_messageutils.py +++ b/tests/test_messageutils.py @@ -1,16 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-import pytest - from typeagent.knowpro.interfaces import TextLocation, TextRange from typeagent.knowpro.messageutils import ( - get_message_chunk_batch, text_range_from_message_chunk, ) -from typeagent.storage.memory.collections import MemoryMessageCollection - -from conftest import FakeMessage class TestTextRangeFromMessageChunk: @@ -27,59 +21,3 @@ def test_explicit_chunk_ordinal(self) -> None: def test_returns_text_range(self) -> None: tr = text_range_from_message_chunk(0) assert isinstance(tr, TextRange) - - -class TestGetMessageChunkBatch: - @pytest.mark.asyncio - async def test_empty_collection(self) -> None: - messages: MemoryMessageCollection[FakeMessage] = MemoryMessageCollection() - batches = await get_message_chunk_batch(messages, 0, 10) - assert batches == [] - - @pytest.mark.asyncio - async def test_single_message_single_chunk(self) -> None: - messages: MemoryMessageCollection[FakeMessage] = MemoryMessageCollection( - [FakeMessage("hello")] - ) - batches = await get_message_chunk_batch(messages, 0, 10) - assert len(batches) == 1 - assert len(batches[0]) == 1 - assert batches[0][0] == TextLocation(0, 0) - - @pytest.mark.asyncio - async def test_message_with_multiple_chunks(self) -> None: - messages: MemoryMessageCollection[FakeMessage] = MemoryMessageCollection( - [FakeMessage(["chunk0", "chunk1", "chunk2"])] - ) - batches = await get_message_chunk_batch(messages, 0, 10) - assert len(batches) == 1 - locs = batches[0] - assert locs == [TextLocation(0, 0), TextLocation(0, 1), TextLocation(0, 2)] - - @pytest.mark.asyncio - async def test_batch_size_splits_across_messages(self) -> None: - messages: MemoryMessageCollection[FakeMessage] = MemoryMessageCollection( - [FakeMessage("a"), FakeMessage("b"), FakeMessage("c")] - ) - batches = await get_message_chunk_batch(messages, 0, batch_size=2) - assert len(batches) == 2 - assert len(batches[0]) == 2 - assert len(batches[1]) == 1 - - @pytest.mark.asyncio - async def test_exact_batch_size(self) -> None: - messages: MemoryMessageCollection[FakeMessage] = MemoryMessageCollection( - [FakeMessage("a"), FakeMessage("b")] - ) - batches = await get_message_chunk_batch(messages, 0, batch_size=2) - assert len(batches) == 1 - assert len(batches[0]) == 2 - - @pytest.mark.asyncio - async def test_start_offset_skips_earlier_messages(self) -> None: - messages: MemoryMessageCollection[FakeMessage] = MemoryMessageCollection( - [FakeMessage("skip"), FakeMessage("include")] - ) - batches = await get_message_chunk_batch(messages, 1, batch_size=10) - assert len(batches) == 1 - assert batches[0][0] == TextLocation(1, 0) diff --git a/tests/test_podcasts.py b/tests/test_podcasts.py index 02ccf8e8..d77f6ba3 100644 --- a/tests/test_podcasts.py +++ b/tests/test_podcasts.py @@ -118,7 +118,7 @@ async def test_ingest_podcast( @pytest.mark.asyncio -async def test_ingest_podcast_parallelism_uses_batch_size( +async def test_ingest_podcast_parallelism_uses_concurrency( temp_dir: str, embedding_model: IEmbeddingModel ) -> None: transcript_path = os.path.join(temp_dir, "parallel_podcast.txt") @@ -130,15 +130,15 @@ async def test_ingest_podcast_parallelism_uses_batch_size( extractor = TrackingKnowledgeExtractor() settings.semantic_ref_index_settings.knowledge_extractor = extractor - batch_size = 20 + concurrency = 5 podcast = await podcast_ingest.ingest_podcast( transcript_path, settings, start_date=Datetime.now(timezone.utc), length_minutes=5.0, - batch_size=batch_size, + concurrency=concurrency, ) assert await podcast.messages.size() == 25 - assert 
extractor.max_concurrency == batch_size + assert extractor.max_concurrency == concurrency assert len(extractor.started_texts) == 25 diff --git a/tests/test_source_id_ingestion.py b/tests/test_source_id_ingestion.py new file mode 100644 index 00000000..1886a7b7 --- /dev/null +++ b/tests/test_source_id_ingestion.py @@ -0,0 +1,159 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Tests for source_id-based ingestion tracking in add_messages_with_indexing.""" + +import os +import tempfile + +import pytest + +from typeagent.aitools.model_adapters import create_test_embedding_model +from typeagent.knowpro.convsettings import ConversationSettings +from typeagent.storage.sqlite.provider import SqliteStorageProvider +from typeagent.transcripts.transcript import ( + Transcript, + TranscriptMessage, + TranscriptMessageMeta, +) + + +def _make_message( + text: str, speaker: str = "Alice", source_id: str | None = None +) -> TranscriptMessage: + return TranscriptMessage( + text_chunks=[text], + metadata=TranscriptMessageMeta(speaker=speaker), + tags=["test"], + source_id=source_id, + ) + + +async def _create_transcript( + db_path: str, +) -> tuple[Transcript, SqliteStorageProvider]: + model = create_test_embedding_model() + settings = ConversationSettings(model=model) + settings.semantic_ref_index_settings.auto_extract_knowledge = False + storage = SqliteStorageProvider( + db_path, + message_type=TranscriptMessage, + message_text_index_settings=settings.message_text_index_settings, + related_term_index_settings=settings.related_term_index_settings, + ) + settings.storage_provider = storage + transcript = await Transcript.create(settings, name="test") + return transcript, storage + + +def _ingested_count(storage: SqliteStorageProvider) -> int: + """Count rows in IngestedSources table.""" + cursor = storage.db.cursor() + cursor.execute("SELECT COUNT(*) FROM IngestedSources") + return cursor.fetchone()[0] + + +@pytest.mark.asyncio +async def test_explicit_source_ids_marks_ingested() -> None: + """Passing source_ids= explicitly marks those IDs as ingested.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + transcript, storage = await _create_transcript(db_path) + + msgs = [_make_message("Hello"), _make_message("World")] + await transcript.add_messages_with_indexing(msgs, source_ids=["src-1", "src-2"]) + + assert await storage.is_source_ingested("src-1") + assert await storage.is_source_ingested("src-2") + assert not await storage.is_source_ingested("src-3") + + await storage.close() + + +@pytest.mark.asyncio +async def test_message_source_id_marks_ingested() -> None: + """When source_ids is omitted, message.source_id is used.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + transcript, storage = await _create_transcript(db_path) + + msgs = [ + _make_message("Hello", source_id="msg-src-1"), + _make_message("World", source_id="msg-src-2"), + ] + await transcript.add_messages_with_indexing(msgs) + + assert await storage.is_source_ingested("msg-src-1") + assert await storage.is_source_ingested("msg-src-2") + + await storage.close() + + +@pytest.mark.asyncio +async def test_message_source_id_none_skipped() -> None: + """Messages with source_id=None are silently skipped (no ingestion mark).""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + transcript, storage = await _create_transcript(db_path) + + msgs = [ + _make_message("Hello", 
source_id="only-one"), + _make_message("World"), # source_id=None + ] + await transcript.add_messages_with_indexing(msgs) + + assert await storage.is_source_ingested("only-one") + # The second message had no source_id, so nothing extra was marked + assert await storage.get_source_status("only-one") == "ingested" + assert _ingested_count(storage) == 1 + + await storage.close() + + +@pytest.mark.asyncio +async def test_explicit_source_ids_overrides_message_source_id() -> None: + """Passing source_ids= takes precedence; message.source_id is ignored.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + transcript, storage = await _create_transcript(db_path) + + msgs = [ + _make_message("Hello", source_id="msg-level"), + ] + await transcript.add_messages_with_indexing(msgs, source_ids=["explicit-id"]) + + assert await storage.is_source_ingested("explicit-id") + assert not await storage.is_source_ingested("msg-level") + + await storage.close() + + +@pytest.mark.asyncio +async def test_source_ids_length_mismatch_raises() -> None: + """Passing source_ids with wrong length raises ValueError.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + transcript, storage = await _create_transcript(db_path) + + msgs = [_make_message("Hello"), _make_message("World")] + with pytest.raises(ValueError, match="Length of source_ids"): + await transcript.add_messages_with_indexing(msgs, source_ids=["only-one"]) + + await storage.close() + + +@pytest.mark.asyncio +async def test_no_source_ids_no_message_source_id() -> None: + """When neither source_ids nor message.source_id is set, nothing is marked.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + transcript, storage = await _create_transcript(db_path) + + msgs = [_make_message("Hello"), _make_message("World")] + result = await transcript.add_messages_with_indexing(msgs) + + assert result.messages_added == 2 + # No source tracking happened + assert _ingested_count(storage) == 0 + + await storage.close() diff --git a/tests/test_transcripts.py b/tests/test_transcripts.py index cc218bf9..e2354405 100644 --- a/tests/test_transcripts.py +++ b/tests/test_transcripts.py @@ -305,7 +305,7 @@ async def test_transcript_knowledge_extraction_slow( # Enable knowledge extraction settings.semantic_ref_index_settings.auto_extract_knowledge = True - settings.semantic_ref_index_settings.batch_size = 10 + settings.semantic_ref_index_settings.concurrency = 10 # Add messages with indexing (this should extract knowledge) result = await transcript.add_messages_with_indexing(messages_list) diff --git a/tools/ingest_email.py b/tools/ingest_email.py index b7768e8a..eccac4cb 100644 --- a/tools/ingest_email.py +++ b/tools/ingest_email.py @@ -301,9 +301,9 @@ async def ingest_emails( if verbose: print(f"Target database: {database}") - batch_size = settings.semantic_ref_index_settings.batch_size + concurrency = settings.semantic_ref_index_settings.concurrency if verbose: - print(f"Batch size: {batch_size}") + print(f"Concurrency: {concurrency}") print("\nParsing and importing emails...") success_count = 0 @@ -344,7 +344,7 @@ async def ingest_emails( sys.exit(f"Authentication error: {e!r}") # Print progress periodically - if (success_count + failed_count) % batch_size == 0: + if concurrency and (success_count + failed_count) % concurrency == 0: elapsed = time.time() - start_time semref_count = await semref_coll.size() print( diff --git a/tools/ingest_podcast.py 
b/tools/ingest_podcast.py index 39195145..c0f7303d 100644 --- a/tools/ingest_podcast.py +++ b/tools/ingest_podcast.py @@ -31,7 +31,13 @@ async def main(): "--batch-size", type=int, default=10, - help="Batch size for message indexing (default 10)", + help="Number of messages per indexing call (default 10)", + ) + parser.add_argument( + "--concurrency", + type=int, + default=0, + help="Max concurrent knowledge extractions (0 = use settings default)", ) parser.add_argument( "--start-message", @@ -75,6 +81,7 @@ async def main(): dbname=args.database, batch_size=args.batch_size, start_message=args.start_message, + concurrency=args.concurrency, verbose=not args.quiet, ) except (RuntimeError, ValueError) as err: diff --git a/tools/ingest_vtt.py b/tools/ingest_vtt.py index 7fbf38fc..ffaccfc1 100644 --- a/tools/ingest_vtt.py +++ b/tools/ingest_vtt.py @@ -75,10 +75,10 @@ def create_arg_parser() -> argparse.ArgumentParser: ) parser.add_argument( - "--batchsize", + "--concurrency", type=int, default=None, - help="Batch size for knowledge extraction (default: from settings)", + help="Max concurrent knowledge extractions (default: from settings)", ) parser.add_argument( @@ -132,7 +132,7 @@ async def ingest_vtt_files( name: str | None = None, merge_consecutive: bool = False, verbose: bool = False, - batchsize: int | None = None, + concurrency: int | None = None, embedding_name: str | None = None, ) -> None: """Ingest one or more VTT files into a database.""" @@ -227,9 +227,9 @@ async def ingest_vtt_files( # Update settings to use our storage provider settings.storage_provider = storage_provider - # Override batch size if specified - if batchsize is not None: - settings.semantic_ref_index_settings.batch_size = batchsize + # Override concurrency if specified + if concurrency is not None: + settings.semantic_ref_index_settings.concurrency = concurrency if verbose: print("Settings and storage provider configured") @@ -368,7 +368,7 @@ def save_current_message(): f" auto_extract_knowledge = {settings.semantic_ref_index_settings.auto_extract_knowledge}" ) print( - f" batch_size = {settings.semantic_ref_index_settings.batch_size}" + f" concurrency = {settings.semantic_ref_index_settings.concurrency}" ) # Create a Transcript object @@ -378,13 +378,14 @@ def save_current_message(): tags=[name, "vtt-transcript"], ) - # Process messages in batches - batch_size = settings.semantic_ref_index_settings.batch_size + # Process messages in batches for recoverability + batch_size = 50 successful_count = 0 start_time = time.time() print( - f" Processing {len(all_messages)} messages in batches of {batch_size}..." + f" Processing {len(all_messages)} messages" + f" (concurrency={settings.semantic_ref_index_settings.concurrency})..." ) for i in range(0, len(all_messages), batch_size): @@ -449,7 +450,7 @@ def main(): database=args.database, name=args.name, merge_consecutive=args.merge, - batchsize=args.batchsize, + concurrency=args.concurrency, embedding_name=args.embedding_name, verbose=args.verbose, )