diff --git a/src/typeagent/knowpro/conversation_base.py b/src/typeagent/knowpro/conversation_base.py
index d4712d10..673695d2 100644
--- a/src/typeagent/knowpro/conversation_base.py
+++ b/src/typeagent/knowpro/conversation_base.py
@@ -226,8 +226,8 @@ async def add_messages_streaming(
         nearly doubles throughput for multi-batch ingestions.
 
         **Source-ID tracking**: each message's ``source_id`` (if not ``None``)
-        is checked before ingestion. Already-ingested sources are silently
-        skipped. Newly ingested sources are marked within the same transaction.
+        is marked as ingested within the commit transaction. Callers are
+        responsible for filtering duplicates before yielding messages.
 
         **Extraction failures**: when knowledge extraction returns a
         ``Failure`` for a chunk, the failure is recorded via
@@ -258,56 +258,39 @@ def _accumulate(result: AddMessagesResult) -> None:
             total.messages_added += result.messages_added
             total.semrefs_added += result.semrefs_added
             total.chunks_added += result.chunks_added
-            total.messages_skipped += result.messages_skipped
             if on_batch_committed:
                 on_batch_committed(result)
 
         pending_commit: asyncio.Task[AddMessagesResult] | None = None
         pending_extraction: asyncio.Task[_ExtractionResult | None] | None = None
-        pending_skipped: int = 0
 
        async def _drain_commit() -> None:
-            nonlocal pending_commit, pending_skipped
+            nonlocal pending_commit
            if pending_commit is not None:
-                result = await pending_commit
-                result.messages_skipped += pending_skipped
-                _accumulate(result)
+                _accumulate(await pending_commit)
                pending_commit = None
-                pending_skipped = 0
 
-        async def _submit_batch(filtered: list[TMessage], skipped: int) -> None:
-            nonlocal pending_commit, pending_extraction, pending_skipped
-            if not filtered and not skipped:
+        async def _submit_batch(batch: list[TMessage]) -> None:
+            nonlocal pending_commit, pending_extraction
+            if not batch:
                 return
-            if filtered and should_extract:
+            if should_extract:
                 next_extraction = asyncio.create_task(
-                    self._extract_knowledge_for_batch(filtered)
+                    self._extract_knowledge_for_batch(batch)
                 )
             else:
                 next_extraction = None
             pending_extraction = next_extraction
-            # Wait for previous commit to finish (frees the DB connection)
             await _drain_commit()
-            if not filtered:
-                # Nothing to commit, just report skipped
-                total.messages_skipped += skipped
-                if on_batch_committed:
-                    on_batch_committed(AddMessagesResult(messages_skipped=skipped))
-                return
-
-            # Await extraction result for this batch
             extraction = await next_extraction if next_extraction is not None else None
             pending_extraction = None
-            # Start commit (DB transaction) — runs concurrently with the
-            # *next* batch's LLM extraction once we yield back to the loop.
             pending_commit = asyncio.create_task(
-                self._commit_batch_streaming(storage, filtered, extraction)
+                self._commit_batch_streaming(storage, batch, extraction)
             )
-            pending_skipped = skipped
 
         try:
             batch: list[TMessage] = []
@@ -315,21 +298,18 @@ async def _submit_batch(filtered: list[TMessage], skipped: int) -> None:
             async for msg in messages:
                 msg_chunks = len(msg.text_chunks)
                 if batch and batch_chunks + msg_chunks > batch_size:
-                    filtered, skipped = await self._filter_ingested(storage, batch)
-                    await _submit_batch(filtered, skipped)
+                    await _submit_batch(batch)
                     batch = []
                     batch_chunks = 0
                 batch.append(msg)
                 batch_chunks += msg_chunks
                 if batch_chunks >= batch_size:
-                    filtered, skipped = await self._filter_ingested(storage, batch)
-                    await _submit_batch(filtered, skipped)
+                    await _submit_batch(batch)
                     batch = []
                     batch_chunks = 0
 
             if batch:
-                filtered, skipped = await self._filter_ingested(storage, batch)
-                await _submit_batch(filtered, skipped)
+                await _submit_batch(batch)
 
             await _drain_commit()
         except BaseException:
@@ -345,31 +325,6 @@ async def _submit_batch(filtered: list[TMessage], skipped: int) -> None:
 
         return total
 
-    async def _filter_ingested(
-        self,
-        storage: IStorageProvider[TMessage],
-        batch: list[TMessage],
-    ) -> tuple[list[TMessage], int]:
-        """Filter out messages whose source_id has already been ingested.
-
-        Returns (filtered_messages, skipped_count).
-
-        Uses a single batch query instead of per-message lookups.
-        """
-        source_ids = [m.source_id for m in batch if m.source_id is not None]
-        if source_ids:
-            ingested = await storage.are_sources_ingested(source_ids)
-        else:
-            ingested = set[str]()
-        filtered: list[TMessage] = []
-        skipped = 0
-        for msg in batch:
-            if msg.source_id is not None and msg.source_id in ingested:
-                skipped += 1
-                continue
-            filtered.append(msg)
-        return filtered, skipped
-
     async def _extract_knowledge_for_batch(
         self,
         messages: list[TMessage],
diff --git a/tests/test_add_messages_streaming.py b/tests/test_add_messages_streaming.py
index 062d056e..9a8b1fd6 100644
--- a/tests/test_add_messages_streaming.py
+++ b/tests/test_add_messages_streaming.py
@@ -161,29 +161,6 @@ async def test_streaming_batching() -> None:
         await storage.close()
 
 
-@pytest.mark.asyncio
-async def test_streaming_skips_already_ingested() -> None:
-    """Messages whose source_id is already ingested are skipped."""
-    with tempfile.TemporaryDirectory() as tmpdir:
-        db_path = os.path.join(tmpdir, "test.db")
-        transcript, storage = await _create_transcript(db_path)
-
-        # Pre-mark some sources as ingested
-        async with storage:
-            await storage.mark_source_ingested("s-1")
-            await storage.mark_source_ingested("s-3")
-
-        msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(5)]
-        result = await transcript.add_messages_streaming(_async_iter(msgs))
-
-        # s-1 and s-3 skipped -> only 3 added
-        assert result.messages_added == 3
-        assert await transcript.messages.size() == 3
-        assert _ingested_count(storage) == 5  # 2 pre-existing + 3 new
-
-        await storage.close()
-
-
 @pytest.mark.asyncio
 async def test_streaming_no_source_id_always_ingested() -> None:
     """Messages without source_id are always ingested (never skipped)."""
@@ -274,26 +251,6 @@ async def test_streaming_empty_iterable() -> None:
 
 
 @pytest.mark.asyncio
-async def test_streaming_all_skipped_batch() -> None:
-    """A batch where all messages are already ingested produces no commit."""
-    with tempfile.TemporaryDirectory() as tmpdir:
-        db_path = os.path.join(tmpdir, "test.db")
-        transcript, storage = await _create_transcript(db_path)
-
-        # Pre-mark all sources
-        async with storage:
-            for i in range(3):
-                await storage.mark_source_ingested(f"s-{i}")
-
-        msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(3)]
-        result = await transcript.add_messages_streaming(_async_iter(msgs))
-
-        assert result.messages_added == 0
-        assert await transcript.messages.size() == 0
-
-        await storage.close()
-
-
 # ---------------------------------------------------------------------------
 # Pipeline overlap and DB batching tests
 # ---------------------------------------------------------------------------
@@ -663,36 +620,6 @@ async def test_streaming_multi_chunk_exception_preserves_earlier_batch() -> None:
         await storage.close()
 
 
-@pytest.mark.asyncio
-async def test_streaming_multi_chunk_skip_and_ingest_mixed() -> None:
-    """Multi-chunk messages are skipped or ingested as a whole based on source_id."""
-    with tempfile.TemporaryDirectory() as tmpdir:
-        db_path = os.path.join(tmpdir, "test.db")
-        extractor = ControlledExtractor()
-        transcript, storage = await _create_transcript(
-            db_path, auto_extract=True, knowledge_extractor=extractor
-        )
-
-        # Pre-mark s-1 as ingested
-        async with storage:
-            await storage.mark_source_ingested("s-1")
-
-        msgs = [
-            _make_multi_chunk_message(["a", "b"], source_id="s-0"),  # ingested
-            _make_multi_chunk_message(["c", "d", "e"], source_id="s-1"),  # skipped
-            _make_message("f", source_id="s-2"),  # ingested
-        ]
-        result = await transcript.add_messages_streaming(_async_iter(msgs))
-
-        assert result.messages_added == 2
-        assert result.messages_skipped == 1
-        assert result.chunks_added == 3  # 2 + 1 (not 5)
-        # Only 3 extraction calls (2 chunks from s-0 + 1 chunk from s-2)
-        assert extractor.call_count == 3
-
-        await storage.close()
-
-
 @pytest.mark.asyncio
 async def test_streaming_batch_size_1_separates_all() -> None:
     """batch_size=1 commits every single-chunk message individually."""
@@ -714,38 +641,6 @@ async def test_streaming_batch_size_1_separates_all() -> None:
         await storage.close()
 
 
-@pytest.mark.asyncio
-async def test_streaming_callback_reports_skipped_multi_chunk() -> None:
-    """on_batch_committed reports skipped count for batches with multi-chunk messages."""
-    with tempfile.TemporaryDirectory() as tmpdir:
-        db_path = os.path.join(tmpdir, "test.db")
-        transcript, storage = await _create_transcript(db_path)
-
-        # Pre-mark s-1 as ingested
-        async with storage:
-            await storage.mark_source_ingested("s-1")
-
-        msgs = [
-            _make_multi_chunk_message(["a", "b"], source_id="s-0"),  # 2 chunks
-            _make_multi_chunk_message(["c", "d"], source_id="s-1"),  # 2 chunks, skipped
-            _make_message("e", source_id="s-2"),  # 1 chunk → total = 5 chunks in batch
-        ]
-        callback_results: list[tuple[int, int]] = []
-        result = await transcript.add_messages_streaming(
-            _async_iter(msgs),
-            batch_size=10,  # all fit in one batch
-            on_batch_committed=lambda r: callback_results.append(
-                (r.messages_added, r.messages_skipped)
-            ),
-        )
-
-        assert result.messages_added == 2
-        assert result.messages_skipped == 1
-        assert callback_results == [(2, 1)]
-
-        await storage.close()
-
-
 @pytest.mark.asyncio
 async def test_streaming_preflush_avoids_oversized_batch() -> None:
     """Adding a message that would exceed batch_size flushes first.
@@ -779,48 +674,6 @@ async def test_streaming_preflush_avoids_oversized_batch() -> None:
         await storage.close()
 
 
-@pytest.mark.asyncio
-async def test_streaming_all_skipped_batch_after_real_batch() -> None:
-    """A batch of all-duplicates reports skipped correctly.
-
-    First call ingests messages s-0..s-2. Second call re-submits the same
-    source_ids — they should all be filtered by _filter_ingested, exercising
-    the all-skipped + pending_commit path.
-    """
-    with tempfile.TemporaryDirectory() as tmpdir:
-        db_path = os.path.join(tmpdir, "test.db")
-        transcript, storage = await _create_transcript(db_path)
-
-        msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(3)]
-
-        # First call — ingest originals
-        result1 = await transcript.add_messages_streaming(
-            _async_iter(msgs),
-            batch_size=3,
-        )
-        assert result1.messages_added == 3
-        assert result1.messages_skipped == 0
-
-        # Second call — all duplicates
-        dupes = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(3)]
-        batch_results: list[AddMessagesResult] = []
-        result2 = await transcript.add_messages_streaming(
-            _async_iter(dupes),
-            batch_size=3,
-            on_batch_committed=lambda r: batch_results.append(r),
-        )
-
-        assert result2.messages_added == 0
-        assert result2.messages_skipped == 3
-
-        # One callback for the all-skipped batch
-        assert len(batch_results) == 1
-        assert batch_results[0].messages_added == 0
-        assert batch_results[0].messages_skipped == 3
-
-        await storage.close()
-
-
 # ---------------------------------------------------------------------------
 # Coverage gap tests
 # ---------------------------------------------------------------------------
@@ -932,8 +785,8 @@ async def slow_commit(*args, **kwargs):
 
 
 @pytest.mark.asyncio
-async def test_streaming_empty_batch_after_filter() -> None:
-    """Streaming with an empty iterator after a real batch returns zeros."""
+async def test_streaming_empty_iterator() -> None:
+    """Streaming with an empty iterator returns zeros."""
     with tempfile.TemporaryDirectory() as tmpdir:
         db_path = os.path.join(tmpdir, "test.db")
         transcript, storage = await _create_transcript(db_path)
diff --git a/tools/ingest_email.py b/tools/ingest_email.py
index fb62fe18..2a8619dd 100644
--- a/tools/ingest_email.py
+++ b/tools/ingest_email.py
@@ -342,9 +342,6 @@ async def _email_generator(
     skip_count = 0
 
     for source_id, email_file, label in _iter_emails(eml_paths, verbose, offset, limit):
-        # Pre-parse dedup: skip before opening the file.
-        # A second dedup pass happens in _filter_ingested() to catch
-        # sources committed by an earlier batch in the same run.
         if await storage.is_source_ingested(source_id):
             counters["skipped"] += 1
             basename = email_file.name
@@ -451,7 +448,6 @@ async def ingest_emails(
     counters: dict[str, int] = {
         "parsed": 0,
         "skipped": 0,
-        "batch_skipped": 0,
         "date_skipped": 0,
         "failed": 0,
         "ingested": 0,
@@ -463,7 +459,6 @@ def on_batch_committed(result: AddMessagesResult) -> None:
         nonlocal last_batch_time
         counters["ingested"] += result.messages_added
-        counters["batch_skipped"] += result.messages_skipped
         counters["chunks"] += result.chunks_added
         counters["semrefs"] += result.semrefs_added
         counters["batches"] += 1
 
@@ -518,7 +513,7 @@ def on_batch_committed(result: AddMessagesResult) -> None:
     )
     total_chunks = result.chunks_added if result is not None else counters["chunks"]
     semrefs_added = result.semrefs_added if result is not None else counters["semrefs"]
-    total_skipped = counters["skipped"] + counters["batch_skipped"]
+    total_skipped = counters["skipped"]
     overall_per_chunk = elapsed / total_chunks if total_chunks else 0
 
     print()
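
Note on the new dedup contract: with _filter_ingested removed, add_messages_streaming no longer drops already-ingested sources itself; per the updated docstring, callers must filter duplicates before yielding. A minimal caller-side sketch, assuming the same storage.is_source_ingested() coroutine that tools/ingest_email.py already uses and messages carrying an optional source_id attribute (the wrapper name dedup_by_source_id is hypothetical, not part of this change):

    from collections.abc import AsyncIterable, AsyncIterator

    async def dedup_by_source_id(
        messages: AsyncIterable,
        storage,
    ) -> AsyncIterator:
        # Drop messages whose source_id was ingested by a previous run.
        # Sources are only marked ingested inside the (pipelined) commit
        # transaction, so also track what this run has already yielded.
        seen_this_run: set[str] = set()
        async for msg in messages:
            sid = msg.source_id
            if sid is None:
                yield msg  # messages without a source_id are never skipped
                continue
            if sid in seen_this_run or await storage.is_source_ingested(sid):
                continue  # duplicate from an earlier run or earlier batch
            seen_this_run.add(sid)
            yield msg

A caller would then wrap its generator, e.g. transcript.add_messages_streaming(dedup_by_source_id(msgs, storage)).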
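The _submit_batch rewrite keeps the commit/extraction overlap intact: batch N's DB commit runs while batch N+1's LLM extraction is already in flight, which is what "nearly doubles throughput for multi-batch ingestions" refers to. A reduced sketch of that pipeline shape, with hypothetical extract and commit coroutines standing in for _extract_knowledge_for_batch and _commit_batch_streaming:

    import asyncio

    async def pipeline(batches, extract, commit):
        pending_commit: asyncio.Task | None = None
        for batch in batches:
            # Kick off extraction for this batch first...
            extraction_task = asyncio.create_task(extract(batch))
            # ...then drain the previous batch's commit; while we wait
            # here, the new extraction task is already running.
            if pending_commit is not None:
                await pending_commit
            extraction = await extraction_task
            # Commit as a task so it overlaps the next batch's extraction.
            pending_commit = asyncio.create_task(commit(batch, extraction))
        if pending_commit is not None:
            await pending_commit  # drain the final commit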