diff --git a/src/scmdata/_xarray.py b/src/scmdata/_xarray.py new file mode 100644 index 00000000..ffe372e0 --- /dev/null +++ b/src/scmdata/_xarray.py @@ -0,0 +1,308 @@ +""" +Interface with `xarray `_ +""" +import numpy as np +import pint.errors +import xarray as xr + +from .errors import NonUniqueMetadataError + + +def to_xarray(self, dimensions=("region",), extras=(), unify_units=True): + """ + Convert to a :class:`xr.Dataset` + + Parameters + ---------- + dimensions : iterable of str + Dimensions for each variable in the returned dataset. If an "_id" co-ordinate is required (see ``extras`` documentation for when "_id" is required) and is not included in ``dimensions`` then it will be the last dimension (or second last dimension if "time" is also not included in ``dimensions``). If "time" is not included in ``dimensions`` it will be the last dimension. + + extras : iterable of str + Columns in ``self.meta`` from which to create "non-dimension co-ordinates" (see `xarray terminology `_ for more details). These non-dimension co-ordinates store extra information and can be mapped to each timeseries found in the data variables of the output :obj:`xr.Dataset`. Where possible, these non-dimension co-ordinates will use dimension co-ordinates as their own co-ordinates. However, if the metadata in ``extras`` is not defined by a single dimension in ``dimensions``, then the ``extras`` co-ordinates will have dimensions of "_id". This "_id" co-ordinate maps the values in the ``extras`` co-ordinates to each timeseries in the serialised dataset. Where "_id" is required, an extra "_id" dimension will also be added to ``dimensions``. + + unify_units : bool + If a given variable has multiple units, should we attempt to unify them? + + Returns + ------- + :obj:`xr.Dataset` + Data in self, re-formatted as an :obj:`xr.Dataset` + + Raises + ------ + ValueError + If a variable has multiple units and ``unify_units`` is ``False``. + + ValueError + If a variable has multiple units which are not able to be converted to a common unit because they have different base units. + """ + dimensions = list(dimensions) + extras = list(extras) + + dimensions_extras_overlap = set(dimensions).intersection(set(extras)) + if dimensions_extras_overlap: + raise ValueError( + "dimensions and extras cannot have any overlap. " + "Current values in both dimensions and extras: {}".format( + dimensions_extras_overlap + ) + ) + + timeseries_dims = list(set(dimensions) - {"time"} - {"_id"}) + + self_unified_units = _unify_scmrun_units(self, unify_units) + timeseries = _get_timeseries_for_xr_dataset( + self_unified_units, timeseries_dims, extras + ) + non_dimension_extra_metadata = _get_other_metdata_for_xr_dataset( + self_unified_units, dimensions, extras + ) + + if extras: + ids, ids_dimensions = _get_ids_for_xr_dataset( + self_unified_units, extras, timeseries_dims + ) + else: + ids = None + ids_dimensions = None + + for_xarray = _get_dataframe_for_xr_dataset( + timeseries, timeseries_dims, extras, ids, ids_dimensions + ) + xr_ds = xr.Dataset.from_dataframe(for_xarray) + + if extras: + xr_ds = _add_extras(xr_ds, ids, ids_dimensions, self_unified_units) + + unit_map = ( + self_unified_units.meta[["variable", "unit"]] + .drop_duplicates() + .set_index("variable")["unit"] + ) + xr_ds = _add_units(xr_ds, unit_map) + xr_ds = _add_scmdata_metadata(xr_ds, non_dimension_extra_metadata) + xr_ds = _set_dimensions(xr_ds, dimensions) + + return xr_ds + + +def _unify_scmrun_units(run, unify_units): + variable_unit_table = run.meta[["variable", "unit"]].drop_duplicates() + variable_units = variable_unit_table.set_index("variable")["unit"] + + variable_counts = variable_unit_table["variable"].value_counts() + more_than_one_unit_variables = variable_counts[variable_counts > 1] + if not more_than_one_unit_variables.empty: + if not unify_units: + error_msg = ( + "The following variables are reported in more than one unit. " + "Found variable-unit combinations are:\n{}".format( + variable_units[more_than_one_unit_variables.index.values] + ) + ) + + raise ValueError(error_msg) + + for variable in more_than_one_unit_variables.index: + out_unit = variable_units[variable].iloc[0] + try: + run = run.convert_unit(out_unit, variable=variable) + except pint.errors.DimensionalityError as exc: + error_msg = ( + "Variable `{}` cannot be converted to a common unit. " + "Units in the provided dataset: {}.".format( + variable, variable_units[variable].values.tolist() + ) + ) + raise ValueError(error_msg) from exc + + return run + + +def _get_timeseries_for_xr_dataset(run, dimensions, extras): + + for d in dimensions: + vals = sorted(run.meta[d].unique()) + if not all([isinstance(v, str) for v in vals]) and np.isnan(vals).any(): + raise AssertionError("nan in dimension: `{}`".format(d)) + + try: + timeseries = run.timeseries(dimensions + extras + ["variable"]) + except NonUniqueMetadataError as exc: + error_msg = ( + "dimensions: `{}` and extras: `{}` do not uniquely define the " + "timeseries, please add extra dimensions and/or extras".format( + dimensions, extras + ) + ) + raise ValueError(error_msg) from exc + + timeseries.columns = run.time_points.as_cftime() + + return timeseries + + +def _get_other_metdata_for_xr_dataset(run, dimensions, extras): + other_dimensions = list( + set(run.meta.columns) - set(dimensions) - set(extras) - {"variable", "unit"} + ) + other_metdata = run.meta[other_dimensions].drop_duplicates() + if other_metdata.shape[0] > 1 and not other_metdata.empty: + error_msg = ( + "Other metadata is not unique for dimensions: `{}` and extras: `{}`. " + "Please add meta columns with more than one value to dimensions or " + "extras.\nNumber of unique values in each column:\n{}.\n" + "Existing values in the other metadata:\n{}.".format( + dimensions, + extras, + other_metdata.nunique(), + other_metdata.drop_duplicates(), + ) + ) + raise ValueError(error_msg) + + return other_metdata + + +def _get_ids_for_xr_dataset(run, extras, dimensions): + # these loops could be very slow with lots of extras and dimensions... + ids_dimensions = {} + for extra in extras: + for col in dimensions: + if _many_to_one(run.meta, extra, col): + dim_col = col + break + else: + dim_col = "_id" + + ids_dimensions[extra] = dim_col + + ids = run.meta[extras].drop_duplicates() + ids["_id"] = range(ids.shape[0]) + ids = ids.set_index(extras) + + return ids, ids_dimensions + + +def _many_to_one(df, col1, col2): + """ + Check if there is a many to one mapping between col2 and col1 + """ + # thanks https://stackoverflow.com/a/59091549 + checker = df[[col1, col2]].drop_duplicates() + + max_count = checker.groupby(col2).count().max()[0] + if max_count < 1: # pragma: no cover # emergency valve + raise AssertionError + + return max_count == 1 + + +def _get_dataframe_for_xr_dataset(timeseries, dimensions, extras, ids, ids_dimensions): + timeseries = timeseries.reset_index() + + add_id_dimension = extras and "_id" in set(ids_dimensions.values()) + if add_id_dimension: + timeseries = ( + timeseries.set_index(ids.index.names) + .join(ids) + .reset_index(drop=True) + .set_index(dimensions + ["variable", "_id"]) + ) + else: + timeseries = timeseries.set_index(dimensions + ["variable"]) + if extras: + timeseries = timeseries.drop(extras, axis="columns") + + timeseries.columns.names = ["time"] + + if ( + len(timeseries.index.unique()) != timeseries.shape[0] + ): # pragma: no cover # emergency valve + # shouldn't be able to get here because any issues should be caught + # by initial creation of timeseries but just in case + raise AssertionError("something not unique") + + for_xarray = ( + timeseries.T.stack(dimensions + ["_id"]) + if add_id_dimension + else timeseries.T.stack(dimensions) + ) + + return for_xarray + + +def _add_extras(xr_ds, ids, ids_dimensions, run): + # this loop could also be slow... + extra_coords = {} + for extra, id_dimension in ids_dimensions.items(): + if id_dimension in ids: + ids_extra = ids.reset_index().set_index(id_dimension) + else: + ids_extra = ( + run.meta[[extra, id_dimension]] + .drop_duplicates() + .set_index(id_dimension) + ) + + extra_coords[extra] = ( + id_dimension, + ids_extra[extra].loc[xr_ds[id_dimension].values], + ) + + xr_ds = xr_ds.assign_coords(extra_coords) + + return xr_ds + + +def _add_units(xr_ds, unit_map): + for data_var in xr_ds.data_vars: + unit = unit_map[data_var] + if not isinstance(unit, str) and len(unit) > 1: + raise AssertionError( + "Found multiple units ({}) for {}".format(unit, data_var) + ) + + xr_ds[data_var].attrs["units"] = unit + + return xr_ds + + +def _add_scmdata_metadata(xr_ds, others): + for col in others: + vals = others[col].unique() + if len(vals) > 1: # pragma: no cover # emergency valve + # should have already been caught... + raise AssertionError("More than one value for meta: {}".format(col)) + + xr_ds.attrs["scmdata_metadata_{}".format(col)] = vals[0] + + return xr_ds + + +def _set_dimensions(xr_ds, dimensions): + out_dimensions = dimensions + if "time" not in dimensions: + out_dimensions += ["time"] + + if "_id" in xr_ds.dims and "_id" not in dimensions: + out_dimensions += ["_id"] + + return xr_ds.transpose(*out_dimensions) + + +def inject_xarray_methods(cls): + """ + Inject the xarray methods + + Parameters + ---------- + cls + Target class + """ + methods = [ + ("to_xarray", to_xarray), + ] + + for name, f in methods: + setattr(cls, name, f) diff --git a/src/scmdata/netcdf.py b/src/scmdata/netcdf.py index 918b903a..d29ef435 100644 --- a/src/scmdata/netcdf.py +++ b/src/scmdata/netcdf.py @@ -14,214 +14,17 @@ from datetime import datetime from logging import getLogger -import numpy as np import xarray as xr from . import __version__ -from .errors import NonUniqueMetadataError logger = getLogger(__name__) -""" -Default to writing float data as 8 byte floats -""" -DEFAULT_FLOAT = "f8" - - def _var_to_nc(var): return var.replace("|", "__").replace(" ", "_") -def _write_nc(fname, run, dimensions, extras, **kwargs): - """ - Low level function to write the dimensions, variables and metadata to disk - """ - xr_ds = _get_xr_dataset(run, dimensions, extras) - - xr_ds.attrs["created_at"] = datetime.utcnow().isoformat() - xr_ds.attrs["_scmdata_version"] = __version__ - - if run.metadata: - xr_ds.attrs.update(run.metadata) - - write_kwargs = _update_kwargs_to_match_serialised_variable_names(xr_ds, kwargs) - xr_ds.to_netcdf(fname, **write_kwargs) - - -def _get_xr_dataset(run, dimensions, extras): - timeseries = _get_timeseries_for_xr_dataset(run, dimensions, extras) - non_dimension_extra_metadata = _get_other_metdata_for_xr_dataset( - run, dimensions, extras - ) - - if extras: - ids, ids_dimensions = _get_ids_for_xr_dataset(run, extras, dimensions) - else: - ids = None - ids_dimensions = None - - for_xarray = _get_dataframe_for_xr_dataset( - timeseries, dimensions, extras, ids, ids_dimensions - ) - xr_ds = xr.Dataset.from_dataframe(for_xarray) - - if extras: - xr_ds = _add_extras(xr_ds, ids, ids_dimensions, run) - - unit_map = ( - run.meta[["variable", "unit"]].drop_duplicates().set_index("variable")["unit"] - ) - xr_ds = _add_units(xr_ds, unit_map) - xr_ds = _rename_variables(xr_ds) - xr_ds = _add_scmdata_metadata(xr_ds, non_dimension_extra_metadata) - - return xr_ds - - -def _get_timeseries_for_xr_dataset(run, dimensions, extras): - for d in dimensions: - vals = sorted(run.meta[d].unique()) - if not all([isinstance(v, str) for v in vals]) and np.isnan(vals).any(): - raise AssertionError("nan in dimension: `{}`".format(d)) - - try: - timeseries = run.timeseries(dimensions + extras + ["variable"]) - except NonUniqueMetadataError as exc: - error_msg = ( - "dimensions: `{}` and extras: `{}` do not uniquely define the " - "timeseries, please add extra dimensions and/or extras".format( - dimensions, extras - ) - ) - raise ValueError(error_msg) from exc - - timeseries.columns = run.time_points.as_cftime() - - return timeseries - - -def _get_other_metdata_for_xr_dataset(run, dimensions, extras): - other_dimensions = list( - set(run.meta.columns) - set(dimensions) - set(extras) - {"variable", "unit"} - ) - other_metdata = run.meta[other_dimensions].drop_duplicates() - if other_metdata.shape[0] > 1 and not other_metdata.empty: - error_msg = ( - "Other metadata is not unique for dimensions: `{}` and extras: `{}`. " - "Please add meta columns with more than one value to dimensions or " - "extras.\nNumber of unique values in each column:\n{}.\n" - "Existing values in the other metadata:\n{}.".format( - dimensions, - extras, - other_metdata.nunique(), - other_metdata.drop_duplicates(), - ) - ) - raise ValueError(error_msg) - - return other_metdata - - -def _get_ids_for_xr_dataset(run, extras, dimensions): - # these loops could be very slow with lots of extras and dimensions... - ids_dimensions = {} - for extra in extras: - for col in dimensions: - if _many_to_one(run.meta, extra, col): - dim_col = col - break - else: - dim_col = "_id" - - ids_dimensions[extra] = dim_col - - ids = run.meta[extras].drop_duplicates() - ids["_id"] = range(ids.shape[0]) - ids = ids.set_index(extras) - - return ids, ids_dimensions - - -def _many_to_one(df, col1, col2): - """ - Check if there is a many to one mapping between col2 and col1 - """ - # thanks https://stackoverflow.com/a/59091549 - checker = df[[col1, col2]].drop_duplicates() - - max_count = checker.groupby(col2).count().max()[0] - if max_count < 1: # pragma: no cover # emergency valve - raise AssertionError - - return max_count == 1 - - -def _get_dataframe_for_xr_dataset(timeseries, dimensions, extras, ids, ids_dimensions): - timeseries = timeseries.reset_index() - - add_id_dimension = extras and "_id" in set(ids_dimensions.values()) - if add_id_dimension: - timeseries = ( - timeseries.set_index(ids.index.names) - .join(ids) - .reset_index(drop=True) - .set_index(dimensions + ["variable", "_id"]) - ) - else: - timeseries = timeseries.set_index(dimensions + ["variable"]) - if extras: - timeseries = timeseries.drop(extras, axis="columns") - - timeseries.columns.names = ["time"] - - if ( - len(timeseries.index.unique()) != timeseries.shape[0] - ): # pragma: no cover # emergency valve - # shouldn't be able to get here because any issues should be caught - # by initial creation of timeseries but just in case - raise AssertionError("something not unique") - - for_xarray = ( - timeseries.T.stack(dimensions + ["_id"]) - if add_id_dimension - else timeseries.T.stack(dimensions) - ) - - return for_xarray - - -def _add_extras(xr_ds, ids, ids_dimensions, run): - # this loop could also be slow... - extra_coords = {} - for extra, id_dimension in ids_dimensions.items(): - if id_dimension in ids: - ids_extra = ids.reset_index().set_index(id_dimension) - else: - ids_extra = ( - run.meta[[extra, id_dimension]] - .drop_duplicates() - .set_index(id_dimension) - ) - - extra_coords[extra] = ( - id_dimension, - ids_extra[extra].loc[xr_ds[id_dimension].values], - ) - - xr_ds = xr_ds.assign_coords(extra_coords) - - return xr_ds - - -def _add_units(xr_ds, unit_map): - for data_var in xr_ds.data_vars: - unit = unit_map[data_var] - xr_ds[data_var].attrs["units"] = unit - - return xr_ds - - def _rename_variables(xr_ds): name_mapping = {} for data_var in xr_ds.data_vars: @@ -234,18 +37,29 @@ def _rename_variables(xr_ds): return xr_ds -def _add_scmdata_metadata(xr_ds, others): - for col in others: - vals = others[col].unique() - if len(vals) > 1: # pragma: no cover # emergency valve - # should have already been caught... - raise AssertionError("More than one value for meta: {}".format(col)) - - xr_ds.attrs["_scmdata_metadata_{}".format(col)] = vals[0] +def _get_xr_dataset_to_write(run, dimensions, extras): + xr_ds = run.to_xarray(dimensions, extras) + xr_ds = _rename_variables(xr_ds) return xr_ds +def _write_nc(fname, run, dimensions, extras, **kwargs): + """ + Low level function to write the dimensions, variables and metadata to disk + """ + xr_ds = _get_xr_dataset_to_write(run, dimensions, extras) + + xr_ds.attrs["created_at"] = datetime.utcnow().isoformat() + xr_ds.attrs["_scmdata_version"] = __version__ + + if run.metadata: + xr_ds.attrs.update(run.metadata) + + write_kwargs = _update_kwargs_to_match_serialised_variable_names(xr_ds, kwargs) + xr_ds.to_netcdf(fname, **write_kwargs) + + def _read_nc(cls, fname): loaded = xr.load_dataset(fname) dataframe = loaded.to_dataframe() @@ -278,8 +92,8 @@ def _reshape_to_scmrun_dataframe(dataframe, loaded): def _convert_to_cls_and_add_metadata(dataframe, loaded, cls): for k in list(loaded.attrs.keys()): - if k.startswith("_scmdata_metadata_"): - dataframe[k.replace("_scmdata_metadata_", "")] = loaded.attrs.pop(k) + if k.startswith("scmdata_metadata_"): + dataframe[k.replace("scmdata_metadata_", "")] = loaded.attrs.pop(k) run = cls(dataframe) run.metadata.update(loaded.attrs) @@ -325,11 +139,14 @@ def run_to_nc(run, fname, dimensions=("region",), extras=(), **kwargs): Path to write the file into dimensions : iterable of str - Dimensions to include in the netCDF file. The time dimension is always included, even if not provided. An additional dimension (specifically a co-ordinate in xarray terms), "_id", will be included if ``extras`` is provided and any of the metadata in ``extras`` is not uniquely defined by ``dimensions``. "_id" maps the timeseries in each variable to their relevant metadata. + Dimensions to include in the netCDF file. The time dimension is always included (if not provided it will be the last dimension). An additional dimension (specifically a co-ordinate in xarray terms), "_id", will be included if ``extras`` is provided and any of the metadata in ``extras`` is not uniquely defined by ``dimensions``. "_id" maps the timeseries in each variable to their relevant metadata. extras : iterable of str Metadata columns to write as variables (specifically co-ordinates in xarray terms) in the netCDF file. Where possible, the metadata in ``dimensions`` will be used as the dimensions of these variables. However, if the metadata in ``extras`` is not defined by a single dimension in ``dimensions``, then the ``extras`` variables will have dimensions of "_id", which maps the metadata to each timeseries in the serialised dataset. + extras : iterable of str + Columns in ``self.meta`` from which to create "non-dimension co-ordinates" (see `xarray terminology `_ for more details). These non-dimension co-ordinates store extra information and can be mapped to each timeseries found in the data variables of the output :obj:`xr.Dataset`. Where possible, these non-dimension co-ordinates will use dimension co-ordinates as their own co-ordinates. However, if the metadata in ``extras`` is not defined by a single dimension in ``dimensions``, then the ``extras`` co-ordinates will have dimensions of "_id". This "_id" co-ordinate maps the values in the ``extras`` co-ordinates to each timeseries in the serialised dataset. Where "_id" is required, an extra "_id" dimension will also be added to ``dimensions``. + kwargs Passed through to :meth:`xarray.Dataset.to_netcdf` diff --git a/src/scmdata/run.py b/src/scmdata/run.py index d646a08a..f7308da0 100644 --- a/src/scmdata/run.py +++ b/src/scmdata/run.py @@ -22,6 +22,7 @@ from openscm_units import unit_registry as ur from xarray.core.ops import inject_binary_ops +from ._xarray import inject_xarray_methods from .errors import MissingRequiredColumnError, NonUniqueMetadataError from .filters import ( HIERARCHY_SEPARATOR, @@ -2185,6 +2186,7 @@ def _handle_potential_duplicates_in_append(data, duplicate_msg): inject_nc_methods(BaseScmRun) inject_plotting_methods(BaseScmRun) inject_ops_methods(BaseScmRun) +inject_xarray_methods(BaseScmRun) class ScmRun(BaseScmRun): diff --git a/tests/unit/test_netcdf.py b/tests/unit/test_netcdf.py index ced183b0..6964c12a 100644 --- a/tests/unit/test_netcdf.py +++ b/tests/unit/test_netcdf.py @@ -13,7 +13,7 @@ import xarray as xr from scmdata import ScmRun -from scmdata.netcdf import _get_xr_dataset, nc_to_run, run_to_nc +from scmdata.netcdf import _get_xr_dataset_to_write, nc_to_run, run_to_nc from scmdata.testing import assert_scmdf_almost_equal @@ -33,15 +33,15 @@ def test_run_to_nc(scm_run): assert ds.variables["scenario"][1] == "a_scenario2" npt.assert_allclose( - ds.variables["Primary_Energy"][:, 0], + ds.variables["Primary_Energy"][0, :], scm_run.filter(variable="Primary Energy", scenario="a_scenario").values[0], ) npt.assert_allclose( - ds.variables["Primary_Energy"][:, 1], + ds.variables["Primary_Energy"][1, :], scm_run.filter(variable="Primary Energy", scenario="a_scenario2").values[0], ) npt.assert_allclose( - ds.variables["Primary_Energy__Coal"][:, 0], + ds.variables["Primary_Energy__Coal"][0, :], scm_run.filter( variable="Primary Energy|Coal", scenario="a_scenario" ).values[0], @@ -411,15 +411,15 @@ def test_run_to_nc_with_extras(scm_run, dtype): assert run_id == exp_val npt.assert_allclose( - ds.variables["Primary_Energy"][:, 0], + ds.variables["Primary_Energy"][0, :], scm_run.filter(variable="Primary Energy", scenario="a_scenario").values[0], ) npt.assert_allclose( - ds.variables["Primary_Energy"][:, 1], + ds.variables["Primary_Energy"][1, :], scm_run.filter(variable="Primary Energy", scenario="a_scenario2").values[0], ) npt.assert_allclose( - ds.variables["Primary_Energy__Coal"][:, 0], + ds.variables["Primary_Energy__Coal"][0, :], scm_run.filter( variable="Primary Energy|Coal", scenario="a_scenario" ).values[0], @@ -667,12 +667,12 @@ def test_run_to_nc_loop_tricky_variable_name(scm_run, start_variable): assert_scmdf_almost_equal(scm_run, loaded, check_ts_names=False) -@patch("scmdata.netcdf._get_xr_dataset") +@patch("scmdata.netcdf._get_xr_dataset_to_write") def test_run_to_nc_xarray_kwarg_passing(mock_get_xr_dataset, scm_run, tmpdir): dimensions = ["scenario"] extras = [] mock_ds = MagicMock() - mock_ds.data_vars = _get_xr_dataset(scm_run, dimensions, extras).data_vars + mock_ds.data_vars = _get_xr_dataset_to_write(scm_run, dimensions, extras).data_vars mock_get_xr_dataset.return_value = mock_ds out_fname = join(tmpdir, "out.nc") @@ -681,7 +681,7 @@ def test_run_to_nc_xarray_kwarg_passing(mock_get_xr_dataset, scm_run, tmpdir): mock_ds.to_netcdf.assert_called_with(out_fname, engine="engine") -@patch("scmdata.netcdf._get_xr_dataset") +@patch("scmdata.netcdf._get_xr_dataset_to_write") @pytest.mark.parametrize( "in_kwargs,call_kwargs", ( @@ -714,7 +714,7 @@ def test_run_to_nc_xarray_kwarg_passing_variable_renaming( extras = [] mock_ds = MagicMock() - mock_ds.data_vars = _get_xr_dataset(scm_run, dimensions, extras).data_vars + mock_ds.data_vars = _get_xr_dataset_to_write(scm_run, dimensions, extras).data_vars mock_get_xr_dataset.return_value = mock_ds out_fname = join(tmpdir, "out.nc") diff --git a/tests/unit/test_ops.py b/tests/unit/test_ops.py index f830b21c..e42c82fe 100644 --- a/tests/unit/test_ops.py +++ b/tests/unit/test_ops.py @@ -1,3 +1,4 @@ +import datetime as dt import re from unittest.mock import patch @@ -989,3 +990,7 @@ def test_linear_regression_scmrun(): assert_scmdf_almost_equal( res, exp, allow_unordered=True, check_ts_names=False, rtol=1e-3 ) + + +# TODO: notebook illustrating rolling mean options +# Rationale: rolling means are really tricky (do you take e.g. an annual mean first, do you worry about happens at the window edge?) and they're pretty easy to convert back into ScmRun objects so a notebooks is probably more helpful than exact functionality for now diff --git a/tests/unit/test_xarray.py b/tests/unit/test_xarray.py new file mode 100644 index 00000000..49bfe4a4 --- /dev/null +++ b/tests/unit/test_xarray.py @@ -0,0 +1,316 @@ +import re + +import numpy as np +import numpy.testing as npt +import pandas as pd +import pytest +import xarray as xr + +import scmdata + + +def do_basic_to_xarray_checks(res, start_run, dimensions, extras): + assert isinstance(res, xr.Dataset) + assert set(res.data_vars) == set(start_run.get_unique_meta("variable")) + + for variable_name, data_var in res.data_vars.items(): + assert data_var.dims == dimensions + + unit = start_run.filter(variable=variable_name).get_unique_meta("unit") + assert data_var.units in unit + + # all other metadata should be in attrs + for meta_col in ( + set(start_run.meta.columns) + - set(dimensions) + - set(extras) + - {"variable", "unit"} + ): + meta_val = start_run.get_unique_meta(meta_col, True) + assert res.attrs["scmdata_metadata_{}".format(meta_col)] == meta_val + + +def do_basic_check_of_data_points(res, start_run, dimensions): + for variable_name, data_var in res.data_vars.items(): + # check a couple of data points to make sure the translation is correct + for idx in [0, -1]: + xarray_spot = data_var.isel({v: idx for v in dimensions}) + fkwargs = {k: [v.values.tolist()] for k, v in xarray_spot.coords.items()} + fkwargs["variable"] = variable_name + xarray_unit = data_var.units + + start_run_spot = start_run.filter(**fkwargs).convert_unit(xarray_unit) + if np.isnan(xarray_spot): + assert start_run_spot.empty + else: + start_run_vals = float(start_run_spot.values.squeeze()) + npt.assert_array_equal(xarray_spot.values, start_run_vals) + + +@pytest.mark.parametrize( + "dimensions,expected_dimensions", + ( + (("region", "scenario", "time"), ("region", "scenario", "time")), + (("time", "region", "scenario"), ("time", "region", "scenario")), + (("region", "time", "scenario"), ("region", "time", "scenario")), + (("region", "scenario"), ("region", "scenario", "time")), + (("scenario", "region"), ("scenario", "region", "time")), + (("scenario",), ("scenario", "time")), + ), +) +def test_to_xarray(scm_run, dimensions, expected_dimensions): + res = scm_run.to_xarray(dimensions=dimensions) + + do_basic_to_xarray_checks( + res, scm_run, expected_dimensions, (), + ) + do_basic_check_of_data_points(res, scm_run, expected_dimensions) + + # no extras + assert not set(res.coords) - set(res.dims) + + +@pytest.mark.parametrize( + "extras", (("model",), ("climate_model",), ("climate_model", "model"),) +) +def test_to_xarray_extras_no_id_coord(scm_run, extras): + dimensions = ("scenario", "region", "time") + res = scm_run.to_xarray(dimensions=dimensions, extras=extras) + + do_basic_to_xarray_checks(res, scm_run, dimensions, extras) + do_basic_check_of_data_points(res, scm_run, dimensions) + + assert set(extras) == set(res.coords) - set(res.dims) + + scm_run_meta = scm_run.meta + for extra_col in extras: + xarray_vals = res[extra_col].values + extra_dims = res[extra_col].dims + assert len(extra_dims) == 1 + extra_dims = extra_dims[0] + xarray_coords = res[extra_col][extra_dims].values + + for xarray_extra_val, extra_xarray_coord in zip(xarray_vals, xarray_coords): + scm_run_extra_val = ( + scm_run_meta[scm_run_meta[extra_dims] == extra_xarray_coord][extra_col] + .unique() + .tolist() + ) + assert len(scm_run_extra_val) == 1 + scm_run_extra_val = scm_run_extra_val[0] + + assert scm_run_extra_val == xarray_extra_val + + +@pytest.mark.parametrize("extras", (("scenario", "model", "random_key"),)) +@pytest.mark.parametrize( + "dimensions,expected_dimensions", + ( + (("climate_model", "run_id"), ("climate_model", "run_id", "time", "_id")), + (("run_id", "climate_model"), ("run_id", "climate_model", "time", "_id")), + ( + ("run_id", "climate_model", "time"), + ("run_id", "climate_model", "time", "_id"), + ), + ( + ("run_id", "time", "climate_model"), + ("run_id", "time", "climate_model", "_id"), + ), + ( + ("run_id", "climate_model", "time", "_id"), + ("run_id", "climate_model", "time", "_id"), + ), + ( + ("_id", "run_id", "time", "climate_model"), + ("_id", "run_id", "time", "climate_model"), + ), + ( + ("run_id", "_id", "climate_model"), + ("run_id", "_id", "climate_model", "time"), + ), + ), +) +def test_to_xarray_extras_with_id_coord( + scm_run, extras, dimensions, expected_dimensions +): + df = scm_run.timeseries() + val_cols = df.columns.tolist() + df = df.reset_index() + + df["climate_model"] = "base_m" + df["run_id"] = 1 + df.loc[:, val_cols] = np.random.rand(df.shape[0], len(val_cols)) + + big_df = [df] + for climate_model in ["abc_m", "def_m", "ghi_m"]: + for run_id in range(10): + new_df = df.copy() + new_df["run_id"] = run_id + new_df["climate_model"] = climate_model + new_df.loc[:, val_cols] = np.random.rand(df.shape[0], len(val_cols)) + + big_df.append(new_df) + + big_df = pd.concat(big_df).reset_index(drop=True) + big_df["random_key"] = (100 * np.random.random(big_df.shape[0])).astype(int) + scm_run = scm_run.__class__(big_df) + + res = scm_run.to_xarray(dimensions=dimensions, extras=extras) + + do_basic_to_xarray_checks(res, scm_run, expected_dimensions, extras) + + assert set(extras) == set(res.coords) - set(res.dims) + + # check a couple of data points to make sure the translation is correct + # and well-defined + scm_run_meta = scm_run.meta + for id_val in res["_id"].values[::10]: + xarray_timeseries = res.sel(_id=id_val) + fkwargs = {} + for extra_col in extras: + val = xarray_timeseries[extra_col].values.tolist() + if isinstance(val, list): + assert len(set(val)) == 1 + fkwargs[extra_col] = val[0] + else: + fkwargs[extra_col] = val + + for i, (key, value) in enumerate(fkwargs.items()): + if i < 1: + keep_meta_rows = scm_run_meta[key] == value + else: + keep_meta_rows &= scm_run_meta[key] == value + + meta_timeseries = scm_run_meta[keep_meta_rows] + for _, row in meta_timeseries.iterrows(): + scm_run_filter = row.to_dict() + scm_run_spot = scm_run.filter(**scm_run_filter) + + xarray_sel = { + k: v for k, v in scm_run_filter.items() if k in xarray_timeseries.dims + } + xarray_spot = xarray_timeseries.sel(**xarray_sel)[ + scm_run_filter["variable"] + ] + + npt.assert_array_equal( + scm_run_spot.values.squeeze(), xarray_spot.values.squeeze() + ) + + +@pytest.mark.parametrize("ch", "!@#$%^&*()~`+={}]<>,;:'\".") +@pytest.mark.parametrize("weird_idx", (0, -1, 5)) +def test_to_xarray_weird_names(scm_run, ch, weird_idx): + new_vars = [] + for i, variable_name in enumerate(scm_run.get_unique_meta("variable")): + if i < 1: + new_name = list(variable_name) + new_name.insert(weird_idx, ch) + new_name = "".join(new_name) + new_vars.append(new_name) + else: + new_vars.append(variable_name) + + dimensions = ("region", "scenario", "time") + res = scm_run.to_xarray(dimensions=dimensions) + + do_basic_to_xarray_checks( + res, scm_run, dimensions, (), + ) + do_basic_check_of_data_points(res, scm_run, dimensions) + + +def get_multiple_units_scm_run(scm_run, new_unit, new_unit_alternate): + first_var = scm_run.get_unique_meta("variable")[0] + scm_run_first_var = scm_run.filter(variable=first_var) + scm_run_first_var["unit"] = [ + v if i >= 1 else new_unit if v != new_unit else new_unit_alternate + for i, v in enumerate(scm_run_first_var["unit"].tolist()) + ] + scm_run_other_vars = scm_run.filter(variable=first_var, keep=False) + + return scmdata.run_append([scm_run_first_var, scm_run_other_vars]) + + +def test_to_xarray_multiple_units_error(scm_run): + scm_run = get_multiple_units_scm_run(scm_run, "J/yr", "MJ/yr") + + variable_unit_table = scm_run.meta[["variable", "unit"]].drop_duplicates() + variable_units = variable_unit_table.set_index("variable")["unit"] + variable_counts = variable_unit_table["variable"].value_counts() + more_than_one_unit_variables = variable_counts[variable_counts > 1] + error_msg = re.escape( + "The following variables are reported in more than one unit. " + "Found variable-unit combinations are:\n{}".format( + variable_units[more_than_one_unit_variables.index.values] + ) + ) + + with pytest.raises(ValueError, match=error_msg): + scm_run.to_xarray(dimensions=("region", "scenario", "time"), unify_units=False) + + +def test_to_xarray_unify_multiple_units(scm_run): + scm_run = get_multiple_units_scm_run(scm_run, "J/yr", "MJ/yr") + + dimensions = ("region", "scenario", "time") + res = scm_run.to_xarray(dimensions=dimensions, unify_units=True) + do_basic_to_xarray_checks( + res, scm_run, dimensions, (), + ) + do_basic_check_of_data_points(res, scm_run, dimensions) + + +def test_to_xarray_unify_multiple_units_incompatible_units(scm_run): + scm_run = get_multiple_units_scm_run(scm_run, "kg", "g") + + dimensions = ("region", "scenario", "time") + + first_var = scm_run.get_unique_meta("variable")[0] + error_msg = re.escape( + "Variable `{}` cannot be converted to a common unit. " + "Units in the provided dataset: {}.".format( + first_var, scm_run.filter(variable=first_var).get_unique_meta("unit") + ) + ) + with pytest.raises(ValueError, match=error_msg): + scm_run.to_xarray(dimensions=dimensions, unify_units=True) + + +@pytest.mark.parametrize( + "dimensions,extras", + ( + (("junk",), (),), + (("junk",), ("climate_model"),), + (("scenario", "junk_1"), ("junk",)), + (("scenario",), ("junk",)), + (("scenario",), ("junk", "climate_model")), + (("scenario",), ("junk", "junk_2", "climate_model")), + ), +) +def test_dimension_and_or_extra_not_in_metadata(scm_run, dimensions, extras): + with pytest.raises(KeyError): + scm_run.to_xarray(dimensions=dimensions, extras=extras) + + +def test_to_xarray_dimensions_extra_overlap(scm_run): + dimensions = ("scenario", "region") + extras = ("scenario",) + + error_msg = re.escape( + "dimensions and extras cannot have any overlap. " + "Current values in both dimensions and extras: {}".format({"scenario"}) + ) + with pytest.raises(ValueError, match=error_msg): + scm_run.to_xarray(dimensions=dimensions, extras=extras) + + +def test_to_xarray_non_unique_timeseries(scm_run): + dimensions = ("region",) + + error_msg = re.escape( + "dimensions: `{}` and extras: `[]` do not uniquely define the timeseries, " + "please add extra dimensions and/or extras".format(list(dimensions)) + ) + with pytest.raises(ValueError, match=error_msg): + scm_run.to_xarray(dimensions=dimensions)