From 16c895ddbd615e799c7ffdd3669ee1841448843a Mon Sep 17 00:00:00 2001 From: Zebedee Nicholls Date: Mon, 3 May 2021 11:45:25 +1000 Subject: [PATCH 01/11] Add _xarray module --- src/scmdata/_xarray.py | 209 ++++++++++++++++++++++++++++++++++++++ src/scmdata/netcdf.py | 208 +------------------------------------ src/scmdata/run.py | 4 + tests/unit/test_netcdf.py | 25 +++-- tests/unit/test_run.py | 8 ++ 5 files changed, 237 insertions(+), 217 deletions(-) create mode 100644 src/scmdata/_xarray.py diff --git a/src/scmdata/_xarray.py b/src/scmdata/_xarray.py new file mode 100644 index 00000000..af350f38 --- /dev/null +++ b/src/scmdata/_xarray.py @@ -0,0 +1,209 @@ +""" +Interface with `xarray `_ +""" +import numpy as np +import xarray as xr + +from .errors import NonUniqueMetadataError + + +def _to_xarray(run, dimensions, extras): + timeseries = _get_timeseries_for_xr_dataset(run, dimensions, extras) + non_dimension_extra_metadata = _get_other_metdata_for_xr_dataset( + run, dimensions, extras + ) + + if extras: + ids, ids_dimensions = _get_ids_for_xr_dataset(run, extras, dimensions) + else: + ids = None + ids_dimensions = None + + for_xarray = _get_dataframe_for_xr_dataset( + timeseries, dimensions, extras, ids, ids_dimensions + ) + xr_ds = xr.Dataset.from_dataframe(for_xarray) + + if extras: + xr_ds = _add_extras(xr_ds, ids, ids_dimensions, run) + + unit_map = ( + run.meta[["variable", "unit"]].drop_duplicates().set_index("variable")["unit"] + ) + xr_ds = _add_units(xr_ds, unit_map) + xr_ds = _rename_variables(xr_ds) + xr_ds = _add_scmdata_metadata(xr_ds, non_dimension_extra_metadata) + + return xr_ds + + +def _get_timeseries_for_xr_dataset(run, dimensions, extras): + for d in dimensions: + vals = sorted(run.meta[d].unique()) + if not all([isinstance(v, str) for v in vals]) and np.isnan(vals).any(): + raise AssertionError("nan in dimension: `{}`".format(d)) + + try: + timeseries = run.timeseries(dimensions + extras + ["variable"]) + except NonUniqueMetadataError as exc: + error_msg = ( + "dimensions: `{}` and extras: `{}` do not uniquely define the " + "timeseries, please add extra dimensions and/or extras".format( + dimensions, extras + ) + ) + raise ValueError(error_msg) from exc + + timeseries.columns = run.time_points.as_cftime() + + return timeseries + + +def _get_other_metdata_for_xr_dataset(run, dimensions, extras): + other_dimensions = list( + set(run.meta.columns) - set(dimensions) - set(extras) - {"variable", "unit"} + ) + other_metdata = run.meta[other_dimensions].drop_duplicates() + if other_metdata.shape[0] > 1 and not other_metdata.empty: + error_msg = ( + "Other metadata is not unique for dimensions: `{}` and extras: `{}`. " + "Please add meta columns with more than one value to dimensions or " + "extras.\nNumber of unique values in each column:\n{}.\n" + "Existing values in the other metadata:\n{}.".format( + dimensions, + extras, + other_metdata.nunique(), + other_metdata.drop_duplicates(), + ) + ) + raise ValueError(error_msg) + + return other_metdata + + +def _get_ids_for_xr_dataset(run, extras, dimensions): + # these loops could be very slow with lots of extras and dimensions... 
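+    # For each extra, look for a dimension whose values map many-to-one onto +    # the extra's values (i.e. the dimension value alone determines the extra's +    # value). If such a dimension exists, the extra can be attached to it as a +    # co-ordinate; otherwise the extra has to be keyed off the synthetic "_id" +    # dimension instead.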
+ ids_dimensions = {} + for extra in extras: + for col in dimensions: + if _many_to_one(run.meta, extra, col): + dim_col = col + break + else: + dim_col = "_id" + + ids_dimensions[extra] = dim_col + + ids = run.meta[extras].drop_duplicates() + ids["_id"] = range(ids.shape[0]) + ids = ids.set_index(extras) + + return ids, ids_dimensions + + +def _many_to_one(df, col1, col2): + """ + Check if there is a many to one mapping between col2 and col1 + """ + # thanks https://stackoverflow.com/a/59091549 + checker = df[[col1, col2]].drop_duplicates() + + max_count = checker.groupby(col2).count().max()[0] + if max_count < 1: # pragma: no cover # emergency valve + raise AssertionError + + return max_count == 1 + + +def _get_dataframe_for_xr_dataset(timeseries, dimensions, extras, ids, ids_dimensions): + timeseries = timeseries.reset_index() + + add_id_dimension = extras and "_id" in set(ids_dimensions.values()) + if add_id_dimension: + timeseries = ( + timeseries.set_index(ids.index.names) + .join(ids) + .reset_index(drop=True) + .set_index(dimensions + ["variable", "_id"]) + ) + else: + timeseries = timeseries.set_index(dimensions + ["variable"]) + if extras: + timeseries = timeseries.drop(extras, axis="columns") + + timeseries.columns.names = ["time"] + + if ( + len(timeseries.index.unique()) != timeseries.shape[0] + ): # pragma: no cover # emergency valve + # shouldn't be able to get here because any issues should be caught + # by initial creation of timeseries but just in case + raise AssertionError("something not unique") + + for_xarray = ( + timeseries.T.stack(dimensions + ["_id"]) + if add_id_dimension + else timeseries.T.stack(dimensions) + ) + + return for_xarray + + +def _add_extras(xr_ds, ids, ids_dimensions, run): + # this loop could also be slow... + extra_coords = {} + for extra, id_dimension in ids_dimensions.items(): + if id_dimension in ids: + ids_extra = ids.reset_index().set_index(id_dimension) + else: + ids_extra = ( + run.meta[[extra, id_dimension]] + .drop_duplicates() + .set_index(id_dimension) + ) + + extra_coords[extra] = ( + id_dimension, + ids_extra[extra].loc[xr_ds[id_dimension].values], + ) + + xr_ds = xr_ds.assign_coords(extra_coords) + + return xr_ds + + +def _add_units(xr_ds, unit_map): + for data_var in xr_ds.data_vars: + unit = unit_map[data_var] + xr_ds[data_var].attrs["units"] = unit + + return xr_ds + + +def _var_to_nc(var): + # TODO: remove renaming in this module + return var.replace("|", "__").replace(" ", "_") + + +def _rename_variables(xr_ds): + name_mapping = {} + for data_var in xr_ds.data_vars: + serialised_name = _var_to_nc(data_var) + name_mapping[data_var] = serialised_name + xr_ds[data_var].attrs["long_name"] = data_var + + xr_ds = xr_ds.rename_vars(name_mapping) + + return xr_ds + + +def _add_scmdata_metadata(xr_ds, others): + for col in others: + vals = others[col].unique() + if len(vals) > 1: # pragma: no cover # emergency valve + # should have already been caught... + raise AssertionError("More than one value for meta: {}".format(col)) + + xr_ds.attrs["_scmdata_metadata_{}".format(col)] = vals[0] + + return xr_ds diff --git a/src/scmdata/netcdf.py b/src/scmdata/netcdf.py index 918b903a..d3b5b204 100644 --- a/src/scmdata/netcdf.py +++ b/src/scmdata/netcdf.py @@ -14,21 +14,13 @@ from datetime import datetime from logging import getLogger -import numpy as np import xarray as xr from . 
import __version__ -from .errors import NonUniqueMetadataError logger = getLogger(__name__) -""" -Default to writing float data as 8 byte floats -""" -DEFAULT_FLOAT = "f8" - - def _var_to_nc(var): return var.replace("|", "__").replace(" ", "_") @@ -37,7 +29,8 @@ def _write_nc(fname, run, dimensions, extras, **kwargs): """ Low level function to write the dimensions, variables and metadata to disk """ - xr_ds = _get_xr_dataset(run, dimensions, extras) + # xr_ds = _get_xr_dataset(run, dimensions, extras) + xr_ds = run.to_xarray(dimensions, extras) xr_ds.attrs["created_at"] = datetime.utcnow().isoformat() xr_ds.attrs["_scmdata_version"] = __version__ @@ -49,203 +42,6 @@ def _write_nc(fname, run, dimensions, extras, **kwargs): xr_ds.to_netcdf(fname, **write_kwargs) -def _get_xr_dataset(run, dimensions, extras): - timeseries = _get_timeseries_for_xr_dataset(run, dimensions, extras) - non_dimension_extra_metadata = _get_other_metdata_for_xr_dataset( - run, dimensions, extras - ) - - if extras: - ids, ids_dimensions = _get_ids_for_xr_dataset(run, extras, dimensions) - else: - ids = None - ids_dimensions = None - - for_xarray = _get_dataframe_for_xr_dataset( - timeseries, dimensions, extras, ids, ids_dimensions - ) - xr_ds = xr.Dataset.from_dataframe(for_xarray) - - if extras: - xr_ds = _add_extras(xr_ds, ids, ids_dimensions, run) - - unit_map = ( - run.meta[["variable", "unit"]].drop_duplicates().set_index("variable")["unit"] - ) - xr_ds = _add_units(xr_ds, unit_map) - xr_ds = _rename_variables(xr_ds) - xr_ds = _add_scmdata_metadata(xr_ds, non_dimension_extra_metadata) - - return xr_ds - - -def _get_timeseries_for_xr_dataset(run, dimensions, extras): - for d in dimensions: - vals = sorted(run.meta[d].unique()) - if not all([isinstance(v, str) for v in vals]) and np.isnan(vals).any(): - raise AssertionError("nan in dimension: `{}`".format(d)) - - try: - timeseries = run.timeseries(dimensions + extras + ["variable"]) - except NonUniqueMetadataError as exc: - error_msg = ( - "dimensions: `{}` and extras: `{}` do not uniquely define the " - "timeseries, please add extra dimensions and/or extras".format( - dimensions, extras - ) - ) - raise ValueError(error_msg) from exc - - timeseries.columns = run.time_points.as_cftime() - - return timeseries - - -def _get_other_metdata_for_xr_dataset(run, dimensions, extras): - other_dimensions = list( - set(run.meta.columns) - set(dimensions) - set(extras) - {"variable", "unit"} - ) - other_metdata = run.meta[other_dimensions].drop_duplicates() - if other_metdata.shape[0] > 1 and not other_metdata.empty: - error_msg = ( - "Other metadata is not unique for dimensions: `{}` and extras: `{}`. " - "Please add meta columns with more than one value to dimensions or " - "extras.\nNumber of unique values in each column:\n{}.\n" - "Existing values in the other metadata:\n{}.".format( - dimensions, - extras, - other_metdata.nunique(), - other_metdata.drop_duplicates(), - ) - ) - raise ValueError(error_msg) - - return other_metdata - - -def _get_ids_for_xr_dataset(run, extras, dimensions): - # these loops could be very slow with lots of extras and dimensions... 
- ids_dimensions = {} - for extra in extras: - for col in dimensions: - if _many_to_one(run.meta, extra, col): - dim_col = col - break - else: - dim_col = "_id" - - ids_dimensions[extra] = dim_col - - ids = run.meta[extras].drop_duplicates() - ids["_id"] = range(ids.shape[0]) - ids = ids.set_index(extras) - - return ids, ids_dimensions - - -def _many_to_one(df, col1, col2): - """ - Check if there is a many to one mapping between col2 and col1 - """ - # thanks https://stackoverflow.com/a/59091549 - checker = df[[col1, col2]].drop_duplicates() - - max_count = checker.groupby(col2).count().max()[0] - if max_count < 1: # pragma: no cover # emergency valve - raise AssertionError - - return max_count == 1 - - -def _get_dataframe_for_xr_dataset(timeseries, dimensions, extras, ids, ids_dimensions): - timeseries = timeseries.reset_index() - - add_id_dimension = extras and "_id" in set(ids_dimensions.values()) - if add_id_dimension: - timeseries = ( - timeseries.set_index(ids.index.names) - .join(ids) - .reset_index(drop=True) - .set_index(dimensions + ["variable", "_id"]) - ) - else: - timeseries = timeseries.set_index(dimensions + ["variable"]) - if extras: - timeseries = timeseries.drop(extras, axis="columns") - - timeseries.columns.names = ["time"] - - if ( - len(timeseries.index.unique()) != timeseries.shape[0] - ): # pragma: no cover # emergency valve - # shouldn't be able to get here because any issues should be caught - # by initial creation of timeseries but just in case - raise AssertionError("something not unique") - - for_xarray = ( - timeseries.T.stack(dimensions + ["_id"]) - if add_id_dimension - else timeseries.T.stack(dimensions) - ) - - return for_xarray - - -def _add_extras(xr_ds, ids, ids_dimensions, run): - # this loop could also be slow... - extra_coords = {} - for extra, id_dimension in ids_dimensions.items(): - if id_dimension in ids: - ids_extra = ids.reset_index().set_index(id_dimension) - else: - ids_extra = ( - run.meta[[extra, id_dimension]] - .drop_duplicates() - .set_index(id_dimension) - ) - - extra_coords[extra] = ( - id_dimension, - ids_extra[extra].loc[xr_ds[id_dimension].values], - ) - - xr_ds = xr_ds.assign_coords(extra_coords) - - return xr_ds - - -def _add_units(xr_ds, unit_map): - for data_var in xr_ds.data_vars: - unit = unit_map[data_var] - xr_ds[data_var].attrs["units"] = unit - - return xr_ds - - -def _rename_variables(xr_ds): - name_mapping = {} - for data_var in xr_ds.data_vars: - serialised_name = _var_to_nc(data_var) - name_mapping[data_var] = serialised_name - xr_ds[data_var].attrs["long_name"] = data_var - - xr_ds = xr_ds.rename_vars(name_mapping) - - return xr_ds - - -def _add_scmdata_metadata(xr_ds, others): - for col in others: - vals = others[col].unique() - if len(vals) > 1: # pragma: no cover # emergency valve - # should have already been caught... 
- raise AssertionError("More than one value for meta: {}".format(col)) - - xr_ds.attrs["_scmdata_metadata_{}".format(col)] = vals[0] - - return xr_ds - - def _read_nc(cls, fname): loaded = xr.load_dataset(fname) dataframe = loaded.to_dataframe() diff --git a/src/scmdata/run.py b/src/scmdata/run.py index d646a08a..c154f5ac 100644 --- a/src/scmdata/run.py +++ b/src/scmdata/run.py @@ -22,6 +22,7 @@ from openscm_units import unit_registry as ur from xarray.core.ops import inject_binary_ops +from ._xarray import _to_xarray from .errors import MissingRequiredColumnError, NonUniqueMetadataError from .filters import ( HIERARCHY_SEPARATOR, @@ -1984,6 +1985,9 @@ def reduce(self, func, dim=None, axis=None, **kwargs): return type(self)(data, index=index, columns=meta) + def to_xarray(self, dimensions, extras): + return _to_xarray(self, dimensions, extras) + def _merge_metadata(metadata): res = metadata[0].copy() diff --git a/tests/unit/test_netcdf.py b/tests/unit/test_netcdf.py index ced183b0..c1314ec5 100644 --- a/tests/unit/test_netcdf.py +++ b/tests/unit/test_netcdf.py @@ -13,7 +13,7 @@ import xarray as xr from scmdata import ScmRun -from scmdata.netcdf import _get_xr_dataset, nc_to_run, run_to_nc +from scmdata.netcdf import nc_to_run, run_to_nc from scmdata.testing import assert_scmdf_almost_equal @@ -667,21 +667,22 @@ def test_run_to_nc_loop_tricky_variable_name(scm_run, start_variable): assert_scmdf_almost_equal(scm_run, loaded, check_ts_names=False) -@patch("scmdata.netcdf._get_xr_dataset") -def test_run_to_nc_xarray_kwarg_passing(mock_get_xr_dataset, scm_run, tmpdir): +def test_run_to_nc_xarray_kwarg_passing(scm_run, tmpdir): dimensions = ["scenario"] extras = [] + mock_ds = MagicMock() - mock_ds.data_vars = _get_xr_dataset(scm_run, dimensions, extras).data_vars - mock_get_xr_dataset.return_value = mock_ds + mock_ds.data_vars = scm_run.to_xarray(dimensions, extras).data_vars + + mock_scm_run = MagicMock() + mock_scm_run.to_xarray.return_value = mock_ds out_fname = join(tmpdir, "out.nc") - run_to_nc(scm_run, out_fname, dimensions=dimensions, extras=extras, engine="engine") + run_to_nc(mock_scm_run, out_fname, dimensions=dimensions, extras=extras, engine="engine") mock_ds.to_netcdf.assert_called_with(out_fname, engine="engine") -@patch("scmdata.netcdf._get_xr_dataset") @pytest.mark.parametrize( "in_kwargs,call_kwargs", ( @@ -708,17 +709,19 @@ def test_run_to_nc_xarray_kwarg_passing(mock_get_xr_dataset, scm_run, tmpdir): ), ) def test_run_to_nc_xarray_kwarg_passing_variable_renaming( - mock_get_xr_dataset, scm_run, tmpdir, in_kwargs, call_kwargs + scm_run, tmpdir, in_kwargs, call_kwargs ): dimensions = ["scenario"] extras = [] mock_ds = MagicMock() - mock_ds.data_vars = _get_xr_dataset(scm_run, dimensions, extras).data_vars - mock_get_xr_dataset.return_value = mock_ds + mock_ds.data_vars = scm_run.to_xarray(dimensions, extras).data_vars + + mock_scm_run = MagicMock() + mock_scm_run.to_xarray.return_value = mock_ds out_fname = join(tmpdir, "out.nc") - run_to_nc(scm_run, out_fname, dimensions=("scenario",), **in_kwargs) + run_to_nc(mock_scm_run, out_fname, dimensions=("scenario",), **in_kwargs) # variable should be renamed so it matches what goes to disk mock_ds.to_netcdf.assert_called_with(out_fname, **call_kwargs) diff --git a/tests/unit/test_run.py b/tests/unit/test_run.py index 25ab4134..1b230cda 100644 --- a/tests/unit/test_run.py +++ b/tests/unit/test_run.py @@ -10,6 +10,7 @@ import numpy as np import pandas as pd import pytest +import xarray as xr from numpy import testing as npt from 
packaging.version import parse from pandas.errors import UnsupportedFunctionCall @@ -3047,3 +3048,10 @@ def test_drop_meta_nonunique(): with pytest.raises(NonUniqueMetadataError): start.drop_meta("new_meta") + + +def test_to_xarray(scm_run): + res = scm_run.to_xarray() + + assert isinstance(res, xr.DataSet) + assert False From 3243912f4a8a64a452a446f56cefd6500244cf29 Mon Sep 17 00:00:00 2001 From: Zebedee Nicholls Date: Mon, 3 May 2021 11:51:48 +1000 Subject: [PATCH 02/11] Clean up renaming --- src/scmdata/_xarray.py | 18 ------------------ src/scmdata/netcdf.py | 22 ++++++++++++++++++++-- src/scmdata/run.py | 10 ++++++++++ tests/unit/test_netcdf.py | 25 +++++++++++-------------- tests/unit/test_run.py | 8 -------- 5 files changed, 41 insertions(+), 42 deletions(-) diff --git a/src/scmdata/_xarray.py b/src/scmdata/_xarray.py index af350f38..1c6fe59b 100644 --- a/src/scmdata/_xarray.py +++ b/src/scmdata/_xarray.py @@ -31,7 +31,6 @@ def _to_xarray(run, dimensions, extras): run.meta[["variable", "unit"]].drop_duplicates().set_index("variable")["unit"] ) xr_ds = _add_units(xr_ds, unit_map) - xr_ds = _rename_variables(xr_ds) xr_ds = _add_scmdata_metadata(xr_ds, non_dimension_extra_metadata) return xr_ds @@ -180,23 +179,6 @@ def _add_units(xr_ds, unit_map): return xr_ds -def _var_to_nc(var): - # TODO: remove renaming in this module - return var.replace("|", "__").replace(" ", "_") - - -def _rename_variables(xr_ds): - name_mapping = {} - for data_var in xr_ds.data_vars: - serialised_name = _var_to_nc(data_var) - name_mapping[data_var] = serialised_name - xr_ds[data_var].attrs["long_name"] = data_var - - xr_ds = xr_ds.rename_vars(name_mapping) - - return xr_ds - - def _add_scmdata_metadata(xr_ds, others): for col in others: vals = others[col].unique() diff --git a/src/scmdata/netcdf.py b/src/scmdata/netcdf.py index d3b5b204..31b48e2d 100644 --- a/src/scmdata/netcdf.py +++ b/src/scmdata/netcdf.py @@ -25,12 +25,30 @@ def _var_to_nc(var): return var.replace("|", "__").replace(" ", "_") +def _rename_variables(xr_ds): + name_mapping = {} + for data_var in xr_ds.data_vars: + serialised_name = _var_to_nc(data_var) + name_mapping[data_var] = serialised_name + xr_ds[data_var].attrs["long_name"] = data_var + + xr_ds = xr_ds.rename_vars(name_mapping) + + return xr_ds + + +def _get_xr_dataset_to_write(run, dimensions, extras): + xr_ds = run.to_xarray(dimensions, extras) + xr_ds = _rename_variables(xr_ds) + + return xr_ds + + def _write_nc(fname, run, dimensions, extras, **kwargs): """ Low level function to write the dimensions, variables and metadata to disk """ - # xr_ds = _get_xr_dataset(run, dimensions, extras) - xr_ds = run.to_xarray(dimensions, extras) + xr_ds = _get_xr_dataset_to_write(run, dimensions, extras) xr_ds.attrs["created_at"] = datetime.utcnow().isoformat() xr_ds.attrs["_scmdata_version"] = __version__ diff --git a/src/scmdata/run.py b/src/scmdata/run.py index c154f5ac..c5ff90bc 100644 --- a/src/scmdata/run.py +++ b/src/scmdata/run.py @@ -1986,6 +1986,16 @@ def reduce(self, func, dim=None, axis=None, **kwargs): return type(self)(data, index=index, columns=meta) def to_xarray(self, dimensions, extras): + """ + Convert to a :class:`xr.Dataset` + + TODO: write rest + + Returns + ------- + :obj:`xr.Dataset` + Data in self, re-formatted as an :obj:`xr.Dataset` + """ return _to_xarray(self, dimensions, extras) diff --git a/tests/unit/test_netcdf.py b/tests/unit/test_netcdf.py index c1314ec5..db223776 100644 --- a/tests/unit/test_netcdf.py +++ b/tests/unit/test_netcdf.py @@ -13,7 +13,7 @@ import 
xarray as xr from scmdata import ScmRun -from scmdata.netcdf import nc_to_run, run_to_nc +from scmdata.netcdf import _get_xr_dataset_to_write, nc_to_run, run_to_nc from scmdata.testing import assert_scmdf_almost_equal @@ -667,22 +667,21 @@ def test_run_to_nc_loop_tricky_variable_name(scm_run, start_variable): assert_scmdf_almost_equal(scm_run, loaded, check_ts_names=False) -def test_run_to_nc_xarray_kwarg_passing(scm_run, tmpdir): +@patch("scmdata.netcdf._get_xr_dataset_to_write") +def test_run_to_nc_xarray_kwarg_passing(mock_get_xr_dataset, scm_run, tmpdir): dimensions = ["scenario"] extras = [] - mock_ds = MagicMock() - mock_ds.data_vars = scm_run.to_xarray(dimensions, extras).data_vars - - mock_scm_run = MagicMock() - mock_scm_run.to_xarray.return_value = mock_ds + mock_ds.data_vars = _get_xr_dataset_to_write(scm_run, dimensions, extras).data_vars + mock_get_xr_dataset.return_value = mock_ds out_fname = join(tmpdir, "out.nc") - run_to_nc(mock_scm_run, out_fname, dimensions=dimensions, extras=extras, engine="engine") + run_to_nc(scm_run, out_fname, dimensions=dimensions, extras=extras, engine="engine") mock_ds.to_netcdf.assert_called_with(out_fname, engine="engine") +@patch("scmdata.netcdf._get_xr_dataset_to_write") @pytest.mark.parametrize( "in_kwargs,call_kwargs", ( @@ -709,19 +708,17 @@ def test_run_to_nc_xarray_kwarg_passing(scm_run, tmpdir): ), ) def test_run_to_nc_xarray_kwarg_passing_variable_renaming( - scm_run, tmpdir, in_kwargs, call_kwargs + mock_get_xr_dataset, scm_run, tmpdir, in_kwargs, call_kwargs ): dimensions = ["scenario"] extras = [] mock_ds = MagicMock() - mock_ds.data_vars = scm_run.to_xarray(dimensions, extras).data_vars - - mock_scm_run = MagicMock() - mock_scm_run.to_xarray.return_value = mock_ds + mock_ds.data_vars = _get_xr_dataset_to_write(scm_run, dimensions, extras).data_vars + mock_get_xr_dataset.return_value = mock_ds out_fname = join(tmpdir, "out.nc") - run_to_nc(mock_scm_run, out_fname, dimensions=("scenario",), **in_kwargs) + run_to_nc(scm_run, out_fname, dimensions=("scenario",), **in_kwargs) # variable should be renamed so it matches what goes to disk mock_ds.to_netcdf.assert_called_with(out_fname, **call_kwargs) diff --git a/tests/unit/test_run.py b/tests/unit/test_run.py index 1b230cda..25ab4134 100644 --- a/tests/unit/test_run.py +++ b/tests/unit/test_run.py @@ -10,7 +10,6 @@ import numpy as np import pandas as pd import pytest -import xarray as xr from numpy import testing as npt from packaging.version import parse from pandas.errors import UnsupportedFunctionCall @@ -3048,10 +3047,3 @@ def test_drop_meta_nonunique(): with pytest.raises(NonUniqueMetadataError): start.drop_meta("new_meta") - - -def test_to_xarray(scm_run): - res = scm_run.to_xarray() - - assert isinstance(res, xr.DataSet) - assert False From 636940443465e9c3ee88c8c40878b7863830915d Mon Sep 17 00:00:00 2001 From: Zebedee Nicholls Date: Mon, 3 May 2021 12:01:39 +1000 Subject: [PATCH 03/11] Start sketching out tests --- src/scmdata/run.py | 5 ++++- tests/unit/test_xarray.py | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) create mode 100644 tests/unit/test_xarray.py diff --git a/src/scmdata/run.py b/src/scmdata/run.py index c5ff90bc..8ac4b520 100644 --- a/src/scmdata/run.py +++ b/src/scmdata/run.py @@ -1985,7 +1985,7 @@ def reduce(self, func, dim=None, axis=None, **kwargs): return type(self)(data, index=index, columns=meta) - def to_xarray(self, dimensions, extras): + def to_xarray(self, dimensions=("region",), extras=()): """ Convert to a :class:`xr.Dataset` @@ 
-1996,6 +1996,9 @@ def to_xarray(self, dimensions, extras): :obj:`xr.Dataset` Data in self, re-formatted as an :obj:`xr.Dataset` """ + dimensions = list(dimensions) + extras = list(extras) + return _to_xarray(self, dimensions, extras) diff --git a/tests/unit/test_xarray.py b/tests/unit/test_xarray.py new file mode 100644 index 00000000..c9533309 --- /dev/null +++ b/tests/unit/test_xarray.py @@ -0,0 +1,18 @@ +import xarray as xr + + +def test_to_xarray(scm_run): + res = scm_run.to_xarray(dimensions=("region", "scenario")) + + assert isinstance(res, xr.Dataset) + assert set(res.data_vars) == set(scm_run.get_unique_meta("variable")) + + assert False, "check dimensions as expected" + assert False, "check extras as expected" + assert False, "check units as expected" + assert False, "check metadata as expected" + +# Tests to write: +# - dimensions handling +# - extras handling +# - weird variable name handling From 465ed9b0419c4357675734247e48346c78f8c5a5 Mon Sep 17 00:00:00 2001 From: Zebedee Nicholls Date: Mon, 3 May 2021 15:05:47 +1000 Subject: [PATCH 04/11] Add basic test of to xarray --- src/scmdata/_xarray.py | 49 +++++++++++++++++++++++++++++++++++---- src/scmdata/run.py | 20 ++-------------- tests/unit/test_xarray.py | 30 +++++++++++++++++++----- 3 files changed, 70 insertions(+), 29 deletions(-) diff --git a/src/scmdata/_xarray.py b/src/scmdata/_xarray.py index 1c6fe59b..7fbcf758 100644 --- a/src/scmdata/_xarray.py +++ b/src/scmdata/_xarray.py @@ -7,20 +7,40 @@ from .errors import NonUniqueMetadataError -def _to_xarray(run, dimensions, extras): - timeseries = _get_timeseries_for_xr_dataset(run, dimensions, extras) +def to_xarray(run, dimensions=("region",), extras=()): + """ + Convert to a :class:`xr.Dataset` + + Parameters + ---------- + dimensions : iterable of str + Dimensions for each variable in the returned dataset. If `"time"` is not included in ``dimensions`` it will be the last dimension. + + extras : iterable of str + TODO: write + + Returns + ------- + :obj:`xr.Dataset` + Data in self, re-formatted as an :obj:`xr.Dataset` + """ + dimensions = list(dimensions) + extras = list(extras) + + timeseries_dims = list(set(dimensions) - {"time"}) + timeseries = _get_timeseries_for_xr_dataset(run, timeseries_dims, extras) non_dimension_extra_metadata = _get_other_metdata_for_xr_dataset( run, dimensions, extras ) if extras: - ids, ids_dimensions = _get_ids_for_xr_dataset(run, extras, dimensions) + ids, ids_dimensions = _get_ids_for_xr_dataset(run, extras, timeseries_dims) else: ids = None ids_dimensions = None for_xarray = _get_dataframe_for_xr_dataset( - timeseries, dimensions, extras, ids, ids_dimensions + timeseries, timeseries_dims, extras, ids, ids_dimensions ) xr_ds = xr.Dataset.from_dataframe(for_xarray) @@ -32,6 +52,8 @@ def _to_xarray(run, dimensions, extras): ) xr_ds = _add_units(xr_ds, unit_map) xr_ds = _add_scmdata_metadata(xr_ds, non_dimension_extra_metadata) + out_dimensions = dimensions if "time" in dimensions else dimensions + ["time"] + xr_ds = xr_ds.transpose(*out_dimensions) return xr_ds @@ -186,6 +208,23 @@ def _add_scmdata_metadata(xr_ds, others): # should have already been caught... 
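            # (uniqueness of the non-dimension, non-extra metadata is enforced earlier in _get_other_metdata_for_xr_dataset, so reaching this branch would mean a bug upstream)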
raise AssertionError("More than one value for meta: {}".format(col)) - xr_ds.attrs["_scmdata_metadata_{}".format(col)] = vals[0] + xr_ds.attrs["scmdata_metadata_{}".format(col)] = vals[0] return xr_ds + + +def inject_xarray_methods(cls): + """ + Inject the xarray methods + + Parameters + ---------- + cls + Target class + """ + methods = [ + ("to_xarray", to_xarray), + ] + + for name, f in methods: + setattr(cls, name, f) diff --git a/src/scmdata/run.py b/src/scmdata/run.py index 8ac4b520..a31dcb26 100644 --- a/src/scmdata/run.py +++ b/src/scmdata/run.py @@ -22,7 +22,7 @@ from openscm_units import unit_registry as ur from xarray.core.ops import inject_binary_ops -from ._xarray import _to_xarray +from ._xarray import inject_xarray_methods from .errors import MissingRequiredColumnError, NonUniqueMetadataError from .filters import ( HIERARCHY_SEPARATOR, @@ -1985,22 +1985,6 @@ def reduce(self, func, dim=None, axis=None, **kwargs): return type(self)(data, index=index, columns=meta) - def to_xarray(self, dimensions=("region",), extras=()): - """ - Convert to a :class:`xr.Dataset` - - TODO: write rest - - Returns - ------- - :obj:`xr.Dataset` - Data in self, re-formatted as an :obj:`xr.Dataset` - """ - dimensions = list(dimensions) - extras = list(extras) - - return _to_xarray(self, dimensions, extras) - def _merge_metadata(metadata): res = metadata[0].copy() @@ -2202,7 +2186,7 @@ def _handle_potential_duplicates_in_append(data, duplicate_msg): inject_nc_methods(BaseScmRun) inject_plotting_methods(BaseScmRun) inject_ops_methods(BaseScmRun) - +inject_xarray_methods(BaseScmRun) class ScmRun(BaseScmRun): """ diff --git a/tests/unit/test_xarray.py b/tests/unit/test_xarray.py index c9533309..59b8365b 100644 --- a/tests/unit/test_xarray.py +++ b/tests/unit/test_xarray.py @@ -1,18 +1,36 @@ +import pytest import xarray as xr -def test_to_xarray(scm_run): - res = scm_run.to_xarray(dimensions=("region", "scenario")) +@pytest.mark.parametrize("dimensions,expected_dimensions", ( + (("region", "scenario", "time"), ("region", "scenario", "time")), + (("time", "region", "scenario"), ("time", "region", "scenario")), + (("region", "time", "scenario"), ("region", "time", "scenario")), + (("region", "scenario"), ("region", "scenario", "time")), + (("scenario", "region"), ("scenario", "region", "time")), + (("scenario",), ("scenario", "time")), +)) +def test_to_xarray_dimension_order(scm_run, dimensions, expected_dimensions): + res = scm_run.to_xarray(dimensions=dimensions) assert isinstance(res, xr.Dataset) assert set(res.data_vars) == set(scm_run.get_unique_meta("variable")) - assert False, "check dimensions as expected" - assert False, "check extras as expected" - assert False, "check units as expected" - assert False, "check metadata as expected" + for variable_name, data_var in res.data_vars.items(): + assert data_var.dims == expected_dimensions + # no extras + assert not set(data_var.coords) - set(data_var.dims) + + unit = scm_run.filter(variable=variable_name).get_unique_meta("unit", True) + assert data_var.units == unit + + # all other metadata should be in attrs + for meta_col in set(scm_run.meta.columns) - set(dimensions) - {"variable", "unit"}: + meta_val = scm_run.get_unique_meta(meta_col, True) + assert res.attrs["scmdata_metadata_{}".format(meta_col)] == meta_val # Tests to write: # - dimensions handling # - extras handling # - weird variable name handling +# - multiple units for given variable From 4f287a1050c6310fa004ce3e313c842ff79d9876 Mon Sep 17 00:00:00 2001 From: Zebedee Nicholls Date: Mon, 3 May 
2021 15:37:08 +1000 Subject: [PATCH 05/11] Add first extras tests --- tests/unit/test_xarray.py | 77 ++++++++++++++++++++++++++++++++------- 1 file changed, 64 insertions(+), 13 deletions(-) diff --git a/tests/unit/test_xarray.py b/tests/unit/test_xarray.py index 59b8365b..63070e5f 100644 --- a/tests/unit/test_xarray.py +++ b/tests/unit/test_xarray.py @@ -1,7 +1,38 @@ +import numpy as np +import numpy.testing as npt import pytest import xarray as xr +def do_basic_to_xarray_checks(res, start_run, dimensions, extras): + assert isinstance(res, xr.Dataset) + assert set(res.data_vars) == set(start_run.get_unique_meta("variable")) + + for variable_name, data_var in res.data_vars.items(): + assert data_var.dims == dimensions + + unit = start_run.filter(variable=variable_name).get_unique_meta("unit", True) + assert data_var.units == unit + + # check a couple of data points to make sure the translation is correct + for idx in [0, -1]: + xarray_spot = data_var.isel({v: idx for v in dimensions}) + fkwargs = {k: [v.values.tolist()] for k, v in xarray_spot.coords.items()} + fkwargs["variable"] = variable_name + + start_run_spot = start_run.filter(**fkwargs) + if np.isnan(xarray_spot): + assert start_run_spot.empty + else: + start_run_vals = float(start_run_spot.values.squeeze()) + npt.assert_array_equal(xarray_spot.values, start_run_vals) + + # all other metadata should be in attrs + for meta_col in set(start_run.meta.columns) - set(dimensions) - set(extras) - {"variable", "unit"}: + meta_val = start_run.get_unique_meta(meta_col, True) + assert res.attrs["scmdata_metadata_{}".format(meta_col)] == meta_val + + @pytest.mark.parametrize("dimensions,expected_dimensions", ( (("region", "scenario", "time"), ("region", "scenario", "time")), (("time", "region", "scenario"), ("time", "region", "scenario")), @@ -10,27 +41,47 @@ (("scenario", "region"), ("scenario", "region", "time")), (("scenario",), ("scenario", "time")), )) -def test_to_xarray_dimension_order(scm_run, dimensions, expected_dimensions): +def test_to_xarray(scm_run, dimensions, expected_dimensions): res = scm_run.to_xarray(dimensions=dimensions) - assert isinstance(res, xr.Dataset) - assert set(res.data_vars) == set(scm_run.get_unique_meta("variable")) + do_basic_to_xarray_checks(res, scm_run, expected_dimensions, (),) - for variable_name, data_var in res.data_vars.items(): - assert data_var.dims == expected_dimensions - # no extras - assert not set(data_var.coords) - set(data_var.dims) + # no extras + assert not set(res.coords) - set(res.dims) - unit = scm_run.filter(variable=variable_name).get_unique_meta("unit", True) - assert data_var.units == unit - # all other metadata should be in attrs - for meta_col in set(scm_run.meta.columns) - set(dimensions) - {"variable", "unit"}: - meta_val = scm_run.get_unique_meta(meta_col, True) - assert res.attrs["scmdata_metadata_{}".format(meta_col)] == meta_val +@pytest.mark.parametrize("extras", ( + ("model",), + ("climate_model",), + ("climate_model", "model"), +)) +def test_to_xarray_extras(scm_run, extras): + dimensions = ("scenario", "region", "time") + res = scm_run.to_xarray(dimensions=dimensions, extras=extras) + + do_basic_to_xarray_checks(res, scm_run, dimensions, extras) + + assert set(extras) == set(res.coords) - set(res.dims) + + scm_run_meta = scm_run.meta + for extra_col in extras: + xarray_vals = res[extra_col].values + extra_dims = res[extra_col].dims + assert len(extra_dims) == 1 + extra_dims = extra_dims[0] + xarray_coords = res[extra_col][extra_dims].values + + for xarray_extra_val, 
extra_xarray_coord in zip(xarray_vals, xarray_coords): + scm_run_extra_val = scm_run_meta[scm_run_meta[extra_dims] == extra_xarray_coord][extra_col].unique().tolist() + assert len(scm_run_extra_val) == 1 + scm_run_extra_val = scm_run_extra_val[0] + + assert scm_run_extra_val == xarray_extra_val # Tests to write: # - dimensions handling # - extras handling # - weird variable name handling # - multiple units for given variable +# - overlapping dimensions and extras +# - underdefined dimensions and extras From b7f9dad33d23e5632c8482a27426290ab19d275a Mon Sep 17 00:00:00 2001 From: Zebedee Nicholls Date: Mon, 3 May 2021 16:58:12 +1000 Subject: [PATCH 06/11] Add tests of handling with id coord --- src/scmdata/_xarray.py | 20 ++++++--- tests/unit/test_xarray.py | 90 ++++++++++++++++++++++++++++++++++++--- 2 files changed, 99 insertions(+), 11 deletions(-) diff --git a/src/scmdata/_xarray.py b/src/scmdata/_xarray.py index 7fbcf758..45c05808 100644 --- a/src/scmdata/_xarray.py +++ b/src/scmdata/_xarray.py @@ -14,10 +14,10 @@ def to_xarray(run, dimensions=("region",), extras=()): Parameters ---------- dimensions : iterable of str - Dimensions for each variable in the returned dataset. If `"time"` is not included in ``dimensions`` it will be the last dimension. + Dimensions for each variable in the returned dataset. If ``"time"`` is not included in ``dimensions`` it will be the last dimension. If ``"_id"`` is required (see ``extras`` documentation for when ``"_id"`` is required) and is not included in ``dimensions`` then it will be the last dimension. extras : iterable of str - TODO: write + TODO: write (when _id is added) Returns ------- @@ -27,7 +27,7 @@ def to_xarray(run, dimensions=("region",), extras=()): dimensions = list(dimensions) extras = list(extras) - timeseries_dims = list(set(dimensions) - {"time"}) + timeseries_dims = list(set(dimensions) - {"time"} - {"_id"}) timeseries = _get_timeseries_for_xr_dataset(run, timeseries_dims, extras) non_dimension_extra_metadata = _get_other_metdata_for_xr_dataset( run, dimensions, extras @@ -52,8 +52,7 @@ def to_xarray(run, dimensions=("region",), extras=()): ) xr_ds = _add_units(xr_ds, unit_map) xr_ds = _add_scmdata_metadata(xr_ds, non_dimension_extra_metadata) - out_dimensions = dimensions if "time" in dimensions else dimensions + ["time"] - xr_ds = xr_ds.transpose(*out_dimensions) + xr_ds = _set_dimensions(xr_ds, dimensions) return xr_ds @@ -213,6 +212,17 @@ def _add_scmdata_metadata(xr_ds, others): return xr_ds +def _set_dimensions(xr_ds, dimensions): + out_dimensions = dimensions + if "time" not in dimensions: + out_dimensions += ["time"] + + if "_id" in xr_ds.dims and "_id" not in dimensions: + out_dimensions += ["_id"] + + return xr_ds.transpose(*out_dimensions) + + def inject_xarray_methods(cls): """ Inject the xarray methods diff --git a/tests/unit/test_xarray.py b/tests/unit/test_xarray.py index 63070e5f..3a6e84ba 100644 --- a/tests/unit/test_xarray.py +++ b/tests/unit/test_xarray.py @@ -1,5 +1,6 @@ import numpy as np import numpy.testing as npt +import pandas as pd import pytest import xarray as xr @@ -14,6 +15,14 @@ def do_basic_to_xarray_checks(res, start_run, dimensions, extras): unit = start_run.filter(variable=variable_name).get_unique_meta("unit", True) assert data_var.units == unit + # all other metadata should be in attrs + for meta_col in set(start_run.meta.columns) - set(dimensions) - set(extras) - {"variable", "unit"}: + meta_val = start_run.get_unique_meta(meta_col, True) + assert 
res.attrs["scmdata_metadata_{}".format(meta_col)] == meta_val + + +def do_basic_check_of_data_points(res, start_run, dimensions): + for variable_name, data_var in res.data_vars.items(): # check a couple of data points to make sure the translation is correct for idx in [0, -1]: xarray_spot = data_var.isel({v: idx for v in dimensions}) @@ -27,11 +36,6 @@ def do_basic_to_xarray_checks(res, start_run, dimensions, extras): start_run_vals = float(start_run_spot.values.squeeze()) npt.assert_array_equal(xarray_spot.values, start_run_vals) - # all other metadata should be in attrs - for meta_col in set(start_run.meta.columns) - set(dimensions) - set(extras) - {"variable", "unit"}: - meta_val = start_run.get_unique_meta(meta_col, True) - assert res.attrs["scmdata_metadata_{}".format(meta_col)] == meta_val - @pytest.mark.parametrize("dimensions,expected_dimensions", ( (("region", "scenario", "time"), ("region", "scenario", "time")), @@ -45,6 +49,7 @@ def test_to_xarray(scm_run, dimensions, expected_dimensions): res = scm_run.to_xarray(dimensions=dimensions) do_basic_to_xarray_checks(res, scm_run, expected_dimensions, (),) + do_basic_check_of_data_points(res, scm_run, expected_dimensions) # no extras assert not set(res.coords) - set(res.dims) @@ -55,11 +60,12 @@ def test_to_xarray(scm_run, dimensions, expected_dimensions): ("climate_model",), ("climate_model", "model"), )) -def test_to_xarray_extras(scm_run, extras): +def test_to_xarray_extras_no_id_coord(scm_run, extras): dimensions = ("scenario", "region", "time") res = scm_run.to_xarray(dimensions=dimensions, extras=extras) do_basic_to_xarray_checks(res, scm_run, dimensions, extras) + do_basic_check_of_data_points(res, scm_run, dimensions) assert set(extras) == set(res.coords) - set(res.dims) @@ -78,6 +84,78 @@ def test_to_xarray_extras(scm_run, extras): assert scm_run_extra_val == xarray_extra_val + +@pytest.mark.parametrize("extras", ( + ("scenario", "model", "random_key"), +)) +@pytest.mark.parametrize("dimensions,expected_dimensions", ( + (("climate_model", "run_id"), ("climate_model", "run_id", "time", "_id")), + (("run_id", "climate_model"), ("run_id", "climate_model", "time", "_id")), + (("run_id", "climate_model", "time"), ("run_id", "climate_model", "time", "_id")), + (("run_id", "time", "climate_model"), ("run_id", "time", "climate_model", "_id")), + (("run_id", "climate_model", "time", "_id"), ("run_id", "climate_model", "time", "_id")), + (("_id", "run_id", "time", "climate_model"), ("_id", "run_id", "time", "climate_model")), + (("run_id", "_id", "climate_model"), ("run_id", "_id", "climate_model", "time")), +)) +def test_to_xarray_extras_with_id_coord(scm_run, extras, dimensions, expected_dimensions): + df = scm_run.timeseries() + val_cols = df.columns.tolist() + df = df.reset_index() + + df["climate_model"] = "base_m" + df["run_id"] = 1 + df.loc[:, val_cols] = np.random.rand(df.shape[0], len(val_cols)) + + big_df = [df] + for climate_model in ["abc_m", "def_m", "ghi_m"]: + for run_id in range(10): + new_df = df.copy() + new_df["run_id"] = run_id + new_df["climate_model"] = climate_model + new_df.loc[:, val_cols] = np.random.rand(df.shape[0], len(val_cols)) + + big_df.append(new_df) + + big_df = pd.concat(big_df).reset_index(drop=True) + big_df["random_key"] = (100 * np.random.random(big_df.shape[0])).astype(int) + scm_run = scm_run.__class__(big_df) + + res = scm_run.to_xarray(dimensions=dimensions, extras=extras) + + do_basic_to_xarray_checks(res, scm_run, expected_dimensions, extras) + + assert set(extras) == set(res.coords) - 
 set(res.dims) + + # check a couple of data points to make sure the translation is correct + # and well-defined + scm_run_meta = scm_run.meta + for id_val in res["_id"].values[::10]: + xarray_timeseries = res.sel(_id=id_val) + fkwargs = {} + for extra_col in extras: + val = xarray_timeseries[extra_col].values.tolist() + if isinstance(val, list): + assert len(set(val)) == 1 + fkwargs[extra_col] = val[0] + else: + fkwargs[extra_col] = val + + for i, (key, value) in enumerate(fkwargs.items()): + if i < 1: + keep_meta_rows = scm_run_meta[key] == value + else: + keep_meta_rows &= scm_run_meta[key] == value + + meta_timeseries = scm_run_meta[keep_meta_rows] + for _, row in meta_timeseries.iterrows(): + scm_run_filter = row.to_dict() + scm_run_spot = scm_run.filter(**scm_run_filter) + + xarray_sel = {k: v for k, v in scm_run_filter.items() if k in xarray_timeseries.dims} + xarray_spot = xarray_timeseries.sel(**xarray_sel)[scm_run_filter["variable"]] + + npt.assert_array_equal(scm_run_spot.values.squeeze(), xarray_spot.values.squeeze()) From 595bb8c6164e77666ec93e3b1b8b281e948bc680 Mon Sep 17 00:00:00 2001 From: Zebedee Nicholls Date: Mon, 3 May 2021 17:02:58 +1000 Subject: [PATCH 07/11] Add test of weird variable names --- tests/unit/test_xarray.py | 27 ++++++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_xarray.py b/tests/unit/test_xarray.py index 3a6e84ba..ed7bd23a 100644 --- a/tests/unit/test_xarray.py +++ b/tests/unit/test_xarray.py @@ -156,10 +156,31 @@ def test_to_xarray_extras_with_id_coord(scm_run, extras, dimensions, expected_di npt.assert_array_equal(scm_run_spot.values.squeeze(), xarray_spot.values.squeeze()) + +@pytest.mark.parametrize("ch", "!@#$%^&*()~`+={}]<>,;:'\".") +@pytest.mark.parametrize("weird_idx", (0, -1, 5)) +def test_to_xarray_weird_names(scm_run, ch, weird_idx): + new_vars = [] + for i, variable_name in enumerate(scm_run.get_unique_meta("variable")): + if i < 1: + new_name = list(variable_name) + new_name.insert(weird_idx, ch) + new_name = "".join(new_name) + new_vars.append(new_name) + else: + new_vars.append(variable_name) + + # apply the renaming, otherwise the weird names never reach to_xarray + renaming = dict(zip(scm_run.get_unique_meta("variable"), new_vars)) + scm_run["variable"] = [renaming[v] for v in scm_run["variable"]] + + dimensions = ("region", "scenario", "time") + res = scm_run.to_xarray(dimensions=dimensions) + + do_basic_to_xarray_checks(res, scm_run, dimensions, (),) + do_basic_check_of_data_points(res, scm_run, dimensions) + # Tests to write: -# - dimensions handling -# - extras handling -# - weird variable name handling # - multiple units for given variable From fe63a2bbe265c7d01aa2f00b0f0b8c3e9f2c1c6b Mon Sep 17 00:00:00 2001 From: Zebedee Nicholls Date: Mon, 3 May 2021 18:05:31 +1000 Subject: [PATCH 08/11] Add tests of unit conversion --- src/scmdata/_xarray.py | 67 +++++++++++++++++++++++++++++++++------ tests/unit/test_xarray.py | 66 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 121 insertions(+), 12 deletions(-) diff --git a/src/scmdata/_xarray.py b/src/scmdata/_xarray.py index 45c05808..1ce0d8b7 100644 --- a/src/scmdata/_xarray.py +++ b/src/scmdata/_xarray.py @@ -2,39 +2,50 @@ Interface with `xarray `_ """ import numpy as np +import pint.errors import xarray as xr from .errors import NonUniqueMetadataError -def to_xarray(run, dimensions=("region",), extras=()): +def to_xarray(self, dimensions=("region",), extras=(), unify_units=True): """ Convert to a :class:`xr.Dataset` Parameters ---------- dimensions : iterable of str - Dimensions for each variable in
the returned dataset. If ``"time"`` is not included in ``dimensions`` it will be the last dimension. If ``"_id"`` is required (see ``extras`` documentation for when ``"_id"`` is required) and is not included in ``dimensions`` then it will be the last dimension. + Dimensions for each variable in the returned dataset. If ``"time"`` is not included in ``dimensions`` it will be the last dimension. If an "_id" co-ordinate is required (see ``extras`` documentation for when "_id" is required) and is not included in ``dimensions`` then it will be the last dimension. extras : iterable of str - TODO: write (when _id is added) + Columns in ``self.meta`` from which to create "non-dimension co-ordinates" (see `xarray terminology `_ for more details). These non-dimension co-ordinates store extra information and can be mapped to each timeseries found in the data variables of the output :obj:`xr.Dataset`. Where possible, these non-dimension co-ordinates will use dimension co-ordinates as their own co-ordinates. However, if the metadata in ``extras`` is not defined by a single dimension in ``dimensions``, then the ``extras`` co-ordinates will have dimensions of "_id". This "_id" co-ordinate maps the values in the ``extras`` co-ordinates to each timeseries in the serialised dataset. Where "_id" is required, an extra "_id" dimension will also be added to ``dimensions``. + + unify_units : bool + If a given variable has multiple units, should we attempt to unify them? Returns ------- :obj:`xr.Dataset` Data in self, re-formatted as an :obj:`xr.Dataset` + + Raises + ------ + ValueError + If a given variable has multiple units and ``unify_units`` is ``False``. """ dimensions = list(dimensions) extras = list(extras) timeseries_dims = list(set(dimensions) - {"time"} - {"_id"}) - timeseries = _get_timeseries_for_xr_dataset(run, timeseries_dims, extras) + + self_unified_units = _unify_scmrun_units(self, unify_units) + timeseries = _get_timeseries_for_xr_dataset(self_unified_units, timeseries_dims, extras, unify_units) non_dimension_extra_metadata = _get_other_metdata_for_xr_dataset( - run, dimensions, extras + self_unified_units, dimensions, extras ) if extras: - ids, ids_dimensions = _get_ids_for_xr_dataset(run, extras, timeseries_dims) + ids, ids_dimensions = _get_ids_for_xr_dataset(self_unified_units, extras, timeseries_dims) else: ids = None ids_dimensions = None @@ -45,10 +56,10 @@ def to_xarray(run, dimensions=("region",), extras=()): xr_ds = xr.Dataset.from_dataframe(for_xarray) if extras: - xr_ds = _add_extras(xr_ds, ids, ids_dimensions, run) + xr_ds = _add_extras(xr_ds, ids, ids_dimensions, self_unified_units) unit_map = ( - run.meta[["variable", "unit"]].drop_duplicates().set_index("variable")["unit"] + self_unified_units.meta[["variable", "unit"]].drop_duplicates().set_index("variable")["unit"] ) xr_ds = _add_units(xr_ds, unit_map) xr_ds = _add_scmdata_metadata(xr_ds, non_dimension_extra_metadata) @@ -57,7 +68,40 @@ def to_xarray(run, dimensions=("region",), extras=()): return xr_ds -def _get_timeseries_for_xr_dataset(run, dimensions, extras): +def _unify_scmrun_units(run, unify_units): + variable_unit_table = run.meta[["variable", "unit"]].drop_duplicates() + variable_units = variable_unit_table.set_index("variable")["unit"] + + variable_counts = variable_unit_table["variable"].value_counts() + more_than_one_unit_variables = variable_counts[variable_counts > 1] + if not more_than_one_unit_variables.empty: + if not unify_units: + error_msg = ( + "The following variables are reported in more than one unit. 
" + "Found variable-unit combinations are:\n{}".format( + variable_units[more_than_one_unit_variables.index.values] + ) + ) + + raise ValueError(error_msg) + + for variable in more_than_one_unit_variables.index: + out_unit = variable_units[variable].iloc[0] + try: + run = run.convert_unit(out_unit, variable=variable) + except pint.errors.DimensionalityError as exc: + error_msg = ( + "Variable `{}` cannot be converted to a common unit. " + "Units in the provided dataset: {}." + .format(variable, variable_units[variable].values.tolist()) + ) + raise ValueError(error_msg) from exc + + return run + + +def _get_timeseries_for_xr_dataset(run, dimensions, extras, unify_units): + for d in dimensions: vals = sorted(run.meta[d].unique()) if not all([isinstance(v, str) for v in vals]) and np.isnan(vals).any(): @@ -195,6 +239,11 @@ def _add_extras(xr_ds, ids, ids_dimensions, run): def _add_units(xr_ds, unit_map): for data_var in xr_ds.data_vars: unit = unit_map[data_var] + if not isinstance(unit, str) and len(unit) > 1: + raise AssertionError( + "Found multiple units ({}) for {}".format(unit, data_var) + ) + xr_ds[data_var].attrs["units"] = unit return xr_ds diff --git a/tests/unit/test_xarray.py b/tests/unit/test_xarray.py index ed7bd23a..7c17a665 100644 --- a/tests/unit/test_xarray.py +++ b/tests/unit/test_xarray.py @@ -1,9 +1,14 @@ +import re + import numpy as np import numpy.testing as npt import pandas as pd import pytest import xarray as xr +import scmdata +from scmdata.errors import NonUniqueMetadataError + def do_basic_to_xarray_checks(res, start_run, dimensions, extras): assert isinstance(res, xr.Dataset) @@ -12,8 +17,8 @@ def do_basic_to_xarray_checks(res, start_run, dimensions, extras): for variable_name, data_var in res.data_vars.items(): assert data_var.dims == dimensions - unit = start_run.filter(variable=variable_name).get_unique_meta("unit", True) - assert data_var.units == unit + unit = start_run.filter(variable=variable_name).get_unique_meta("unit") + assert data_var.units in unit # all other metadata should be in attrs for meta_col in set(start_run.meta.columns) - set(dimensions) - set(extras) - {"variable", "unit"}: @@ -28,8 +33,9 @@ def do_basic_check_of_data_points(res, start_run, dimensions): xarray_spot = data_var.isel({v: idx for v in dimensions}) fkwargs = {k: [v.values.tolist()] for k, v in xarray_spot.coords.items()} fkwargs["variable"] = variable_name + xarray_unit = data_var.units - start_run_spot = start_run.filter(**fkwargs) + start_run_spot = start_run.filter(**fkwargs).convert_unit(xarray_unit) if np.isnan(xarray_spot): assert start_run_spot.empty else: @@ -176,6 +182,60 @@ def test_to_xarray_weird_names(scm_run, ch, weird_idx): do_basic_to_xarray_checks(res, scm_run, dimensions, (),) do_basic_check_of_data_points(res, scm_run, dimensions) + +def get_multiple_units_scm_run(scm_run, new_unit, new_unit_alternate): + first_var = scm_run.get_unique_meta("variable")[0] + scm_run_first_var = scm_run.filter(variable=first_var) + scm_run_first_var["unit"] = [ + v if i >= 1 else new_unit if v != new_unit else new_unit_alternate + for i, v in enumerate(scm_run_first_var["unit"].tolist()) + ] + scm_run_other_vars = scm_run.filter(variable=first_var, keep=False) + + return scmdata.run_append([scm_run_first_var, scm_run_other_vars]) + + +def test_to_xarray_multiple_units_error(scm_run): + scm_run = get_multiple_units_scm_run(scm_run, "J/yr", "MJ/yr") + + variable_unit_table = scm_run.meta[["variable", "unit"]].drop_duplicates() + variable_units = 
variable_unit_table.set_index("variable")["unit"] + variable_counts = variable_unit_table["variable"].value_counts() + more_than_one_unit_variables = variable_counts[variable_counts > 1] + error_msg = re.escape( + "The following variables are reported in more than one unit. " + "Found variable-unit combinations are:\n{}".format( + variable_units[more_than_one_unit_variables.index.values] + ) + ) + + with pytest.raises(ValueError, match=error_msg): + scm_run.to_xarray(dimensions=("region", "scenario", "time"), unify_units=False) + + +def test_to_xarray_unify_multiple_units(scm_run): + scm_run = get_multiple_units_scm_run(scm_run, "J/yr", "MJ/yr") + + dimensions = ("region", "scenario", "time") + res = scm_run.to_xarray(dimensions=dimensions, unify_units=True) + do_basic_to_xarray_checks(res, scm_run, dimensions, (),) + do_basic_check_of_data_points(res, scm_run, dimensions) + + +def test_to_xarray_unify_multiple_units_incompatible_units(scm_run): + scm_run = get_multiple_units_scm_run(scm_run, "kg", "g") + + dimensions = ("region", "scenario", "time") + + first_var = scm_run.get_unique_meta("variable")[0] + error_msg = re.escape( + "Variable `{}` cannot be converted to a common unit. " + "Units in the provided dataset: {}." + .format(first_var, scm_run.filter(variable=first_var).get_unique_meta("unit")) + ) + with pytest.raises(ValueError, match=error_msg): + scm_run.to_xarray(dimensions=dimensions, unify_units=True) + # Tests to write: # - multiple units for given variable # - overlapping dimensions and extras From aa16a12dd3345422c585fd267205ca3a0b4c7cfa Mon Sep 17 00:00:00 2001 From: Zebedee Nicholls Date: Mon, 3 May 2021 18:30:25 +1000 Subject: [PATCH 09/11] Round out tests --- src/scmdata/_xarray.py | 19 ++++++++++++++---- src/scmdata/netcdf.py | 5 ++++- tests/unit/test_xarray.py | 41 +++++++++++++++++++++++++++++++++++---- 3 files changed, 56 insertions(+), 9 deletions(-) diff --git a/src/scmdata/_xarray.py b/src/scmdata/_xarray.py index 1ce0d8b7..130d8c0d 100644 --- a/src/scmdata/_xarray.py +++ b/src/scmdata/_xarray.py @@ -15,7 +15,7 @@ def to_xarray(self, dimensions=("region",), extras=(), unify_units=True): Parameters ---------- dimensions : iterable of str - Dimensions for each variable in the returned dataset. If ``"time"`` is not included in ``dimensions`` it will be the last dimension. If an "_id" co-ordinate is required (see ``extras`` documentation for when "_id" is required) and is not included in ``dimensions`` then it will be the last dimension. + Dimensions for each variable in the returned dataset. If an "_id" co-ordinate is required (see ``extras`` documentation for when "_id" is required) and is not included in ``dimensions`` then it will be the last dimension (or second last dimension if "time" is also not included in ``dimensions``). If "time" is not included in ``dimensions`` it will be the last dimension. extras : iterable of str Columns in ``self.meta`` from which to create "non-dimension co-ordinates" (see `xarray terminology `_ for more details). These non-dimension co-ordinates store extra information and can be mapped to each timeseries found in the data variables of the output :obj:`xr.Dataset`. Where possible, these non-dimension co-ordinates will use dimension co-ordinates as their own co-ordinates. However, if the metadata in ``extras`` is not defined by a single dimension in ``dimensions``, then the ``extras`` co-ordinates will have dimensions of "_id". 
This "_id" co-ordinate maps the values in the ``extras`` co-ordinates to each timeseries in the serialised dataset. Where "_id" is required, an extra "_id" dimension will also be added to ``dimensions``. @@ -31,15 +31,26 @@ def to_xarray(self, dimensions=("region",), extras=(), unify_units=True): Raises ------ ValueError - If a given variable has multiple units and ``unify_units`` is ``False``. + If a variable has multiple units and ``unify_units`` is ``False``. + + ValueError + If a variable has multiple units which are not able to be converted to a common unit because they have different base units. """ dimensions = list(dimensions) extras = list(extras) + dimensions_extras_overlap = set(dimensions).intersection(set(extras)) + if dimensions_extras_overlap: + raise ValueError( + "dimensions and extras cannot have any overlap. " + "Current values in both dimensions and extras: {}" + .format(dimensions_extras_overlap) + ) + timeseries_dims = list(set(dimensions) - {"time"} - {"_id"}) self_unified_units = _unify_scmrun_units(self, unify_units) - timeseries = _get_timeseries_for_xr_dataset(self_unified_units, timeseries_dims, extras, unify_units) + timeseries = _get_timeseries_for_xr_dataset(self_unified_units, timeseries_dims, extras) non_dimension_extra_metadata = _get_other_metdata_for_xr_dataset( self_unified_units, dimensions, extras ) @@ -100,7 +111,7 @@ def _unify_scmrun_units(run, unify_units): return run -def _get_timeseries_for_xr_dataset(run, dimensions, extras, unify_units): +def _get_timeseries_for_xr_dataset(run, dimensions, extras): for d in dimensions: vals = sorted(run.meta[d].unique()) diff --git a/src/scmdata/netcdf.py b/src/scmdata/netcdf.py index 31b48e2d..c00d25bc 100644 --- a/src/scmdata/netcdf.py +++ b/src/scmdata/netcdf.py @@ -139,11 +139,14 @@ def run_to_nc(run, fname, dimensions=("region",), extras=(), **kwargs): Path to write the file into dimensions : iterable of str - Dimensions to include in the netCDF file. The time dimension is always included, even if not provided. An additional dimension (specifically a co-ordinate in xarray terms), "_id", will be included if ``extras`` is provided and any of the metadata in ``extras`` is not uniquely defined by ``dimensions``. "_id" maps the timeseries in each variable to their relevant metadata. + Dimensions to include in the netCDF file. The time dimension is always included (if not provided it will be the last dimension). An additional dimension (specifically a co-ordinate in xarray terms), "_id", will be included if ``extras`` is provided and any of the metadata in ``extras`` is not uniquely defined by ``dimensions``. "_id" maps the timeseries in each variable to their relevant metadata. extras : iterable of str Metadata columns to write as variables (specifically co-ordinates in xarray terms) in the netCDF file. Where possible, the metadata in ``dimensions`` will be used as the dimensions of these variables. However, if the metadata in ``extras`` is not defined by a single dimension in ``dimensions``, then the ``extras`` variables will have dimensions of "_id", which maps the metadata to each timeseries in the serialised dataset. + extras : iterable of str + Columns in ``self.meta`` from which to create "non-dimension co-ordinates" (see `xarray terminology `_ for more details). These non-dimension co-ordinates store extra information and can be mapped to each timeseries found in the data variables of the output :obj:`xr.Dataset`. 
     kwargs
         Passed through to :meth:`xarray.Dataset.to_netcdf`
 
diff --git a/tests/unit/test_xarray.py b/tests/unit/test_xarray.py
index 7c17a665..fdca1ed6 100644
--- a/tests/unit/test_xarray.py
+++ b/tests/unit/test_xarray.py
@@ -236,7 +236,40 @@ def test_to_xarray_unify_multiple_units_incompatible_units(scm_run):
     with pytest.raises(ValueError, match=error_msg):
         scm_run.to_xarray(dimensions=dimensions, unify_units=True)
 
-# Tests to write:
-# - multiple units for given variable
-# - overlapping dimensions and extras
-# - underdefined dimensions and extras
+
+@pytest.mark.parametrize("dimensions,extras", (
+    (("junk",), (),),
+    (("junk",), ("climate_model",),),
+    (("scenario", "junk_1"), ("junk",)),
+    (("scenario",), ("junk",)),
+    (("scenario",), ("junk", "climate_model")),
+    (("scenario",), ("junk", "junk_2", "climate_model")),
+))
+def test_dimension_and_or_extra_not_in_metadata(scm_run, dimensions, extras):
+    with pytest.raises(KeyError):
+        scm_run.to_xarray(dimensions=dimensions, extras=extras)
+
+
+def test_to_xarray_dimensions_extra_overlap(scm_run):
+    dimensions = ("scenario", "region")
+    extras = ("scenario",)
+
+    error_msg = re.escape(
+        "dimensions and extras cannot have any overlap. "
+        "Current values in both dimensions and extras: {}"
+        .format({"scenario"})
+    )
+    with pytest.raises(ValueError, match=error_msg):
+        scm_run.to_xarray(dimensions=dimensions, extras=extras)
+
+
+def test_to_xarray_non_unique_timeseries(scm_run):
+    dimensions = ("region",)
+
+    error_msg = re.escape(
+        "dimensions: `{}` and extras: `[]` do not uniquely define the timeseries, "
+        "please add extra dimensions and/or extras"
+        .format(list(dimensions))
+    )
+    with pytest.raises(ValueError, match=error_msg):
+        scm_run.to_xarray(dimensions=dimensions)

From 189a5b7f3c97e5cb8f047d87f163648535ad92aa Mon Sep 17 00:00:00 2001
From: Zebedee Nicholls
Date: Mon, 3 May 2021 18:44:35 +1000
Subject: [PATCH 10/11] Format

---
 src/scmdata/_xarray.py    |  22 ++++--
 src/scmdata/netcdf.py     |   4 +-
 src/scmdata/run.py        |   1 +
 tests/unit/test_netcdf.py |  12 ++--
 tests/unit/test_xarray.py | 139 ++++++++++++++++++++++++--------------
 5 files changed, 114 insertions(+), 64 deletions(-)

diff --git a/src/scmdata/_xarray.py b/src/scmdata/_xarray.py
index 130d8c0d..ffe372e0 100644
--- a/src/scmdata/_xarray.py
+++ b/src/scmdata/_xarray.py
@@ -43,20 +43,25 @@ def to_xarray(self, dimensions=("region",), extras=(), unify_units=True):
     if dimensions_extras_overlap:
         raise ValueError(
             "dimensions and extras cannot have any overlap. 
" - "Current values in both dimensions and extras: {}" - .format(dimensions_extras_overlap) + "Current values in both dimensions and extras: {}".format( + dimensions_extras_overlap + ) ) timeseries_dims = list(set(dimensions) - {"time"} - {"_id"}) self_unified_units = _unify_scmrun_units(self, unify_units) - timeseries = _get_timeseries_for_xr_dataset(self_unified_units, timeseries_dims, extras) + timeseries = _get_timeseries_for_xr_dataset( + self_unified_units, timeseries_dims, extras + ) non_dimension_extra_metadata = _get_other_metdata_for_xr_dataset( self_unified_units, dimensions, extras ) if extras: - ids, ids_dimensions = _get_ids_for_xr_dataset(self_unified_units, extras, timeseries_dims) + ids, ids_dimensions = _get_ids_for_xr_dataset( + self_unified_units, extras, timeseries_dims + ) else: ids = None ids_dimensions = None @@ -70,7 +75,9 @@ def to_xarray(self, dimensions=("region",), extras=(), unify_units=True): xr_ds = _add_extras(xr_ds, ids, ids_dimensions, self_unified_units) unit_map = ( - self_unified_units.meta[["variable", "unit"]].drop_duplicates().set_index("variable")["unit"] + self_unified_units.meta[["variable", "unit"]] + .drop_duplicates() + .set_index("variable")["unit"] ) xr_ds = _add_units(xr_ds, unit_map) xr_ds = _add_scmdata_metadata(xr_ds, non_dimension_extra_metadata) @@ -103,8 +110,9 @@ def _unify_scmrun_units(run, unify_units): except pint.errors.DimensionalityError as exc: error_msg = ( "Variable `{}` cannot be converted to a common unit. " - "Units in the provided dataset: {}." - .format(variable, variable_units[variable].values.tolist()) + "Units in the provided dataset: {}.".format( + variable, variable_units[variable].values.tolist() + ) ) raise ValueError(error_msg) from exc diff --git a/src/scmdata/netcdf.py b/src/scmdata/netcdf.py index c00d25bc..d29ef435 100644 --- a/src/scmdata/netcdf.py +++ b/src/scmdata/netcdf.py @@ -92,8 +92,8 @@ def _reshape_to_scmrun_dataframe(dataframe, loaded): def _convert_to_cls_and_add_metadata(dataframe, loaded, cls): for k in list(loaded.attrs.keys()): - if k.startswith("_scmdata_metadata_"): - dataframe[k.replace("_scmdata_metadata_", "")] = loaded.attrs.pop(k) + if k.startswith("scmdata_metadata_"): + dataframe[k.replace("scmdata_metadata_", "")] = loaded.attrs.pop(k) run = cls(dataframe) run.metadata.update(loaded.attrs) diff --git a/src/scmdata/run.py b/src/scmdata/run.py index a31dcb26..f7308da0 100644 --- a/src/scmdata/run.py +++ b/src/scmdata/run.py @@ -2188,6 +2188,7 @@ def _handle_potential_duplicates_in_append(data, duplicate_msg): inject_ops_methods(BaseScmRun) inject_xarray_methods(BaseScmRun) + class ScmRun(BaseScmRun): """ Data container for holding one or many time-series of SCM data. 
diff --git a/tests/unit/test_netcdf.py b/tests/unit/test_netcdf.py index db223776..6964c12a 100644 --- a/tests/unit/test_netcdf.py +++ b/tests/unit/test_netcdf.py @@ -33,15 +33,15 @@ def test_run_to_nc(scm_run): assert ds.variables["scenario"][1] == "a_scenario2" npt.assert_allclose( - ds.variables["Primary_Energy"][:, 0], + ds.variables["Primary_Energy"][0, :], scm_run.filter(variable="Primary Energy", scenario="a_scenario").values[0], ) npt.assert_allclose( - ds.variables["Primary_Energy"][:, 1], + ds.variables["Primary_Energy"][1, :], scm_run.filter(variable="Primary Energy", scenario="a_scenario2").values[0], ) npt.assert_allclose( - ds.variables["Primary_Energy__Coal"][:, 0], + ds.variables["Primary_Energy__Coal"][0, :], scm_run.filter( variable="Primary Energy|Coal", scenario="a_scenario" ).values[0], @@ -411,15 +411,15 @@ def test_run_to_nc_with_extras(scm_run, dtype): assert run_id == exp_val npt.assert_allclose( - ds.variables["Primary_Energy"][:, 0], + ds.variables["Primary_Energy"][0, :], scm_run.filter(variable="Primary Energy", scenario="a_scenario").values[0], ) npt.assert_allclose( - ds.variables["Primary_Energy"][:, 1], + ds.variables["Primary_Energy"][1, :], scm_run.filter(variable="Primary Energy", scenario="a_scenario2").values[0], ) npt.assert_allclose( - ds.variables["Primary_Energy__Coal"][:, 0], + ds.variables["Primary_Energy__Coal"][0, :], scm_run.filter( variable="Primary Energy|Coal", scenario="a_scenario" ).values[0], diff --git a/tests/unit/test_xarray.py b/tests/unit/test_xarray.py index fdca1ed6..49bfe4a4 100644 --- a/tests/unit/test_xarray.py +++ b/tests/unit/test_xarray.py @@ -7,7 +7,6 @@ import xarray as xr import scmdata -from scmdata.errors import NonUniqueMetadataError def do_basic_to_xarray_checks(res, start_run, dimensions, extras): @@ -21,7 +20,12 @@ def do_basic_to_xarray_checks(res, start_run, dimensions, extras): assert data_var.units in unit # all other metadata should be in attrs - for meta_col in set(start_run.meta.columns) - set(dimensions) - set(extras) - {"variable", "unit"}: + for meta_col in ( + set(start_run.meta.columns) + - set(dimensions) + - set(extras) + - {"variable", "unit"} + ): meta_val = start_run.get_unique_meta(meta_col, True) assert res.attrs["scmdata_metadata_{}".format(meta_col)] == meta_val @@ -43,29 +47,32 @@ def do_basic_check_of_data_points(res, start_run, dimensions): npt.assert_array_equal(xarray_spot.values, start_run_vals) -@pytest.mark.parametrize("dimensions,expected_dimensions", ( - (("region", "scenario", "time"), ("region", "scenario", "time")), - (("time", "region", "scenario"), ("time", "region", "scenario")), - (("region", "time", "scenario"), ("region", "time", "scenario")), - (("region", "scenario"), ("region", "scenario", "time")), - (("scenario", "region"), ("scenario", "region", "time")), - (("scenario",), ("scenario", "time")), -)) +@pytest.mark.parametrize( + "dimensions,expected_dimensions", + ( + (("region", "scenario", "time"), ("region", "scenario", "time")), + (("time", "region", "scenario"), ("time", "region", "scenario")), + (("region", "time", "scenario"), ("region", "time", "scenario")), + (("region", "scenario"), ("region", "scenario", "time")), + (("scenario", "region"), ("scenario", "region", "time")), + (("scenario",), ("scenario", "time")), + ), +) def test_to_xarray(scm_run, dimensions, expected_dimensions): res = scm_run.to_xarray(dimensions=dimensions) - do_basic_to_xarray_checks(res, scm_run, expected_dimensions, (),) + do_basic_to_xarray_checks( + res, scm_run, expected_dimensions, 
(), + ) do_basic_check_of_data_points(res, scm_run, expected_dimensions) # no extras assert not set(res.coords) - set(res.dims) -@pytest.mark.parametrize("extras", ( - ("model",), - ("climate_model",), - ("climate_model", "model"), -)) +@pytest.mark.parametrize( + "extras", (("model",), ("climate_model",), ("climate_model", "model"),) +) def test_to_xarray_extras_no_id_coord(scm_run, extras): dimensions = ("scenario", "region", "time") res = scm_run.to_xarray(dimensions=dimensions, extras=extras) @@ -84,26 +91,48 @@ def test_to_xarray_extras_no_id_coord(scm_run, extras): xarray_coords = res[extra_col][extra_dims].values for xarray_extra_val, extra_xarray_coord in zip(xarray_vals, xarray_coords): - scm_run_extra_val = scm_run_meta[scm_run_meta[extra_dims] == extra_xarray_coord][extra_col].unique().tolist() + scm_run_extra_val = ( + scm_run_meta[scm_run_meta[extra_dims] == extra_xarray_coord][extra_col] + .unique() + .tolist() + ) assert len(scm_run_extra_val) == 1 scm_run_extra_val = scm_run_extra_val[0] assert scm_run_extra_val == xarray_extra_val -@pytest.mark.parametrize("extras", ( - ("scenario", "model", "random_key"), -)) -@pytest.mark.parametrize("dimensions,expected_dimensions", ( - (("climate_model", "run_id"), ("climate_model", "run_id", "time", "_id")), - (("run_id", "climate_model"), ("run_id", "climate_model", "time", "_id")), - (("run_id", "climate_model", "time"), ("run_id", "climate_model", "time", "_id")), - (("run_id", "time", "climate_model"), ("run_id", "time", "climate_model", "_id")), - (("run_id", "climate_model", "time", "_id"), ("run_id", "climate_model", "time", "_id")), - (("_id", "run_id", "time", "climate_model"), ("_id", "run_id", "time", "climate_model")), - (("run_id", "_id", "climate_model"), ("run_id", "_id", "climate_model", "time")), -)) -def test_to_xarray_extras_with_id_coord(scm_run, extras, dimensions, expected_dimensions): +@pytest.mark.parametrize("extras", (("scenario", "model", "random_key"),)) +@pytest.mark.parametrize( + "dimensions,expected_dimensions", + ( + (("climate_model", "run_id"), ("climate_model", "run_id", "time", "_id")), + (("run_id", "climate_model"), ("run_id", "climate_model", "time", "_id")), + ( + ("run_id", "climate_model", "time"), + ("run_id", "climate_model", "time", "_id"), + ), + ( + ("run_id", "time", "climate_model"), + ("run_id", "time", "climate_model", "_id"), + ), + ( + ("run_id", "climate_model", "time", "_id"), + ("run_id", "climate_model", "time", "_id"), + ), + ( + ("_id", "run_id", "time", "climate_model"), + ("_id", "run_id", "time", "climate_model"), + ), + ( + ("run_id", "_id", "climate_model"), + ("run_id", "_id", "climate_model", "time"), + ), + ), +) +def test_to_xarray_extras_with_id_coord( + scm_run, extras, dimensions, expected_dimensions +): df = scm_run.timeseries() val_cols = df.columns.tolist() df = df.reset_index() @@ -157,10 +186,16 @@ def test_to_xarray_extras_with_id_coord(scm_run, extras, dimensions, expected_di scm_run_filter = row.to_dict() scm_run_spot = scm_run.filter(**scm_run_filter) - xarray_sel = {k: v for k, v in scm_run_filter.items() if k in xarray_timeseries.dims} - xarray_spot = xarray_timeseries.sel(**xarray_sel)[scm_run_filter["variable"]] + xarray_sel = { + k: v for k, v in scm_run_filter.items() if k in xarray_timeseries.dims + } + xarray_spot = xarray_timeseries.sel(**xarray_sel)[ + scm_run_filter["variable"] + ] - npt.assert_array_equal(scm_run_spot.values.squeeze(), xarray_spot.values.squeeze()) + npt.assert_array_equal( + scm_run_spot.values.squeeze(), 
xarray_spot.values.squeeze()
+    )
 
 
 @pytest.mark.parametrize("ch", "!@#$%^&*()~`+={}]<>,;:'\".")
@@ -179,7 +214,9 @@ def test_to_xarray_weird_names(scm_run, ch, weird_idx):
     dimensions = ("region", "scenario", "time")
     res = scm_run.to_xarray(dimensions=dimensions)
 
-    do_basic_to_xarray_checks(res, scm_run, dimensions, (),)
+    do_basic_to_xarray_checks(
+        res, scm_run, dimensions, (),
+    )
     do_basic_check_of_data_points(res, scm_run, dimensions)
 
 
@@ -218,7 +255,9 @@ def test_to_xarray_unify_multiple_units(scm_run):
     dimensions = ("region", "scenario", "time")
     res = scm_run.to_xarray(dimensions=dimensions, unify_units=True)
 
-    do_basic_to_xarray_checks(res, scm_run, dimensions, (),)
+    do_basic_to_xarray_checks(
+        res, scm_run, dimensions, (),
+    )
     do_basic_check_of_data_points(res, scm_run, dimensions)
 
 
@@ -230,21 +269,25 @@ def test_to_xarray_unify_multiple_units_incompatible_units(scm_run):
     first_var = scm_run.get_unique_meta("variable")[0]
     error_msg = re.escape(
         "Variable `{}` cannot be converted to a common unit. "
-        "Units in the provided dataset: {}."
-        .format(first_var, scm_run.filter(variable=first_var).get_unique_meta("unit"))
+        "Units in the provided dataset: {}.".format(
+            first_var, scm_run.filter(variable=first_var).get_unique_meta("unit")
+        )
     )
     with pytest.raises(ValueError, match=error_msg):
         scm_run.to_xarray(dimensions=dimensions, unify_units=True)
 
 
-@pytest.mark.parametrize("dimensions,extras", (
-    (("junk",), (),),
-    (("junk",), ("climate_model",),),
-    (("scenario", "junk_1"), ("junk",)),
-    (("scenario",), ("junk",)),
-    (("scenario",), ("junk", "climate_model")),
-    (("scenario",), ("junk", "junk_2", "climate_model")),
-))
+@pytest.mark.parametrize(
+    "dimensions,extras",
+    (
+        (("junk",), (),),
+        (("junk",), ("climate_model",),),
+        (("scenario", "junk_1"), ("junk",)),
+        (("scenario",), ("junk",)),
+        (("scenario",), ("junk", "climate_model")),
+        (("scenario",), ("junk", "junk_2", "climate_model")),
+    ),
+)
 def test_dimension_and_or_extra_not_in_metadata(scm_run, dimensions, extras):
     with pytest.raises(KeyError):
         scm_run.to_xarray(dimensions=dimensions, extras=extras)
 
@@ -256,8 +299,7 @@ def test_to_xarray_dimensions_extra_overlap(scm_run):
 
     error_msg = re.escape(
         "dimensions and extras cannot have any overlap. 
" - "Current values in both dimensions and extras: {}" - .format({"scenario"}) + "Current values in both dimensions and extras: {}".format({"scenario"}) ) with pytest.raises(ValueError, match=error_msg): scm_run.to_xarray(dimensions=dimensions, extras=extras) @@ -268,8 +310,7 @@ def test_to_xarray_non_unique_timeseries(scm_run): error_msg = re.escape( "dimensions: `{}` and extras: `[]` do not uniquely define the timeseries, " - "please add extra dimensions and/or extras" - .format(list(dimensions)) + "please add extra dimensions and/or extras".format(list(dimensions)) ) with pytest.raises(ValueError, match=error_msg): scm_run.to_xarray(dimensions=dimensions) From d98f2c568d71f2729cf4a8ec963c5df2c786e04a Mon Sep 17 00:00:00 2001 From: Zebedee Nicholls Date: Mon, 3 May 2021 18:57:01 +1000 Subject: [PATCH 11/11] Add note about notebook to write --- tests/unit/test_ops.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/unit/test_ops.py b/tests/unit/test_ops.py index f830b21c..e42c82fe 100644 --- a/tests/unit/test_ops.py +++ b/tests/unit/test_ops.py @@ -1,3 +1,4 @@ +import datetime as dt import re from unittest.mock import patch @@ -989,3 +990,7 @@ def test_linear_regression_scmrun(): assert_scmdf_almost_equal( res, exp, allow_unordered=True, check_ts_names=False, rtol=1e-3 ) + + +# TODO: notebook illustrating rolling mean options +# Rationale: rolling means are really tricky (do you take e.g. an annual mean first, do you worry about happens at the window edge?) and they're pretty easy to convert back into ScmRun objects so a notebooks is probably more helpful than exact functionality for now