Merged — changes from 9 commits
3 changes: 3 additions & 0 deletions .github/workflows/test.yml
@@ -66,6 +66,9 @@ jobs:
cache-dependency-glob: pyproject.toml
- name: create hatch environment
run: uvx hatch env create ${{ matrix.env.name }}
- name: show installed packages
run: |
uvx hatch run ${{ matrix.env.name }}:uv pip freeze
- name: run tests using hatch
env:
MPLBACKEND: agg
26 changes: 26 additions & 0 deletions ehrapy/_compat.py
@@ -7,6 +7,8 @@
from subprocess import PIPE, Popen
from typing import TYPE_CHECKING, ParamSpec, TypeVar, cast

import numpy as np

P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
@@ -165,6 +167,30 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
return decorator


def _apply_over_time_axis(f: Callable) -> Callable:
    """Decorator that lets a function written for 2D arrays also handle 3D arrays.

    - If the input is 2D: pass it through unchanged.
    - If the input is 3D with shape ``(n_obs, n_vars, n_time)``: flatten the time axis into the
      observation axis, call the function on the resulting 2D array, then restore the 3D shape.
    """

    @wraps(f)
    def wrapper(arr, *args, **kwargs):
        if arr.ndim == 2:
            return f(arr, *args, **kwargs)

        elif arr.ndim == 3:
            n_obs, n_vars, n_time = arr.shape
            # (n_obs, n_vars, n_time) -> (n_obs, n_time, n_vars) -> (n_obs * n_time, n_vars)
            arr_2d = np.moveaxis(arr, 1, 2).reshape(-1, n_vars)
            arr_modified_2d = f(arr_2d, *args, **kwargs)
            # Invert the reshape to recover (n_obs, n_vars, n_time).
            return np.moveaxis(arr_modified_2d.reshape(n_obs, n_time, n_vars), 1, 2)

        else:
            raise ValueError(f"Unsupported array dimensionality: {arr.ndim}. Please reshape the array to 2D or 3D.")

    return wrapper
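
Note (editor's sketch, not part of the diff): a minimal usage example of `_apply_over_time_axis`, assuming the decorator is in scope. The decorated function only ever sees a 2D `(rows, n_vars)` array; `column_zscore` is a hypothetical function, not something added by this PR.

```python
import numpy as np

@_apply_over_time_axis
def column_zscore(arr_2d):
    # Standardize each variable (column) across all rows it receives.
    return (arr_2d - np.nanmean(arr_2d, axis=0)) / np.nanstd(arr_2d, axis=0)

arr_3d = np.random.default_rng(0).random((5, 4, 3))  # (n_obs, n_vars, n_time)
out = column_zscore(arr_3d)
assert out.shape == (5, 4, 3)  # time axis flattened for the call, then restored
```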


def _cast_adata_to_match_data_type(input_data: AnnData, target_type_reference: EHRData | AnnData) -> EHRData | AnnData:
"""Cast the data object to the type used by the function."""
if isinstance(input_data, type(target_type_reference)):
Expand Down
82 changes: 46 additions & 36 deletions ehrapy/preprocessing/_imputation.py
@@ -8,18 +8,23 @@

import numpy as np
import pandas as pd
import scipy.sparse as sp
from ehrdata._logger import logger
from sklearn.experimental import enable_iterative_imputer # noinspection PyUnresolvedReference
from sklearn.impute import SimpleImputer

from ehrapy import settings
from ehrapy._compat import DaskArray, _raise_array_type_not_implemented, function_2D_only, use_ehrdata
from ehrapy._compat import (
DaskArray,
_apply_over_time_axis,
_raise_array_type_not_implemented,
function_2D_only,
use_ehrdata,
)
from ehrapy._progress import spinner
from ehrapy.anndata import _check_feature_types
from ehrapy.anndata._feature_specifications import _infer_numerical_column_indices
from ehrapy.anndata.anndata_ext import (
_get_var_indices,
)
from ehrapy.anndata.anndata_ext import _get_var_indices

if TYPE_CHECKING:
from anndata import AnnData
@@ -106,7 +111,7 @@ def _replace_explicit(arr, replacement: str | int, impute_empty_strings: bool) -
_raise_array_type_not_implemented(_replace_explicit, type(arr))


@_replace_explicit.register
@_replace_explicit.register(np.ndarray)
def _(arr: np.ndarray, replacement: str | int, impute_empty_strings: bool) -> np.ndarray:
"""Replace one column or whole X with a value where missing values are stored."""
if not impute_empty_strings: # pragma: no cover
@@ -148,9 +153,36 @@ def _extract_impute_value(replacement: dict[str, str | int], column_name: str) -
return None


@singledispatch
def _simple_impute_function(arr, strategy: Literal["mean", "median", "most_frequent"]) -> None:
_raise_array_type_not_implemented(_simple_impute_function, type(arr))


@_simple_impute_function.register(sp.coo_array)
def _(arr: sp.coo_array, strategy: Literal["mean", "median", "most_frequent"]) -> sp.coo_array:
_raise_array_type_not_implemented(_simple_impute_function, type(arr))


@_simple_impute_function.register(DaskArray)
@_apply_over_time_axis
def _(arr: DaskArray, strategy: Literal["mean", "median", "most_frequent"]) -> DaskArray:
import dask_ml.impute

arr_dtype = arr.dtype
return dask_ml.impute.SimpleImputer(strategy=strategy).fit_transform(arr.astype(float)).astype(arr_dtype)


@_simple_impute_function.register(sp.csc_array)
@_simple_impute_function.register(sp.csr_array)
@_simple_impute_function.register(np.ndarray)
@_apply_over_time_axis
def _(arr: np.ndarray, strategy: Literal["mean", "median", "most_frequent"]) -> np.ndarray:
import sklearn

return sklearn.impute.SimpleImputer(strategy=strategy).fit_transform(arr)
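
Note (editor's sketch, not part of the diff): how the `_simple_impute_function` dispatch behaves with the registrations above — NumPy and CSR/CSC arrays route through scikit-learn's `SimpleImputer`, Dask arrays through `dask_ml`, and `coo_array` raises the not-implemented error. The toy array is hypothetical:

```python
import numpy as np

arr = np.array([[1.0, np.nan],
                [3.0, 4.0]])
filled = _simple_impute_function(arr, strategy="mean")
# The missing entry is replaced by its column mean over the non-missing values.
assert np.allclose(filled, [[1.0, 4.0], [3.0, 4.0]])
```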


@use_ehrdata(deprecated_after="1.0.0")
@function_2D_only()
@spinner("Performing simple impute")
def simple_impute(
edata: EHRData | AnnData,
var_names: Iterable[str] | None = None,
@@ -186,39 +218,17 @@ def simple_impute(
if copy:
edata = edata.copy()

_warn_imputation_threshold(edata, var_names, threshold=warning_threshold, layer=layer)
# TODO: re-enable this warning once qc_metrics supports 3D data
# _warn_imputation_threshold(edata, var_names, threshold=warning_threshold, layer=layer)

if strategy in {"median", "mean"}:
try:
_simple_impute(edata, var_names, strategy, layer)
except ValueError:
raise ValueError(
f"Can only impute numerical data using {strategy} strategy. Try to restrict imputation "
"to certain columns using var_names parameter or use a different mode."
) from None
# most_frequent imputation works with non-numerical data as well
elif strategy == "most_frequent":
_simple_impute(edata, var_names, strategy, layer)
else:
raise ValueError(
f"Unknown impute strategy {strategy} for simple Imputation. Choose any of mean, median or most_frequent."
) from None
var_indices = _get_var_indices(edata, edata.var_names if var_names is None else var_names)

return edata if copy else None


def _simple_impute(edata: EHRData | AnnData, var_names: Iterable[str] | None, strategy: str, layer: str | None) -> None:
imputer = SimpleImputer(strategy=strategy)
if layer is None:
if isinstance(var_names, Iterable) and all(isinstance(item, str) for item in var_names):
edata[:, var_names].X = imputer.fit_transform(edata[:, var_names].X)
else:
edata.X = imputer.fit_transform(edata.X)
edata.X[:, var_indices] = _simple_impute_function(edata.X[:, var_indices], strategy)
else:
if isinstance(var_names, Iterable) and all(isinstance(item, str) for item in var_names):
edata[:, var_names].layers[layer] = imputer.fit_transform(edata[:, var_names].layers[layer])
else:
edata.layers[layer] = imputer.fit_transform(edata.layers[layer])
edata.layers[layer][:, var_indices] = _simple_impute_function(edata.layers[layer][:, var_indices], strategy)

return edata if copy else None
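
Note (editor's sketch, not part of the diff): the in-place writes above rely on NumPy's fancy-indexing assignment semantics — reading `X[:, var_indices]` with an integer-index list returns a copy, but assigning through the same expression writes back into the original array:

```python
import numpy as np

X = np.arange(12, dtype=float).reshape(4, 3)
var_indices = [0, 2]

X[:, var_indices] = -X[:, var_indices]  # updates columns 0 and 2 of X in place
assert (X[:, [0, 2]] <= 0).all() and (X[:, 1] >= 0).all()
```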


@_check_feature_types
30 changes: 28 additions & 2 deletions tests/conftest.py
@@ -9,6 +9,7 @@
import numpy as np
import pandas as pd
import pytest
import scipy.sparse as sp
from anndata import AnnData
from ehrdata.core.constants import CATEGORICAL_TAG, DEFAULT_TEM_LAYER_NAME, FEATURE_TYPE_KEY, NUMERIC_TAG
from matplotlib.testing.compare import compare_images
@@ -141,7 +142,11 @@ def mcar_edata(rng) -> ed.EHRData:
missing_indices = rng.choice(a=[False, True], size=data.shape, p=[1 - 0.1, 0.1])
data[missing_indices] = np.nan

return ed.EHRData(data)
data_3d = rng.random((100, 10, 3))
missing_indices = rng.choice(a=[False, True], size=data_3d.shape, p=[1 - 0.1, 0.1])
data_3d[missing_indices] = np.nan

return ed.EHRData(data, layers={DEFAULT_TEM_LAYER_NAME: data_3d})


@pytest.fixture
@@ -151,6 +156,20 @@ def edata_mini():
)


@pytest.fixture
def edata_mini_3D_missing_values():
tiny_mixed_array = np.array(
[
[[138, 139], [78, np.nan], [77, 76], [1, 2], ["A", "B"], ["Yes", np.nan]],
[[140, 141], [80, 81], [60, 90], [0, 1], ["A", "A"], ["Yes", "Yes"]],
[[148, 149], [77, 78], [110, np.nan], [0, 1], [np.nan, "B"], ["Yes", "Yes"]],
[[150, 151], [79, 80], [56, np.nan], [2, 3], ["B", "B"], ["Yes", "No"]],
],
dtype=object,
)
return ed.EHRData(layers={DEFAULT_TEM_LAYER_NAME: tiny_mixed_array})


@pytest.fixture
def edata_mini_sample():
return ed.io.read_csv(f"{TEST_DATA_PATH}/dataset1.csv", columns_obs_only=["clinic_day"])
Expand Down Expand Up @@ -415,4 +434,11 @@ def as_dense_dask_array(a, chunk_size=1000):
return da.from_array(a, chunks=chunk_size)


ARRAY_TYPES = (asarray, as_dense_dask_array)
ARRAY_TYPES_NUMERIC = (
    asarray,
    as_dense_dask_array,
    sp.csr_array,
    sp.csc_array,
)  # add coo_array once supported in AnnData

Review comment (Collaborator, author): I suggest these collections of array types for testing, which I hope reduces some repetitiveness.

Review comment (Member): I wonder whether we should have this here, or even in our compat code, because we might also use it for the implementations?

Review comment (Collaborator, author): Agree actually, didn't occur to me before. -> asked @sueoglu to move in #967

ARRAY_TYPES_NUMERIC_3D_ABLE = (asarray, as_dense_dask_array) # add coo_array once supported in AnnData
ARRAY_TYPES_NONNUMERIC = (asarray, as_dense_dask_array)
120 changes: 105 additions & 15 deletions tests/preprocessing/test_imputation.py
@@ -21,7 +21,7 @@
miss_forest_impute,
simple_impute,
)
from tests.conftest import ARRAY_TYPES, TEST_DATA_PATH
from tests.conftest import ARRAY_TYPES_NONNUMERIC, ARRAY_TYPES_NUMERIC, ARRAY_TYPES_NUMERIC_3D_ABLE, TEST_DATA_PATH

CURRENT_DIR = Path(__file__).parent
_TEST_PATH = f"{TEST_DATA_PATH}/imputation"
@@ -92,7 +92,9 @@ def _is_val_missing(data: np.ndarray) -> np.ndarray[Any, np.dtype[np.bool_]]:
raise AssertionError("Values outside imputed columns were modified.")

# Ensure imputation does not alter non-NaN values in the imputed columns
imputed_non_nan_mask = (~before_nan_mask) & imputed_mask
imputed_non_nan_mask = (~before_nan_mask) & (
imputed_mask[None, :] if layer_before.ndim == 2 else imputed_mask[None, :, None]
)
if not _are_ndarrays_equal(layer_before[imputed_non_nan_mask], layer_after[imputed_non_nan_mask]):
raise AssertionError("Non-NaN values in imputed columns were modified.")

Expand Down Expand Up @@ -147,10 +149,36 @@ def test_base_check_imputation_change_detected_in_imputed_column(impute_num_edat
_base_check_imputation(impute_num_edata, edata_imputed)


@pytest.mark.parametrize(
"array_type,expected_error",
[
(np.array, None),
(da.array, None),
(sparse.csr_array, None),
(sparse.csc_array, None),
# (sparse.coo_array, None) # not yet supported by AnnData
],
)
def test_simple_impute_array_types(impute_num_edata, array_type, expected_error):
impute_num_edata.X = array_type(impute_num_edata.X)

    if expected_error:
        with pytest.raises(expected_error):
            simple_impute(impute_num_edata, strategy="mean")
    else:
        simple_impute(impute_num_edata, strategy="mean")


@pytest.mark.parametrize("array_type", ARRAY_TYPES_NUMERIC)
@pytest.mark.parametrize("strategy", ["mean", "median", "most_frequent"])
def test_simple_impute_basic(impute_num_edata, strategy):
edata_imputed = simple_impute(impute_num_edata, strategy=strategy, copy=True)
_base_check_imputation(impute_num_edata, edata_imputed)
def test_simple_impute_basic(impute_num_edata, array_type, strategy):
Review comment (Collaborator, author): Explanation: this test and the following one are specific to the function under test, which here is imputation.

impute_num_edata.X = array_type(impute_num_edata.X)

if isinstance(impute_num_edata.X, da.Array) and strategy != "mean":
with pytest.raises(ValueError):
edata_imputed = simple_impute(impute_num_edata, strategy=strategy, copy=True)

else:
edata_imputed = simple_impute(impute_num_edata, strategy=strategy, copy=True)
_base_check_imputation(impute_num_edata, edata_imputed)


@pytest.mark.parametrize("strategy", ["mean", "median", "most_frequent"])
Expand All @@ -161,19 +189,81 @@ def test_simple_impute_copy(impute_num_edata, strategy):
_base_check_imputation(impute_num_edata, edata_imputed)


@pytest.mark.parametrize("array_type", ARRAY_TYPES_NONNUMERIC)
@pytest.mark.parametrize("strategy", ["mean", "median", "most_frequent"])
def test_simple_impute_subset(impute_edata, strategy):
def test_simple_impute_subset(impute_edata, array_type, strategy):
impute_edata.X = array_type(impute_edata.X)
var_names = ("intcol", "indexcol")
edata_imputed = simple_impute(impute_edata, var_names=var_names, copy=True)
if isinstance(impute_edata.X, da.Array) and strategy != "mean":
with pytest.raises(ValueError):
edata_imputed = simple_impute(impute_edata, var_names=var_names, strategy=strategy, copy=True)
else:
edata_imputed = simple_impute(impute_edata, var_names=var_names, strategy=strategy, copy=True)

_base_check_imputation(impute_edata, edata_imputed, imputed_var_names=var_names)
assert np.any([item != item for item in edata_imputed.X[::, 3:4]])
_base_check_imputation(impute_edata, edata_imputed, imputed_var_names=var_names)
assert np.any([item != item for item in edata_imputed.X[::, 3:4]])

# manually verified computation result
if strategy == "mean":
assert edata_imputed.X[0, 1] == 3.0
elif strategy == "most_frequent":
assert edata_imputed.X[0, 1] == 2.0 # if multiple equally frequent values, return minimum

def test_simple_impute_3D_edata(edata_blob_small):
simple_impute(edata_blob_small, layer="layer_2")
with pytest.raises(ValueError, match=r"only supports 2D data"):
simple_impute(edata_blob_small, layer=DEFAULT_TEM_LAYER_NAME)

@pytest.mark.parametrize("array_type", ARRAY_TYPES_NUMERIC_3D_ABLE)
@pytest.mark.parametrize("strategy", ["mean", "median", "most_frequent"])
def test_simple_impute_3D_edata(mcar_edata, array_type, strategy):
Review comment (Collaborator, author): Explanation: this test checks that the method can handle 3D data for numeric inputs. There might be combinations of arguments that are not supported; if that is not obvious, it should be mentioned in the documentation.

mcar_edata.layers[DEFAULT_TEM_LAYER_NAME] = array_type(mcar_edata.layers[DEFAULT_TEM_LAYER_NAME])

if isinstance(mcar_edata.layers[DEFAULT_TEM_LAYER_NAME], da.Array) and strategy != "mean":
with pytest.raises(ValueError):
edata_imputed = simple_impute(mcar_edata, layer=DEFAULT_TEM_LAYER_NAME, strategy=strategy, copy=True)

else:
edata_imputed = simple_impute(mcar_edata, layer=DEFAULT_TEM_LAYER_NAME, strategy=strategy, copy=True)
_base_check_imputation(
mcar_edata,
edata_imputed,
before_imputation_layer=DEFAULT_TEM_LAYER_NAME,
after_imputation_layer=DEFAULT_TEM_LAYER_NAME,
)

# manually verify computation result for 1 value
if strategy in {"mean", "median"}:
element = edata_imputed[9, 0, 0].layers[DEFAULT_TEM_LAYER_NAME]

if strategy == "mean":
reference_value = np.nanmean(mcar_edata[:, 0, :].layers[DEFAULT_TEM_LAYER_NAME])
elif strategy == "median":
reference_value = np.nanmedian(mcar_edata[:, 0, :].layers[DEFAULT_TEM_LAYER_NAME])

assert np.isclose(element, reference_value)


@pytest.mark.parametrize("array_type", ARRAY_TYPES_NONNUMERIC)
@pytest.mark.parametrize("strategy", ["mean", "median", "most_frequent"])
def test_simple_impute_3D_edata_nonnumeric(edata_mini_3D_missing_values, array_type, strategy):
Review comment (Collaborator, author): Explanation: this checks whether the function can also handle non-numeric 3D data. There might be combinations of arguments that are not supported; if that is not obvious, it should be mentioned in the documentation.

edata_mini_3D_missing_values.layers[DEFAULT_TEM_LAYER_NAME] = array_type(
edata_mini_3D_missing_values.layers[DEFAULT_TEM_LAYER_NAME]
)

if strategy == "most_frequent" and not isinstance(
edata_mini_3D_missing_values.layers[DEFAULT_TEM_LAYER_NAME], da.Array
):
edata_imputed = simple_impute(
edata_mini_3D_missing_values, layer=DEFAULT_TEM_LAYER_NAME, strategy=strategy, copy=True
)
_base_check_imputation(
edata_mini_3D_missing_values,
edata_imputed,
before_imputation_layer=DEFAULT_TEM_LAYER_NAME,
after_imputation_layer=DEFAULT_TEM_LAYER_NAME,
)
else:
with pytest.raises(ValueError):
edata_imputed = simple_impute(
edata_mini_3D_missing_values, layer=DEFAULT_TEM_LAYER_NAME, strategy=strategy, copy=True
)


@pytest.mark.parametrize("strategy", ["mean", "median"])
@@ -330,7 +420,7 @@ def test_explicit_impute_3D_edata(edata_blob_small):
explicit_impute(edata_blob_small, replacement=1011, layer=DEFAULT_TEM_LAYER_NAME)


@pytest.mark.parametrize("array_type", ARRAY_TYPES)
@pytest.mark.parametrize("array_type", ARRAY_TYPES_NONNUMERIC)
def test_explicit_impute_all(array_type, impute_num_edata):
impute_num_edata.X = array_type(impute_num_edata.X)
warnings.filterwarnings("ignore", category=FutureWarning)
Expand All @@ -340,7 +430,7 @@ def test_explicit_impute_all(array_type, impute_num_edata):
assert np.sum([edata_imputed.X == 1011]) == 3


@pytest.mark.parametrize("array_type", ARRAY_TYPES)
@pytest.mark.parametrize("array_type", ARRAY_TYPES_NONNUMERIC)
def test_explicit_impute_subset(impute_edata, array_type):
impute_edata.X = array_type(impute_edata.X)
edata_imputed = explicit_impute(impute_edata, replacement={"strcol": "REPLACED", "intcol": 1011}, copy=True)