Skip to content
Merged
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ Documentation = "https://github.com/microsoft/typeagent-py/tree/main/docs/README
[tool.uv.build-backend]
module-root = "src"

[tool.uv.sources]
pytest-async-benchmark = { git = "https://github.com/KRRT7/pytest-async-benchmark.git", rev = "feat/pedantic-mode" }

[tool.pytest.ini_options]
asyncio_default_fixture_loop_scope = "function"
testpaths = ["tests"]
Expand Down Expand Up @@ -91,6 +94,7 @@ dev = [
"opentelemetry-instrumentation-httpx>=0.57b0",
"pyright>=1.1.408", # 407 has a regression
"pytest>=8.3.5",
"pytest-async-benchmark",
"pytest-asyncio>=0.26.0",
"pytest-benchmark>=5.1.0",
"pytest-mock>=3.14.0",
Expand Down
29 changes: 16 additions & 13 deletions src/typeagent/knowpro/answers.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,19 +452,22 @@ async def get_scored_semantic_refs_from_ordinals_iter(
semantic_ref_matches: list[ScoredSemanticRefOrdinal],
knowledge_type: KnowledgeType,
) -> list[Scored[SemanticRef]]:
result = []
for semantic_ref_match in semantic_ref_matches:
semantic_ref = await semantic_refs.get_item(
semantic_ref_match.semantic_ref_ordinal
)
if semantic_ref.knowledge.knowledge_type == knowledge_type:
result.append(
Scored(
item=semantic_ref,
score=semantic_ref_match.score,
)
)
return result
if not semantic_ref_matches:
return []
ordinals = [m.semantic_ref_ordinal for m in semantic_ref_matches]
metadata = await semantic_refs.get_metadata_multiple(ordinals)
matching = [
(sr_match, m.ordinal)
for sr_match, m in zip(semantic_ref_matches, metadata)
if m.knowledge_type == knowledge_type
]
if not matching:
return []
full_refs = await semantic_refs.get_multiple([o for _, o in matching])
return [
Scored(item=ref, score=sr_match.score)
for (sr_match, _), ref in zip(matching, full_refs)
]


def merge_scored_concrete_entities(
Expand Down
37 changes: 23 additions & 14 deletions src/typeagent/knowpro/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,13 +331,17 @@ async def group_matches_by_type(
self,
semantic_refs: ISemanticRefCollection,
) -> dict[KnowledgeType, "SemanticRefAccumulator"]:
matches = list(self)
if not matches:
return {}
ordinals = [match.value for match in matches]
metadata = await semantic_refs.get_metadata_multiple(ordinals)
groups: dict[KnowledgeType, SemanticRefAccumulator] = {}
for match in self:
semantic_ref = await semantic_refs.get_item(match.value)
group = groups.get(semantic_ref.knowledge.knowledge_type)
for match, m in zip(matches, metadata):
group = groups.get(m.knowledge_type)
if group is None:
group = SemanticRefAccumulator(self.search_term_matches)
groups[semantic_ref.knowledge.knowledge_type] = group
groups[m.knowledge_type] = group
group.set_match(match)
return groups

Expand All @@ -346,11 +350,14 @@ async def get_matches_in_scope(
semantic_refs: ISemanticRefCollection,
ranges_in_scope: "TextRangesInScope",
) -> "SemanticRefAccumulator":
matches = list(self)
if not matches:
return SemanticRefAccumulator(self.search_term_matches)
ordinals = [match.value for match in matches]
metadata = await semantic_refs.get_metadata_multiple(ordinals)
accumulator = SemanticRefAccumulator(self.search_term_matches)
for match in self:
if ranges_in_scope.is_range_in_scope(
(await semantic_refs.get_item(match.value)).range
):
for match, m in zip(matches, metadata):
if ranges_in_scope.is_range_in_scope(m.range):
accumulator.set_match(match)
return accumulator

Expand Down Expand Up @@ -519,12 +526,14 @@ def add_ranges(self, text_ranges: "list[TextRange] | TextRangeCollection") -> No
self.add_range(text_range)

def contains_range(self, inner_range: TextRange) -> bool:
# Since ranges are sorted by start, once we pass inner_range's start
# no further range can contain it.
for outer_range in self._ranges:
if outer_range.start > inner_range.start:
break
if inner_range in outer_range:
if not self._ranges:
return False
# Bisect on start only to find all ranges with start <= inner.start,
# then scan backwards — the most likely containing range has the
# largest start still <= inner's.
hi = bisect.bisect_right(self._ranges, inner_range.start, key=lambda r: r.start)
for i in range(hi - 1, -1, -1):
if inner_range in self._ranges[i]:
return True
return False

Expand Down
36 changes: 12 additions & 24 deletions src/typeagent/knowpro/interfaces_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,32 +255,24 @@ def __repr__(self) -> str:
else:
return f"{self.__class__.__name__}({self.start}, {self.end})"

@staticmethod
def _effective_end(tr: "TextRange") -> tuple[int, int]:
    """Normalized (message_ordinal, chunk_ordinal) end point for *tr*.

    A missing ``end`` means the range spans exactly one chunk, so the
    implied end is one chunk past the start.
    """
    end = tr.end
    if end is None:
        return (tr.start.message_ordinal, tr.start.chunk_ordinal + 1)
    return (end.message_ordinal, end.chunk_ordinal)

def __eq__(self, other: object) -> bool:
    """Equality on (start, effective end); NotImplemented for non-TextRange."""
    if not isinstance(other, TextRange):
        return NotImplemented
    if self.start != other.start:
        return False
    # Compare normalized ends so an implicit one-chunk end compares equal
    # to the same end written out explicitly (see _effective_end).
    return TextRange._effective_end(self) == TextRange._effective_end(other)

def __lt__(self, other: Self) -> bool:
    """Strict ordering: by start first, then by effective end."""
    if self.start != other.start:
        return self.start < other.start
    # Same start: break the tie on the normalized end point.
    return TextRange._effective_end(self) < TextRange._effective_end(other)

def __gt__(self, other: Self) -> bool:
    """True if *other* sorts strictly before *self* (mirror of __lt__)."""
    # Delegate to the counterpart's ordering so __lt__ stays the single root.
    return other < self
Expand All @@ -292,13 +284,9 @@ def __le__(self, other: Self) -> bool:
return not other.__lt__(self)

def __contains__(self, other: Self) -> bool:
    """True if *other* lies entirely within *self*.

    Containment means self.start <= other.start and the normalized end of
    *other* does not extend past the normalized end of *self*.
    """
    if not (self.start <= other.start):
        return False
    return TextRange._effective_end(other) <= TextRange._effective_end(self)

def serialize(self) -> TextRangeData:
return self.__pydantic_serializer__.to_python( # type: ignore
Expand Down
19 changes: 18 additions & 1 deletion src/typeagent/knowpro/interfaces_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,18 @@

from collections.abc import AsyncIterable, Iterable
from datetime import datetime as Datetime
from typing import Any, Protocol, Self
from typing import Any, NamedTuple, Protocol, Self

from pydantic.dataclasses import dataclass

from .interfaces_core import (
IMessage,
ITermToSemanticRefIndex,
KnowledgeType,
MessageOrdinal,
SemanticRef,
SemanticRefOrdinal,
TextRange,
)
from .interfaces_indexes import (
IConversationSecondaryIndexes,
Expand Down Expand Up @@ -57,6 +59,14 @@ class ConversationMetadata:
extra: dict[str, str] | None = None


class SemanticRefMetadata(NamedTuple):
    """Lightweight metadata for filtering without full knowledge deserialization."""

    # Position of the semantic ref within its collection.
    ordinal: SemanticRefOrdinal
    # Text span the semantic ref covers; used for range-in-scope filtering.
    range: TextRange
    # Discriminator mirroring SemanticRef.knowledge.knowledge_type.
    knowledge_type: KnowledgeType


class IReadonlyCollection[T, TOrdinal](AsyncIterable[T], Protocol):
async def size(self) -> int: ...

Expand Down Expand Up @@ -91,6 +101,12 @@ class IMessageCollection[TMessage: IMessage](
class ISemanticRefCollection(ICollection[SemanticRef, SemanticRefOrdinal], Protocol):
    """A collection of SemanticRefs."""

    async def get_metadata_multiple(
        self, ordinals: list[SemanticRefOrdinal]
    ) -> list[SemanticRefMetadata]:
        """Batch-fetch lightweight metadata without deserializing knowledge.

        Implementations return one entry per requested ordinal, in the same
        order as *ordinals*.
        """
        ...


class IStorageProvider[TMessage: IMessage](Protocol):
"""API spec for storage providers -- maybe in-memory or persistent."""
Expand Down Expand Up @@ -190,4 +206,5 @@ class IConversation[
"ISemanticRefCollection",
"IStorageProvider",
"STATUS_INGESTED",
"SemanticRefMetadata",
]
18 changes: 7 additions & 11 deletions src/typeagent/knowpro/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
ScoredSemanticRefOrdinal,
SearchTerm,
SemanticRef,
SemanticRefMetadata,
SemanticRefOrdinal,
SemanticRefSearchResult,
Term,
Expand Down Expand Up @@ -174,17 +175,14 @@ async def lookup_term_filtered(
semantic_ref_index: ITermToSemanticRefIndex,
term: Term,
semantic_refs: ISemanticRefCollection,
filter: Callable[[SemanticRef, ScoredSemanticRefOrdinal], bool],
filter: Callable[[SemanticRefMetadata, ScoredSemanticRefOrdinal], bool],
) -> list[ScoredSemanticRefOrdinal] | None:
"""Look up a term in the semantic reference index and filter the results."""
scored_refs = await semantic_ref_index.lookup_term(term.text)
if scored_refs:
filtered = []
for sr in scored_refs:
semantic_ref = await semantic_refs.get_item(sr.semantic_ref_ordinal)
if filter(semantic_ref, sr):
filtered.append(sr)
return filtered
ordinals = [sr.semantic_ref_ordinal for sr in scored_refs]
metadata = await semantic_refs.get_metadata_multiple(ordinals)
return [sr for sr, m in zip(scored_refs, metadata) if filter(m, sr)]
return None


Expand All @@ -202,10 +200,8 @@ async def lookup_term(
semantic_ref_index,
term,
semantic_refs,
lambda sr, _: (
not knowledge_type or sr.knowledge.knowledge_type == knowledge_type
)
and ranges_in_scope.is_range_in_scope(sr.range),
lambda m, _: (not knowledge_type or m.knowledge_type == knowledge_type)
and ranges_in_scope.is_range_in_scope(m.range),
)
return await semantic_ref_index.lookup_term(term.text)

Expand Down
13 changes: 13 additions & 0 deletions src/typeagent/storage/memory/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
IMessage,
MessageOrdinal,
SemanticRef,
SemanticRefMetadata,
SemanticRefOrdinal,
)

Expand Down Expand Up @@ -63,6 +64,18 @@ async def extend(self, items: Iterable[T]) -> None:
class MemorySemanticRefCollection(MemoryCollection[SemanticRef, SemanticRefOrdinal]):
    """A collection of semantic references."""

    async def get_metadata_multiple(
        self, ordinals: list[SemanticRefOrdinal]
    ) -> list[SemanticRefMetadata]:
        """Batch-fetch lightweight metadata for *ordinals*, in request order.

        Callers that only need the range and knowledge type for filtering
        avoid handing around full SemanticRef objects.
        """
        result: list[SemanticRefMetadata] = []
        for ordinal in ordinals:
            # Single indexed lookup instead of one per extracted field.
            ref = self.items[ordinal]
            result.append(
                SemanticRefMetadata(
                    ordinal=ordinal,
                    range=ref.range,
                    knowledge_type=ref.knowledge.knowledge_type,
                )
            )
        return result


class MemoryMessageCollection[TMessage: IMessage](
MemoryCollection[TMessage, MessageOrdinal]
Expand Down
13 changes: 7 additions & 6 deletions src/typeagent/storage/memory/propindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,12 +330,13 @@ async def lookup_property_in_property_index(
property_value,
)
if ranges_in_scope is not None and scored_refs:
filtered_refs = []
for sr in scored_refs:
semantic_ref = await semantic_refs.get_item(sr.semantic_ref_ordinal)
if ranges_in_scope.is_range_in_scope(semantic_ref.range):
filtered_refs.append(sr)
scored_refs = filtered_refs
ordinals = [sr.semantic_ref_ordinal for sr in scored_refs]
metadata = await semantic_refs.get_metadata_multiple(ordinals)
scored_refs = [
sr
for sr, m in zip(scored_refs, metadata)
if ranges_in_scope.is_range_in_scope(m.range)
]

return scored_refs or None # Return None if no results

Expand Down
44 changes: 44 additions & 0 deletions src/typeagent/storage/sqlite/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,50 @@ async def get_multiple(self, arg: list[int]) -> list[interfaces.SemanticRef]:
assert set(rowdict) == set(arg)
return [self._deserialize_semantic_ref_from_row(rowdict[ordl]) for ordl in arg]

async def get_metadata_multiple(
    self, ordinals: list[int]
) -> list[interfaces.SemanticRefMetadata]:
    """Batch-fetch lightweight metadata for *ordinals*, in request order.

    Selects only the range and knowledge-type columns, skipping the
    knowledge JSON payload entirely.

    Raises:
        KeyError: if any requested ordinal is absent from SemanticRefs.
    """
    if not ordinals:
        return []
    cursor = self.db.cursor()
    # One query for the whole batch; ordinals are bound as parameters.
    placeholders = ",".join("?" * len(ordinals))
    cursor.execute(
        f"""
        SELECT semref_id, range_json, knowledge_type
        FROM SemanticRefs WHERE semref_id IN ({placeholders})
        """,
        ordinals,
    )
    rowdict = {row[0]: row for row in cursor.fetchall()}
    result = []
    for o in ordinals:
        semref_id, range_json, knowledge_type = rowdict[o]
        range_data = json.loads(range_json)
        start = range_data["start"]
        end_data = range_data.get("end")
        result.append(
            interfaces.SemanticRefMetadata(
                ordinal=semref_id,
                range=interfaces.TextRange(
                    start=interfaces.TextLocation(
                        start["messageOrdinal"],
                        start.get("chunkOrdinal", 0),
                    ),
                    end=(
                        interfaces.TextLocation(
                            end_data["messageOrdinal"],
                            end_data.get("chunkOrdinal", 0),
                        )
                        if end_data
                        else None
                    ),
                ),
                knowledge_type=knowledge_type,
            )
        )
    return result

async def append(self, item: interfaces.SemanticRef) -> None:
cursor = self.db.cursor()
semref_id, range_json, knowledge_type, knowledge_json = (
Expand Down
Loading
Loading