71 changes: 13 additions & 58 deletions src/typeagent/knowpro/conversation_base.py
@@ -226,8 +226,8 @@ async def add_messages_streaming(
         nearly doubles throughput for multi-batch ingestions.

         **Source-ID tracking**: each message's ``source_id`` (if not ``None``)
-        is checked before ingestion. Already-ingested sources are silently
-        skipped. Newly ingested sources are marked within the same transaction.
+        is marked as ingested within the commit transaction. Callers are
+        responsible for filtering duplicates before yielding messages.

         **Extraction failures**: when knowledge extraction returns a
         ``Failure`` for a chunk, the failure is recorded via
@@ -258,78 +258,58 @@ def _accumulate(result: AddMessagesResult) -> None:
             total.messages_added += result.messages_added
             total.semrefs_added += result.semrefs_added
             total.chunks_added += result.chunks_added
-            total.messages_skipped += result.messages_skipped
             if on_batch_committed:
                 on_batch_committed(result)

         pending_commit: asyncio.Task[AddMessagesResult] | None = None
         pending_extraction: asyncio.Task[_ExtractionResult | None] | None = None
-        pending_skipped: int = 0

         async def _drain_commit() -> None:
-            nonlocal pending_commit, pending_skipped
+            nonlocal pending_commit
             if pending_commit is not None:
-                result = await pending_commit
-                result.messages_skipped += pending_skipped
-                _accumulate(result)
+                _accumulate(await pending_commit)
                 pending_commit = None
-                pending_skipped = 0

-        async def _submit_batch(filtered: list[TMessage], skipped: int) -> None:
-            nonlocal pending_commit, pending_extraction, pending_skipped
-            if not filtered and not skipped:
+        async def _submit_batch(batch: list[TMessage]) -> None:
+            nonlocal pending_commit, pending_extraction
+            if not batch:
                 return

-            if filtered and should_extract:
+            if should_extract:
                 next_extraction = asyncio.create_task(
-                    self._extract_knowledge_for_batch(filtered)
+                    self._extract_knowledge_for_batch(batch)
                 )
             else:
                 next_extraction = None
             pending_extraction = next_extraction

             # Wait for previous commit to finish (frees the DB connection)
             await _drain_commit()

-            if not filtered:
-                # Nothing to commit, just report skipped
-                total.messages_skipped += skipped
-                if on_batch_committed:
-                    on_batch_committed(AddMessagesResult(messages_skipped=skipped))
-                return
-
             # Await extraction result for this batch
             extraction = await next_extraction if next_extraction is not None else None
             pending_extraction = None

             # Start commit (DB transaction) — runs concurrently with the
             # *next* batch's LLM extraction once we yield back to the loop.
             pending_commit = asyncio.create_task(
-                self._commit_batch_streaming(storage, filtered, extraction)
+                self._commit_batch_streaming(storage, batch, extraction)
             )
-            pending_skipped = skipped

         try:
             batch: list[TMessage] = []
             batch_chunks = 0
             async for msg in messages:
                 msg_chunks = len(msg.text_chunks)
                 if batch and batch_chunks + msg_chunks > batch_size:
-                    filtered, skipped = await self._filter_ingested(storage, batch)
-                    await _submit_batch(filtered, skipped)
+                    await _submit_batch(batch)
                     batch = []
                     batch_chunks = 0
                 batch.append(msg)
                 batch_chunks += msg_chunks
                 if batch_chunks >= batch_size:
-                    filtered, skipped = await self._filter_ingested(storage, batch)
-                    await _submit_batch(filtered, skipped)
+                    await _submit_batch(batch)
                     batch = []
                     batch_chunks = 0

             if batch:
-                filtered, skipped = await self._filter_ingested(storage, batch)
-                await _submit_batch(filtered, skipped)
+                await _submit_batch(batch)

             await _drain_commit()
         except BaseException:
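The comment in `_submit_batch` is the heart of this function: batch N's DB commit runs as a background task while batch N+1's LLM extraction proceeds in the foreground. A minimal, self-contained sketch of that overlap shape, with illustrative names rather than this module's API:

```python
import asyncio
from typing import Any, Awaitable, Callable, Iterable


async def pipeline(
    batches: Iterable[Any],
    extract: Callable[[Any], Awaitable[Any]],
    commit: Callable[[Any, Any], Awaitable[None]],
) -> None:
    """Overlap batch N's commit with batch N+1's extraction."""
    pending: asyncio.Task[None] | None = None
    for batch in batches:
        extraction = await extract(batch)  # foreground LLM work for this batch
        if pending is not None:
            await pending  # previous commit must finish (single DB connection)
        # The commit runs in the background while the next loop iteration
        # starts the next batch's extraction.
        pending = asyncio.create_task(commit(batch, extraction))
    if pending is not None:
        await pending
```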
@@ -345,31 +325,6 @@ async def _submit_batch(filtered: list[TMessage], skipped: int) -> None:

         return total

-    async def _filter_ingested(
-        self,
-        storage: IStorageProvider[TMessage],
-        batch: list[TMessage],
-    ) -> tuple[list[TMessage], int]:
-        """Filter out messages whose source_id has already been ingested.
-
-        Returns (filtered_messages, skipped_count).
-
-        Uses a single batch query instead of per-message lookups.
-        """
-        source_ids = [m.source_id for m in batch if m.source_id is not None]
-        if source_ids:
-            ingested = await storage.are_sources_ingested(source_ids)
-        else:
-            ingested = set[str]()
-        filtered: list[TMessage] = []
-        skipped = 0
-        for msg in batch:
-            if msg.source_id is not None and msg.source_id in ingested:
-                skipped += 1
-                continue
-            filtered.append(msg)
-        return filtered, skipped
-
     async def _extract_knowledge_for_batch(
         self,
         messages: list[TMessage],
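With `_filter_ingested` removed, dedup becomes the caller's job. A minimal sketch of what that caller-side filtering might look like, assuming the storage provider still exposes the batched `are_sources_ingested()` lookup the removed helper used; `filter_new_messages`, `_yield_unseen`, and `lookup_batch` are hypothetical names:

```python
from collections.abc import AsyncIterable, AsyncIterator
from typing import Any


async def filter_new_messages(
    storage: Any,  # IStorageProvider[TMessage] in the real code
    messages: AsyncIterable[Any],
    lookup_batch: int = 100,
) -> AsyncIterator[Any]:
    """Yield only messages whose source_id has not been ingested yet."""
    buffer: list[Any] = []
    async for msg in messages:
        buffer.append(msg)
        if len(buffer) >= lookup_batch:
            async for kept in _yield_unseen(storage, buffer):
                yield kept
            buffer = []
    async for kept in _yield_unseen(storage, buffer):
        yield kept


async def _yield_unseen(storage: Any, buffer: list[Any]) -> AsyncIterator[Any]:
    # One batched query per buffer, mirroring the removed _filter_ingested().
    source_ids = [m.source_id for m in buffer if m.source_id is not None]
    ingested = await storage.are_sources_ingested(source_ids) if source_ids else set()
    for msg in buffer:
        if msg.source_id is None or msg.source_id not in ingested:
            yield msg
```

A caller would then write `await transcript.add_messages_streaming(filter_new_messages(storage, msgs))` to recover the old skip behavior without the streaming path paying for it.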
151 changes: 2 additions & 149 deletions tests/test_add_messages_streaming.py
@@ -161,29 +161,6 @@ async def test_streaming_batching() -> None:
         await storage.close()


-@pytest.mark.asyncio
-async def test_streaming_skips_already_ingested() -> None:
-    """Messages whose source_id is already ingested are skipped."""
-    with tempfile.TemporaryDirectory() as tmpdir:
-        db_path = os.path.join(tmpdir, "test.db")
-        transcript, storage = await _create_transcript(db_path)
-
-        # Pre-mark some sources as ingested
-        async with storage:
-            await storage.mark_source_ingested("s-1")
-            await storage.mark_source_ingested("s-3")
-
-        msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(5)]
-        result = await transcript.add_messages_streaming(_async_iter(msgs))
-
-        # s-1 and s-3 skipped -> only 3 added
-        assert result.messages_added == 3
-        assert await transcript.messages.size() == 3
-        assert _ingested_count(storage) == 5  # 2 pre-existing + 3 new
-
-        await storage.close()
-
-
 @pytest.mark.asyncio
 async def test_streaming_no_source_id_always_ingested() -> None:
     """Messages without source_id are always ingested (never skipped)."""
@@ -274,26 +251,6 @@ async def test_streaming_empty_iterable() -> None:


-@pytest.mark.asyncio
-async def test_streaming_all_skipped_batch() -> None:
-    """A batch where all messages are already ingested produces no commit."""
-    with tempfile.TemporaryDirectory() as tmpdir:
-        db_path = os.path.join(tmpdir, "test.db")
-        transcript, storage = await _create_transcript(db_path)
-
-        # Pre-mark all sources
-        async with storage:
-            for i in range(3):
-                await storage.mark_source_ingested(f"s-{i}")
-
-        msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(3)]
-        result = await transcript.add_messages_streaming(_async_iter(msgs))
-
-        assert result.messages_added == 0
-        assert await transcript.messages.size() == 0
-
-        await storage.close()
-

 # ---------------------------------------------------------------------------
 # Pipeline overlap and DB batching tests
 # ---------------------------------------------------------------------------
@@ -663,36 +620,6 @@ async def test_streaming_multi_chunk_exception_preserves_earlier_batch() -> None:
         await storage.close()


-@pytest.mark.asyncio
-async def test_streaming_multi_chunk_skip_and_ingest_mixed() -> None:
-    """Multi-chunk messages are skipped or ingested as a whole based on source_id."""
-    with tempfile.TemporaryDirectory() as tmpdir:
-        db_path = os.path.join(tmpdir, "test.db")
-        extractor = ControlledExtractor()
-        transcript, storage = await _create_transcript(
-            db_path, auto_extract=True, knowledge_extractor=extractor
-        )
-
-        # Pre-mark s-1 as ingested
-        async with storage:
-            await storage.mark_source_ingested("s-1")
-
-        msgs = [
-            _make_multi_chunk_message(["a", "b"], source_id="s-0"),  # ingested
-            _make_multi_chunk_message(["c", "d", "e"], source_id="s-1"),  # skipped
-            _make_message("f", source_id="s-2"),  # ingested
-        ]
-        result = await transcript.add_messages_streaming(_async_iter(msgs))
-
-        assert result.messages_added == 2
-        assert result.messages_skipped == 1
-        assert result.chunks_added == 3  # 2 + 1 (not 5)
-        # Only 3 extraction calls (2 chunks from s-0 + 1 chunk from s-2)
-        assert extractor.call_count == 3
-
-        await storage.close()
-
-
 @pytest.mark.asyncio
 async def test_streaming_batch_size_1_separates_all() -> None:
     """batch_size=1 commits every single-chunk message individually."""
@@ -714,38 +641,6 @@ async def test_streaming_batch_size_1_separates_all() -> None:
         await storage.close()


-@pytest.mark.asyncio
-async def test_streaming_callback_reports_skipped_multi_chunk() -> None:
-    """on_batch_committed reports skipped count for batches with multi-chunk messages."""
-    with tempfile.TemporaryDirectory() as tmpdir:
-        db_path = os.path.join(tmpdir, "test.db")
-        transcript, storage = await _create_transcript(db_path)
-
-        # Pre-mark s-1 as ingested
-        async with storage:
-            await storage.mark_source_ingested("s-1")
-
-        msgs = [
-            _make_multi_chunk_message(["a", "b"], source_id="s-0"),  # 2 chunks
-            _make_multi_chunk_message(["c", "d"], source_id="s-1"),  # 2 chunks, skipped
-            _make_message("e", source_id="s-2"),  # 1 chunk → total = 5 chunks in batch
-        ]
-        callback_results: list[tuple[int, int]] = []
-        result = await transcript.add_messages_streaming(
-            _async_iter(msgs),
-            batch_size=10,  # all fit in one batch
-            on_batch_committed=lambda r: callback_results.append(
-                (r.messages_added, r.messages_skipped)
-            ),
-        )
-
-        assert result.messages_added == 2
-        assert result.messages_skipped == 1
-        assert callback_results == [(2, 1)]
-
-        await storage.close()
-
-
 @pytest.mark.asyncio
 async def test_streaming_preflush_avoids_oversized_batch() -> None:
     """Adding a message that would exceed batch_size flushes first.
@@ -779,48 +674,6 @@ async def test_streaming_preflush_avoids_oversized_batch() -> None:
         await storage.close()


-@pytest.mark.asyncio
-async def test_streaming_all_skipped_batch_after_real_batch() -> None:
-    """A batch of all-duplicates reports skipped correctly.
-
-    First call ingests messages s-0..s-2. Second call re-submits the same
-    source_ids — they should all be filtered by _filter_ingested, exercising
-    the all-skipped + pending_commit path.
-    """
-    with tempfile.TemporaryDirectory() as tmpdir:
-        db_path = os.path.join(tmpdir, "test.db")
-        transcript, storage = await _create_transcript(db_path)
-
-        msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(3)]
-
-        # First call — ingest originals
-        result1 = await transcript.add_messages_streaming(
-            _async_iter(msgs),
-            batch_size=3,
-        )
-        assert result1.messages_added == 3
-        assert result1.messages_skipped == 0
-
-        # Second call — all duplicates
-        dupes = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(3)]
-        batch_results: list[AddMessagesResult] = []
-        result2 = await transcript.add_messages_streaming(
-            _async_iter(dupes),
-            batch_size=3,
-            on_batch_committed=lambda r: batch_results.append(r),
-        )
-
-        assert result2.messages_added == 0
-        assert result2.messages_skipped == 3
-
-        # One callback for the all-skipped batch
-        assert len(batch_results) == 1
-        assert batch_results[0].messages_added == 0
-        assert batch_results[0].messages_skipped == 3
-
-        await storage.close()
-
-
 # ---------------------------------------------------------------------------
 # Coverage gap tests
 # ---------------------------------------------------------------------------
@@ -932,8 +785,8 @@ async def slow_commit(*args, **kwargs):


 @pytest.mark.asyncio
-async def test_streaming_empty_batch_after_filter() -> None:
-    """Streaming with an empty iterator after a real batch returns zeros."""
+async def test_streaming_empty_iterator() -> None:
+    """Streaming with an empty iterator returns zeros."""
     with tempfile.TemporaryDirectory() as tmpdir:
         db_path = os.path.join(tmpdir, "test.db")
         transcript, storage = await _create_transcript(db_path)
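If the caller-side contract needs regression coverage of its own, something like the following could replace the deleted dedup tests. This is a hedged sketch, not part of the PR: it reuses the file's existing helpers (`_create_transcript`, `_make_message`, `_async_iter`) and assumes a caller-side filter such as the hypothetical `filter_new_messages` sketched above.

```python
import os
import tempfile

import pytest


@pytest.mark.asyncio
async def test_streaming_caller_side_dedup() -> None:
    """Duplicates are skipped only when the caller filters them out."""
    with tempfile.TemporaryDirectory() as tmpdir:
        db_path = os.path.join(tmpdir, "test.db")
        transcript, storage = await _create_transcript(db_path)

        # Pre-mark s-1 as ingested, as the removed tests did.
        async with storage:
            await storage.mark_source_ingested("s-1")

        msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(3)]
        result = await transcript.add_messages_streaming(
            filter_new_messages(storage, _async_iter(msgs))  # hypothetical filter
        )

        assert result.messages_added == 2
        assert await transcript.messages.size() == 2

        await storage.close()
```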
7 changes: 1 addition & 6 deletions tools/ingest_email.py
@@ -342,9 +342,6 @@ async def _email_generator(
     skip_count = 0

     for source_id, email_file, label in _iter_emails(eml_paths, verbose, offset, limit):
-        # Pre-parse dedup: skip before opening the file.
-        # A second dedup pass happens in _filter_ingested() to catch
-        # sources committed by an earlier batch in the same run.
         if await storage.is_source_ingested(source_id):
             counters["skipped"] += 1
             basename = email_file.name
@@ -451,7 +448,6 @@ async def ingest_emails(
     counters: dict[str, int] = {
         "parsed": 0,
         "skipped": 0,
-        "batch_skipped": 0,
         "date_skipped": 0,
         "failed": 0,
         "ingested": 0,
Expand All @@ -463,7 +459,6 @@ async def ingest_emails(
def on_batch_committed(result: AddMessagesResult) -> None:
nonlocal last_batch_time
counters["ingested"] += result.messages_added
counters["batch_skipped"] += result.messages_skipped
counters["chunks"] += result.chunks_added
counters["semrefs"] += result.semrefs_added
counters["batches"] += 1
@@ -518,7 +513,7 @@ def on_batch_committed(result: AddMessagesResult) -> None:
     )
     total_chunks = result.chunks_added if result is not None else counters["chunks"]
     semrefs_added = result.semrefs_added if result is not None else counters["semrefs"]
-    total_skipped = counters["skipped"] + counters["batch_skipped"]
+    total_skipped = counters["skipped"]
     overall_per_chunk = elapsed / total_chunks if total_chunks else 0

     print()
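With the in-pipeline dedup pass gone, the generator's pre-parse check is the only source of skips, which is why `total_skipped` reduces to `counters["skipped"]`. The surviving shape, as a sketch: `storage.is_source_ingested()` and `_iter_emails` appear in the diff above, while `parse_email` and the exact signature are hypothetical stand-ins for the tool's real helpers.

```python
async def _email_generator(storage, eml_paths, counters, verbose=False, offset=0, limit=None):
    for source_id, email_file, label in _iter_emails(eml_paths, verbose, offset, limit):
        # Single pre-parse dedup pass: skip before opening the file.
        if await storage.is_source_ingested(source_id):
            counters["skipped"] += 1
            continue
        yield parse_email(email_file)  # parse_email is a hypothetical stand-in
```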