diff --git a/requirements_all.txt b/requirements_all.txt
index 6983f07837..65370c9ec3 100644
--- a/requirements_all.txt
+++ b/requirements_all.txt
@@ -22,6 +22,7 @@ netCDF4
 protobuf
 rasterio
 rdp
+requests-mock
 s2sphere
 sentence-transformers
 tabula-py
diff --git a/scripts/earthengine/utils.py b/scripts/earthengine/utils.py
index 2f014c7dd2..44dbd71535 100644
--- a/scripts/earthengine/utils.py
+++ b/scripts/earthengine/utils.py
@@ -23,10 +23,10 @@ import re
 import sys
 import tempfile

+from pathlib import Path
 from typing import Union

 from absl import logging
-import datacommons as dc
 from dateutil.relativedelta import relativedelta
 from geopy import distance
 import s2sphere
@@ -41,12 +41,12 @@ os.path.join(os.path.dirname(os.path.dirname(_SCRIPTS_DIR)), 'util'))

 from config_map import ConfigMap, read_py_dict_from_file, write_py_dict_to_file
-from dc_api_wrapper import dc_api_wrapper
+from dc_api_wrapper import dc_api_get_node_property

 # Constants
 _MAX_LATITUDE = 90.0
 _MAX_LONGITUDE = 180.0

-_DC_API_ROOT = 'http://autopush.api.datacommons.org'
+_DC_API_ROOT = 'https://api.datacommons.org'

 # Utilities for dicts.
@@ -366,27 +366,31 @@ def place_id_to_lat_lng(placeid: str,
                                                            placeid)
     elif dc_api_lookup:
         # Get the lat/lng from the DC API
-        latlng = []
-        for prop in ['latitude', 'longitude']:
-            # dc.utils._API_ROOT = 'http://autopush.api.datacommons.org'
-            # resp = dc.get_property_values([placeid], prop)
-            resp = dc_api_wrapper(
-                function=dc.get_property_values,
-                args={
-                    'dcids': [placeid],
-                    'prop': prop,
-                },
-                use_cache=True,
-                api_root=_DC_API_ROOT,
-            )
-            if not resp or placeid not in resp:
-                return (0, 0)
-            values = resp[placeid]
-            if not len(values):
-                return (0, 0)
-            latlng.append(float(values[0]))
-        lat = latlng[0]
-        lng = latlng[1]
+        resp = dc_api_get_node_property(
+            [placeid],
+            ['latitude', 'longitude'],
+            {
+                'dc_api_version': 'V2',
+                'dc_api_use_cache': True,
+                'dc_api_root': _DC_API_ROOT,
+            },
+        )
+        node_props = resp.get(placeid) if resp else None
+        if not node_props:
+            return (0, 0)
+
+        def _parse_coordinate(val):
+            if isinstance(val, list):
+                val = val[0] if val else None
+            if isinstance(val, str):
+                val = val.split(',')[0].strip().strip('"')
+            return str_get_numeric_value(val)
+
+        lat = _parse_coordinate(node_props.get('latitude'))
+        lng = _parse_coordinate(node_props.get('longitude'))
+
+        if lat is None or lng is None:
+            return (0, 0)
     return (lat, lng)


diff --git a/scripts/earthengine/utils_test.py b/scripts/earthengine/utils_test.py
index 1847cf2d31..bfdd347661 100644
--- a/scripts/earthengine/utils_test.py
+++ b/scripts/earthengine/utils_test.py
@@ -15,9 +15,11 @@
 import math
 import os
+from pathlib import Path
 import sys
 import tempfile
 import unittest
+from unittest import mock

 from absl import logging
 import s2sphere
@@ -371,3 +373,26 @@ def test_date_format_by_time_period(self):
                          utils.date_format_by_time_period('2022-04-10', 'P3M'))
         self.assertEqual('2021',
                          utils.date_format_by_time_period('2021-01-10', '1Y'))
+
+
+class PlaceUtilsTest(unittest.TestCase):
+
+    def test_place_id_to_lat_lng_dc_api(self):
+        placeid = 'geoId/06085'
+        response = {
+            placeid: {
+                'latitude': '"37.221614","37.36"',
+                'longitude': '"-121.68954","-121.97"',
+            }
+        }
+        with mock.patch('utils.dc_api_get_node_property',
+                        return_value=response) as mock_get:
+            lat, lng = utils.place_id_to_lat_lng(placeid, dc_api_lookup=True)
+            self.assertAlmostEqual(37.221614, lat)
+            self.assertAlmostEqual(-121.68954, lng)
+            mock_get.assert_called_once_with(
+                [placeid], ['latitude', 'longitude'], {
+                    'dc_api_version': 'V2',
+                    'dc_api_use_cache': True,
+                    'dc_api_root': utils._DC_API_ROOT,
+                })
diff --git a/scripts/rff/preprocess_raster.py b/scripts/rff/preprocess_raster.py
index 2aff529ae7..50be43928e 100644
--- a/scripts/rff/preprocess_raster.py
+++ b/scripts/rff/preprocess_raster.py
@@ -1,16 +1,30 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import csv
-import datacommons as dc
 import glob
 import json
 import numpy as np
 import os
 from osgeo import gdal
+from pathlib import Path
 from shapely import geometry
 import sys

-RFF_DIR = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(RFF_DIR)
-from rff import util
+REPO_ROOT = Path(__file__).resolve().parents[2]
+sys.path.insert(0, str(REPO_ROOT))
+from util.dc_api_wrapper import dc_api_batched_wrapper, get_datacommons_client
+from scripts.rff import util

 bandname_to_gdcStatVars = {
     "std_dev": "StandardDeviation_",
@@ -36,16 +50,61 @@ def get_dcid(sp_scale, lat, lon):


 def get_county_geoid(lat, lon):
-    counties = dc.get_places_in(['country/USA'], 'County')['country/USA']
-    counties_simp = dc.get_property_values(counties, 'geoJsonCoordinatesDP1')
+    config = {'dc_api_use_cache': True}
+    client = get_datacommons_client(config)
+
+    def extract_geojson(node_data, prop_name):
+        nodes = node_data.get('arcs', {}).get(prop_name, {}).get('nodes', [])
+        if not nodes:
+            return None
+        first_node = nodes[0]
+        if isinstance(first_node, dict):
+            return first_node.get('value')
+        return first_node.value
+
+    counties_result = client.node.fetch_place_children(
+        place_dcids=['country/USA'],
+        children_type='County',
+        as_dict=True,
+    )
+    counties = [
+        node.get('dcid')
+        for node in counties_result.get('country/USA', [])
+        if node.get('dcid')
+    ]
+    counties_simp = dc_api_batched_wrapper(
+        function=client.node.fetch_property_values,
+        dcids=counties,
+        args={'properties': 'geoJsonCoordinatesDP1'},
+        dcid_arg_kw='node_dcids',
+        config=config,
+    )
     point = geometry.Point(lon, lat)
-    for p, gj in counties_simp.items():
-        if len(gj) == 0:
-            gj = dc.get_property_values([p], 'geoJsonCoordinates')[p]
-        if len(gj) == 0:  # property not defined for one county in alaska
-            continue
-        if geometry.shape(json.loads(gj[0])).contains(point):
-            return p
+    counties_missing_dp1 = []
+    for county in counties:
+        node_data = counties_simp.get(county, {})
+        geojson = extract_geojson(node_data, 'geoJsonCoordinatesDP1')
+        if not geojson:
+            counties_missing_dp1.append(county)
+            continue
+        if geometry.shape(json.loads(geojson)).contains(point):
+            return county
+    fallback = {}
+    if counties_missing_dp1:
+        fallback = dc_api_batched_wrapper(
+            function=client.node.fetch_property_values,
+            dcids=counties_missing_dp1,
+            args={'properties': 'geoJsonCoordinates'},
+            dcid_arg_kw='node_dcids',
+            config=config,
+        )
+    for county in counties_missing_dp1:
+        node_data = fallback.get(county, {})
+        geojson = extract_geojson(node_data, 'geoJsonCoordinates')
+        if not geojson:  # property not defined for one county in alaska
+            continue
+        if geometry.shape(json.loads(geojson)).contains(point):
+            return county
     return None
diff --git a/scripts/rff/preprocess_raster_test.py b/scripts/rff/preprocess_raster_test.py
new file mode 100644
index 0000000000..9066e74da8
--- /dev/null
+++ b/scripts/rff/preprocess_raster_test.py
@@ -0,0 +1,122 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+from pathlib import Path
+import types
+import unittest
+from unittest import mock
+
+REPO_ROOT = Path(__file__).resolve().parents[2]
+sys.path.insert(0, str(REPO_ROOT))
+
+if "osgeo" not in sys.modules:
+    osgeo_module = types.ModuleType("osgeo")
+    gdal_module = types.ModuleType("gdal")
+    osgeo_module.gdal = gdal_module
+    sys.modules["osgeo"] = osgeo_module
+    sys.modules["osgeo.gdal"] = gdal_module
+
+from scripts.rff import preprocess_raster
+
+
+class FakeNodeEndpoint:
+
+    def __init__(self, place_children):
+        self._place_children = place_children
+
+    def fetch_place_children(self, place_dcids, children_type, as_dict):
+        return {"country/USA": self._place_children}
+
+    def fetch_property_values(self, node_dcids, properties):
+        raise AssertionError("fetch_property_values should not be called")
+
+
+class FakeClient:
+
+    def __init__(self, node):
+        self.node = node
+
+
+class PreprocessRasterTest(unittest.TestCase):
+
+    def test_get_county_geoid_dp1(self):
+        county = "geoId/06085"
+        geojson = (
+            '{"type":"Polygon","coordinates":[[[0,0],[0,2],[2,2],[2,0],[0,0]]]}'
+        )
+        dp1_properties = {
+            county: {
+                "arcs": {
+                    "geoJsonCoordinatesDP1": {
+                        "nodes": [{
+                            "value": geojson
+                        }],
+                    },
+                },
+            },
+        }
+        node = FakeNodeEndpoint(place_children=[{"dcid": county}])
+        client = FakeClient(node)
+        with mock.patch.object(preprocess_raster,
+                               "get_datacommons_client",
+                               return_value=client), mock.patch.object(
+                                   preprocess_raster,
+                                   "dc_api_batched_wrapper",
+                                   return_value=dp1_properties) as mock_wrapper:
+            result = preprocess_raster.get_county_geoid(1.0, 1.0)
+            self.assertEqual(result, county)
+            self.assertEqual(mock_wrapper.call_count, 1)
+
+    def test_get_county_geoid_fallback(self):
+        county = "geoId/06085"
+        geojson = (
+            '{"type":"Polygon","coordinates":[[[0,0],[0,2],[2,2],[2,0],[0,0]]]}'
+        )
+        dp1_properties = {
+            county: {
+                "arcs": {
+                    "geoJsonCoordinatesDP1": {
+                        "nodes": [],
+                    },
+                },
+            },
+        }
+        fallback_properties = {
+            county: {
+                "arcs": {
+                    "geoJsonCoordinates": {
+                        "nodes": [{
+                            "value": geojson
+                        }],
+                    },
+                },
+            },
+        }
+        node = FakeNodeEndpoint(place_children=[{"dcid": county}])
+        client = FakeClient(node)
+        with mock.patch.object(preprocess_raster,
+                               "get_datacommons_client",
+                               return_value=client), mock.patch.object(
+                                   preprocess_raster,
+                                   "dc_api_batched_wrapper",
+                                   side_effect=[
+                                       dp1_properties, fallback_properties
+                                   ]) as mock_wrapper:
+            result = preprocess_raster.get_county_geoid(1.0, 1.0)
+            self.assertEqual(result, county)
+            self.assertEqual(mock_wrapper.call_count, 2)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/scripts/us_census/enhanced_tmcf/process_etmcf.py b/scripts/us_census/enhanced_tmcf/process_etmcf.py
index fd51dbbf2d..9afa4e7957 100644
--- a/scripts/us_census/enhanced_tmcf/process_etmcf.py
+++ b/scripts/us_census/enhanced_tmcf/process_etmcf.py
@@ -1,11 +1,17 @@
 import csv
-import datacommons as dc
 import os
+import sys
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Dict, List, Tuple

 from absl import app
 from absl import flags
-from dataclasses import dataclass
-from typing import Dict, List, Tuple
+
+REPO_ROOT = Path(__file__).resolve().parents[3]
+sys.path.insert(0, str(REPO_ROOT))
+
+from util.dc_api_wrapper import dc_api_get_node_property

 GEO_ID_COLUMN = 'GEO_ID'
 NUM_DCIDS_TO_QUERY = 50
@@ -70,9 +76,10 @@ def _get_places_not_found(census_geoids: List[str]) -> List[str]:
     for i in range(0, len(geo_ids), NUM_DCIDS_TO_QUERY):
         selected_geo_ids = geo_ids[i:i + NUM_DCIDS_TO_QUERY]
         selected_dcids = [geoId_to_dcids[g] for g in selected_geo_ids]
-        res = dc.get_property_values(selected_dcids, 'name')
+        res = dc_api_get_node_property(selected_dcids, 'name')
         for index in range(len(selected_dcids)):
-            if not res[selected_dcids[index]]:
+            name = res.get(selected_dcids[index], {}).get('name')
+            if not name:
                 geoIds_not_found.append(selected_geo_ids[index])
     return geoIds_not_found

@@ -292,4 +299,4 @@ def process_enhanced_tmcf(input_folder, output_folder, etmcf_filename,
     # Use the existing input CSV, the new_csv_columns list and maps of geoIds to DCIDs (for places)
     # and a list of geoIds not found to produce the processed (traditional) TMCF and corresponding CSV.
     _write_modified_csv(input_csv_filepath, csv_out, new_csv_columns,
-                        geo_ids_to_dcids, geo_ids_not_found)
\ No newline at end of file
+                        geo_ids_to_dcids, geo_ids_not_found)
diff --git a/scripts/us_census/enhanced_tmcf/process_etmcf_test.py b/scripts/us_census/enhanced_tmcf/process_etmcf_test.py
index f87a48cddc..941e5c1e99 100644
--- a/scripts/us_census/enhanced_tmcf/process_etmcf_test.py
+++ b/scripts/us_census/enhanced_tmcf/process_etmcf_test.py
@@ -14,9 +14,16 @@
 """Tests for process_etmcf.py"""

 import os
+import sys
 import tempfile
 import unittest
-from .process_etmcf import *
+from pathlib import Path
+from unittest import mock
+
+REPO_ROOT = Path(__file__).resolve().parents[3]
+sys.path.insert(0, str(REPO_ROOT))
+
+from scripts.us_census.enhanced_tmcf import process_etmcf

 _CODEDIR = os.path.dirname(os.path.realpath(__file__))
 _INPUT_DIR = os.path.join(_CODEDIR, 'testdata', 'input')
@@ -33,6 +40,15 @@ def compare_files(t, output_path, expected_path):

 class Process_ETMCF_Test(unittest.TestCase):

+    def test_get_places_not_found_uses_v2_wrapper(self):
+        geo_ids = ['0500000US06085', '0500000US06001']
+        mock_response = {'geoId/06085': {'name': 'Santa Clara County'}}
+        with mock.patch.object(process_etmcf,
+                               'dc_api_get_node_property',
+                               return_value=mock_response):
+            got = process_etmcf._get_places_not_found(geo_ids)
+        self.assertEqual(got, ['0500000US06001'])
+
     def test_simple_success(self):
         self.maxDiff = None
         input_etmcf = 'simple'
@@ -41,8 +57,17 @@ def test_simple_success(self):
         output_tmcf = "simple_processed"
         output_csv = "simple_processed"
         with tempfile.TemporaryDirectory() as tmp_dir:
-            process_enhanced_tmcf(_INPUT_DIR, tmp_dir, input_etmcf, input_csv,
-                                  output_tmcf, output_csv)
+            with mock.patch.object(
+                    process_etmcf,
+                    'dc_api_get_node_property',
+                    side_effect=lambda dcids, prop:
+                    {dcid: {
+                        'name': 'name'
+                    } for dcid in dcids},
+            ):
+                process_etmcf.process_enhanced_tmcf(_INPUT_DIR, tmp_dir,
+                                                    input_etmcf, input_csv,
+                                                    output_tmcf, output_csv)
             for fname in [output_tmcf + ".tmcf", output_csv + ".csv"]:
                 output_path = os.path.join(tmp_dir, fname)
                 expected_path = os.path.join(_EXPECTED_DIR, fname)
@@ -56,8 +81,17 @@ def test_simple_opaque_success(self):
         output_tmcf = "simple_opaque_processed"
         output_csv = "simple_opaque_processed"
         with tempfile.TemporaryDirectory() as tmp_dir:
-            process_enhanced_tmcf(_INPUT_DIR, tmp_dir, input_etmcf, input_csv,
-                                  output_tmcf, output_csv)
+            with mock.patch.object(
+                    process_etmcf,
+                    'dc_api_get_node_property',
+                    side_effect=lambda dcids, prop:
+                    {dcid: {
+                        'name': 'name'
+                    } for dcid in dcids},
+            ):
+                process_etmcf.process_enhanced_tmcf(_INPUT_DIR, tmp_dir,
+                                                    input_etmcf, input_csv,
+                                                    output_tmcf, output_csv)
             for fname in [output_tmcf + ".tmcf", output_csv + ".csv"]:
                 output_path = os.path.join(tmp_dir, fname)
                 expected_path = os.path.join(_EXPECTED_DIR, fname)
@@ -71,8 +105,17 @@ def test_process_enhanced_tmcf_medium_success(self):
         output_tmcf = "ECNBASIC2012.EC1200A1_processed"
         output_csv = "ECNBASIC2012.EC1200A1_processed"
         with tempfile.TemporaryDirectory() as tmp_dir:
-            process_enhanced_tmcf(_INPUT_DIR, tmp_dir, input_etmcf, input_csv,
-                                  output_tmcf, output_csv)
+            with mock.patch.object(
+                    process_etmcf,
+                    'dc_api_get_node_property',
+                    side_effect=lambda dcids, prop:
+                    {dcid: {
+                        'name': 'name'
+                    } for dcid in dcids},
+            ):
+                process_etmcf.process_enhanced_tmcf(_INPUT_DIR, tmp_dir,
+                                                    input_etmcf, input_csv,
+                                                    output_tmcf, output_csv)
             for fname in [output_tmcf + ".tmcf", output_csv + ".csv"]:
                 output_path = os.path.join(tmp_dir, fname)
                 expected_path = os.path.join(_EXPECTED_DIR, fname)
@@ -87,8 +130,9 @@ def test_tmcf_file_not_found_exception(self):
         output_csv = "no_output"
         with tempfile.TemporaryDirectory() as tmp_dir:
             with self.assertRaises(Exception):
-                process_enhanced_tmcf(_INPUT_DIR, tmp_dir, input_etmcf,
-                                      input_csv, output_tmcf, output_csv)
+                process_etmcf.process_enhanced_tmcf(_INPUT_DIR, tmp_dir,
+                                                    input_etmcf, input_csv,
+                                                    output_tmcf, output_csv)

     def test_csv_file_not_found_exception(self):
         self.maxDiff = None
@@ -99,8 +143,9 @@ def test_csv_file_not_found_exception(self):
         output_csv = "no_output"
         with tempfile.TemporaryDirectory() as tmp_dir:
             with self.assertRaises(Exception):
-                process_enhanced_tmcf(_INPUT_DIR, tmp_dir, input_etmcf,
-                                      input_csv, output_tmcf, output_csv)
+                process_etmcf.process_enhanced_tmcf(_INPUT_DIR, tmp_dir,
+                                                    input_etmcf, input_csv,
+                                                    output_tmcf, output_csv)

     def test_bad_tmcf_variable_measured_two_question_marks_exception(self):
         self.maxDiff = None
@@ -113,8 +158,9 @@ def test_bad_tmcf_variable_measured_two_question_marks_exception(self):
         with self.assertRaisesRegex(
                 Exception, "Exactly one '\?' expected in variableMeasured*"):
-            process_enhanced_tmcf(_INPUT_DIR, tmp_dir, input_etmcf,
-                                  input_csv, output_tmcf, output_csv)
+            process_etmcf.process_enhanced_tmcf(_INPUT_DIR, tmp_dir,
+                                                input_etmcf, input_csv,
+                                                output_tmcf, output_csv)

     def test_bad_tmcf_variable_measured_two_equals_exception(self):
         self.maxDiff = None
@@ -127,8 +173,9 @@ def test_bad_tmcf_variable_measured_two_equals_exception(self):
         with self.assertRaisesRegex(
                 Exception,
                 "Exactly one '=' expected in the key/val opaque mapping*"):
-            process_enhanced_tmcf(_INPUT_DIR, tmp_dir, input_etmcf,
-                                  input_csv, output_tmcf, output_csv)
+            process_etmcf.process_enhanced_tmcf(_INPUT_DIR, tmp_dir,
+                                                input_etmcf, input_csv,
+                                                output_tmcf, output_csv)


 if __name__ == '__main__':
diff --git a/scripts/us_epa/parent_company/download_existing_facilities.py b/scripts/us_epa/parent_company/download_existing_facilities.py
index 7f87ae830c..1156172e42 100644
--- a/scripts/us_epa/parent_company/download_existing_facilities.py
+++ b/scripts/us_epa/parent_company/download_existing_facilities.py
@@ -14,34 +14,65 @@
 """A simple script to download existing Facilities in Data Commons."""

 import os
-import pathlib
+import sys
+from pathlib import Path

-import datacommons
 import pandas as pd
+import requests
 from absl import app
 from absl import flags

+REPO_ROOT = Path(__file__).resolve().parents[3]
+if str(REPO_ROOT) not in sys.path:
+    sys.path.insert(0, str(REPO_ROOT))
+
+from util.dc_api_wrapper import get_dc_api_key
+
 FLAGS = flags.FLAGS

-flags.DEFINE_string('output_path', 'tmp_data', 'Output directory')
+_V2_SPARQL_URL = "https://api.datacommons.org/v2/sparql"
+
+
+def _define_flags() -> None:
+    flags.DEFINE_string('output_path', 'tmp_data', 'Output directory')


-def main(_):
-    assert FLAGS.output_path
-    pathlib.Path(FLAGS.output_path).mkdir(exist_ok=True)
-    out_file = os.path.join(FLAGS.output_path, 'existing_facilities.csv')
+def download_existing_facilities(output_path: str) -> str:
+    Path(output_path).mkdir(exist_ok=True)
+    out_file = os.path.join(output_path, 'existing_facilities.csv')
     q = "SELECT DISTINCT ?dcid WHERE {?a typeOf EpaReportingFacility . ?a dcid ?dcid }"
-    res = datacommons.query(q)
+    headers = {"Content-Type": "application/json"}
+    api_key = get_dc_api_key()
+    if api_key:
+        headers["X-API-Key"] = api_key
+    response = requests.post(_V2_SPARQL_URL, json={"query": q}, headers=headers)
+    response.raise_for_status()
+    res = response.json()
     facility_ids = []
-    for facility in res:
-        facility_ids.append(facility["?dcid"])
+    for row in res.get('rows', []):
+        cells = row.get('cells', [])
+        if not cells:
+            continue
+        value = cells[0].get('value')
+        if value:
+            facility_ids.append(value)
     df = pd.DataFrame.from_dict({"epaGhgrpFacilityId": facility_ids})
     df.to_csv(out_file, mode="w", header=True, index=False)
+    return out_file
+
+
+def main(_: list[str]) -> int:
+    output_path = FLAGS.output_path
+    if not output_path:
+        raise ValueError("output_path is required.")
+    download_existing_facilities(output_path)
+    return 0


 if __name__ == '__main__':
+    _define_flags()
     app.run(main)
diff --git a/scripts/us_epa/parent_company/download_existing_facilities_test.py b/scripts/us_epa/parent_company/download_existing_facilities_test.py
new file mode 100644
index 0000000000..12b5c7c6f7
--- /dev/null
+++ b/scripts/us_epa/parent_company/download_existing_facilities_test.py
@@ -0,0 +1,81 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for download_existing_facilities.py."""
+
+import os
+import sys
+import tempfile
+from pathlib import Path
+from unittest import mock
+
+import requests_mock
+from absl.testing import absltest
+
+REPO_ROOT = Path(__file__).resolve().parents[3]
+sys.path.insert(0, str(REPO_ROOT))
+
+from scripts.us_epa.parent_company.download_existing_facilities import (
+    download_existing_facilities,)
+from scripts.us_epa.parent_company.download_existing_facilities import (
+    _V2_SPARQL_URL,)
+
+
+class DownloadExistingFacilitiesTest(absltest.TestCase):
+
+    def test_download_existing_facilities(self):
+        response = {
+            "header": ["?dcid"],
+            "rows": [
+                {
+                    "cells": [{
+                        "value": "epaGhgrpFacilityId/1001"
+                    }]
+                },
+                {
+                    "cells": [{
+                        "value": "epaGhgrpFacilityId/1002"
+                    }]
+                },
+            ],
+        }
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with requests_mock.Mocker() as mocker:
+                mocker.post(_V2_SPARQL_URL, json=response)
+                with mock.patch(
+                        "scripts.us_epa.parent_company."
+                        "download_existing_facilities.get_dc_api_key",
+                        return_value="test-key"):
+                    output_path = download_existing_facilities(tmp_dir)
+
+                self.assertTrue(os.path.exists(output_path))
+                with open(output_path, "r", encoding="utf-8") as handle:
+                    contents = handle.read()
+                self.assertEqual(
+                    contents,
+                    "epaGhgrpFacilityId\n"
+                    "epaGhgrpFacilityId/1001\n"
+                    "epaGhgrpFacilityId/1002\n",
+                )
+                self.assertLen(mocker.request_history, 1)
+                request = mocker.request_history[0]
+                self.assertEqual(request.headers.get("X-API-Key"), "test-key")
+                self.assertEqual(
+                    request.json().get("query"),
+                    "SELECT DISTINCT ?dcid WHERE {?a typeOf "
+                    "EpaReportingFacility . ?a dcid ?dcid }",
?a dcid ?dcid }", + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/scripts/us_epa/parent_company/process_parent_company.py b/scripts/us_epa/parent_company/process_parent_company.py index 0af9d3353f..454a4cbb58 100644 --- a/scripts/us_epa/parent_company/process_parent_company.py +++ b/scripts/us_epa/parent_company/process_parent_company.py @@ -18,8 +18,6 @@ import sys import csv -import datacommons -import json import pandas as pd from absl import app diff --git a/scripts/us_epa/parent_company/process_parent_company_test.py b/scripts/us_epa/parent_company/process_parent_company_test.py index db968ee81a..32908bf298 100644 --- a/scripts/us_epa/parent_company/process_parent_company_test.py +++ b/scripts/us_epa/parent_company/process_parent_company_test.py @@ -14,11 +14,20 @@ """Tests for process_parent_company.py""" import os +import sys import tempfile import unittest -from .process_parent_company import process_companies -from .process_parent_company import process_svobs -from .process_parent_company import _COUNTERS_COMPANIES +from pathlib import Path + +REPO_ROOT = Path(__file__).resolve().parents[3] +sys.path.insert(0, str(REPO_ROOT)) + +from scripts.us_epa.parent_company.process_parent_company import ( + process_companies,) +from scripts.us_epa.parent_company.process_parent_company import ( + process_svobs,) +from scripts.us_epa.parent_company.process_parent_company import ( + _COUNTERS_COMPANIES,) _CODEDIR = os.path.dirname(os.path.realpath(__file__)) _INPUT_DATA_DIR = os.path.join(_CODEDIR, 'testdata', 'input') diff --git a/tools/statvar_importer/mcf_filter.py b/tools/statvar_importer/mcf_filter.py index 11bc79997a..b06b94cf50 100644 --- a/tools/statvar_importer/mcf_filter.py +++ b/tools/statvar_importer/mcf_filter.py @@ -34,7 +34,6 @@ from absl import app from absl import flags from absl import logging -import datacommons as dc _FLAGS = flags.FLAGS diff --git a/tools/statvar_importer/place/place_resolver.py b/tools/statvar_importer/place/place_resolver.py index 330677cb0f..643dc986d9 100644 --- a/tools/statvar_importer/place/place_resolver.py +++ b/tools/statvar_importer/place/place_resolver.py @@ -51,7 +51,6 @@ from absl import app from absl import flags from absl import logging -import datacommons as dc # uncomment to run pprof # from pypprof.net_http import start_pprof_server diff --git a/util/dc_api_wrapper.py b/util/dc_api_wrapper.py index 2e084eb766..f5cbb868f7 100644 --- a/util/dc_api_wrapper.py +++ b/util/dc_api_wrapper.py @@ -30,6 +30,7 @@ import urllib import requests import threading +from typing import Union from absl import logging from datacommons_client.client import DataCommonsClient @@ -274,8 +275,8 @@ def dc_api_merge_results(results: dict, new_result: dict) -> dict: return results -def get_datacommons_client(config: dict = None) -> DataCommonsClient: - """Returns a DataCommonsClient object initialized using config.""" +def get_dc_api_key(config: dict = None) -> str: + """Returns the API key for DC API calls.""" if config is None: config = {} api_key = config.get('dc_api_key', os.environ.get('DC_API_KEY')) @@ -287,6 +288,14 @@ def get_datacommons_client(config: dict = None) -> DataCommonsClient: 'for more details.', n=1) api_key = _DEFAULT_DC_API_KEY + return api_key + + +def get_datacommons_client(config: dict = None) -> DataCommonsClient: + """Returns a DataCommonsClient object initialized using config.""" + if config is None: + config = {} + api_key = get_dc_api_key(config) dc_instance = config.get('dc_api_root') url = None # Check if API root is 
@@ -348,7 +357,9 @@ def dc_api_is_defined_dcid(dcids: list, config: dict = {}) -> dict:
     return response


-def dc_api_get_node_property(dcids: list, prop: str, config: dict = {}) -> dict:
+def dc_api_get_node_property(dcids: list,
+                             prop: Union[str, list],
+                             config: dict = {}) -> dict:
     """Returns a dictionary keyed by dcid with { prop:value } for each dcid.

     Uses the get_property_values() DC API to lookup the property for each dcid.
@@ -362,19 +373,29 @@ def dc_api_get_node_property(dcids: list, prop: str, config: dict = {}) -> dict
       dictionary with each input dcid mapped to a True/False value.
     """
     is_v2 = config.get('dc_api_version', 'V2') == 'V2'
+    if isinstance(prop, list):
+        if not prop:
+            raise ValueError('prop list is empty.')
+        if len(prop) == 1:
+            prop = prop[0]
+        elif not is_v2:
+            raise ValueError(
+                'V1 dc_api_get_node_property supports a single property.')
+    if is_v2:
+        return _dc_api_get_node_property_v2(dcids=dcids,
+                                            prop=prop,
+                                            config=config)
+    return _dc_api_get_node_property_v1(dcids=dcids, prop=prop, config=config)
+
+
+def _dc_api_get_node_property_v2(dcids: list,
+                                 prop: Union[str, list],
+                                 config: dict = {}) -> dict:
     # Set parameters for V2 node API.
     client = get_datacommons_client(config)
     api_function = client.node.fetch_property_values
     args = {'properties': prop}
     dcid_arg_kw = 'node_dcids'
-    if not is_v2:
-        # Set parameters for V1 API.
-        api_function = dc.get_property_values
-        args = {
-            'prop': prop,
-            'out': True,
-        }
-        dcid_arg_kw = 'dcids'
     api_result = dc_api_batched_wrapper(function=api_function,
                                         dcids=dcids,
                                         args=args,
@@ -387,10 +408,11 @@ def dc_api_get_node_property(dcids: list, prop: str, config: dict = {}) -> dict
         if not node_data:
             continue

-        if is_v2:
+        arcs = node_data.get('arcs', {})
+        prop_list = prop if isinstance(prop, list) else [prop]
+        for prop_name in prop_list:
             values = []
-            arcs = node_data.get('arcs', {})
-            prop_nodes = arcs.get(prop, {}).get('nodes', [])
+            prop_nodes = arcs.get(prop_name, {}).get('nodes', [])
             for node in prop_nodes:
                 val_dcid = node.get('dcid')
                 if val_dcid:
@@ -400,10 +422,33 @@ def dc_api_get_node_property(dcids: list, prop: str, config: dict = {}) -> dict
                         value = '"' + value + '"'
                     values.append(value)
             if values:
-                response[dcid] = {prop: ','.join(values)}
-        else:  # V1
-            if node_data:
-                response[dcid] = {prop: node_data}
+                if dcid not in response:
+                    response[dcid] = {}
+                response[dcid][prop_name] = ','.join(values)
+    return response
+
+
+def _dc_api_get_node_property_v1(dcids: list,
+                                 prop: str,
+                                 config: dict = {}) -> dict:
+    # Set parameters for V1 API.
+    api_function = dc.get_property_values
+    args = {
+        'prop': prop,
+        'out': True,
+    }
+    dcid_arg_kw = 'dcids'
+    api_result = dc_api_batched_wrapper(function=api_function,
+                                        dcids=dcids,
+                                        args=args,
+                                        dcid_arg_kw=dcid_arg_kw,
+                                        config=config)
+    response = {}
+    for dcid in dcids:
+        dcid_stripped = _strip_namespace(dcid)
+        node_data = api_result.get(dcid_stripped)
+        if node_data:
+            response[dcid] = {prop: node_data}
     return response
diff --git a/util/dc_api_wrapper_test.py b/util/dc_api_wrapper_test.py
index 184127922a..48dab3ca88 100644
--- a/util/dc_api_wrapper_test.py
+++ b/util/dc_api_wrapper_test.py
@@ -102,6 +102,19 @@ def test_dc_api_get_node_property(self):
         self.assertEqual(response_v2['Count_Person'],
                          {'name': '"Total population"'})

+    def test_dc_api_get_node_property_multi_v2(self):
+        """Test API wrapper to get multiple properties for a node."""
+        dcids = ['Count_Person']
+        props = ['populationType', 'measuredProperty', 'typeOf']
+        response_v2 = dc_api.dc_api_get_node_property(dcids, props,
+                                                      {'dc_api_version': 'V2'})
+        self.assertTrue(response_v2)
+        self.assertIn('Count_Person', response_v2)
+        statvar_pvs = response_v2['Count_Person']
+        self.assertTrue(statvar_pvs.get('populationType'))
+        self.assertTrue(statvar_pvs.get('measuredProperty'))
+        self.assertIn('StatisticalVariable', statvar_pvs.get('typeOf', ''))
+
     def test_dc_api_resolve_placeid(self):
         """Test API wrapper to resolve entity using a placeid."""
         placeids = ['ChIJT3IGqvxznW4Rqgw7pv9zYz8']