Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from retriever.data_tiers.tier_1.elasticsearch.types import (
ESDocument,
ESHit,
ESEdge,
ESPayload,
ESResponse,
)
Expand All @@ -26,7 +26,7 @@ class QueryBody(QueryInfo):

async def parse_response(
response: ObjectApiResponse[ESResponse], page_size: int
) -> tuple[list[ESHit], list[Any] | None]:
) -> tuple[list[ESEdge], list[Any] | None]:
"""Parse an ES response and for 0) list of hits, and 1) search_after i.e. the pagination anchor for next query."""
if "hits" not in response:
raise RuntimeError(f"Invalid ES response: no hits in response body: {response}")
Expand All @@ -39,12 +39,7 @@ async def parse_response(
if len(fetched_documents) == page_size:
search_after = fetched_documents[-1]["sort"]

hits: list[ESHit] = [
hit["_source"]
if "_index" not in hit
else {**hit["_source"], "_index": hit["_index"]}
for hit in fetched_documents
]
hits = [ESEdge.from_dict(hit) for hit in fetched_documents]

return hits, search_after

Expand All @@ -71,13 +66,13 @@ async def run_single_query(
index_name: str,
query: ESPayload,
page_size: int = 1000,
) -> list[ESHit]:
) -> list[ESEdge]:
"""Adapter for running single query through _search and aggregating all hits."""
query_info: QueryInfo = {
"query": query["query"],
}

results: list[ESHit] = []
results = list[ESEdge]()

while True:
query_body = generate_query_body(query_info, page_size)
Expand All @@ -99,7 +94,7 @@ async def run_batch_query(
index_name: str,
queries: list[ESPayload],
page_size: int = 1000,
) -> list[list[ESHit]]:
) -> list[list[ESEdge]]:
"""Adapter for running batch queries through _msearch and aggregating all hits."""
query_collection: list[QueryInfo] = [
{
Expand All @@ -108,7 +103,7 @@ async def run_batch_query(
for query in queries
]

results: list[list[ESHit]] = [[] for _ in query_collection]
results: list[list[ESEdge]] = [[] for _ in query_collection]

current_query_indices = range(0, len(query_collection))

Expand Down
6 changes: 3 additions & 3 deletions src/retriever/data_tiers/tier_1/elasticsearch/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
run_batch_query,
run_single_query,
)
from retriever.data_tiers.tier_1.elasticsearch.types import ESHit, ESPayload
from retriever.data_tiers.tier_1.elasticsearch.types import ESEdge, ESPayload
from retriever.data_tiers.utils import parse_dingo_metadata
from retriever.types.dingo import DINGO_ADAPTER, DINGOMetadata
from retriever.types.metakg import Operation, OperationNode
Expand Down Expand Up @@ -112,7 +112,7 @@ async def close(self) -> None:

async def run(
self, query: ESPayload | list[ESPayload]
) -> list[ESHit] | list[list[ESHit]] | None:
) -> list[ESEdge] | list[list[ESEdge]] | None:
"""Execute query logic."""
# Check ES connection instance
if self.es_connection is None:
Expand Down Expand Up @@ -162,7 +162,7 @@ async def run(
@tracer.start_as_current_span("elasticsearch_query")
async def run_query(
self, query: ESPayload | list[ESPayload], *args: Any, **kwargs: Any
) -> list[ESHit] | list[list[ESHit]] | None:
) -> list[ESEdge] | list[list[ESEdge]] | None:
"""Use ES async client to execute query via the `_search/_msearch` endpoints."""
otel_span = trace.get_current_span()
if not otel_span or not otel_span.is_recording():
Expand Down
80 changes: 39 additions & 41 deletions src/retriever/data_tiers/tier_1/elasticsearch/transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,12 @@
)
from retriever.data_tiers.tier_1.elasticsearch.types import (
ESBooleanQuery,
ESEdge,
ESFilterClause,
ESHit,
ESNode,
ESPayload,
ESQueryContext,
)
from retriever.data_tiers.utils import (
DINGO_KG_EDGE_TOPLEVEL_VALUES,
DINGO_KG_NODE_TOPLEVEL_VALUES,
)
from retriever.types.general import BackendResult
from retriever.types.trapi import (
CURIE,
Expand Down Expand Up @@ -181,15 +178,15 @@ def convert_triple(self, qgraph: QueryGraphDict) -> ESPayload:
def convert_batch_triple(self, qgraphs: list[QueryGraphDict]) -> list[ESPayload]:
return [self.convert_triple(qgraph) for qgraph in qgraphs]

def build_nodes(self, hits: list[ESHit]) -> dict[CURIE, NodeDict]:
def build_nodes(self, edges: list[ESEdge]) -> dict[CURIE, NodeDict]:
"""Build TRAPI nodes from backend representation."""
nodes = dict[CURIE, NodeDict]()
for hit in hits:
for edge in edges:
node_ids = dict[str, CURIE]()
for argument in ("subject", "object"):
node = hit[argument]
node_id = node["id"]
node_ids[argument] = node_id
for node_pos in ("subject", "object"):
node: ESNode = getattr(edge, node_pos)
node_id = CURIE(node.id)
node_ids[node_pos] = node_id
if node_id in nodes:
continue
attributes: list[AttributeDict] = []
Expand All @@ -198,13 +195,14 @@ def build_nodes(self, hits: list[ESHit]) -> dict[CURIE, NodeDict]:
special_cases: dict[str, tuple[str, Any]] = {
"equivalent_identifiers": (
"biolink:xref",
[CURIE(i) for i in node["equivalent_identifiers"]],
[
CURIE(i)
for i in node.attributes.get("equivalent_identifiers", [])
],
)
}

for field, value in node.items():
if field in DINGO_KG_NODE_TOPLEVEL_VALUES:
continue
for field, value in node.attributes.items():
if field in special_cases:
continue
if value is not None and value not in ([], ""):
Expand All @@ -222,21 +220,21 @@ def build_nodes(self, hits: list[ESHit]) -> dict[CURIE, NodeDict]:
)

trapi_node = NodeDict(
name=node["name"],
name=node.name,
categories=[
BiolinkEntity(biolink.ensure_prefix(cat))
for cat in node["category"]
for cat in node.category
],
attributes=attributes,
)

nodes[node_id] = trapi_node
return nodes

def build_edges(self, hits: list[ESHit]) -> dict[EdgeIdentifier, EdgeDict]:
def build_edges(self, edges: list[ESEdge]) -> dict[EdgeIdentifier, EdgeDict]:
"""Build TRAPI edges from backend representation."""
edges = dict[EdgeIdentifier, EdgeDict]()
for hit in hits:
trapi_edges = dict[EdgeIdentifier, EdgeDict]()
for edge in edges:
attributes: list[AttributeDict] = []
qualifiers: list[QualifierDict] = []
sources: list[RetrievalSourceDict] = []
Expand All @@ -247,25 +245,15 @@ def build_edges(self, hits: list[ESHit]) -> dict[EdgeIdentifier, EdgeDict]:
"biolink:category",
[
BiolinkEntity(biolink.ensure_prefix(cat))
for cat in hit.get("category", [])
for cat in edge.attributes.get("category", [])
],
),
}

# Build Attributes and Qualifiers
for field, value in hit.items():
if field in DINGO_KG_EDGE_TOPLEVEL_VALUES or field in special_cases:
for field, value in edge.attributes.items():
if field in special_cases:
continue
if biolink.is_qualifier(field):
qualifiers.append(
QualifierDict(
qualifier_type_id=QualifierTypeID(
biolink.ensure_prefix(field)
),
qualifier_value=str(value),
)
)
pass
elif value is not None and value not in ([], ""):
attributes.append(
AttributeDict(
Expand All @@ -274,6 +262,15 @@ def build_edges(self, hits: list[ESHit]) -> dict[EdgeIdentifier, EdgeDict]:
)
)

# Build Qualifiers
for qtype, qval in edge.qualifiers.items():
qualifiers.append(
QualifierDict(
qualifier_type_id=QualifierTypeID(biolink.ensure_prefix(qtype)),
qualifier_value=qval,
)
)

# Special case attributes
for name, value in special_cases.values():
if value is not None and value not in ([], ""):
Expand All @@ -282,7 +279,7 @@ def build_edges(self, hits: list[ESHit]) -> dict[EdgeIdentifier, EdgeDict]:
)

# Build Sources
for source in hit["sources"]:
for source in edge.sources:
retrieval_source = RetrievalSourceDict(
resource_id=Infores(source["resource_id"]),
resource_role=source["resource_role"],
Expand All @@ -297,9 +294,9 @@ def build_edges(self, hits: list[ESHit]) -> dict[EdgeIdentifier, EdgeDict]:

# Build Edge
trapi_edge = EdgeDict(
predicate=BiolinkPredicate(biolink.ensure_prefix(hit["predicate"])),
subject=CURIE(hit["subject"]["id"]),
object=CURIE(hit["object"]["id"]),
predicate=BiolinkPredicate(biolink.ensure_prefix(edge.predicate)),
subject=CURIE(edge.subject.id),
object=CURIE(edge.object.id),
sources=sources,
)
if len(attributes) > 0:
Expand All @@ -310,12 +307,13 @@ def build_edges(self, hits: list[ESHit]) -> dict[EdgeIdentifier, EdgeDict]:
append_aggregator_source(trapi_edge, Infores(CONFIG.tier1.backend_infores))

edge_hash = hash_hex(hash_edge(trapi_edge))
edges[edge_hash] = trapi_edge
return edges
trapi_edges[edge_hash] = trapi_edge

return trapi_edges

@override
def convert_results(
self, qgraph: QueryGraphDict, results: list[ESHit] | None
self, qgraph: QueryGraphDict, results: list[ESEdge] | None
) -> BackendResult:
nodes = self.build_nodes(results) if results is not None else {}
edges = self.build_edges(results) if results is not None else {}
Expand All @@ -327,7 +325,7 @@ def convert_results(
)

def convert_batch_results(
self, qgraph_list: list[QueryGraphDict], results: list[list[ESHit]]
self, qgraph_list: list[QueryGraphDict], results: list[list[ESEdge]]
) -> list[BackendResult]:
"""Wrapper for converting results for a batch query."""
return [
Expand Down
Loading