25 changes: 25 additions & 0 deletions ehrapy/_compat.py
@@ -165,6 +165,31 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
return decorator


def _support_3d(f: Callable) -> Callable:
"""Decorator to allow functions to handle both 2D and 3D arrays.

- If the input is 2D: pass it through unchanged.
- If the input is 3D: reshape to 2D before calling the function,
then reshape the result back to 3D.
"""

@wraps(f)
def wrapper(arr, *args, **kwargs):
if arr.ndim == 2:
return f(arr, *args, **kwargs)

elif arr.ndim == 3:
n_obs, n_vars, n_time = arr.shape
arr_2d = arr.transpose(0, 2, 1).reshape(-1, n_vars)
arr_modified_2d = f(arr_2d, *args, **kwargs)
return arr_modified_2d.reshape(n_obs, n_time, n_vars).transpose(0, 2, 1)

else:
raise ValueError(f"Unsupported array dimensionality: {arr.ndim}")

return wrapper


def _cast_adata_to_match_data_type(input_data: AnnData, target_type_reference: EHRData | AnnData) -> EHRData | AnnData:
"""Cast the data object to the type used by the function."""
if isinstance(input_data, type(target_type_reference)):
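The reshape round-trip inside `_support_3d` is the heart of this helper: a 3D block of shape `(n_obs, n_vars, n_time)` is flattened so that every variable becomes a single column with `n_obs * n_time` samples, and the result is folded back afterwards. A minimal standalone sketch of that round-trip (toy array, not ehrapy code):

```python
import numpy as np

arr = np.arange(24, dtype=float).reshape(2, 4, 3)           # (n_obs, n_vars, n_time)
arr_2d = arr.transpose(0, 2, 1).reshape(-1, arr.shape[1])   # (n_obs * n_time, n_vars): one feature per column
restored = arr_2d.reshape(2, 3, 4).transpose(0, 2, 1)       # invert: back to (n_obs, n_vars, n_time)
assert np.array_equal(arr, restored)
```

Any column-wise operation applied to the flattened array therefore treats each variable's values across all observations and time points as one sample pool.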
65 changes: 31 additions & 34 deletions ehrapy/preprocessing/_imputation.py
@@ -13,13 +13,11 @@
from sklearn.impute import SimpleImputer

from ehrapy import settings
from ehrapy._compat import DaskArray, _raise_array_type_not_implemented, function_2D_only, use_ehrdata
from ehrapy._compat import DaskArray, _raise_array_type_not_implemented, _support_3d, function_2D_only, use_ehrdata
from ehrapy._progress import spinner
from ehrapy.anndata import _check_feature_types
from ehrapy.anndata._feature_specifications import _infer_numerical_column_indices
from ehrapy.anndata.anndata_ext import (
_get_var_indices,
)
from ehrapy.anndata.anndata_ext import _get_var_indices

if TYPE_CHECKING:
from anndata import AnnData
@@ -148,8 +146,29 @@ def _extract_impute_value(replacement: dict[str, str | int], column_name: str) -
return None


@singledispatch
def _simple_impute_function(arr, strategy: Literal["mean", "median", "most_frequent"]) -> None:
_raise_array_type_not_implemented(_simple_impute_function, type(arr))


@_simple_impute_function.register
@_support_3d
def _(arr: DaskArray, strategy: Literal["mean", "median", "most_frequent"]) -> DaskArray:
import dask_ml.impute

arr_dtype = arr.dtype
return dask_ml.impute.SimpleImputer(strategy=strategy).fit_transform(arr.astype(float)).astype(arr_dtype)


@_simple_impute_function.register
@_support_3d
def _(arr: np.ndarray, strategy: Literal["mean", "median", "most_frequent"]) -> np.ndarray:
import sklearn.impute

return sklearn.impute.SimpleImputer(strategy=strategy).fit_transform(arr)


@use_ehrdata(deprecated_after="1.0.0")
@function_2D_only()
@spinner("Performing simple impute")
def simple_impute(
edata: EHRData | AnnData,
@@ -186,39 +205,17 @@ def simple_impute(
if copy:
edata = edata.copy()

_warn_imputation_threshold(edata, var_names, threshold=warning_threshold, layer=layer)

if strategy in {"median", "mean"}:
try:
_simple_impute(edata, var_names, strategy, layer)
except ValueError:
raise ValueError(
f"Can only impute numerical data using {strategy} strategy. Try to restrict imputation "
"to certain columns using var_names parameter or use a different mode."
) from None
# most_frequent imputation works with non-numerical data as well
elif strategy == "most_frequent":
_simple_impute(edata, var_names, strategy, layer)
else:
raise ValueError(
f"Unknown impute strategy {strategy} for simple Imputation. Choose any of mean, median or most_frequent."
) from None
# TODO: re-enable this warning once qc_metrics supports 3D data
# _warn_imputation_threshold(edata, var_names, threshold=warning_threshold, layer=layer)

return edata if copy else None
var_indices = _get_var_indices(edata, edata.var_names if var_names is None else var_names)


def _simple_impute(edata: EHRData | AnnData, var_names: Iterable[str] | None, strategy: str, layer: str | None) -> None:
imputer = SimpleImputer(strategy=strategy)
if layer is None:
if isinstance(var_names, Iterable) and all(isinstance(item, str) for item in var_names):
edata[:, var_names].X = imputer.fit_transform(edata[:, var_names].X)
else:
edata.X = imputer.fit_transform(edata.X)
edata.X[:, var_indices] = _simple_impute_function(edata.X[:, var_indices], strategy)
else:
if isinstance(var_names, Iterable) and all(isinstance(item, str) for item in var_names):
edata[:, var_names].layers[layer] = imputer.fit_transform(edata[:, var_names].layers[layer])
else:
edata.layers[layer] = imputer.fit_transform(edata.layers[layer])
edata.layers[layer][:, var_indices] = _simple_impute_function(edata.layers[layer][:, var_indices], strategy)

return edata if copy else None


@_check_feature_types
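Taken together, the singledispatch registrations route NumPy arrays to scikit-learn's `SimpleImputer` and Dask arrays to `dask_ml.impute.SimpleImputer`, while `_support_3d` lets either implementation run on a 3D layer. A hedged end-to-end sketch of the resulting behavior, mirroring the test fixture further below; the literal layer key `"tem"` is an assumption standing in for `DEFAULT_TEM_LAYER_NAME`:

```python
import numpy as np
import ehrdata as ed
import ehrapy as ep

rng = np.random.default_rng(0)
data_3d = rng.random((100, 10, 3))                 # (n_obs, n_vars, n_time)
data_3d[rng.random(data_3d.shape) < 0.1] = np.nan  # ~10% missing completely at random

# "tem" is an assumed layer key for illustration; the tests use DEFAULT_TEM_LAYER_NAME
edata = ed.EHRData(rng.random((100, 10)), layers={"tem": data_3d})
ep.pp.simple_impute(edata, layer="tem", strategy="mean")  # modifies in place, since copy=False
assert not np.isnan(edata.layers["tem"]).any()
```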
6 changes: 5 additions & 1 deletion tests/conftest.py
@@ -142,7 +142,11 @@ def mcar_edata(rng) -> ed.EHRData:
missing_indices = rng.choice(a=[False, True], size=data.shape, p=[1 - 0.1, 0.1])
data[missing_indices] = np.nan

return ed.EHRData(data)
data_3d = rng.random((100, 10, 3))
missing_indices = rng.choice(a=[False, True], size=data_3d.shape, p=[1 - 0.1, 0.1])
data_3d[missing_indices] = np.nan

return ed.EHRData(data, layers={DEFAULT_TEM_LAYER_NAME: data_3d})


@pytest.fixture
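The extended fixture injects missingness completely at random with probability 0.1 per entry; a quick standalone sanity check of the same construction (illustrative, not part of the fixture):

```python
import numpy as np

rng = np.random.default_rng(0)
data_3d = rng.random((100, 10, 3))
missing = rng.choice(a=[False, True], size=data_3d.shape, p=[0.9, 0.1])
data_3d[missing] = np.nan
assert 0.05 < np.isnan(data_3d).mean() < 0.15  # roughly 10% of entries end up NaN
```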
86 changes: 74 additions & 12 deletions tests/preprocessing/test_imputation.py
@@ -92,7 +92,9 @@ def _is_val_missing(data: np.ndarray) -> np.ndarray[Any, np.dtype[np.bool_]]:
raise AssertionError("Values outside imputed columns were modified.")

# Ensure imputation does not alter non-NaN values in the imputed columns
imputed_non_nan_mask = (~before_nan_mask) & imputed_mask
imputed_non_nan_mask = (~before_nan_mask) & (
imputed_mask[None, :] if layer_before.ndim == 2 else imputed_mask[None, :, None]
)
if not _are_ndarrays_equal(layer_before[imputed_non_nan_mask], layer_after[imputed_non_nan_mask]):
raise AssertionError("Non-NaN values in imputed columns were modified.")

@@ -147,10 +149,34 @@ def test_base_check_imputation_change_detected_in_imputed_column(impute_num_edat
_base_check_imputation(impute_num_edata, edata_imputed)


@pytest.mark.parametrize(
"array_type,expected_error",
[
(np.array, None),
(da.array, None),
(sparse.csr_matrix, NotImplementedError),
],
)
def test_simple_impute_array_types(impute_num_edata, array_type, expected_error):
impute_num_edata.X = array_type(impute_num_edata.X)

if expected_error:
with pytest.raises(expected_error):
simple_impute(impute_num_edata, strategy="mean")


@pytest.mark.parametrize("array_type", ARRAY_TYPES)
@pytest.mark.parametrize("strategy", ["mean", "median", "most_frequent"])
def test_simple_impute_basic(impute_num_edata, strategy):
edata_imputed = simple_impute(impute_num_edata, strategy=strategy, copy=True)
_base_check_imputation(impute_num_edata, edata_imputed)
def test_simple_impute_basic(impute_num_edata, array_type, strategy):
[Reviewer comment, Collaborator Author]: this test and the following one are specific to the function under test, which here is imputation.

impute_num_edata.X = array_type(impute_num_edata.X)

if isinstance(impute_num_edata.X, da.Array) and strategy != "mean":
with pytest.raises(ValueError):
edata_imputed = simple_impute(impute_num_edata, strategy=strategy, copy=True)

else:
edata_imputed = simple_impute(impute_num_edata, strategy=strategy, copy=True)
_base_check_imputation(impute_num_edata, edata_imputed)


@pytest.mark.parametrize("strategy", ["mean", "median", "most_frequent"])
@@ -161,19 +187,55 @@ def test_simple_impute_copy(impute_num_edata, strategy):
_base_check_imputation(impute_num_edata, edata_imputed)


@pytest.mark.parametrize("array_type", ARRAY_TYPES)
@pytest.mark.parametrize("strategy", ["mean", "median", "most_frequent"])
def test_simple_impute_subset(impute_edata, strategy):
def test_simple_impute_subset(impute_edata, array_type, strategy):
impute_edata.X = array_type(impute_edata.X)
var_names = ("intcol", "indexcol")
edata_imputed = simple_impute(impute_edata, var_names=var_names, copy=True)
if isinstance(impute_edata.X, da.Array) and strategy != "mean":
with pytest.raises(ValueError):
edata_imputed = simple_impute(impute_edata, var_names=var_names, strategy=strategy, copy=True)
else:
edata_imputed = simple_impute(impute_edata, var_names=var_names, strategy=strategy, copy=True)

_base_check_imputation(impute_edata, edata_imputed, imputed_var_names=var_names)
assert np.any([item != item for item in edata_imputed.X[::, 3:4]])
_base_check_imputation(impute_edata, edata_imputed, imputed_var_names=var_names)
assert np.any([item != item for item in edata_imputed.X[::, 3:4]])

# manually verified computation result
if strategy == "mean":
assert edata_imputed.X[0, 1] == 3.0
elif strategy == "most_frequent":
assert edata_imputed.X[0, 1] == 2.0 # if multiple equally frequent values, return minimum

def test_simple_impute_3D_edata(edata_blob_small):
simple_impute(edata_blob_small, layer="layer_2")
with pytest.raises(ValueError, match=r"only supports 2D data"):
simple_impute(edata_blob_small, layer=DEFAULT_TEM_LAYER_NAME)

@pytest.mark.parametrize("array_type", ARRAY_TYPES)
@pytest.mark.parametrize("strategy", ["mean", "median", "most_frequent"])
def test_simple_impute_3D_edata(mcar_edata, array_type, strategy):
[Reviewer comment, Collaborator Author]: this test checks that, for numeric data, the method can handle 3D data. There might be combinations of arguments that are not supported; these should be mentioned in the documentation if not obvious.

mcar_edata.layers[DEFAULT_TEM_LAYER_NAME] = array_type(mcar_edata.layers[DEFAULT_TEM_LAYER_NAME])

if isinstance(mcar_edata.layers[DEFAULT_TEM_LAYER_NAME], da.Array) and strategy != "mean":
with pytest.raises(ValueError):
edata_imputed = simple_impute(mcar_edata, layer=DEFAULT_TEM_LAYER_NAME, strategy=strategy, copy=True)

else:
edata_imputed = simple_impute(mcar_edata, layer=DEFAULT_TEM_LAYER_NAME, strategy=strategy, copy=True)
_base_check_imputation(
mcar_edata,
edata_imputed,
before_imputation_layer=DEFAULT_TEM_LAYER_NAME,
after_imputation_layer=DEFAULT_TEM_LAYER_NAME,
)

# manually verify computation result for 1 value
if strategy in {"mean", "median"}:
element = edata_imputed[9, 0, 0].layers[DEFAULT_TEM_LAYER_NAME]

if strategy == "mean":
reference_value = np.nanmean(mcar_edata[:, 0, :].layers[DEFAULT_TEM_LAYER_NAME])
elif strategy == "median":
reference_value = np.nanmedian(mcar_edata[:, 0, :].layers[DEFAULT_TEM_LAYER_NAME])

assert np.isclose(element, reference_value)
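The reference value checked above follows from the `_support_3d` reshape: once `(n_obs, n_vars, n_time)` is flattened to `(n_obs * n_time, n_vars)`, the column statistic for a variable pools all observations and time points, which is exactly what `np.nanmean(arr[:, 0, :])` computes. A small sketch of that equivalence (toy array, not the fixture):

```python
import numpy as np

arr = np.arange(24, dtype=float).reshape(2, 4, 3)   # (n_obs, n_vars, n_time)
arr[1, 0, 2] = np.nan
arr_2d = arr.transpose(0, 2, 1).reshape(-1, 4)      # (n_obs * n_time, n_vars)
# the per-column fill value for variable 0 ...
col0_mean = np.nanmean(arr_2d[:, 0])
# ... equals the nanmean over every observation and time point of variable 0
assert np.isclose(col0_mean, np.nanmean(arr[:, 0, :]))
```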


@pytest.mark.parametrize("strategy", ["mean", "median"])