diff --git a/server/services/datacommons.py b/server/services/datacommons.py
index 0a2bd4e6f5..e64d0a83e8 100644
--- a/server/services/datacommons.py
+++ b/server/services/datacommons.py
@@ -14,6 +14,7 @@
 """Copy of Data Commons Python Client API Core without pandas dependency."""
 
 import asyncio
+import collections
 import json
 import logging
 from typing import Dict, List
@@ -370,12 +371,6 @@ def v2event(node, prop):
   return post(url, {"node": node, "property": prop})
 
 
-def get_place_info(dcids: List[str]) -> Dict:
-  """Retrieves Place Info given a list of DCIDs."""
-  url = get_service_url("/v1/bulk/info/place")
-  return post(f"{url}", {"nodes": sorted(set(dcids))})
-
-
 def get_variable_group_info(nodes: List[str],
                             entities: List[str],
                             numEntitiesExistence=1) -> Dict:
@@ -403,16 +398,228 @@ def get_variable_ancestors(dcid: str):
   return get(url).get("ancestors", [])
 
 
+# Rank used to order a place's ancestors from smallest to largest region.
+PLACE_TYPE_RANK = {
+    "CensusZipCodeTabulationArea": 1,
+    "AdministrativeArea5": 2,
+    "AdministrativeArea4": 2,
+    "Village": 5,
+    "City": 5,
+    "Town": 5,
+    "Borough": 5,
+    "AdministrativeArea3": 5,
+    "County": 10,
+    "AdministrativeArea2": 10,
+    "EurostatNUTS3": 10,
+    "CensusDivision": 15,
+    "State": 20,
+    "AdministrativeArea1": 20,
+    "EurostatNUTS2": 20,
+    "EurostatNUTS1": 20,
+    "Country": 30,
+    "CensusRegion": 35,
+    "GeoRegion": 38,
+    "Continent": 40,
+    "Place": 50,
+}
+
+
+def get_place_info(dcids: List[str]) -> Dict:
+  """Retrieves Place Info given a list of DCIDs."""
+  # Get ancestors using BFS since v2/node doesn't support recursive
+  # ->containedInPlace+
+  ancestors_map = {dcid: set() for dcid in dcids}
+
+  parent_graph = {}  # child_dcid -> list of parent_dcids
+  frontier = set(dcids)
+  visited = set()
+
+  # BFS to build parent graph (max depth 10)
+  for _ in range(10):
+    if not frontier:
+      break
+
+    # Filter visited nodes to avoid cycles
+    fetch_dcids = [d for d in frontier if d not in visited]
+    if not fetch_dcids:
+      break
+
+    resp = v2node(fetch_dcids, '->containedInPlace')
+    data = resp.get('data', {})
+
+    current_frontier = set()
+    for dcid in fetch_dcids:
+      visited.add(dcid)
+      node_data = data.get(dcid, {})
+
+      arcs_obj = node_data.get('arcs', {}).get('containedInPlace', {})
+      nodes_list = arcs_obj.get('nodes', []) if isinstance(arcs_obj,
+                                                           dict) else []
+
+      parents = [x['dcid'] for x in nodes_list if 'dcid' in x]
+      if parents:
+        parent_graph[dcid] = parents
+        current_frontier.update(parents)
+
+    frontier = current_frontier
+
+  # Build ancestors list from the graph
+  for dcid in dcids:
+    queue = collections.deque([dcid])
+    seen = {dcid}
+    while queue:
+      curr = queue.popleft()
+      parents = parent_graph.get(curr, [])
+      for p in parents:
+        if p not in seen:
+          seen.add(p)
+          # Add to ancestors if it's not the node itself
+          if p != dcid:
+            ancestors_map[dcid].add(p)
+          queue.append(p)
+
+  all_dcids = set()
+  for anc_set in ancestors_map.values():
+    all_dcids.update(anc_set)
+  all_dcids.update(dcids)
+
+  all_dcids_list = sorted(list(all_dcids))
+  if not all_dcids_list:
+    return {'data': []}
+
+  types_resp = v2node(all_dcids_list, '->typeOf')
+  names_resp = v2node(all_dcids_list, '->name')
+
+  def get_first_value(resp, dcid, prop, key='dcid'):
+    # Returns the first arc value for `prop` on `dcid`, or '' if absent.
+    node_data = resp.get('data', {}).get(dcid, {})
+    arcs_obj = node_data.get('arcs', {}).get(prop, {})
+    if not arcs_obj:
+      # Try checking without arrow if key mismatch
+      arcs_obj = node_data.get('arcs', {}).get(prop.replace('->', ''), {})
+
+    nodes_list = arcs_obj.get('nodes', []) if isinstance(arcs_obj, dict) else []
+
+    if nodes_list:
+      return nodes_list[0].get(key, '')
+    return ''
+
+  result_data = []
+  for dcid in dcids:
+    self_type = get_first_value(types_resp, dcid, 'typeOf')
+    self_name = get_first_value(names_resp, dcid, 'name', 'value')
+
+    parents = []
+    for anc_dcid in ancestors_map.get(dcid, []):
+      if anc_dcid == dcid:
+        continue
+
+      anc_type = get_first_value(types_resp, anc_dcid, 'typeOf')
+      anc_name = get_first_value(names_resp, anc_dcid, 'name', 'value')
+
+      if anc_type in PLACE_TYPE_RANK:
+        parents.append({
+            'dcid': anc_dcid,
+            'type': anc_type,
+            'name': anc_name,
+            'rank': PLACE_TYPE_RANK[anc_type]
+        })
+
+    # Order parents smallest-to-largest region; rank is internal only.
+    parents.sort(key=lambda x: x['rank'])
+    for p in parents:
+      del p['rank']
+
+    result_data.append({
+        'node': dcid,
+        'info': {
+            'self': {
+                'dcid': dcid,
+                'type': self_type,
+                'name': self_name
+            },
+            'parents': parents
+        }
+    })
+
+  return {'data': result_data}
+
+
 def get_series_dates(parent_entity, child_type, variables):
   """Get series dates."""
-  url = get_service_url("/v1/bulk/observation-dates/linked")
-  return post(
-      url, {
-          "linked_property": "containedInPlace",
-          "linked_entity": parent_entity,
-          "entity_type": child_type,
-          "variables": variables,
-      })
+  # Get direct children
+  children_resp = v2node([parent_entity], '<-containedInPlace')
+  child_dcids = []
+
+  node_data = children_resp.get('data', {}).get(parent_entity, {})
+  arcs_obj = node_data.get('arcs', {}).get('containedInPlace', {})
+  nodes_list = arcs_obj.get('nodes', []) if isinstance(arcs_obj, dict) else []
+  possible_children = [x['dcid'] for x in nodes_list if 'dcid' in x]
+
+  # Filter by type if there are children
+  if possible_children:
+    # Filter children by requested type; arrow notation matches the other
+    # v2node calls in this module.
+    type_resp = v2node(possible_children, '->typeOf')
+    for child in possible_children:
+      # Check node types
+      c_data = type_resp.get('data', {}).get(child, {})
+      c_arcs = c_data.get('arcs', {}).get('typeOf', {})
+      c_types = c_arcs.get('nodes', []) if isinstance(c_arcs, dict) else []
+      c_type_ids = [t.get('dcid') for t in c_types]
+      if child_type in c_type_ids:
+        child_dcids.append(child)
+
+  if not child_dcids:
+    return {"datesByVariable": [], "facets": {}}
+
+  # Get observation dates for the filtered children
+  obs_resp = v2observation(
+      select=['date', 'variable', 'entity', 'value', 'facet'],
+      entity={'dcids': child_dcids},
+      variable={'dcids': variables})
+
+  # Aggregate observation dates across the child entities.
+  # Structure: { variable: { date: { facet_id: entity_count } } }
+  agg_data = collections.defaultdict(
+      lambda: collections.defaultdict(lambda: collections.defaultdict(int)))
+
+  # NOTE(review): assumes 'series' sits directly under byEntity in the V2
+  # observation response — confirm against the live V2 schema.
+  by_var = obs_resp.get('byVariable', {})
+
+  all_facets = obs_resp.get('facets', {})
+
+  for var, var_data in by_var.items():
+    by_ent = var_data.get('byEntity', {})
+    for ent, ent_data in by_ent.items():
+
+      series = ent_data.get('series', [])
+      for obs in series:
+        date = obs.get('date')
+        if not date:
+          continue
+
+        # Facet handling
+        facet_id = obs.get('facet', "")
+        agg_data[var][date][facet_id] += 1
+        # Each observation carries only a facet ID; the full facet
+        # metadata lives in the top-level 'facets' map of the V2
+        # response, which is passed through unchanged below so callers
+        # can resolve IDs the same way as with the V1 endpoint.
+
+  # Construct response
+  resp_dates = []
+  for var, dates_map in agg_data.items():
+    obs_dates = []
+    for date, facet_counts in dates_map.items():
+      entity_counts = []
+      for facet_id, count in facet_counts.items():
+        entity_counts.append({
+            "count": count,
+            "facet": facet_id  # facet ID string, as in V1 EntityCount
+            # Facet metadata for this ID is resolved via the top-level
+            # "facets" map returned alongside datesByVariable.
+        })
+      obs_dates.append({"date": date, "entityCount": entity_counts})
+    resp_dates.append({"variable": var, "observationDates": obs_dates})
+
+  return {"datesByVariable": resp_dates, "facets": all_facets}
 
 
 def resolve(nodes, prop):
diff --git a/server/tests/migration_verification_test.py b/server/tests/migration_verification_test.py
new file mode 100644
index 0000000000..f33d88106c
--- /dev/null
+++ b/server/tests/migration_verification_test.py
@@ -0,0 +1,359 @@
+import json
+import unittest
+from unittest.mock import patch
+
+from server.services import datacommons as dc
+
+
+class TestMigrationVerification(unittest.TestCase):
+
+  @patch('server.services.datacommons.post')
+  def test_get_place_info_v2(self, mock_post):
+    # Setup
+    dcids = ["geoId/06"]
+
+    # Mock V2 responses for ancestors, types, and names
+    # Mock side_effect to handle these calls
+    def side_effect(url, data, api_key=None, log_extreme_calls=False):
+      prop = data.get("property", "")
+
+      # Ancestors call (BFS uses ->containedInPlace)
+      if "->containedInPlace" in prop:
+        # Mock returning USA as parent of California
+        resp_data = {}
+        nodes = data.get("nodes", [])
+        for node in nodes:
+          if node == "geoId/06":
+            resp_data[node] = {
+                "arcs": {
+                    "containedInPlace": {
+                        "nodes": [{
+                            "dcid": "country/USA"
+                        }]
+                    }
+                }
+            }
+        # no parent for country/USA result in empty dict or just no entry
+        return {"data": resp_data}
+
+      # Key property call (types or name)
+      if "nodes" in data and len(data["nodes"]) > 0:
+        nodes = data["nodes"]
+        if "typeOf" in prop:
+          return {
+              "data": {
+                  "country/USA": {
+                      "arcs": {
+                          "typeOf": {
+                              "nodes": [{
+                                  "dcid": "Country"
+                              }]
+                          }
+                      }
+                  },
+                  "geoId/06": {
+                      "arcs": {
+                          "typeOf": {
+                              "nodes": [{
+                                  "dcid": "State"
+                              }]
+                          }
+                      }
+                  }
+              }
+          }
+        if "name" in prop:
+          return {
+              "data": {
+                  "country/USA": {
+                      "arcs": {
+                          "name": {
+                              "nodes": [{
+                                  "value": "United States"
+                              }]
+                          }
+                      }
+                  },
+                  "geoId/06": {
+                      "arcs": {
+                          "name": {
+                              "nodes": [{
+                                  "value": "California"
+                              }]
+                          }
+                      }
+                  }
+              }
+          }
+      return {}
+
+    mock_post.side_effect = side_effect
+
+    # Execute
+    result = dc.get_place_info(dcids)
+
+    # Verify
+    expected_parents = [{
+        "dcid": "country/USA",
+        "type": "Country",
+        "name": "United States"
+    }]
+
+    self.assertIn("data", result)
+    self.assertEqual(len(result["data"]), 1)
+    item = result["data"][0]
+    self.assertEqual(item["node"], "geoId/06")
+    self.assertIn("info", item)
+    self.assertEqual(item["info"]["self"]["dcid"], "geoId/06")
+    self.assertEqual(item["info"]["self"]["type"], "State")
+    self.assertEqual(item["info"]["self"]["name"], "California")
+
+    # Verify parents are sorted
+    self.assertEqual(item["info"]["parents"], expected_parents)
+
+  @patch('server.services.datacommons.post')
+  def test_get_series_dates_v2(self, mock_post):
+    # Setup
+    parent_entity = "geoId/06"
+    child_type = "County"
+    variables = ["Count_Person"]
+
+    def side_effect(url, data, api_key=None, log_extreme_calls=False):
+      # Child nodes
+      if "<-containedInPlace" in data.get("property", ""):
+        return {
+            "data": {
+                "geoId/06": {
+                    "arcs": {
+                        "containedInPlace": {
+                            "nodes": [{
+                                "dcid": "geoId/06001",
+                                "name": "Alameda County",
+                                "types": ["County"]
+                            }, {
+                                "dcid": "geoId/06085",
+                                "name": "Santa Clara County",
+                                "types": ["County"]
+                            }]
+                        }
+                    }
+                }
+            }
+        }
+      # Child types
+      if "typeOf" in data.get("property", ""):
+        return {
+            "data": {
+                "geoId/06001": {
+                    "arcs": {
+                        "typeOf": {
+                            "nodes": [{
+                                "dcid": "County"
+                            }]
+                        }
+                    }
+                },
+                "geoId/06085": {
+                    "arcs": {
+                        "typeOf": {
+                            "nodes": [{
+                                "dcid": "County"
+                            }]
+                        }
+                    }
+                }
+            }
+        }
+      # Observations
+      if "variable" in data and "entity" in data:
+        return {
+            "byVariable": {
+                "Count_Person": {
+                    "byEntity": {
+                        "geoId/06001": {
+                            "series": [{
+                                "date": "2020",
+                                "value": 100
+                            }]
+                        },
+                        "geoId/06085": {
+                            "series": [{
+                                "date": "2020",
+                                "value": 200
+                            }, {
+                                "date": "2021",
+                                "value": 210
+                            }]
+                        }
+                    }
+                }
+            }
+        }
+      return {}
+
+    mock_post.side_effect = side_effect
+
+    # Execute
+    result = dc.get_series_dates(parent_entity, child_type, variables)
+
+    # Verify
+    self.assertIn("datesByVariable", result)
+    self.assertEqual(len(result["datesByVariable"]), 1)
+    var_data = result["datesByVariable"][0]
+    self.assertEqual(var_data["variable"], "Count_Person")
+
+    dates = {d["date"]: d for d in var_data["observationDates"]}
+    self.assertIn("2020", dates)
+    self.assertEqual(dates["2020"]["entityCount"][0]["count"], 2)
+
+    self.assertIn("2021", dates)
+    self.assertEqual(dates["2021"]["entityCount"][0]["count"], 1)
+
+  @patch('server.services.datacommons.post')
+  def test_get_place_info_edge_cases(self, mock_post):
+    """Test recursion limits, cycles, and missing data for get_place_info."""
+
+    # Scenario 1: Max Recursion Depth (Chain > 10 levels)
+    # We expect it to stop at level 10.
+    def recursion_side_effect(url, data, api_key=None, log_extreme_calls=False):
+      if "->containedInPlace" in data.get("property", ""):
+        # Expect 'nodes' in payload
+        nodes = data.get("nodes", [])
+        resp_data = {}
+        for node in nodes:
+          if node.startswith("node"):
+            try:
+              idx = int(node[4:])
+              if idx < 15:  # Create chain up to 15
+                parent = f"node{idx+1}"
+                resp_data[node] = {
+                    "arcs": {
+                        "containedInPlace": {
+                            "nodes": [{
+                                "dcid": parent
+                            }]
+                        }
+                    }
+                }
+            except ValueError:
+              pass
+        return {"data": resp_data}
+
+      # For types/names, return dummy data
+      if "nodes" in data:
+        resp_data = {}
+        for node in data["nodes"]:
+          resp_data[node] = {
+              "arcs": {
+                  "typeOf": {
+                      "nodes": [{
+                          "dcid": "Place"
+                      }]
+                  },
+                  "name": {
+                      "nodes": [{
+                          "value": f"Name {node}"
+                      }]
+                  }
+              }
+          }
+        return {"data": resp_data}
+      return {}
+
+    mock_post.side_effect = recursion_side_effect
+
+    # Test max depth
+    dcids = ["node0"]
+    result = dc.get_place_info(dcids)
+
+    self.assertIn("data", result)
+    item = result["data"][0]
+    # We expect parents to contain node1..node10 (10 levels)
+    self.assertEqual(len(item["info"]["parents"]), 10)
+    parent_dcids = [p["dcid"] for p in item["info"]["parents"]]
+    self.assertIn("node10", parent_dcids)
+    self.assertNotIn("node11", parent_dcids)
+
+  @patch('server.services.datacommons.post')
+  def test_get_place_info_cycle(self, mock_post):
+    """Test handling of cycles in parent graph (A -> B -> A)."""
+
+    def cycle_side_effect(url, data, api_key=None, log_extreme_calls=False):
+      if "->containedInPlace" in data.get("property", ""):
+        nodes = data.get("nodes", [])
+        resp_data = {}
+        for node in nodes:
+          if node == "nodeA":
+            resp_data[node] = {
+                "arcs": {
+                    "containedInPlace": {
+                        "nodes": [{
+                            "dcid": "nodeB"
+                        }]
+                    }
+                }
+            }
+          elif node == "nodeB":
+            resp_data[node] = {
+                "arcs": {
+                    "containedInPlace": {
+                        "nodes": [{
+                            "dcid": "nodeA"
+                        }]
+                    }
+                }
+            }
+        return {"data": resp_data}
+      # names/types
+      return {
+          "data": {
+              "nodeA": {
+                  "arcs": {
+                      "typeOf": {
+                          "nodes": [{
+                              "dcid": "Place"
+                          }]
+                      },
+                      "name": {
+                          "nodes": [{
+                              "value": "A"
+                          }]
+                      }
+                  }
+              },
+              "nodeB": {
+                  "arcs": {
+                      "typeOf": {
+                          "nodes": [{
+                              "dcid": "Place"
+                          }]
+                      },
+                      "name": {
+                          "nodes": [{
+                              "value": "B"
+                          }]
+                      }
+                  }
+              }
+          }
+      }
+
+    mock_post.side_effect = cycle_side_effect
+
+    dcids = ["nodeA"]
+    result = dc.get_place_info(dcids)
+
+    # Should not hang or crash
+    item = result["data"][0]
+    parents = item["info"]["parents"]
+    # nodeA's parents should include nodeB.
+    self.assertEqual(len(parents), 1)
+    self.assertEqual(parents[0]["dcid"], "nodeB")
+
+  @patch('server.services.datacommons.post')
+  def test_get_series_dates_error(self, mock_post):
+    """Test error handling (e.g. 500 from API)."""
+    mock_post.side_effect = Exception("API Error")
+
+    with self.assertRaises(Exception):
+      dc.get_series_dates("geoId/06", "County", ["Var1"])