diff --git a/src/retriever/data_tiers/tier_0/dgraph/result_models.py b/src/retriever/data_tiers/tier_0/dgraph/result_models.py index 52ff8141..20f33ffd 100644 --- a/src/retriever/data_tiers/tier_0/dgraph/result_models.py +++ b/src/retriever/data_tiers/tier_0/dgraph/result_models.py @@ -1,11 +1,13 @@ from __future__ import annotations +import base64 import re from collections.abc import Mapping from contextlib import suppress from dataclasses import dataclass, field, fields from typing import Any, Literal, Self, TypeGuard, cast +import msgpack import orjson from retriever.data_tiers.utils import ( @@ -25,6 +27,21 @@ def _strip_prefix(d: Mapping[str, Any], prefix: str | None) -> Mapping[str, Any] return {(k.removeprefix(prefix)): v for k, v in d.items()} +def _decode_msgpack_base64(value: Any) -> Any | None: + """Decode a base64 string containing msgpack data into a Python object. + + Returns None if value is missing or decoding fails. + """ + if not isinstance(value, str | bytes | bytearray): + return None + try: + raw_bytes = base64.b64decode(value) + # raw=False decodes msgpack bytes into str for map keys/values + return msgpack.unpackb(raw_bytes, raw=False) + except Exception: + return None + + # ----------------- # Dataclasses # ----------------- @@ -44,6 +61,8 @@ class Source: resource_role: str upstream_resource_ids: list[str] = field(default_factory=list) source_record_urls: list[str] = field(default_factory=list) + source_id: str = "" + source_category: list[str] = field(default_factory=list) @classmethod def from_dict( @@ -56,6 +75,8 @@ def from_dict( resource_role=str(norm.get("resource_role", "")), upstream_resource_ids=_to_str_list(norm.get("upstream_resource_ids")), source_record_urls=_to_str_list(norm.get("source_record_urls")), + source_id=str(norm.get("source_id", "")), + source_category=_to_str_list(norm.get("source_category")), ) @@ -93,6 +114,24 @@ class Edge: sources: list[Source] = field(default_factory=list) id: str | None = None category: list[str] = field(default_factory=list) + anatomical_context_qualifier: list[str] = field(default_factory=list) + causal_mechanism_qualifier: str | None = None + species_context_qualifier: str | None = None + object_aspect_qualifier: str | None = None + object_direction_qualifier: str | None = None + subject_aspect_qualifier: str | None = None + subject_direction_qualifier: str | None = None + qualifiers: list[str] = field(default_factory=list) + FDA_regulatory_approvals: list[str] = field(default_factory=list) + clinical_approval_status: str | None = None + max_research_phase: str | None = None + p_value: float | None = None + adjusted_p_value: float | None = None + number_of_cases: int | None = None + dgidb_evidence_score: float | None = None + dgidb_interaction_score: float | None = None + has_supporting_studies_raw: str | None = None + has_supporting_studies: Any | None = None def get_attributes(self) -> dict[str, Any]: """Return all fields which correspond to TRAPI attributes as a dict.""" @@ -164,6 +203,10 @@ def from_dict( # noqa: PLR0913 if isinstance(source_item, Mapping) ] + # Decode supporting studies (base64 → msgpack → Python) + raw_supporting = cast(str | None, norm.get("has_supporting_studies")) + decoded_supporting = _decode_msgpack_base64(raw_supporting) + return cls( binding=binding, direction="in" if direction == "in" else "out", @@ -231,6 +274,50 @@ def from_dict( # noqa: PLR0913 sources=parsed_sources, id=str(norm["eid"]) if "eid" in norm else None, category=_to_str_list(norm.get("ecategory")), + anatomical_context_qualifier=_to_str_list( + norm.get("anatomical_context_qualifier") + ), + causal_mechanism_qualifier=str(norm["causal_mechanism_qualifier"]) + if "causal_mechanism_qualifier" in norm + else None, + species_context_qualifier=str(norm["species_context_qualifier"]) + if "species_context_qualifier" in norm + else None, + object_aspect_qualifier=str(norm["object_aspect_qualifier"]) + if "object_aspect_qualifier" in norm + else None, + object_direction_qualifier=str(norm["object_direction_qualifier"]) + if "object_direction_qualifier" in norm + else None, + subject_aspect_qualifier=str(norm["subject_aspect_qualifier"]) + if "subject_aspect_qualifier" in norm + else None, + subject_direction_qualifier=str(norm["subject_direction_qualifier"]) + if "subject_direction_qualifier" in norm + else None, + qualifiers=_to_str_list(norm.get("qualifiers")), + FDA_regulatory_approvals=_to_str_list(norm.get("FDA_regulatory_approvals")), + clinical_approval_status=str(norm["clinical_approval_status"]) + if "clinical_approval_status" in norm + else None, + max_research_phase=str(norm["max_research_phase"]) + if "max_research_phase" in norm + else None, + p_value=float(norm["p_value"]) if "p_value" in norm else None, + adjusted_p_value=float(norm["adjusted_p_value"]) + if "adjusted_p_value" in norm + else None, + number_of_cases=int(norm["number_of_cases"]) + if "number_of_cases" in norm + else None, + dgidb_evidence_score=float(norm["dgidb_evidence_score"]) + if "dgidb_evidence_score" in norm + else None, + dgidb_interaction_score=float(norm["dgidb_interaction_score"]) + if "dgidb_interaction_score" in norm + else None, + has_supporting_studies_raw=raw_supporting, + has_supporting_studies=decoded_supporting, ) @@ -249,6 +336,15 @@ class Node: provided_by: list[str] = field(default_factory=list) description: str | None = None equivalent_identifiers: list[str] = field(default_factory=list) + full_name: str | None = None + symbol: str | None = None + synonym: list[str] = field(default_factory=list) + xref: list[str] = field(default_factory=list) + taxon: str | None = None + chembl_availability_type: str | None = None + chembl_black_box_warning: str | None = None + chembl_natural_product: bool | None = None + chembl_prodrug: bool | None = None def get_attributes(self) -> dict[str, Any]: """Return all fields which correspond to TRAPI attributes as a dict.""" @@ -367,6 +463,23 @@ def from_dict( provided_by=_to_str_list(norm.get("provided_by")), description=str(norm["description"]) if "description" in norm else None, equivalent_identifiers=_to_str_list(norm.get("equivalent_identifiers")), + full_name=str(norm["full_name"]) if "full_name" in norm else None, + symbol=str(norm["symbol"]) if "symbol" in norm else None, + synonym=_to_str_list(norm.get("synonym")), + xref=_to_str_list(norm.get("xref")), + taxon=str(norm["taxon"]) if "taxon" in norm else None, + chembl_availability_type=str(norm["chembl_availability_type"]) + if "chembl_availability_type" in norm + else None, + chembl_black_box_warning=str(norm["chembl_black_box_warning"]) + if "chembl_black_box_warning" in norm + else None, + chembl_natural_product=bool(norm["chembl_natural_product"]) + if "chembl_natural_product" in norm + else None, + chembl_prodrug=bool(norm["chembl_prodrug"]) + if "chembl_prodrug" in norm + else None, ) diff --git a/tests/data_tiers/tier_0/dgraph/test_result_models.py b/tests/data_tiers/tier_0/dgraph/test_result_models.py index 408e5274..e9cb6a6c 100644 --- a/tests/data_tiers/tier_0/dgraph/test_result_models.py +++ b/tests/data_tiers/tier_0/dgraph/test_result_models.py @@ -1,21 +1,32 @@ -import pytest +import base64 from typing import Any +import msgpack +import pytest + from retriever.data_tiers.tier_0.dgraph import result_models as dg_models -def test_parse_single_success_case_versioned(): - """Test parsing of a well-formed, multi-hop Dgraph response.""" - # The raw response must match the actual Dgraph output format. - # A well-formed response with data for all node and edge properties. - raw_response = { +def _build_supporting_studies_payload() -> tuple[dict[str, Any], str]: + supporting_obj = { + "studies": [ + {"pmid": "PMID:42", "score": 0.95, "label": "supporting evidence"}, + {"pmid": "PMID:43", "score": 0.75, "label": "additional evidence"}, + ] + } + supporting_b64 = base64.b64encode( + msgpack.packb(supporting_obj, use_bin_type=True) + ).decode("ascii") + return supporting_obj, supporting_b64 + + +def _sample_versioned_raw_response(supporting_b64: str) -> dict[str, Any]: + return { "q1_node_n0": [ { "vA_name": "cytoplasmic vesicle", "vA_information_content": 56.8, - "vA_equivalent_identifiers": [ - "GO:0031410" - ], + "vA_equivalent_identifiers": ["GO:0031410"], "vA_id": "GO:0031410", "vA_category": [ "NamedThing", @@ -26,44 +37,65 @@ def test_parse_single_success_case_versioned(): "ThingWithTaxon", "SubjectOfInvestigation", "AnatomicalEntity", - "BiologicalEntity" + "BiologicalEntity", ], "vA_description": "A vesicle found in the cytoplasm of a cell.", + "vA_full_name": "Cytoplasmic vesicle", + "vA_symbol": "CV", + "vA_synonym": ["cytoplasmic vesicle", "cytoplasmic vesicles"], + "vA_xref": ["GO:0031410", "UMLS:C123"], + "vA_taxon": "NCBITaxon:9606", + "vA_chembl_availability_type": "clinical", + "vA_chembl_black_box_warning": "WARNING", + "vA_chembl_natural_product": True, + "vA_chembl_prodrug": False, "in_edges_e0": [ { "vA_knowledge_level": "prediction", - "vA_has_evidence": [ - "ECO:IEA" - ], + "vA_has_evidence": ["ECO:IEA"], "vA_original_subject": "UniProtKB:Q9UMZ2", "vA_sources": [ { "vA_resource_id": "infores:biolink", "vA_resource_role": "aggregator_knowledge_source", "vA_upstream_resource_ids": ["infores:goa"], - "vA_source_record_urls": ["https://example.com/record/123"] + "vA_source_record_urls": ["https://example.com/record/123"], + "vA_source_id": "123", + "vA_source_category": ["category1", "category2"], }, { "vA_resource_id": "infores:goa", - "vA_resource_role": "primary_knowledge_source" - } - ], - "vA_ecategory": [ - "Association" + "vA_resource_role": "primary_knowledge_source", + }, ], + "vA_ecategory": ["Association"], "vA_predicate": "located_in", - "vA_source_inforeses": [ - "infores:biolink", - "infores:goa" - ], + "vA_source_inforeses": ["infores:biolink", "infores:goa"], "vA_predicate_ancestors": [ "related_to_at_instance_level", "located_in", - "related_to" + "related_to", ], "vA_agent_type": "automated_agent", "vA_original_object": "GO:0031410", "vA_eid": "urn:uuid:0763a393-7cc8-4d80-8720-0efcc0f9245f", + "vA_anatomical_context_qualifier": ["UBERON:0001062"], + "vA_causal_mechanism_qualifier": "increases_activity", + "vA_species_context_qualifier": "NCBITaxon:9606", + "vA_object_aspect_qualifier": "expression", + "vA_object_direction_qualifier": "increased", + "vA_subject_aspect_qualifier": "activity", + "vA_subject_direction_qualifier": "reduced", + "vA_qualifiers": ["qual:1", "qual:2"], + "vA_FDA_regulatory_approvals": ["FDA:DrugA"], + "vA_clinical_approval_status": "approved", + "vA_max_research_phase": "Phase 4", + "vA_p_value": 0.0123, + "vA_adjusted_p_value": 0.0234, + "vA_number_of_cases": 42, + "vA_dgidb_evidence_score": 0.75, + "vA_dgidb_interaction_score": 0.88, + "vA_has_supporting_studies": supporting_b64, "node_n1": { "vA_information_content": 83.6, "vA_category": [ @@ -81,7 +113,7 @@ def test_parse_single_success_case_versioned(): "GenomicEntity", "GeneProductMixin", "Protein", - "BiologicalEntity" + "BiologicalEntity", ], "vA_equivalent_identifiers": [ "PR:Q9UMZ2", @@ -92,28 +124,24 @@ def test_parse_single_success_case_versioned(): "UMLS:C0893518", "MESH:C121510", "HGNC:557", - "NCBIGene:11276" + "NCBIGene:11276", ], "vA_id": "NCBIGene:11276", "vA_name": "SYNRG", "vA_description": "synergin gamma", - "vA_in_taxon": [ - "NCBITaxon:9606" - ] - } + "vA_in_taxon": ["NCBITaxon:9606"], + "vA_symbol": "SYNRG", + "vA_synonym": ["synergin gamma"], + "vA_xref": ["HGNC:557", "UniProtKB:Q9UMZ2"], + }, } - ] + ], } ] } - # 1. Parse the response - parsed = dg_models.DgraphResponse.parse(raw_response, prefix="vA_") - assert "q1" in parsed.data - assert len(parsed.data["q1"]) == 1 - # 2. Assertions for the root node (n0) - root_node = parsed.data["q1"][0] +def _assert_root_node_fields(root_node: dg_models.Node) -> None: assert root_node.binding == "n0" assert root_node.id == "GO:0031410" assert root_node.name == "cytoplasmic vesicle" @@ -131,10 +159,21 @@ def test_parse_single_success_case_versioned(): assert root_node.information_content == 56.8 assert root_node.equivalent_identifiers == ["GO:0031410"] assert root_node.description == "A vesicle found in the cytoplasm of a cell." + assert root_node.full_name == "Cytoplasmic vesicle" + assert root_node.symbol == "CV" + assert root_node.synonym == ["cytoplasmic vesicle", "cytoplasmic vesicles"] + assert root_node.xref == ["GO:0031410", "UMLS:C123"] + assert root_node.taxon == "NCBITaxon:9606" + assert root_node.chembl_availability_type == "clinical" + assert root_node.chembl_black_box_warning == "WARNING" + assert root_node.chembl_natural_product is True + assert root_node.chembl_prodrug is False assert len(root_node.edges) == 1 - # 3. Assertions for the incoming edge (e0) - in_edge = root_node.edges[0] + +def _assert_in_edge_fields( + in_edge: dg_models.Edge, supporting_b64: str, supporting_obj: dict[str, Any] +) -> None: assert in_edge.binding == "e0" assert in_edge.direction == "in" assert in_edge.predicate == "located_in" @@ -151,13 +190,43 @@ def test_parse_single_success_case_versioned(): ] assert in_edge.id == "urn:uuid:0763a393-7cc8-4d80-8720-0efcc0f9245f" assert in_edge.category == ["Association"] + assert in_edge.anatomical_context_qualifier == ["UBERON:0001062"] + assert in_edge.causal_mechanism_qualifier == "increases_activity" + assert in_edge.species_context_qualifier == "NCBITaxon:9606" + assert in_edge.object_aspect_qualifier == "expression" + assert in_edge.object_direction_qualifier == "increased" + assert in_edge.subject_aspect_qualifier == "activity" + assert in_edge.subject_direction_qualifier == "reduced" + assert in_edge.qualifiers == ["qual:1", "qual:2"] + assert in_edge.FDA_regulatory_approvals == ["FDA:DrugA"] + assert in_edge.clinical_approval_status == "approved" + assert in_edge.max_research_phase == "Phase 4" + assert in_edge.p_value == 0.0123 + assert in_edge.adjusted_p_value == 0.0234 + assert in_edge.number_of_cases == 42 + assert in_edge.dgidb_evidence_score == 0.75 + assert in_edge.dgidb_interaction_score == 0.88 + assert in_edge.has_supporting_studies_raw == supporting_b64 + assert in_edge.has_supporting_studies == supporting_obj assert in_edge.sources == [ - dg_models.Source(resource_id="infores:biolink", resource_role="aggregator_knowledge_source", upstream_resource_ids=["infores:goa"], source_record_urls=["https://example.com/record/123"]), - dg_models.Source(resource_id="infores:goa", resource_role="primary_knowledge_source", upstream_resource_ids=[], source_record_urls=[]), + dg_models.Source( + resource_id="infores:biolink", + resource_role="aggregator_knowledge_source", + upstream_resource_ids=["infores:goa"], + source_record_urls=["https://example.com/record/123"], + source_id="123", + source_category=["category1", "category2"], + ), + dg_models.Source( + resource_id="infores:goa", + resource_role="primary_knowledge_source", + upstream_resource_ids=[], + source_record_urls=[], + ), ] - # 4. Assertions for the connected node (n1) - connected_node = in_edge.node + +def _assert_connected_node_fields(connected_node: dg_models.Node) -> None: assert connected_node.binding == "n1" assert connected_node.id == "NCBIGene:11276" assert connected_node.name == "SYNRG" @@ -192,6 +261,27 @@ def test_parse_single_success_case_versioned(): "HGNC:557", "NCBIGene:11276", ] + assert connected_node.symbol == "SYNRG" + assert connected_node.synonym == ["synergin gamma"] + assert connected_node.xref == ["HGNC:557", "UniProtKB:Q9UMZ2"] + + +def test_parse_single_success_case_versioned(): + """Test parsing of a well-formed, multi-hop Dgraph response.""" + supporting_obj, supporting_b64 = _build_supporting_studies_payload() + raw_response = _sample_versioned_raw_response(supporting_b64) + + parsed = dg_models.DgraphResponse.parse(raw_response, prefix="vA_") + assert "q1" in parsed.data + assert len(parsed.data["q1"]) == 1 + + root_node = parsed.data["q1"][0] + _assert_root_node_fields(root_node) + + in_edge = root_node.edges[0] + _assert_in_edge_fields(in_edge, supporting_b64, supporting_obj) + + _assert_connected_node_fields(in_edge.node) def test_parse_batch_success_case(): @@ -366,7 +456,7 @@ def test_parse_symmetric_predicate_success_case(): "vC_name": "monoatomic ion channel complex", } } - ] + ] }] } @@ -381,7 +471,7 @@ def test_parse_symmetric_predicate_success_case(): assert root_node.id == "NCBIGene:3778" assert root_node.name == "KCNMA1" - # 3. Assert symmetric predicates: both in_edges_e0 and in_edges_e0_reverse + # 3. Assert symmetric predicates: both in_edges_e0 and in_edges_e0_reverse # are merged under binding "e0" (due to split("_", 3)) # Total: 1 out_edge + 2 in_edges (one from in_edges_e0, one from in_edges_e0_reverse) assert len(root_node.edges) == 3, "Should have 3 edges total: 1 out + 2 in (merged)"