diff --git a/src/typeagent/aitools/vectorbase.py b/src/typeagent/aitools/vectorbase.py index e22083c8..f67812e6 100644 --- a/src/typeagent/aitools/vectorbase.py +++ b/src/typeagent/aitools/vectorbase.py @@ -113,11 +113,14 @@ async def add_key(self, key: str, cache: bool = True) -> None: embedding = await self.get_embedding(key, cache=cache) self.add_embedding(key if cache else None, embedding) - async def add_keys(self, keys: list[str], cache: bool = True) -> None: + async def add_keys( + self, keys: list[str], cache: bool = True + ) -> NormalizedEmbeddings | None: if not keys: - return + return None embeddings = await self.get_embeddings(keys, cache=cache) self.add_embeddings(keys if cache else None, embeddings) + return embeddings def fuzzy_lookup_embedding( self, diff --git a/src/typeagent/knowpro/conversation_base.py b/src/typeagent/knowpro/conversation_base.py index 131b0ceb..83ad45dc 100644 --- a/src/typeagent/knowpro/conversation_base.py +++ b/src/typeagent/knowpro/conversation_base.py @@ -3,7 +3,9 @@ """Base class for conversations with incremental indexing support.""" -from collections.abc import AsyncIterable, Callable +import asyncio +import contextlib +from collections.abc import AsyncIterable, Callable, Sequence from dataclasses import dataclass from datetime import datetime, timezone from typing import Generic, Self, TypeVar @@ -37,12 +39,22 @@ MessageOrdinal, Topic, ) +from .interfaces_core import TextLocation from .knowledge import extract_knowledge_from_text_batch from .messageutils import get_all_message_chunk_locations TMessage = TypeVar("TMessage", bound=IMessage) +@dataclass(frozen=True) +class _ExtractionResult: + """Pre-extracted knowledge for a batch, ready to commit.""" + + messages: Sequence[IMessage] + text_locations: list[TextLocation] + knowledge_results: list[typechat.Result[kplib.KnowledgeResponse]] + + @dataclass(init=False) class ConversationBase( Generic[TMessage], IConversation[TMessage, ITermToSemanticRefIndex] @@ -158,13 +170,13 
@@ async def add_messages_with_indexing( async with storage: # Mark source IDs as ingested (will be rolled back on error) - if source_ids is not None: - for sid in source_ids: - await storage.mark_source_ingested(sid) - else: - for msg in messages: - if msg.source_id is not None: - await storage.mark_source_ingested(msg.source_id) + sids = ( + source_ids + if source_ids is not None + else [m.source_id for m in messages if m.source_id is not None] + ) + if sids: + await storage.mark_sources_ingested_batch(sids) start_points = IndexingStartPoints( message_count=await self.messages.size(), @@ -208,12 +220,10 @@ async def add_messages_streaming( ) -> AddMessagesResult: """Add messages from an async iterable, committing in batches. - Unlike ``add_messages_with_indexing`` which processes all messages in a - single transaction, this method buffers messages into batches of - ``batch_size``, processes each batch in its own transaction, and commits - after every batch. This makes it suitable for ingesting large streams - (millions of messages) where a single all-or-nothing transaction would - be impractical. + Uses a two-stage pipeline: while batch N is being committed (DB writes, + embeddings, secondary indexes), batch N+1's LLM extraction runs + concurrently. LLM extraction is typically 95% of wall time, so this + nearly doubles throughput for multi-batch ingestions. **Source-ID tracking**: each message's ``source_id`` (if not ``None``) is checked before ingestion. Already-ingested sources are silently @@ -236,48 +246,88 @@ async def add_messages_streaming( Cumulative ``AddMessagesResult`` across all committed batches. 
""" storage = await self.settings.get_storage_provider() - total_messages_added = 0 - total_semrefs_added = 0 - total_chunks_added = 0 - - batch: list[TMessage] = [] - async for msg in messages: - batch.append(msg) - if len(batch) >= batch_size: - result = await self._ingest_batch_streaming(storage, batch) - total_messages_added += result.messages_added - total_semrefs_added += result.semrefs_added - total_chunks_added += result.chunks_added - if on_batch_committed: - on_batch_committed(result) - batch = [] - - # Flush remaining messages - if batch: - result = await self._ingest_batch_streaming(storage, batch) - total_messages_added += result.messages_added - total_semrefs_added += result.semrefs_added - total_chunks_added += result.chunks_added + should_extract = ( + self.settings.semantic_ref_index_settings.auto_extract_knowledge + ) + total = AddMessagesResult() + + def _accumulate(result: AddMessagesResult) -> None: + total.messages_added += result.messages_added + total.semrefs_added += result.semrefs_added + total.chunks_added += result.chunks_added if on_batch_committed: on_batch_committed(result) - return AddMessagesResult( - messages_added=total_messages_added, - chunks_added=total_chunks_added, - semrefs_added=total_semrefs_added, - ) + pending_commit: asyncio.Task[AddMessagesResult] | None = None + + async def _drain_commit() -> None: + nonlocal pending_commit + if pending_commit is not None: + _accumulate(await pending_commit) + pending_commit = None + + async def _submit_batch(filtered: list[TMessage]) -> None: + nonlocal pending_commit + if not filtered: + return + + if should_extract: + next_extraction = asyncio.create_task( + self._extract_knowledge_for_batch(filtered) + ) + else: + next_extraction = None + + # Wait for previous commit to finish (frees the DB connection) + await _drain_commit() - async def _ingest_batch_streaming( + # Await extraction result for this batch + extraction = ( + await next_extraction if next_extraction is not None else 
None + ) + + # Start commit (DB transaction) — runs concurrently with the + # *next* batch's LLM extraction once we yield back to the loop. + pending_commit = asyncio.create_task( + self._commit_batch_streaming(storage, filtered, extraction) + ) + + try: + batch: list[TMessage] = [] + async for msg in messages: + batch.append(msg) + if len(batch) >= batch_size: + filtered = await self._filter_ingested(storage, batch) + await _submit_batch(filtered) + batch = [] + + if batch: + filtered = await self._filter_ingested(storage, batch) + await _submit_batch(filtered) + + await _drain_commit() + except BaseException: + if pending_commit is not None and not pending_commit.done(): + pending_commit.cancel() + with contextlib.suppress(asyncio.CancelledError): + await pending_commit + raise + + return total + + async def _filter_ingested( self, storage: IStorageProvider[TMessage], batch: list[TMessage], - ) -> AddMessagesResult: - """Process and commit a single batch within a transaction. - - Messages whose ``source_id`` is already ingested are filtered out. - Extraction ``Failure``\\s are recorded as chunk failures. + ) -> list[TMessage]: + """Filter out messages whose source_id has already been ingested. + + Safe to call while a pending_commit task exists: is_source_ingested + is a synchronous SELECT on SQLite's single connection, so it won't + interleave with the commit task's cursor operations in asyncio's + cooperative model. If the storage provider becomes truly async + (e.g. aiosqlite), this assumption needs revisiting. """ - # Filter out already-ingested sources filtered: list[TMessage] = [] for msg in batch: if msg.source_id is not None and await storage.is_source_ingested( @@ -285,10 +335,76 @@ async def _ingest_batch_streaming( ): continue filtered.append(msg) + return filtered + + async def _extract_knowledge_for_batch( + self, + messages: list[TMessage], + ) -> _ExtractionResult | None: + """Run LLM extraction on message texts — no DB access. 
+ + Uses 0-based ordinals; the caller remaps to global ordinals at commit + time. Safe to run concurrently with a DB transaction on another batch. + """ + text_locations = get_all_message_chunk_locations(messages, 0) + if not text_locations: + return None + + settings = self.settings.semantic_ref_index_settings + knowledge_extractor = ( + settings.knowledge_extractor or convknowledge.KnowledgeExtractor() + ) + + text_batch = [ + messages[tl.message_ordinal].text_chunks[tl.chunk_ordinal].strip() + for tl in text_locations + ] + + knowledge_results = await extract_knowledge_from_text_batch( + knowledge_extractor, + text_batch, + settings.concurrency, + ) + return _ExtractionResult( + messages=messages, + text_locations=text_locations, + knowledge_results=knowledge_results, + ) - if not filtered: - return AddMessagesResult() + async def _apply_extraction_results( + self, + storage: IStorageProvider[TMessage], + extraction: _ExtractionResult, + global_message_start: int, + ) -> None: + """Write pre-extracted knowledge into the DB. 
Must be inside a transaction.""" + bulk_items: list[tuple[int, int, kplib.KnowledgeResponse]] = [] + for i, knowledge_result in enumerate(extraction.knowledge_results): + tl = extraction.text_locations[i] + global_msg_ord = tl.message_ordinal + global_message_start + if isinstance(knowledge_result, typechat.Failure): + await storage.record_chunk_failure( + global_msg_ord, + tl.chunk_ordinal, + type(knowledge_result).__name__, + knowledge_result.message[:500], + ) + continue + bulk_items.append( + (global_msg_ord, tl.chunk_ordinal, knowledge_result.value) + ) + if bulk_items: + await semrefindex.add_knowledge_batch_to_semantic_ref_index( + self, bulk_items + ) + async def _commit_batch_streaming( + self, + storage: IStorageProvider[TMessage], + filtered: list[TMessage], + extraction: _ExtractionResult | None, + ) -> AddMessagesResult: + """Commit a single batch with pre-extracted knowledge.""" async with storage: start_points = IndexingStartPoints( message_count=await self.messages.size(), @@ -297,16 +413,15 @@ async def _ingest_batch_streaming( await self.messages.extend(filtered) - # Mark source IDs as ingested (rolled back on error) - for msg in filtered: - if msg.source_id is not None: - await storage.mark_source_ingested(msg.source_id) + source_ids = [m.source_id for m in filtered if m.source_id is not None] + if source_ids: + await storage.mark_sources_ingested_batch(source_ids) await self._add_metadata_knowledge_incremental(start_points.message_count) - if self.settings.semantic_ref_index_settings.auto_extract_knowledge: - await self._add_llm_knowledge_streaming( - storage, filtered, start_points.message_count + if extraction is not None: + await self._apply_extraction_results( + storage, extraction, start_points.message_count ) await self._update_secondary_indexes_incremental(start_points) @@ -324,59 +439,6 @@ async def _ingest_batch_streaming( - start_points.semref_count, ) - async def _add_llm_knowledge_streaming( - self, - storage: 
IStorageProvider[TMessage], - messages: list[TMessage], - start_from_message_ordinal: int, - ) -> None: - """Extract LLM knowledge, recording failures instead of raising. - - On ``Failure``: records a chunk failure via the storage provider and - continues. On a raised exception: lets it propagate (the caller's - ``async with storage`` will roll back the transaction). - """ - settings = self.settings.semantic_ref_index_settings - knowledge_extractor = ( - settings.knowledge_extractor or convknowledge.KnowledgeExtractor() - ) - - text_locations = get_all_message_chunk_locations( - messages, start_from_message_ordinal - ) - if not text_locations: - return - - start_ordinal = text_locations[0].message_ordinal - text_batch: list[str] = [] - for tl in text_locations: - list_index = tl.message_ordinal - start_ordinal - text_batch.append( - messages[list_index].text_chunks[tl.chunk_ordinal].strip() - ) - - knowledge_results = await extract_knowledge_from_text_batch( - knowledge_extractor, - text_batch, - settings.concurrency, - ) - for i, knowledge_result in enumerate(knowledge_results): - tl = text_locations[i] - if isinstance(knowledge_result, typechat.Failure): - await storage.record_chunk_failure( - tl.message_ordinal, - tl.chunk_ordinal, - type(knowledge_result).__name__, - knowledge_result.message[:500], - ) - continue - await semrefindex.add_knowledge_to_semantic_ref_index( - self, - tl.message_ordinal, - tl.chunk_ordinal, - knowledge_result.value, - ) - async def _add_metadata_knowledge_incremental( self, start_from_message_ordinal: int, diff --git a/src/typeagent/knowpro/interfaces_storage.py b/src/typeagent/knowpro/interfaces_storage.py index c0450f29..20a58858 100644 --- a/src/typeagent/knowpro/interfaces_storage.py +++ b/src/typeagent/knowpro/interfaces_storage.py @@ -183,6 +183,12 @@ async def mark_source_ingested( """Mark a source as ingested (no commit; call within transaction context).""" ... 
+ async def mark_sources_ingested_batch( + self, source_ids: list[str], status: str = STATUS_INGESTED + ) -> None: + """Mark multiple sources as ingested in one operation.""" + ... + # Chunk-level extraction failure tracking async def record_chunk_failure( diff --git a/src/typeagent/knowpro/knowledge.py b/src/typeagent/knowpro/knowledge.py index e2503967..50723f50 100644 --- a/src/typeagent/knowpro/knowledge.py +++ b/src/typeagent/knowpro/knowledge.py @@ -28,8 +28,7 @@ async def extract_knowledge_from_text( knowledge_extractor: IKnowledgeExtractor, text: str, ) -> Result[kplib.KnowledgeResponse]: - """Extract knowledge from a single text input with retries.""" - # TODO: Add a retry mechanism to handle transient errors. + """Extract knowledge from a single text input.""" return await knowledge_extractor.extract(text) diff --git a/src/typeagent/storage/memory/provider.py b/src/typeagent/storage/memory/provider.py index 603fbd24..f37ec96f 100644 --- a/src/typeagent/storage/memory/provider.py +++ b/src/typeagent/storage/memory/provider.py @@ -176,6 +176,12 @@ async def mark_source_ingested( """ self._ingested_sources.add(source_id) + async def mark_sources_ingested_batch( + self, source_ids: list[str], status: str = STATUS_INGESTED + ) -> None: + """Mark multiple sources as ingested in one operation.""" + self._ingested_sources.update(source_ids) + async def record_chunk_failure( self, message_ordinal: int, diff --git a/src/typeagent/storage/memory/semrefindex.py b/src/typeagent/storage/memory/semrefindex.py index 8654e5a3..84892171 100644 --- a/src/typeagent/storage/memory/semrefindex.py +++ b/src/typeagent/storage/memory/semrefindex.py @@ -48,9 +48,9 @@ async def add_batch_to_semantic_ref_index[ conversation: IConversation[TMessage, TTermToSemanticRefIndex], batch: list[TextLocation], knowledge_extractor: IKnowledgeExtractor, - terms_added: set[str] | None = None, concurrency: int = 4, ) -> None: + """Extract knowledge and bulk-add to the semantic ref index.""" 
messages = conversation.messages text_batch = [ @@ -65,19 +65,19 @@ async def add_batch_to_semantic_ref_index[ text_batch, concurrency, ) + bulk_items: list[tuple[int, int, kplib.KnowledgeResponse]] = [] for i, knowledge_result in enumerate(knowledge_results): if isinstance(knowledge_result, Failure): raise RuntimeError( f"Knowledge extraction failed: {knowledge_result.message}" ) - text_location = batch[i] - knowledge = knowledge_result.value - await add_knowledge_to_semantic_ref_index( - conversation, - text_location.message_ordinal, - text_location.chunk_ordinal, - knowledge, - terms_added, + tl = batch[i] + bulk_items.append( + (tl.message_ordinal, tl.chunk_ordinal, knowledge_result.value) + ) + if bulk_items: + await add_knowledge_batch_to_semantic_ref_index( + conversation, bulk_items ) @@ -88,55 +88,43 @@ async def add_batch_to_semantic_ref_index_from_list[ messages: list[TMessage], batch: list[TextLocation], knowledge_extractor: IKnowledgeExtractor, - terms_added: set[str] | None = None, concurrency: int = 4, ) -> None: - """ - Add a batch of knowledge to semantic ref index, extracting from provided message list. 
- - Args: - conversation: The conversation containing semantic refs and index - messages: List of messages containing the text to extract from - batch: List of text locations (ordinals) to process - knowledge_extractor: Extractor for LLM-based knowledge extraction - terms_added: Optional set to track newly added terms - """ - # Get the starting ordinal of the message list + """Extract knowledge from messages and bulk-add to the semantic ref index.""" if not batch: return start_ordinal = batch[0].message_ordinal - # Extract text from the messages list text_batch = [] for tl in batch: - # Calculate index in the list from the ordinal list_index = tl.message_ordinal - start_ordinal if list_index < 0 or list_index >= len(messages): raise IndexError( - f"Message ordinal {tl.message_ordinal} out of range for list starting at {start_ordinal}" + f"Message ordinal {tl.message_ordinal} out of range " + f"for list starting at {start_ordinal}" ) - message = messages[list_index] - text = message.text_chunks[tl.chunk_ordinal].strip() - text_batch.append(text) + text_batch.append( + messages[list_index].text_chunks[tl.chunk_ordinal].strip() + ) knowledge_results = await extract_knowledge_from_text_batch( knowledge_extractor, text_batch, concurrency, ) + bulk_items: list[tuple[int, int, kplib.KnowledgeResponse]] = [] for i, knowledge_result in enumerate(knowledge_results): if isinstance(knowledge_result, Failure): raise RuntimeError( f"Knowledge extraction failed: {knowledge_result.message:.150}" ) - text_location = batch[i] - knowledge = knowledge_result.value - await add_knowledge_to_semantic_ref_index( - conversation, - text_location.message_ordinal, - text_location.chunk_ordinal, - knowledge, - terms_added, + tl = batch[i] + bulk_items.append( + (tl.message_ordinal, tl.chunk_ordinal, knowledge_result.value) + ) + if bulk_items: + await add_knowledge_batch_to_semantic_ref_index( + conversation, bulk_items ) @@ -357,22 +345,83 @@ async def add_action( # TODO:L KnowledgeValidator 
+def _collect_knowledge_refs_and_terms( + base_ordinal: SemanticRefOrdinal, + message_ordinal: MessageOrdinal, + chunk_ordinal: int, + knowledge: kplib.KnowledgeResponse, +) -> tuple[list[SemanticRef], list[tuple[str, SemanticRefOrdinal]]]: + """Collect SemanticRefs and index terms without writing to storage.""" + refs: list[SemanticRef] = [] + terms: list[tuple[str, SemanticRefOrdinal]] = [] + ordinal = base_ordinal + text_range = text_range_from_message_chunk(message_ordinal, chunk_ordinal) + + for entity in knowledge.entities: + if not validate_entity(entity): + continue + refs.append(SemanticRef( + semantic_ref_ordinal=ordinal, + range=text_range, + knowledge=entity, + )) + terms.append((entity.name, ordinal)) + for type_name in entity.type: + terms.append((type_name, ordinal)) + if entity.facets: + for facet in entity.facets: + if facet is not None: + terms.append((facet.name, ordinal)) + if facet.value is not None: + terms.append((str(facet.value), ordinal)) + ordinal += 1 + + for action in list(knowledge.actions) + list(knowledge.inverse_actions): + refs.append(SemanticRef( + semantic_ref_ordinal=ordinal, + range=text_range, + knowledge=action, + )) + terms.append((" ".join(action.verbs), ordinal)) + if action.subject_entity_name != "none": + terms.append((action.subject_entity_name, ordinal)) + if action.object_entity_name != "none": + terms.append((action.object_entity_name, ordinal)) + if action.indirect_object_entity_name != "none": + terms.append((action.indirect_object_entity_name, ordinal)) + if action.params: + for param in action.params: + if isinstance(param, str): + terms.append((param, ordinal)) + else: + terms.append((param.name, ordinal)) + if isinstance(param.value, str): + terms.append((param.value, ordinal)) + if action.subject_entity_facet is not None: + terms.append((action.subject_entity_facet.name, ordinal)) + if action.subject_entity_facet.value is not None: + terms.append((str(action.subject_entity_facet.value), ordinal)) + ordinal += 
1 + + for topic_text in knowledge.topics: + refs.append(SemanticRef( + semantic_ref_ordinal=ordinal, + range=text_range, + knowledge=Topic(text=topic_text), + )) + terms.append((topic_text, ordinal)) + ordinal += 1 + + return refs, terms + + async def add_knowledge_to_semantic_ref_index( conversation: IConversation, message_ordinal: MessageOrdinal, chunk_ordinal: int, knowledge: kplib.KnowledgeResponse, - terms_added: set[str] | None = None, ) -> None: - """Add knowledge to the semantic reference index of a conversation. - - Args: - conversation: The conversation to add knowledge to - message_ordinal: Ordinal of the message containing the knowledge - chunk_ordinal: Ordinal of the chunk within the message - knowledge: Knowledge response containing entities, actions and topics - terms_added: Optional set to track terms added to the index - """ + """Add knowledge to the semantic reference index of a conversation.""" verify_has_semantic_ref_index(conversation) semantic_refs = conversation.semantic_refs @@ -380,47 +429,46 @@ async def add_knowledge_to_semantic_ref_index( semantic_ref_index = conversation.semantic_ref_index assert semantic_ref_index is not None - for entity in knowledge.entities: - if validate_entity(entity): - await add_entity( - entity, - semantic_refs, - semantic_ref_index, - message_ordinal, - chunk_ordinal, - terms_added, - ) + base_ordinal = await semantic_refs.size() + refs, terms = _collect_knowledge_refs_and_terms( + base_ordinal, message_ordinal, chunk_ordinal, knowledge, + ) - for action in knowledge.actions: - await add_action( - action, - semantic_refs, - semantic_ref_index, - message_ordinal, - chunk_ordinal, - terms_added, - ) + if refs: + await semantic_refs.extend(refs) + if terms: + await semantic_ref_index.add_terms_batch(terms) - for inverse_action in knowledge.inverse_actions: - await add_action( - inverse_action, - semantic_refs, - semantic_ref_index, - message_ordinal, - chunk_ordinal, - terms_added, - ) - for topic in 
knowledge.topics: - topic_obj = Topic(text=topic) - await add_topic( - topic_obj, - semantic_refs, - semantic_ref_index, - message_ordinal, - chunk_ordinal, - terms_added, +async def add_knowledge_batch_to_semantic_ref_index( + conversation: IConversation, + items: list[tuple[MessageOrdinal, int, kplib.KnowledgeResponse]], +) -> None: + """Bulk-add knowledge from multiple chunks in two DB round-trips.""" + if not items: + return + verify_has_semantic_ref_index(conversation) + + semantic_refs = conversation.semantic_refs + assert semantic_refs is not None + semantic_ref_index = conversation.semantic_ref_index + assert semantic_ref_index is not None + + all_refs: list[SemanticRef] = [] + all_terms: list[tuple[str, SemanticRefOrdinal]] = [] + base_ordinal = await semantic_refs.size() + + for msg_ord, chunk_ord, knowledge in items: + refs, terms = _collect_knowledge_refs_and_terms( + base_ordinal + len(all_refs), msg_ord, chunk_ord, knowledge, ) + all_refs.extend(refs) + all_terms.extend(terms) + + if all_refs: + await semantic_refs.extend(all_refs) + if all_terms: + await semantic_ref_index.add_terms_batch(all_terms) def validate_entity(entity: kplib.ConcreteEntity) -> bool: @@ -722,40 +770,37 @@ async def add_to_semantic_ref_index[ conversation: IConversation[TMessage, TTermToSemanticRefIndex], settings: SemanticRefIndexSettings, message_ordinal_start_at: MessageOrdinal, - terms_added: set[str] | None = None, ) -> None: """Add semantic references to the conversation's semantic reference index.""" + if not settings.auto_extract_knowledge: + return - # Only create knowledge extractor if auto extraction is enabled - if settings.auto_extract_knowledge: - knowledge_extractor = ( - settings.knowledge_extractor or convknowledge.KnowledgeExtractor() - ) + knowledge_extractor = ( + settings.knowledge_extractor or convknowledge.KnowledgeExtractor() + ) - # Build a flat list of all text locations - text_locations: list[TextLocation] = [] - message_ordinal = 
message_ordinal_start_at - async for message in conversation.messages: - if message_ordinal < message_ordinal_start_at: - message_ordinal += 1 - continue - for chunk_ordinal in range(len(message.text_chunks)): - text_locations.append( - TextLocation( - message_ordinal=message_ordinal, - chunk_ordinal=chunk_ordinal, - ) - ) + text_locations: list[TextLocation] = [] + message_ordinal = message_ordinal_start_at + async for message in conversation.messages: + if message_ordinal < message_ordinal_start_at: message_ordinal += 1 - - if text_locations: - await add_batch_to_semantic_ref_index( - conversation, - text_locations, - knowledge_extractor, - terms_added, - concurrency=settings.concurrency, + continue + for chunk_ordinal in range(len(message.text_chunks)): + text_locations.append( + TextLocation( + message_ordinal=message_ordinal, + chunk_ordinal=chunk_ordinal, + ) ) + message_ordinal += 1 + + if text_locations: + await add_batch_to_semantic_ref_index( + conversation, + text_locations, + knowledge_extractor, + concurrency=settings.concurrency, + ) def verify_has_semantic_ref_index(conversation: IConversation) -> None: diff --git a/src/typeagent/storage/sqlite/provider.py b/src/typeagent/storage/sqlite/provider.py index 2978c8ed..295cd0b3 100644 --- a/src/typeagent/storage/sqlite/provider.py +++ b/src/typeagent/storage/sqlite/provider.py @@ -629,6 +629,18 @@ async def mark_source_ingested( (source_id, status), ) + async def mark_sources_ingested_batch( + self, source_ids: list[str], status: str = STATUS_INGESTED + ) -> None: + """Mark multiple sources as ingested in one operation.""" + if not source_ids: + return + cursor = self.db.cursor() + cursor.executemany( + "INSERT OR REPLACE INTO IngestedSources (source_id, status) VALUES (?, ?)", + [(sid, status) for sid in source_ids], + ) + async def record_chunk_failure( self, message_ordinal: int, diff --git a/src/typeagent/storage/sqlite/reltermsindex.py b/src/typeagent/storage/sqlite/reltermsindex.py index 
dec29db2..cf5b201b 100644 --- a/src/typeagent/storage/sqlite/reltermsindex.py +++ b/src/typeagent/storage/sqlite/reltermsindex.py @@ -209,30 +209,24 @@ async def get_terms(self) -> list[str]: return [row[0] for row in cursor.fetchall()] async def add_terms(self, texts: list[str]) -> None: - """Add terms.""" + """Add terms with batched embedding generation and DB writes.""" + new_terms = [t for t in texts if t not in self._added_terms] + if not new_terms: + return + + embeddings = await self._vector_base.add_keys(new_terms) + assert embeddings is not None + cursor = self.db.cursor() - # TODO: Batch additions to database - for text in texts: - if text in self._added_terms: - continue - - # Add to VectorBase for fuzzy lookup - await self._vector_base.add_key(text) - self._terms_list.append(text) - self._added_terms.add(text) - - # Generate embedding for term and store in database - embedding = await self._vector_base.get_embedding(text) # Cached - serialized_embedding = serialize_embedding(embedding) - # Insert term and embedding - cursor.execute( - """ - INSERT OR REPLACE INTO RelatedTermsFuzzy - (term, term_embedding) - VALUES (?, ?) 
- """, - (text, serialized_embedding), - ) + cursor.executemany( + "INSERT OR REPLACE INTO RelatedTermsFuzzy (term, term_embedding) VALUES (?, ?)", + [ + (term, serialize_embedding(embeddings[i])) + for i, term in enumerate(new_terms) + ], + ) + self._terms_list.extend(new_terms) + self._added_terms.update(new_terms) async def lookup_terms( self, diff --git a/src/typeagent/storage/sqlite/timestampindex.py b/src/typeagent/storage/sqlite/timestampindex.py index 1419b340..8fe017dd 100644 --- a/src/typeagent/storage/sqlite/timestampindex.py +++ b/src/typeagent/storage/sqlite/timestampindex.py @@ -88,12 +88,13 @@ async def add_timestamps( self, message_timestamps: list[tuple[interfaces.MessageOrdinal, str]] ) -> None: """Add multiple timestamps.""" + if not message_timestamps: + return cursor = self.db.cursor() - for message_ordinal, timestamp in message_timestamps: - cursor.execute( - "UPDATE Messages SET start_timestamp = ? WHERE msg_id = ?", - (timestamp, message_ordinal), - ) + cursor.executemany( + "UPDATE Messages SET start_timestamp = ? 
WHERE msg_id = ?", + [(ts, ordinal) for ordinal, ts in message_timestamps], + ) async def lookup_range( self, date_range: interfaces.DateRange diff --git a/tests/test_add_messages_streaming.py b/tests/test_add_messages_streaming.py index bdc25ede..dc3f55b4 100644 --- a/tests/test_add_messages_streaming.py +++ b/tests/test_add_messages_streaming.py @@ -291,3 +291,192 @@ async def test_streaming_all_skipped_batch() -> None: assert await transcript.messages.size() == 0 await storage.close() + + +# --------------------------------------------------------------------------- +# Pipeline overlap and DB batching tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_streaming_on_batch_committed_fires_per_batch() -> None: + """on_batch_committed fires once per non-empty batch with the pipelined approach.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + transcript, storage = await _create_transcript(db_path) + + msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(7)] + batch_results: list[int] = [] + result = await transcript.add_messages_streaming( + _async_iter(msgs), + batch_size=3, + on_batch_committed=lambda r: batch_results.append(r.messages_added), + ) + + assert result.messages_added == 7 + # 3 batches: [0,1,2], [3,4,5], [6] + assert batch_results == [3, 3, 1] + + await storage.close() + + +@pytest.mark.asyncio +async def test_streaming_extraction_with_multiple_batches() -> None: + """Extraction results are correctly applied across batches with ordinal remapping.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + extractor = ControlledExtractor() + transcript, storage = await _create_transcript( + db_path, auto_extract=True, knowledge_extractor=extractor + ) + + msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(6)] + result = await transcript.add_messages_streaming( + 
@pytest.mark.asyncio
async def test_streaming_extraction_failure_across_batches() -> None:
    """Failures reported by the extractor carry batch-spanning global ordinals."""
    with tempfile.TemporaryDirectory() as workdir:
        db_path = os.path.join(workdir, "test.db")
        # Calls 1 and 4 fail: the second message of batch 0 and of batch 1.
        extractor = ControlledExtractor(fail_on={1, 4})
        transcript, storage = await _create_transcript(
            db_path, auto_extract=True, knowledge_extractor=extractor
        )

        batch = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(6)]
        outcome = await transcript.add_messages_streaming(
            _async_iter(batch), batch_size=3
        )

        assert outcome.messages_added == 6
        assert _failure_count(storage) == 2

        recorded = await storage.get_chunk_failures()
        ordinals = sorted(entry.message_ordinal for entry in recorded)
        # Batch 0 msg 1 -> global ordinal 1; batch 1 msg 1 -> global ordinal 4.
        assert ordinals == [1, 4]

        await storage.close()


@pytest.mark.asyncio
async def test_streaming_exception_in_later_batch_preserves_earlier() -> None:
    """An exception raised in batch 1 halts ingestion; batch 0 stays committed."""
    with tempfile.TemporaryDirectory() as workdir:
        db_path = os.path.join(workdir, "test.db")
        # Call index 3 is the first extraction of batch 1 (batch 0 holds 3 msgs).
        extractor = ControlledExtractor(raise_on={3})
        transcript, storage = await _create_transcript(
            db_path, auto_extract=True, knowledge_extractor=extractor
        )

        batch = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(6)]
        with pytest.raises(ExceptionGroup) as exc_info:
            await transcript.add_messages_streaming(
                _async_iter(batch), batch_size=3
            )

        matched = [
            err
            for err in exc_info.value.exceptions
            if isinstance(err, RuntimeError) and "Systemic failure" in str(err)
        ]
        assert matched

        # The first batch was persisted; the failing one was rolled back.
        assert await transcript.messages.size() == 3
        assert _ingested_count(storage) == 3

        await storage.close()


@pytest.mark.asyncio
async def test_mark_sources_ingested_batch_sqlite() -> None:
    """A single batch call marks every listed source as ingested."""
    with tempfile.TemporaryDirectory() as workdir:
        db_path = os.path.join(workdir, "test.db")
        _, storage = await _create_transcript(db_path)

        async with storage:
            await storage.mark_sources_ingested_batch(["a", "b", "c"])

        for source in ("a", "b", "c"):
            assert await storage.is_source_ingested(source)
        assert not await storage.is_source_ingested("d")
        assert _ingested_count(storage) == 3

        await storage.close()


@pytest.mark.asyncio
async def test_mark_sources_ingested_batch_empty() -> None:
    """Passing an empty source list performs no writes."""
    with tempfile.TemporaryDirectory() as workdir:
        db_path = os.path.join(workdir, "test.db")
        _, storage = await _create_transcript(db_path)

        async with storage:
            await storage.mark_sources_ingested_batch([])

        assert _ingested_count(storage) == 0

        await storage.close()


@pytest.mark.asyncio
async def test_mark_sources_ingested_batch_idempotent() -> None:
    """Re-marking an already-ingested source leaves exactly one row per source."""
    with tempfile.TemporaryDirectory() as workdir:
        db_path = os.path.join(workdir, "test.db")
        _, storage = await _create_transcript(db_path)

        # Overlapping batches: "b" is written twice (INSERT OR REPLACE semantics).
        async with storage:
            await storage.mark_sources_ingested_batch(["a", "b"])
        async with storage:
            await storage.mark_sources_ingested_batch(["b", "c"])

        assert _ingested_count(storage) == 3
        for source in ("a", "b", "c"):
            assert await storage.is_source_ingested(source)

        await storage.close()


@pytest.mark.asyncio
async def test_streaming_extraction_with_empty_text_chunks() -> None:
    """A message with no text chunks is stored but never sent to the extractor."""
    with tempfile.TemporaryDirectory() as workdir:
        db_path = os.path.join(workdir, "test.db")
        extractor = ControlledExtractor()
        transcript, storage = await _create_transcript(
            db_path, auto_extract=True, knowledge_extractor=extractor
        )

        batch = [
            TranscriptMessage(
                text_chunks=[],
                metadata=TranscriptMessageMeta(speaker="Alice"),
                tags=["test"],
                source_id="empty-chunks",
            ),
            _make_message("has content", source_id="has-content"),
        ]
        outcome = await transcript.add_messages_streaming(_async_iter(batch))

        assert outcome.messages_added == 2
        # Only the second message carries extractable content.
        assert extractor.call_count == 1

        await storage.close()
+ +Usage: + uv run python tools/benchmark_semref_writes.py + uv run python tools/benchmark_semref_writes.py --chunks 100 --rounds 20 +""" + +from __future__ import annotations + +import argparse +import asyncio +import os +import shutil +import statistics +import tempfile +import time + +from typeagent.aitools.model_adapters import create_test_embedding_model +from typeagent.knowpro import knowledge_schema as kplib +from typeagent.knowpro.convsettings import ConversationSettings +from typeagent.knowpro.interfaces_core import SemanticRef, Topic +from typeagent.storage.memory.semrefindex import ( + add_knowledge_batch_to_semantic_ref_index, + text_range_from_message_chunk, + validate_entity, + verify_has_semantic_ref_index, +) +from typeagent.storage.sqlite.provider import SqliteStorageProvider +from typeagent.transcripts.transcript import ( + Transcript, + TranscriptMessage, + TranscriptMessageMeta, +) + + +# --------------------------------------------------------------------------- +# Inlined pre-optimization write path (one append + add_term per item) +# --------------------------------------------------------------------------- + + +async def _individual_add_knowledge( + conversation, message_ordinal, chunk_ordinal, knowledge, +): + """Reproduces the pre-optimization per-item write logic.""" + verify_has_semantic_ref_index(conversation) + semantic_refs = conversation.semantic_refs + assert semantic_refs is not None + semantic_ref_index = conversation.semantic_ref_index + assert semantic_ref_index is not None + + for entity in knowledge.entities: + if not validate_entity(entity): + continue + ordinal = await semantic_refs.size() + await semantic_refs.append( + SemanticRef( + semantic_ref_ordinal=ordinal, + range=text_range_from_message_chunk(message_ordinal, chunk_ordinal), + knowledge=entity, + ) + ) + await semantic_ref_index.add_term(entity.name, ordinal) + for type_name in entity.type: + await semantic_ref_index.add_term(type_name, ordinal) + if 
entity.facets: + for facet in entity.facets: + if facet is not None: + await semantic_ref_index.add_term(facet.name, ordinal) + if facet.value is not None: + await semantic_ref_index.add_term(str(facet.value), ordinal) + + for action in list(knowledge.actions) + list(knowledge.inverse_actions): + ordinal = await semantic_refs.size() + await semantic_refs.append( + SemanticRef( + semantic_ref_ordinal=ordinal, + range=text_range_from_message_chunk(message_ordinal, chunk_ordinal), + knowledge=action, + ) + ) + await semantic_ref_index.add_term(" ".join(action.verbs), ordinal) + if action.subject_entity_name != "none": + await semantic_ref_index.add_term(action.subject_entity_name, ordinal) + if action.object_entity_name != "none": + await semantic_ref_index.add_term(action.object_entity_name, ordinal) + if action.indirect_object_entity_name != "none": + await semantic_ref_index.add_term(action.indirect_object_entity_name, ordinal) + if action.params: + for param in action.params: + if isinstance(param, str): + await semantic_ref_index.add_term(param, ordinal) + else: + await semantic_ref_index.add_term(param.name, ordinal) + if isinstance(param.value, str): + await semantic_ref_index.add_term(param.value, ordinal) + if action.subject_entity_facet is not None: + await semantic_ref_index.add_term(action.subject_entity_facet.name, ordinal) + if action.subject_entity_facet.value is not None: + await semantic_ref_index.add_term( + str(action.subject_entity_facet.value), ordinal + ) + + for topic_text in knowledge.topics: + ordinal = await semantic_refs.size() + await semantic_refs.append( + SemanticRef( + semantic_ref_ordinal=ordinal, + range=text_range_from_message_chunk(message_ordinal, chunk_ordinal), + knowledge=Topic(text=topic_text), + ) + ) + await semantic_ref_index.add_term(topic_text, ordinal) + + +# --------------------------------------------------------------------------- +# Synthetic data +# 
--------------------------------------------------------------------------- + + +def synthetic_knowledge(chunk_index: int) -> kplib.KnowledgeResponse: + return kplib.KnowledgeResponse( + entities=[ + kplib.ConcreteEntity( + name=f"entity_{chunk_index}_{j}", + type=[f"type_{j}", f"category_{chunk_index % 5}"], + facets=[ + kplib.Facet(name=f"facet_{j}", value=f"value_{j}") + for j in range(2) + ], + ) + for j in range(3) + ], + actions=[ + kplib.Action( + verbs=[f"verb_{chunk_index}"], + verb_tense="past", + subject_entity_name=f"entity_{chunk_index}_0", + object_entity_name=f"entity_{chunk_index}_1", + indirect_object_entity_name="none", + params=[f"param_{chunk_index}"], + ) + ], + inverse_actions=[], + topics=[f"topic_{chunk_index}", f"theme_{chunk_index % 3}"], + ) + + +def synthetic_messages(count: int) -> list[TranscriptMessage]: + return [ + TranscriptMessage( + text_chunks=[f"Message {i} about topic {i % 10}"], + metadata=TranscriptMessageMeta(speaker=f"Speaker{i % 3}"), + tags=[f"tag{i % 5}"], + ) + for i in range(count) + ] + + +# --------------------------------------------------------------------------- +# Benchmark harness +# --------------------------------------------------------------------------- + + +async def create_transcript(db_path: str) -> Transcript: + model = create_test_embedding_model() + settings = ConversationSettings(model=model) + settings.semantic_ref_index_settings.auto_extract_knowledge = False + storage = SqliteStorageProvider( + db_path, + message_type=TranscriptMessage, + message_text_index_settings=settings.message_text_index_settings, + related_term_index_settings=settings.related_term_index_settings, + ) + settings.storage_provider = storage + return await Transcript.create(settings, name="bench-semref") + + +async def bench_individual(transcript: Transcript, chunks: int) -> None: + for i in range(chunks): + await _individual_add_knowledge(transcript, i, 0, synthetic_knowledge(i)) + + +async def bench_batched(transcript: 
async def bench_batched(transcript: Transcript, chunks: int) -> None:
    """Write every chunk's knowledge through the bulk batch API (new path)."""
    payload = [(chunk, 0, synthetic_knowledge(chunk)) for chunk in range(chunks)]
    await add_knowledge_batch_to_semantic_ref_index(transcript, payload)


async def run_benchmark(
    label: str,
    factory,
    chunks: int,
    rounds: int,
    warmup: int,
) -> list[float]:
    """Time *factory* over ``warmup + rounds`` fresh databases.

    Each round gets its own temp directory and transcript so rounds are
    independent; only the post-warmup timings (in microseconds) are kept.
    (*label* is accepted for call-site symmetry but not used here.)
    """
    timings_us: list[float] = []
    for round_index in range(warmup + rounds):
        temp_dir = tempfile.mkdtemp(prefix="bench-semref-")
        try:
            transcript = await create_transcript(os.path.join(temp_dir, "bench.db"))
            await transcript.add_messages_with_indexing(synthetic_messages(chunks))

            start_ns = time.perf_counter_ns()
            await factory(transcript, chunks)
            elapsed_us = (time.perf_counter_ns() - start_ns) / 1_000

            if round_index >= warmup:
                timings_us.append(elapsed_us)
        finally:
            shutil.rmtree(temp_dir, ignore_errors=True)
    return timings_us


def print_report(label: str, samples_us: list[float], rounds: int, warmup: int) -> None:
    """Print min/mean/median/max of *samples_us* (microseconds) under *label*."""
    summary = (
        ("min", min(samples_us)),
        ("mean", statistics.fmean(samples_us)),
        ("median", statistics.median(samples_us)),
        ("max", max(samples_us)),
    )
    print(f"\n{label}")
    print(f"  rounds: {rounds} ({warmup} warmup)")
    for name, value in summary:
        # "name:" padded to 8 columns keeps the numbers aligned.
        print(f"  {name + ':':<8}{value:12.1f} us")


async def main() -> None:
    """Parse CLI options, run both write strategies, and report the speedup."""
    parser = argparse.ArgumentParser(
        description="Benchmark semref index write strategies.",
    )
    parser.add_argument(
        "--chunks", type=int, default=50,
        help="Number of knowledge chunks to write per run (default: 50).",
    )
    parser.add_argument(
        "--rounds", type=int, default=10,
        help="Number of timed rounds (default: 10).",
    )
    parser.add_argument(
        "--warmup", type=int, default=2,
        help="Number of untimed warmup rounds (default: 2).",
    )
    args = parser.parse_args()

    # Estimate refs per chunk from one sample payload (entity validation
    # below only approximates validate_entity, hence the "~" in the output).
    sample = synthetic_knowledge(0)
    refs_per_chunk = (
        sum(1 for entity in sample.entities if entity.name)
        + len(sample.actions)
        + len(sample.inverse_actions)
        + len(sample.topics)
    )
    print(f"Chunks per run: {args.chunks}")
    print(f"Semrefs per chunk: ~{refs_per_chunk}")
    print(f"Total semrefs per run: ~{refs_per_chunk * args.chunks}")

    individual = await run_benchmark(
        "Individual writes", bench_individual,
        args.chunks, args.rounds, args.warmup,
    )
    print_report(
        "Individual writes (per-entity append + add_term)",
        individual, args.rounds, args.warmup,
    )

    batched = await run_benchmark(
        "Batched writes", bench_batched,
        args.chunks, args.rounds, args.warmup,
    )
    print_report(
        "Batched writes (bulk extend + add_terms_batch)",
        batched, args.rounds, args.warmup,
    )

    speedup = statistics.fmean(individual) / statistics.fmean(batched)
    print(f"\nSpeedup: {speedup:.2f}x")


if __name__ == "__main__":
    asyncio.run(main())