Skip to content

Commit 7d66982

Browse files
authored
Fix streaming skip accounting and next_extraction task leak (#266)
## Summary

Follow-up to #265. Fixes two issues identified during review:

- **Skip accounting gap.** `ingest_email.py` only reported generator-level skips (`counters["skipped"]`). Batch-level skips from `_filter_ingested` (`result.messages_skipped`) were never surfaced in the final summary. The two populations are disjoint — a source caught by the generator never reaches the batch layer. The summary now reports `total_skipped = counters["skipped"] + counters["batch_skipped"]`. Survives ^C since `on_batch_committed` fires per committed batch.
- **`next_extraction` task leak.** In `_submit_batch`, if `_drain_commit()` raises after `next_extraction = asyncio.create_task(...)` but before it's awaited, the task leaked. Promoted `next_extraction` to a `nonlocal` (`pending_extraction`) so the `except BaseException` block can cancel it alongside `pending_commit`.

## Changes

- `tools/ingest_email.py`: Add `batch_skipped` counter, track `result.messages_skipped` in `on_batch_committed`, report combined total in summary
- `src/typeagent/knowpro/conversation_base.py`: Track `pending_extraction`, cancel in except block
- `tests/test_add_messages_streaming.py`: 4 new tests covering both cancellation paths and edge cases

## Test plan

- [x] `make format check test` passes (701 tests)
- [x] Coverage for `conversation_base.py`: 94% → 96%
- [x] New test: `pending_extraction` cancelled when prior commit raises during `_drain_commit`
- [x] New test: `pending_commit` cancelled when message iterator raises
- [x] New test: empty iterator returns zeros
- [x] New test: messages with empty `text_chunks` skip extraction entirely
1 parent b40c125 commit 7d66982

3 files changed

Lines changed: 184 additions & 5 deletions

File tree

src/typeagent/knowpro/conversation_base.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -263,6 +263,7 @@ def _accumulate(result: AddMessagesResult) -> None:
263263
on_batch_committed(result)
264264

265265
pending_commit: asyncio.Task[AddMessagesResult] | None = None
266+
pending_extraction: asyncio.Task[_ExtractionResult | None] | None = None
266267
pending_skipped: int = 0
267268

268269
async def _drain_commit() -> None:
@@ -275,7 +276,7 @@ async def _drain_commit() -> None:
275276
pending_skipped = 0
276277

277278
async def _submit_batch(filtered: list[TMessage], skipped: int) -> None:
278-
nonlocal pending_commit, pending_skipped
279+
nonlocal pending_commit, pending_extraction, pending_skipped
279280
if not filtered and not skipped:
280281
return
281282

@@ -285,6 +286,7 @@ async def _submit_batch(filtered: list[TMessage], skipped: int) -> None:
285286
)
286287
else:
287288
next_extraction = None
289+
pending_extraction = next_extraction
288290

289291
# Wait for previous commit to finish (frees the DB connection)
290292
await _drain_commit()
@@ -298,6 +300,7 @@ async def _submit_batch(filtered: list[TMessage], skipped: int) -> None:
298300

299301
# Await extraction result for this batch
300302
extraction = await next_extraction if next_extraction is not None else None
303+
pending_extraction = None
301304

302305
# Start commit (DB transaction) — runs concurrently with the
303306
# *next* batch's LLM extraction once we yield back to the loop.
@@ -330,6 +333,10 @@ async def _submit_batch(filtered: list[TMessage], skipped: int) -> None:
330333

331334
await _drain_commit()
332335
except BaseException:
336+
if pending_extraction is not None and not pending_extraction.done():
337+
pending_extraction.cancel()
338+
with contextlib.suppress(asyncio.CancelledError):
339+
await pending_extraction
333340
if pending_commit is not None and not pending_commit.done():
334341
pending_commit.cancel()
335342
with contextlib.suppress(asyncio.CancelledError):

tests/test_add_messages_streaming.py

Lines changed: 169 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33

44
"""Tests for add_messages_streaming."""
55

6+
import asyncio
67
from collections.abc import AsyncIterator
78
import os
89
import tempfile
@@ -818,3 +819,171 @@ async def test_streaming_all_skipped_batch_after_real_batch() -> None:
818819
assert batch_results[0].messages_skipped == 3
819820

820821
await storage.close()
822+
823+
824+
# ---------------------------------------------------------------------------
825+
# Coverage gap tests
826+
# ---------------------------------------------------------------------------
827+
828+
829+
class SlowExtractor:
830+
"""Extractor that blocks on an event, allowing tests to control timing."""
831+
832+
def __init__(self, block_from: int) -> None:
833+
self.call_count = 0
834+
self.block_from = block_from
835+
self.blocked = asyncio.Event()
836+
self.cancelled = False
837+
838+
async def extract(self, message: str) -> typechat.Result[kplib.KnowledgeResponse]:
839+
idx = self.call_count
840+
self.call_count += 1
841+
if idx >= self.block_from:
842+
self.blocked.set()
843+
try:
844+
await asyncio.sleep(60)
845+
except asyncio.CancelledError:
846+
self.cancelled = True
847+
raise
848+
return typechat.Success(_EMPTY_RESPONSE)
849+
850+
851+
@pytest.mark.asyncio
852+
async def test_streaming_pending_extraction_cancelled_on_commit_failure() -> None:
853+
"""pending_extraction is cancelled when a prior commit raises during _drain_commit.
854+
855+
Timeline:
856+
1. Batch 0: extraction succeeds (calls 0-2, fast), commit task created
857+
(pending_commit = failing_commit)
858+
2. Batch 1: extraction task created (pending_extraction, calls 3+, slow),
859+
_drain_commit awaits batch 0's pending_commit which raises
860+
3. except block: pending_extraction (batch 1's) is still in-flight → cancelled
861+
"""
862+
with tempfile.TemporaryDirectory() as tmpdir:
863+
db_path = os.path.join(tmpdir, "test.db")
864+
# Block extraction starting from call 3 (first call of batch 1)
865+
# so that pending_extraction is still running when the except fires
866+
extractor = SlowExtractor(block_from=3)
867+
transcript, storage = await _create_transcript(
868+
db_path, auto_extract=True, knowledge_extractor=extractor
869+
)
870+
871+
async def failing_commit(*args, **kwargs):
872+
raise RuntimeError("Simulated commit failure")
873+
874+
transcript._commit_batch_streaming = failing_commit # type: ignore[assignment]
875+
876+
msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(6)]
877+
878+
with pytest.raises(RuntimeError, match="Simulated commit failure"):
879+
await transcript.add_messages_streaming(_async_iter(msgs), batch_size=3)
880+
881+
assert extractor.cancelled
882+
883+
await storage.close()
884+
885+
886+
@pytest.mark.asyncio
887+
async def test_streaming_pending_commit_cancelled_on_iterator_error() -> None:
888+
"""pending_commit is cancelled when the message iterator raises.
889+
890+
After batch 0 is submitted (pending_commit in flight), the async iterator
891+
raises on the next message. The except block must cancel the still-running
892+
pending_commit.
893+
"""
894+
895+
async def _error_after(
896+
items: list[TranscriptMessage], error_after: int
897+
) -> AsyncIterator[TranscriptMessage]:
898+
for i, item in enumerate(items):
899+
if i == error_after:
900+
# Yield to event loop so pending tasks start running
901+
await asyncio.sleep(0)
902+
raise ValueError("Iterator error")
903+
yield item
904+
905+
with tempfile.TemporaryDirectory() as tmpdir:
906+
db_path = os.path.join(tmpdir, "test.db")
907+
transcript, storage = await _create_transcript(db_path)
908+
909+
commit_cancelled = False
910+
911+
async def slow_commit(*args, **kwargs):
912+
nonlocal commit_cancelled
913+
try:
914+
await asyncio.sleep(60)
915+
except asyncio.CancelledError:
916+
commit_cancelled = True
917+
raise
918+
return AddMessagesResult()
919+
920+
transcript._commit_batch_streaming = slow_commit # type: ignore[assignment]
921+
922+
msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(6)]
923+
924+
with pytest.raises(ValueError, match="Iterator error"):
925+
await transcript.add_messages_streaming(
926+
_error_after(msgs, error_after=4), batch_size=3
927+
)
928+
929+
assert commit_cancelled
930+
931+
await storage.close()
932+
933+
934+
@pytest.mark.asyncio
935+
async def test_streaming_empty_batch_after_filter() -> None:
936+
"""Streaming with an empty iterator after a real batch returns zeros."""
937+
with tempfile.TemporaryDirectory() as tmpdir:
938+
db_path = os.path.join(tmpdir, "test.db")
939+
transcript, storage = await _create_transcript(db_path)
940+
941+
# Ingest one real message, then do a second call with an empty iterator
942+
msgs = [_make_message("msg-0", source_id="s-0")]
943+
r1 = await transcript.add_messages_streaming(_async_iter(msgs))
944+
assert r1.messages_added == 1
945+
946+
# Empty iterator → _submit_batch never called with content
947+
r2 = await transcript.add_messages_streaming(_async_iter([]))
948+
assert r2.messages_added == 0
949+
assert r2.messages_skipped == 0
950+
951+
await storage.close()
952+
953+
954+
@pytest.mark.asyncio
955+
async def test_streaming_extraction_returns_none_for_empty_chunks() -> None:
956+
"""_extract_knowledge_for_batch returns None when no text_locations exist.
957+
958+
Messages with empty text_chunks produce no TextLocations, so extraction
959+
should be skipped entirely.
960+
"""
961+
with tempfile.TemporaryDirectory() as tmpdir:
962+
db_path = os.path.join(tmpdir, "test.db")
963+
extractor = ControlledExtractor()
964+
transcript, storage = await _create_transcript(
965+
db_path, auto_extract=True, knowledge_extractor=extractor
966+
)
967+
968+
msgs = [
969+
TranscriptMessage(
970+
text_chunks=[],
971+
metadata=TranscriptMessageMeta(speaker="Alice"),
972+
tags=["test"],
973+
source_id="empty-0",
974+
),
975+
TranscriptMessage(
976+
text_chunks=[],
977+
metadata=TranscriptMessageMeta(speaker="Bob"),
978+
tags=["test"],
979+
source_id="empty-1",
980+
),
981+
]
982+
result = await transcript.add_messages_streaming(_async_iter(msgs))
983+
984+
assert result.messages_added == 2
985+
assert result.chunks_added == 0
986+
# No extraction calls since there are no chunks
987+
assert extractor.call_count == 0
988+
989+
await storage.close()

tools/ingest_email.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -451,6 +451,7 @@ async def ingest_emails(
451451
counters: dict[str, int] = {
452452
"parsed": 0,
453453
"skipped": 0,
454+
"batch_skipped": 0,
454455
"date_skipped": 0,
455456
"failed": 0,
456457
"ingested": 0,
@@ -462,6 +463,7 @@ async def ingest_emails(
462463
def on_batch_committed(result: AddMessagesResult) -> None:
463464
nonlocal last_batch_time
464465
counters["ingested"] += result.messages_added
466+
counters["batch_skipped"] += result.messages_skipped
465467
counters["chunks"] += result.chunks_added
466468
counters["semrefs"] += result.semrefs_added
467469
counters["batches"] += 1
@@ -516,6 +518,7 @@ def on_batch_committed(result: AddMessagesResult) -> None:
516518
)
517519
total_chunks = result.chunks_added if result is not None else counters["chunks"]
518520
semrefs_added = result.semrefs_added if result is not None else counters["semrefs"]
521+
total_skipped = counters["skipped"] + counters["batch_skipped"]
519522
overall_per_chunk = elapsed / total_chunks if total_chunks else 0
520523

521524
print()
@@ -524,8 +527,8 @@ def on_batch_committed(result: AddMessagesResult) -> None:
524527
print("Ingestion interrupted by user (^C).")
525528
print(f"Successfully ingested {messages_ingested} email(s)")
526529
print(f"Ingested {total_chunks} chunk(s)")
527-
if counters["skipped"]:
528-
print(f"Skipped {counters['skipped']} already-ingested email(s)")
530+
if total_skipped:
531+
print(f"Skipped {total_skipped} already-ingested email(s)")
529532
if counters["date_skipped"]:
530533
print(f"Skipped {counters['date_skipped']} email(s) outside date range")
531534
if counters["failed"]:
@@ -539,8 +542,8 @@ def on_batch_committed(result: AddMessagesResult) -> None:
539542
f"({total_chunks} chunks, {semrefs_added} refs added, {elapsed:.1f}s, "
540543
f"{overall_per_chunk:.2f}s/chunk)"
541544
)
542-
if counters["skipped"]:
543-
print(f"Skipped: {counters['skipped']} (already ingested)")
545+
if total_skipped:
546+
print(f"Skipped: {total_skipped} (already ingested)")
544547
if counters["date_skipped"]:
545548
print(f"Skipped: {counters['date_skipped']} (outside date range)")
546549
if counters["failed"]:

0 commit comments

Comments (0)