diff --git a/src/typeagent/knowpro/conversation_base.py b/src/typeagent/knowpro/conversation_base.py index 3f7a999b..fcc5e667 100644 --- a/src/typeagent/knowpro/conversation_base.py +++ b/src/typeagent/knowpro/conversation_base.py @@ -4,8 +4,8 @@ """Base class for conversations with incremental indexing support.""" import asyncio -import contextlib from collections.abc import AsyncIterable, Callable, Sequence +import contextlib from dataclasses import dataclass from datetime import datetime, timezone from typing import Generic, Self, TypeVar @@ -282,9 +282,7 @@ async def _submit_batch(filtered: list[TMessage]) -> None: await _drain_commit() # Await extraction result for this batch - extraction = ( - await next_extraction if next_extraction is not None else None - ) + extraction = await next_extraction if next_extraction is not None else None # Start commit (DB transaction) — runs concurrently with the # *next* batch's LLM extraction once we yield back to the loop. diff --git a/src/typeagent/knowpro/knowledge.py b/src/typeagent/knowpro/knowledge.py index 69465f0a..9dedefbe 100644 --- a/src/typeagent/knowpro/knowledge.py +++ b/src/typeagent/knowpro/knowledge.py @@ -7,7 +7,6 @@ from typechat import Result -from . import convknowledge from . import knowledge_schema as kplib from .interfaces import IKnowledgeExtractor diff --git a/src/typeagent/storage/memory/semrefindex.py b/src/typeagent/storage/memory/semrefindex.py index 84892171..76feb2d1 100644 --- a/src/typeagent/storage/memory/semrefindex.py +++ b/src/typeagent/storage/memory/semrefindex.py @@ -76,9 +76,7 @@ async def add_batch_to_semantic_ref_index[ (tl.message_ordinal, tl.chunk_ordinal, knowledge_result.value) ) if bulk_items: - await add_knowledge_batch_to_semantic_ref_index( - conversation, bulk_items - ) + await add_knowledge_batch_to_semantic_ref_index(conversation, bulk_items) async def add_batch_to_semantic_ref_index_from_list[ @@ -103,9 +101,7 @@ async def add_batch_to_semantic_ref_index_from_list[ f"Message ordinal {tl.message_ordinal} out of range " f"for list starting at {start_ordinal}" ) - text_batch.append( - messages[list_index].text_chunks[tl.chunk_ordinal].strip() - ) + text_batch.append(messages[list_index].text_chunks[tl.chunk_ordinal].strip()) knowledge_results = await extract_knowledge_from_text_batch( knowledge_extractor, @@ -123,9 +119,7 @@ async def add_batch_to_semantic_ref_index_from_list[ (tl.message_ordinal, tl.chunk_ordinal, knowledge_result.value) ) if bulk_items: - await add_knowledge_batch_to_semantic_ref_index( - conversation, bulk_items - ) + await add_knowledge_batch_to_semantic_ref_index(conversation, bulk_items) async def add_term_to_index( @@ -360,11 +354,13 @@ def _collect_knowledge_refs_and_terms( for entity in knowledge.entities: if not validate_entity(entity): continue - refs.append(SemanticRef( - semantic_ref_ordinal=ordinal, - range=text_range, - knowledge=entity, - )) + refs.append( + SemanticRef( + semantic_ref_ordinal=ordinal, + range=text_range, + knowledge=entity, + ) + ) terms.append((entity.name, ordinal)) for type_name in entity.type: terms.append((type_name, ordinal)) @@ -377,11 +373,13 @@ def _collect_knowledge_refs_and_terms( ordinal += 1 for action in list(knowledge.actions) + list(knowledge.inverse_actions): - refs.append(SemanticRef( - semantic_ref_ordinal=ordinal, - range=text_range, - knowledge=action, - )) + refs.append( + SemanticRef( + semantic_ref_ordinal=ordinal, + range=text_range, + knowledge=action, + ) + ) terms.append((" ".join(action.verbs), ordinal)) if action.subject_entity_name != "none": terms.append((action.subject_entity_name, ordinal)) @@ -404,11 +402,13 @@ def _collect_knowledge_refs_and_terms( ordinal += 1 for topic_text in knowledge.topics: - refs.append(SemanticRef( - semantic_ref_ordinal=ordinal, - range=text_range, - knowledge=Topic(text=topic_text), - )) + refs.append( + SemanticRef( + semantic_ref_ordinal=ordinal, + range=text_range, + knowledge=Topic(text=topic_text), + ) + ) terms.append((topic_text, ordinal)) ordinal += 1 @@ -431,7 +431,10 @@ async def add_knowledge_to_semantic_ref_index( base_ordinal = await semantic_refs.size() refs, terms = _collect_knowledge_refs_and_terms( - base_ordinal, message_ordinal, chunk_ordinal, knowledge, + base_ordinal, + message_ordinal, + chunk_ordinal, + knowledge, ) if refs: @@ -460,7 +463,10 @@ async def add_knowledge_batch_to_semantic_ref_index( for msg_ord, chunk_ord, knowledge in items: refs, terms = _collect_knowledge_refs_and_terms( - base_ordinal + len(all_refs), msg_ord, chunk_ord, knowledge, + base_ordinal + len(all_refs), + msg_ord, + chunk_ord, + knowledge, ) all_refs.extend(refs) all_terms.extend(terms) diff --git a/tests/test_add_messages_streaming.py b/tests/test_add_messages_streaming.py index dc3f55b4..5707f71b 100644 --- a/tests/test_add_messages_streaming.py +++ b/tests/test_add_messages_streaming.py @@ -384,9 +384,7 @@ async def test_streaming_exception_in_later_batch_preserves_earlier() -> None: msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(6)] with pytest.raises(ExceptionGroup) as exc_info: - await transcript.add_messages_streaming( - _async_iter(msgs), batch_size=3 - ) + await transcript.add_messages_streaming(_async_iter(msgs), batch_size=3) assert any( isinstance(e, RuntimeError) and "Systemic failure" in str(e) diff --git a/tools/benchmark_semref_writes.py b/tools/benchmark_semref_writes.py index d799cba1..3162f495 100644 --- a/tools/benchmark_semref_writes.py +++ b/tools/benchmark_semref_writes.py @@ -43,14 +43,16 @@ TranscriptMessageMeta, ) - # --------------------------------------------------------------------------- # Inlined pre-optimization write path (one append + add_term per item) # --------------------------------------------------------------------------- async def _individual_add_knowledge( - conversation, message_ordinal, chunk_ordinal, knowledge, + conversation, + message_ordinal, + chunk_ordinal, + knowledge, ): """Reproduces the pre-optimization per-item write logic.""" verify_has_semantic_ref_index(conversation) @@ -95,7 +97,9 @@ async def _individual_add_knowledge( if action.object_entity_name != "none": await semantic_ref_index.add_term(action.object_entity_name, ordinal) if action.indirect_object_entity_name != "none": - await semantic_ref_index.add_term(action.indirect_object_entity_name, ordinal) + await semantic_ref_index.add_term( + action.indirect_object_entity_name, ordinal + ) if action.params: for param in action.params: if isinstance(param, str): @@ -135,8 +139,7 @@ def synthetic_knowledge(chunk_index: int) -> kplib.KnowledgeResponse: name=f"entity_{chunk_index}_{j}", type=[f"type_{j}", f"category_{chunk_index % 5}"], facets=[ - kplib.Facet(name=f"facet_{j}", value=f"value_{j}") - for j in range(2) + kplib.Facet(name=f"facet_{j}", value=f"value_{j}") for j in range(2) ], ) for j in range(3) @@ -237,15 +240,21 @@ async def main() -> None: description="Benchmark semref index write strategies.", ) parser.add_argument( - "--chunks", type=int, default=50, + "--chunks", + type=int, + default=50, help="Number of knowledge chunks to write per run (default: 50).", ) parser.add_argument( - "--rounds", type=int, default=10, + "--rounds", + type=int, + default=10, help="Number of timed rounds (default: 10).", ) parser.add_argument( - "--warmup", type=int, default=2, + "--warmup", + type=int, + default=2, help="Number of untimed warmup rounds (default: 2).", ) args = parser.parse_args() @@ -262,21 +271,31 @@ async def main() -> None: print(f"Total semrefs per run: ~{refs_per_chunk * args.chunks}") individual = await run_benchmark( - "Individual writes", bench_individual, - args.chunks, args.rounds, args.warmup, + "Individual writes", + bench_individual, + args.chunks, + args.rounds, + args.warmup, ) print_report( "Individual writes (per-entity append + add_term)", - individual, args.rounds, args.warmup, + individual, + args.rounds, + args.warmup, ) batched = await run_benchmark( - "Batched writes", bench_batched, - args.chunks, args.rounds, args.warmup, + "Batched writes", + bench_batched, + args.chunks, + args.rounds, + args.warmup, ) print_report( "Batched writes (bulk extend + add_terms_batch)", - batched, args.rounds, args.warmup, + batched, + args.rounds, + args.warmup, ) speedup = statistics.fmean(individual) / statistics.fmean(batched)