3 changes: 2 additions & 1 deletion AGENTS.md
@@ -25,6 +25,7 @@ AGENTS.md. In all cases show what you added to AGENTS.md.
- Use `make test` to run all tests
- Use `make check test` to run `make check` and if it passes also run `make test`
- Use `make format` to format all files using `black`. Do this before reporting success.
- When validating changes, first run `pytest` only on new/modified test files, then run `make format check test` once at the end.
- Keep ad-hoc and performance benchmarks under `tools/`, not `tests/`, so `make test` does not run them.

## Package Management with uv
@@ -36,7 +37,7 @@ AGENTS.md. In all cases show what you added to AGENTS.md.
- uv maintains consistency between `pyproject.toml`, `uv.lock`, and installed packages
- Trust uv's automatic version resolution and file management

**IMPORTANT! YOU ARE NOT DONE UNTIL `make check test format` PASSES**
**IMPORTANT! YOU ARE NOT DONE UNTIL `make format check test` PASSES**

# Code generation

1 change: 1 addition & 0 deletions src/typeagent/emails/email_message.py
@@ -161,6 +161,7 @@ def __init__(self, **data: Any) -> None:
)
timestamp: str | None = None # Use metadata.sent_on for the actual sent time
src_url: str | None = None # Source file or uri for this email
source_id: str | None = None # External source id (see IMessage.source_id)

def get_knowledge(self) -> kplib.KnowledgeResponse:
return self.metadata.get_knowledge()
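Here `src_url` records where the email was read from, while the new `source_id` field is the stable external identifier used for ingestion bookkeeping. A hedged construction sketch: the class takes pydantic-style keyword data via `__init__(self, **data)` as shown above, and the body text and IMAP id below are hypothetical.

msg = EmailMessage(
    text_chunks=[body_text],  # assumed inherited IMessage field
    src_url="file:///mail/inbox/0001.eml",
    source_id="imap:uid:41372",  # hypothetical stable id from the mail server
)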
227 changes: 207 additions & 20 deletions src/typeagent/knowpro/conversation_base.py
@@ -3,6 +3,7 @@

"""Base class for conversations with incremental indexing support."""

from collections.abc import AsyncIterable, Callable
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Generic, Self, TypeVar
@@ -36,6 +37,8 @@
MessageOrdinal,
Topic,
)
from .knowledge import extract_knowledge_from_text_batch
from .messageutils import get_all_message_chunk_locations

TMessage = TypeVar("TMessage", bound=IMessage)

@@ -132,9 +135,12 @@ async def add_messages_with_indexing(

Args:
messages: Messages to add
source_ids: Optional list of source IDs to mark as ingested. These are
marked within the same transaction, so if the indexing fails, the
source IDs won't be marked as ingested (for SQLite storage).
source_ids: Optional explicit list of source IDs to mark as ingested,
one per message. When ``None`` (the default), each message's
``source_id`` attribute is used instead — messages whose
``source_id`` is ``None`` are silently skipped. These are marked
within the same transaction, so if the indexing fails, the source
IDs won't be marked as ingested (for SQLite storage).

Returns:
Result with counts of messages/semrefs added
@@ -143,7 +149,7 @@
Exception: Any error
"""
storage = await self.settings.get_storage_provider()
if source_ids:
if source_ids is not None:
if len(source_ids) != len(messages):
raise ValueError(
f"Length of source_ids {len(source_ids)} "
@@ -152,9 +158,13 @@

async with storage:
# Mark source IDs as ingested (will be rolled back on error)
if source_ids:
for source_id in source_ids:
await storage.mark_source_ingested(source_id)
if source_ids is not None:
for sid in source_ids:
await storage.mark_source_ingested(sid)
else:
for msg in messages:
if msg.source_id is not None:
await storage.mark_source_ingested(msg.source_id)

start_points = IndexingStartPoints(
message_count=await self.messages.size(),
@@ -173,8 +183,11 @@

await self._update_secondary_indexes_incremental(start_points)

messages_added = await self.messages.size() - start_points.message_count
chunks_added = sum(len(m.text_chunks) for m in messages[:messages_added])
result = AddMessagesResult(
messages_added=await self.messages.size() - start_points.message_count,
messages_added=messages_added,
chunks_added=chunks_added,
semrefs_added=await self.semantic_refs.size()
- start_points.semref_count,
)
@@ -186,6 +199,184 @@

return result

async def add_messages_streaming(
self,
messages: AsyncIterable[TMessage],
*,
batch_size: int = 100,
on_batch_committed: Callable[[AddMessagesResult], None] | None = None,
) -> AddMessagesResult:
"""Add messages from an async iterable, committing in batches.

Unlike ``add_messages_with_indexing`` which processes all messages in a
single transaction, this method buffers messages into batches of
``batch_size``, processes each batch in its own transaction, and commits
after every batch. This makes it suitable for ingesting large streams
(millions of messages) where a single all-or-nothing transaction would
be impractical.

**Source-ID tracking**: each message's ``source_id`` (if not ``None``)
is checked before ingestion. Already-ingested sources are silently
skipped. Newly ingested sources are marked within the same transaction.

**Extraction failures**: when knowledge extraction returns a
``Failure`` for a chunk, the failure is recorded via
``storage.record_chunk_failure`` and processing continues with the
remaining chunks. Raised exceptions (HTTP errors, timeouts, etc.)
are treated as systemic and stop the run immediately — the current
batch is rolled back and the exception propagates.

Args:
messages: An async iterable of messages to ingest.
batch_size: Number of messages per commit batch.
on_batch_committed: Optional callback invoked after each batch is
committed, receiving the batch's ``AddMessagesResult``.

Returns:
Cumulative ``AddMessagesResult`` across all committed batches.
"""
storage = await self.settings.get_storage_provider()
total_messages_added = 0
total_semrefs_added = 0
total_chunks_added = 0

batch: list[TMessage] = []
async for msg in messages:
batch.append(msg)
if len(batch) >= batch_size:
result = await self._ingest_batch_streaming(storage, batch)
total_messages_added += result.messages_added
total_semrefs_added += result.semrefs_added
total_chunks_added += result.chunks_added
if on_batch_committed:
on_batch_committed(result)
batch = []

# Flush remaining messages
if batch:
result = await self._ingest_batch_streaming(storage, batch)
total_messages_added += result.messages_added
total_semrefs_added += result.semrefs_added
total_chunks_added += result.chunks_added
if on_batch_committed:
on_batch_committed(result)

return AddMessagesResult(
messages_added=total_messages_added,
chunks_added=total_chunks_added,
semrefs_added=total_semrefs_added,
)

async def _ingest_batch_streaming(
self,
storage: IStorageProvider[TMessage],
batch: list[TMessage],
) -> AddMessagesResult:
"""Process and commit a single batch within a transaction.

Messages whose ``source_id`` is already ingested are filtered out.
Extraction ``Failure``\\s are recorded as chunk failures.
"""
# Filter out already-ingested sources
filtered: list[TMessage] = []
for msg in batch:
if msg.source_id is not None and await storage.is_source_ingested(
msg.source_id
):
continue
filtered.append(msg)

Collaborator: is_source_ingested is called outside the transaction. In a multi-process scenario two workers could both pass the check and both ingest the same source.

Collaborator: I don't believe such a scenario is valid anyway, so let's not move the check into the transaction (it will just slow down other tasks that want to write to the db).

if not filtered:
return AddMessagesResult()

async with storage:
start_points = IndexingStartPoints(
message_count=await self.messages.size(),
semref_count=await self.semantic_refs.size(),
)

await self.messages.extend(filtered)

# Mark source IDs as ingested (rolled back on error)
for msg in filtered:
if msg.source_id is not None:
await storage.mark_source_ingested(msg.source_id)

await self._add_metadata_knowledge_incremental(start_points.message_count)

if self.settings.semantic_ref_index_settings.auto_extract_knowledge:
await self._add_llm_knowledge_streaming(
storage, filtered, start_points.message_count
)

await self._update_secondary_indexes_incremental(start_points)

await storage.update_conversation_timestamps(
updated_at=datetime.now(timezone.utc)
)

messages_added = await self.messages.size() - start_points.message_count
chunks_added = sum(len(m.text_chunks) for m in filtered[:messages_added])
return AddMessagesResult(
messages_added=messages_added,
chunks_added=chunks_added,
semrefs_added=await self.semantic_refs.size()
- start_points.semref_count,
)

async def _add_llm_knowledge_streaming(
self,
storage: IStorageProvider[TMessage],
messages: list[TMessage],
start_from_message_ordinal: int,
) -> None:
"""Extract LLM knowledge, recording failures instead of raising.

On ``Failure``: records a chunk failure via the storage provider and
continues. On a raised exception: lets it propagate (the caller's
``async with storage`` will roll back the transaction).
"""
settings = self.settings.semantic_ref_index_settings
knowledge_extractor = (
settings.knowledge_extractor or convknowledge.KnowledgeExtractor()
)

text_locations = get_all_message_chunk_locations(
messages, start_from_message_ordinal
)
if not text_locations:
return

start_ordinal = text_locations[0].message_ordinal
text_batch: list[str] = []
for tl in text_locations:
list_index = tl.message_ordinal - start_ordinal
text_batch.append(
messages[list_index].text_chunks[tl.chunk_ordinal].strip()
)

knowledge_results = await extract_knowledge_from_text_batch(
knowledge_extractor,
text_batch,
settings.concurrency,
)
for i, knowledge_result in enumerate(knowledge_results):
tl = text_locations[i]
if isinstance(knowledge_result, typechat.Failure):
await storage.record_chunk_failure(
tl.message_ordinal,
tl.chunk_ordinal,
type(knowledge_result).__name__,
knowledge_result.message[:500],
)
continue
await semrefindex.add_knowledge_to_semantic_ref_index(
self,
tl.message_ordinal,
tl.chunk_ordinal,
knowledge_result.value,
)

async def _add_metadata_knowledge_incremental(
self,
start_from_message_ordinal: int,
@@ -216,21 +407,17 @@ async def _add_llm_knowledge_incremental(
settings.knowledge_extractor or convknowledge.KnowledgeExtractor()
)

# Get batches of text locations from the message list
from .messageutils import get_message_chunk_batch_from_list

batches = get_message_chunk_batch_from_list(
text_locations = get_all_message_chunk_locations(
messages,
start_from_message_ordinal,
settings.batch_size,
)
for text_location_batch in batches:
await semrefindex.add_batch_to_semantic_ref_index_from_list(
self,
messages,
text_location_batch,
knowledge_extractor,
)
await semrefindex.add_batch_to_semantic_ref_index_from_list(
self,
messages,
text_locations,
knowledge_extractor,
concurrency=settings.concurrency,
)

async def _update_secondary_indexes_incremental(
self,
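A minimal usage sketch of the new streaming API; `make_message` and the `conversation` object are hypothetical, while the method signature and `AddMessagesResult` fields come from the diff above.

import asyncio

async def iter_messages():
    # Hypothetical source: yields one message at a time, so the full
    # stream never has to fit in memory.
    for i in range(2_000):
        yield make_message(i)  # assumed helper that builds a TMessage

def report(batch):
    # Invoked after each batch commits, with that batch's AddMessagesResult.
    print(f"committed {batch.messages_added} messages, "
          f"{batch.chunks_added} chunks, {batch.semrefs_added} semrefs")

async def main(conversation):
    total = await conversation.add_messages_streaming(
        iter_messages(),
        batch_size=500,  # commit every 500 messages
        on_batch_committed=report,
    )
    print(f"total messages added: {total.messages_added}")

Because every batch commits in its own transaction, a crash loses at most the current batch, and a rerun skips all sources that were already marked as ingested.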
4 changes: 2 additions & 2 deletions src/typeagent/knowpro/convsettings.py
@@ -29,7 +29,7 @@ def __init__(self, embedding_index_settings: TextEmbeddingIndexSettings):

@dataclass
class SemanticRefIndexSettings:
batch_size: int
concurrency: int
auto_extract_knowledge: bool
knowledge_extractor: IKnowledgeExtractor | None = None

@@ -54,7 +54,7 @@ def __init__(
TextEmbeddingIndexSettings(model, min_score=0.7)
)
self.semantic_ref_index_settings = SemanticRefIndexSettings(
batch_size=4, # Effectively max concurrency
concurrency=4,
gvanrossum marked this conversation as resolved.
auto_extract_knowledge=True, # The high-level API wants this
)

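The old name `batch_size` was, per its inline comment, effectively a concurrency cap, so the rename states the intent directly. A small construction sketch with illustrative values (`knowledge_extractor` keeps its `None` default):

settings = SemanticRefIndexSettings(
    concurrency=8,  # max parallel knowledge-extraction calls; was batch_size
    auto_extract_knowledge=True,
)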
11 changes: 9 additions & 2 deletions src/typeagent/knowpro/interfaces_core.py
@@ -90,8 +90,9 @@ class IndexingStartPoints:
class AddMessagesResult:
"""Result of add_messages_with_indexing operation."""

messages_added: int
semrefs_added: int
messages_added: int = 0
chunks_added: int = 0
semrefs_added: int = 0


# Messages are referenced by their sequential ordinal numbers.
@@ -129,6 +130,12 @@ class IMessage[TMetadata: IMessageMetadata](IKnowledgeSource, Protocol):
# Metadata associated with the message such as its source.
metadata: TMetadata | None = None

# Optional external identifier of the source this message was ingested from
# (e.g., an email ID, a file path, a URL). Used by ingestion pipelines to
# detect already-ingested sources for restartability. None means the message
# is not associated with an external source (e.g., synthesized in tests).
source_id: str | None = None


# Semantic references are also ordinal.
type SemanticRefOrdinal = int
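For illustration, a hedged sketch of a message type carrying `source_id`; `FileMessage` and its field choices are hypothetical, and the streaming ingester performs the already-ingested check itself, so an implementation only needs to populate the field.

from dataclasses import dataclass, field

@dataclass
class FileMessage:
    # IMessage requires text_chunks; one chunk per file here.
    text_chunks: list[str] = field(default_factory=list)
    timestamp: str | None = None
    metadata: None = None  # no metadata in this sketch
    # The source file path doubles as the external id, letting a restarted
    # ingestion run detect that this file was already processed.
    source_id: str | None = None

    def get_knowledge(self):
        # Minimal stub; a real implementation returns a KnowledgeResponse.
        raise NotImplementedError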