Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions src/retriever/data_tiers/tier_0/dgraph/result_models.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

import base64
import re
from collections.abc import Mapping
from contextlib import suppress
from dataclasses import dataclass, field, fields
from typing import Any, Literal, Self, TypeGuard, cast

import msgpack
import orjson

from retriever.data_tiers.utils import (
Expand All @@ -25,6 +27,21 @@ def _strip_prefix(d: Mapping[str, Any], prefix: str | None) -> Mapping[str, Any]
return {(k.removeprefix(prefix)): v for k, v in d.items()}


def _decode_msgpack_base64(value: Any) -> Any | None:
"""Decode a base64 string containing msgpack data into a Python object.

Returns None if value is missing or decoding fails.
"""
if not isinstance(value, str | bytes | bytearray):
return None
try:
raw_bytes = base64.b64decode(value)
# raw=False decodes msgpack bytes into str for map keys/values
return msgpack.unpackb(raw_bytes, raw=False)
except Exception:
return None


# -----------------
# Dataclasses
# -----------------
Expand All @@ -44,6 +61,8 @@ class Source:
resource_role: str
upstream_resource_ids: list[str] = field(default_factory=list)
source_record_urls: list[str] = field(default_factory=list)
source_id: str = ""
source_category: list[str] = field(default_factory=list)

@classmethod
def from_dict(
Expand All @@ -56,6 +75,8 @@ def from_dict(
resource_role=str(norm.get("resource_role", "")),
upstream_resource_ids=_to_str_list(norm.get("upstream_resource_ids")),
source_record_urls=_to_str_list(norm.get("source_record_urls")),
source_id=str(norm.get("source_id", "")),
source_category=_to_str_list(norm.get("source_category")),
)


Expand Down Expand Up @@ -93,6 +114,24 @@ class Edge:
sources: list[Source] = field(default_factory=list)
id: str | None = None
category: list[str] = field(default_factory=list)
anatomical_context_qualifier: list[str] = field(default_factory=list)
causal_mechanism_qualifier: str | None = None
species_context_qualifier: str | None = None
object_aspect_qualifier: str | None = None
object_direction_qualifier: str | None = None
subject_aspect_qualifier: str | None = None
subject_direction_qualifier: str | None = None
qualifiers: list[str] = field(default_factory=list)
FDA_regulatory_approvals: list[str] = field(default_factory=list)
clinical_approval_status: str | None = None
max_research_phase: str | None = None
p_value: float | None = None
adjusted_p_value: float | None = None
number_of_cases: int | None = None
dgidb_evidence_score: float | None = None
dgidb_interaction_score: float | None = None
has_supporting_studies_raw: str | None = None
has_supporting_studies: Any | None = None

def get_attributes(self) -> dict[str, Any]:
"""Return all fields which correspond to TRAPI attributes as a dict."""
Expand Down Expand Up @@ -164,6 +203,10 @@ def from_dict( # noqa: PLR0913
if isinstance(source_item, Mapping)
]

# Decode supporting studies (base64 → msgpack → Python)
raw_supporting = cast(str | None, norm.get("has_supporting_studies"))
decoded_supporting = _decode_msgpack_base64(raw_supporting)

return cls(
binding=binding,
direction="in" if direction == "in" else "out",
Expand Down Expand Up @@ -231,6 +274,50 @@ def from_dict( # noqa: PLR0913
sources=parsed_sources,
id=str(norm["eid"]) if "eid" in norm else None,
category=_to_str_list(norm.get("ecategory")),
anatomical_context_qualifier=_to_str_list(
norm.get("anatomical_context_qualifier")
),
causal_mechanism_qualifier=str(norm["causal_mechanism_qualifier"])
if "causal_mechanism_qualifier" in norm
else None,
species_context_qualifier=str(norm["species_context_qualifier"])
if "species_context_qualifier" in norm
else None,
object_aspect_qualifier=str(norm["object_aspect_qualifier"])
if "object_aspect_qualifier" in norm
else None,
object_direction_qualifier=str(norm["object_direction_qualifier"])
if "object_direction_qualifier" in norm
else None,
subject_aspect_qualifier=str(norm["subject_aspect_qualifier"])
if "subject_aspect_qualifier" in norm
else None,
subject_direction_qualifier=str(norm["subject_direction_qualifier"])
if "subject_direction_qualifier" in norm
else None,
qualifiers=_to_str_list(norm.get("qualifiers")),
FDA_regulatory_approvals=_to_str_list(norm.get("FDA_regulatory_approvals")),
clinical_approval_status=str(norm["clinical_approval_status"])
if "clinical_approval_status" in norm
else None,
max_research_phase=str(norm["max_research_phase"])
if "max_research_phase" in norm
else None,
p_value=float(norm["p_value"]) if "p_value" in norm else None,
adjusted_p_value=float(norm["adjusted_p_value"])
if "adjusted_p_value" in norm
else None,
number_of_cases=int(norm["number_of_cases"])
if "number_of_cases" in norm
else None,
dgidb_evidence_score=float(norm["dgidb_evidence_score"])
if "dgidb_evidence_score" in norm
else None,
dgidb_interaction_score=float(norm["dgidb_interaction_score"])
if "dgidb_interaction_score" in norm
else None,
has_supporting_studies_raw=raw_supporting,
has_supporting_studies=decoded_supporting,
)


Expand All @@ -249,6 +336,15 @@ class Node:
provided_by: list[str] = field(default_factory=list)
description: str | None = None
equivalent_identifiers: list[str] = field(default_factory=list)
full_name: str | None = None
symbol: str | None = None
synonym: list[str] = field(default_factory=list)
xref: list[str] = field(default_factory=list)
taxon: str | None = None
chembl_availability_type: str | None = None
chembl_black_box_warning: str | None = None
chembl_natural_product: bool | None = None
chembl_prodrug: bool | None = None

def get_attributes(self) -> dict[str, Any]:
"""Return all fields which correspond to TRAPI attributes as a dict."""
Expand Down Expand Up @@ -367,6 +463,23 @@ def from_dict(
provided_by=_to_str_list(norm.get("provided_by")),
description=str(norm["description"]) if "description" in norm else None,
equivalent_identifiers=_to_str_list(norm.get("equivalent_identifiers")),
full_name=str(norm["full_name"]) if "full_name" in norm else None,
symbol=str(norm["symbol"]) if "symbol" in norm else None,
synonym=_to_str_list(norm.get("synonym")),
xref=_to_str_list(norm.get("xref")),
taxon=str(norm["taxon"]) if "taxon" in norm else None,
chembl_availability_type=str(norm["chembl_availability_type"])
if "chembl_availability_type" in norm
else None,
chembl_black_box_warning=str(norm["chembl_black_box_warning"])
if "chembl_black_box_warning" in norm
else None,
chembl_natural_product=bool(norm["chembl_natural_product"])
if "chembl_natural_product" in norm
else None,
chembl_prodrug=bool(norm["chembl_prodrug"])
if "chembl_prodrug" in norm
else None,
)


Expand Down
Loading