Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 55 additions & 120 deletions src/retriever/data_tiers/tier_0/dgraph/result_models.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

import base64
import re
from collections.abc import Mapping
from contextlib import suppress
from dataclasses import dataclass, field, fields
from dataclasses import dataclass, field
from typing import Any, Literal, Self, TypeGuard, cast

import msgpack
import orjson

from retriever.data_tiers.utils import (
Expand All @@ -25,6 +27,21 @@ def _strip_prefix(d: Mapping[str, Any], prefix: str | None) -> Mapping[str, Any]
return {(k.removeprefix(prefix)): v for k, v in d.items()}


def _decode_msgpack_base64(value: Any) -> Any | None:
"""Decode a base64 string containing msgpack data into a Python object.

Returns None if value is missing or decoding fails.
"""
if not isinstance(value, str | bytes | bytearray):
return None
try:
raw_bytes = base64.b64decode(value)
# raw=False decodes msgpack bytes into str for map keys/values
return msgpack.unpackb(raw_bytes, raw=False)
except Exception:
return None


# -----------------
# Dataclasses
# -----------------
Expand All @@ -44,6 +61,8 @@ class Source:
resource_role: str
upstream_resource_ids: list[str] = field(default_factory=list)
source_record_urls: list[str] = field(default_factory=list)
source_id: str = ""
source_category: list[str] = field(default_factory=list)

@classmethod
def from_dict(
Expand All @@ -56,6 +75,8 @@ def from_dict(
resource_role=str(norm.get("resource_role", "")),
upstream_resource_ids=_to_str_list(norm.get("upstream_resource_ids")),
source_record_urls=_to_str_list(norm.get("source_record_urls")),
source_id=str(norm.get("source_id", "")),
source_category=_to_str_list(norm.get("source_category")),
)


Expand All @@ -67,51 +88,11 @@ class Edge:
direction: Literal["in"] | Literal["out"]
predicate: str
node: Node
agent_type: str | None = None
knowledge_level: str | None = None
publications: list[str] = field(default_factory=list)
qualified_predicate: str | None = None
predicate_ancestors: list[str] = field(default_factory=list)
source_inforeses: list[str] = field(default_factory=list)
subject_form_or_variant_qualifier: str | None = None
disease_context_qualifier: str | None = None
frequency_qualifier: str | None = None
onset_qualifier: str | None = None
sex_qualifier: str | None = None
original_subject: str | None = None
original_predicate: str | None = None
original_object: str | None = None
allelic_requirement: str | None = None
update_date: str | None = None
z_score: float | None = None
has_evidence: list[str] = field(default_factory=list)
has_confidence_score: float | None = None
has_count: float | None = None
has_total: float | None = None
has_percentage: float | None = None
has_quotient: float | None = None
sources: list[Source] = field(default_factory=list)
source_inforeses: list[str] = field(default_factory=list)
id: str | None = None
category: list[str] = field(default_factory=list)

def get_attributes(self) -> dict[str, Any]:
    """Return all fields which correspond to TRAPI attributes as a dict.

    A dataclass field is included unless its name is listed in
    ``DINGO_KG_EDGE_TOPLEVEL_VALUES`` or ``biolink.is_qualifier`` reports it
    as a qualifier. Keys are field names, values are the current field values.
    """
    return {
        spec.name: getattr(self, spec.name)
        for spec in fields(self)
        if not (
            spec.name in DINGO_KG_EDGE_TOPLEVEL_VALUES
            or biolink.is_qualifier(spec.name)
        )
    }

def get_qualifiers(self) -> dict[str, Any]:
    """Return all fields which correspond to TRAPI qualifiers as a dict.

    A dataclass field is included when ``biolink.is_qualifier`` accepts its
    name. Values are taken as-is, so unset qualifiers appear as ``None``.
    """
    return {
        spec.name: getattr(self, spec.name)
        for spec in fields(self)
        if biolink.is_qualifier(spec.name)
    }
qualifiers: dict[str, str]
attributes: dict[str, Any]

@classmethod
def from_dict( # noqa: PLR0913
Expand Down Expand Up @@ -164,6 +145,27 @@ def from_dict( # noqa: PLR0913
if isinstance(source_item, Mapping)
]

# Decode supporting studies (base64 → msgpack → Python)
msgpack_encoded_keys = ["has_supporting_studies"]

qualifiers = dict[str, str]()
attributes = dict[str, Any]()
for key, value in norm.items():
if key in DINGO_KG_EDGE_TOPLEVEL_VALUES or key.startswith("node_"):
continue
if biolink.is_qualifier(key) and value is not None:
if not isinstance(value, str):
qualifiers[key] = orjson.dumps(value).decode()
else:
qualifiers[key] = str(value)
elif key in msgpack_encoded_keys:
attributes[key] = _decode_msgpack_base64(value)
elif value is not None:
new_key = key
if key == "ecategory":
new_key = "category"
attributes[new_key] = value

return cls(
binding=binding,
direction="in" if direction == "in" else "out",
Expand All @@ -175,62 +177,10 @@ def from_dict( # noqa: PLR0913
edge_id_map=edge_id_map,
node_id_map=node_id_map,
),
agent_type=str(norm["agent_type"]) if "agent_type" in norm else None,
knowledge_level=str(norm["knowledge_level"])
if "knowledge_level" in norm
else None,
publications=_to_str_list(norm.get("publications")),
qualified_predicate=str(norm["qualified_predicate"])
if "qualified_predicate" in norm
else None,
predicate_ancestors=_to_str_list(norm.get("predicate_ancestors")),
source_inforeses=_to_str_list(norm.get("source_inforeses")),
subject_form_or_variant_qualifier=str(
norm["subject_form_or_variant_qualifier"]
)
if "subject_form_or_variant_qualifier" in norm
else None,
disease_context_qualifier=str(norm["disease_context_qualifier"])
if "disease_context_qualifier" in norm
else None,
frequency_qualifier=str(norm["frequency_qualifier"])
if "frequency_qualifier" in norm
else None,
onset_qualifier=str(norm["onset_qualifier"])
if "onset_qualifier" in norm
else None,
sex_qualifier=str(norm["sex_qualifier"])
if "sex_qualifier" in norm
else None,
original_subject=str(norm["original_subject"])
if "original_subject" in norm
else None,
original_predicate=str(norm["original_predicate"])
if "original_predicate" in norm
else None,
original_object=str(norm["original_object"])
if "original_object" in norm
else None,
allelic_requirement=str(norm["allelic_requirement"])
if "allelic_requirement" in norm
else None,
update_date=str(norm["update_date"]) if "update_date" in norm else None,
z_score=float(norm["z_score"]) if "z_score" in norm else None,
has_evidence=_to_str_list(norm.get("has_evidence")),
has_confidence_score=float(norm["has_confidence_score"])
if "has_confidence_score" in norm
else None,
has_count=float(norm["has_count"]) if "has_count" in norm else None,
has_total=float(norm["has_total"]) if "has_total" in norm else None,
has_percentage=float(norm["has_percentage"])
if "has_percentage" in norm
else None,
has_quotient=float(norm["has_quotient"])
if "has_quotient" in norm
else None,
sources=parsed_sources,
id=str(norm["eid"]) if "eid" in norm else None,
category=_to_str_list(norm.get("ecategory")),
attributes=attributes,
qualifiers=qualifiers,
)


Expand All @@ -243,20 +193,7 @@ class Node:
name: str
edges: list[Edge] = field(default_factory=list)
category: list[str] = field(default_factory=list)
in_taxon: list[str] = field(default_factory=list)
information_content: float | None = None
inheritance: str | None = None
provided_by: list[str] = field(default_factory=list)
description: str | None = None
equivalent_identifiers: list[str] = field(default_factory=list)

def get_attributes(self) -> dict[str, Any]:
    """Return all fields which correspond to TRAPI attributes as a dict.

    Every dataclass field whose name is not listed in
    ``DINGO_KG_NODE_TOPLEVEL_VALUES`` is included, keyed by field name.
    """
    return {
        spec.name: getattr(self, spec.name)
        for spec in fields(self)
        if spec.name not in DINGO_KG_NODE_TOPLEVEL_VALUES
    }
attributes: dict[str, Any]

@classmethod
def from_dict(
Expand All @@ -283,6 +220,7 @@ def from_dict(
norm = _strip_prefix(data, prefix)

edges: list[Edge] = []
attributes = dict[str, Any]()
for key, value in norm.items():
# Parse incoming edges (where this node is the OBJECT)
# Handle both "in_edges_e0" and "in_edges-symmetric_e0"
Expand Down Expand Up @@ -352,21 +290,18 @@ def from_dict(
)
for e in filter(_is_mapping, cast(list[Any], value))
)
elif key in DINGO_KG_NODE_TOPLEVEL_VALUES:
continue
elif value is not None:
attributes[key] = value

return cls(
binding=binding,
id=str(norm.get("id", "")),
name=str(norm.get("name", "")),
edges=edges,
category=_to_str_list(norm.get("category")),
in_taxon=_to_str_list(norm.get("in_taxon")),
information_content=float(norm["information_content"])
if "information_content" in norm
else None,
inheritance=str(norm["inheritance"]) if "inheritance" in norm else None,
provided_by=_to_str_list(norm.get("provided_by")),
description=str(norm["description"]) if "description" in norm else None,
equivalent_identifiers=_to_str_list(norm.get("equivalent_identifiers")),
attributes=attributes,
)


Expand Down
15 changes: 8 additions & 7 deletions src/retriever/data_tiers/tier_0/dgraph/transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,11 +957,11 @@ def _build_trapi_node(self, node: dg.Node) -> NodeDict:
special_cases: dict[str, tuple[str, Any]] = {
"equivalent_identifiers": (
"biolink:xref",
[CURIE(i) for i in node.equivalent_identifiers],
[CURIE(i) for i in node.attributes.get("equivalent_identifiers", [])],
)
}

for attr_id, value in node.get_attributes().items():
for attr_id, value in node.attributes.items():
if attr_id in special_cases:
continue
if value is not None and value not in ([], ""):
Expand Down Expand Up @@ -995,10 +995,13 @@ def _build_trapi_edge(self, edge: dg.Edge, initial_curie: str) -> EdgeDict:
special_cases: dict[str, tuple[str, Any]] = {
"category": (
"biolink:category",
[BiolinkEntity(biolink.ensure_prefix(cat)) for cat in edge.category],
[
BiolinkEntity(biolink.ensure_prefix(cat))
for cat in edge.attributes.get("category", [])
],
),
}
for attr_id, value in edge.get_attributes().items():
for attr_id, value in edge.attributes.items():
if attr_id in special_cases:
continue
if value is not None and value not in ([], ""):
Expand All @@ -1013,9 +1016,7 @@ def _build_trapi_edge(self, edge: dg.Edge, initial_curie: str) -> EdgeDict:
attributes.append(AttributeDict(attribute_type_id=name, value=value))

# Build qualifiers
for qualifier_id, value in edge.get_qualifiers().items():
if value is None:
continue
for qualifier_id, value in edge.qualifiers.items():
qualifiers.append(
QualifierDict(
qualifier_type_id=QualifierTypeID(
Expand Down
1 change: 1 addition & 0 deletions src/retriever/data_tiers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
"_index",
"seq_",
"negated", # Should only ever show up as false, field to be removed in future
"eid",
}


Expand Down
Loading