3 changes: 3 additions & 0 deletions .github/workflows/test.yml
@@ -66,6 +66,9 @@ jobs:
          cache-dependency-glob: pyproject.toml
      - name: create hatch environment
        run: uvx hatch env create ${{ matrix.env.name }}
+      - name: Show installed packages
+        run: |
+          uvx hatch run ${{ matrix.env.name }}:uv pip freeze
      - name: run tests using hatch
        env:
          MPLBACKEND: agg
26 changes: 26 additions & 0 deletions ehrapy/_compat.py
@@ -7,6 +7,8 @@
from subprocess import PIPE, Popen
from typing import TYPE_CHECKING, ParamSpec, TypeVar, cast

+import numpy as np
+
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
@@ -165,6 +167,30 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
return decorator


+def _apply_over_time_axis(f: Callable) -> Callable:
+    """Decorator to allow functions to handle both 2D and 3D arrays.
+
+    - If the input is 2D: pass it through unchanged.
+    - If the input is 3D: reshape to 2D before calling the function, then reshape the result back to 3D.
+    """
+
+    @wraps(f)
+    def wrapper(arr, *args, **kwargs):
+        if arr.ndim == 2:
+            return f(arr, *args, **kwargs)
+
+        elif arr.ndim == 3:
+            n_obs, n_vars, n_time = arr.shape
+            arr_2d = np.moveaxis(arr, 1, 2).reshape(-1, n_vars)
+            arr_modified_2d = f(arr_2d, *args, **kwargs)
+            return np.moveaxis(arr_modified_2d.reshape(n_obs, n_time, n_vars), 1, 2)
+
+        else:
+            raise ValueError(f"Unsupported array dimensionality: {arr.ndim}. Please reshape the array to 2D or 3D.")
+
+    return wrapper
+
+
def _cast_adata_to_match_data_type(input_data: AnnData, target_type_reference: EHRData | AnnData) -> EHRData | AnnData:
"""Cast the data object to the type used by the function."""
if isinstance(input_data, type(target_type_reference)):
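A quick sanity check of what `_apply_over_time_axis` does, as a sketch (the decorated `zero_fill` below is a hypothetical stand-in for an imputer, not part of the PR):

```python
import numpy as np

@_apply_over_time_axis
def zero_fill(arr_2d):
    # stand-in 2D function: replaces NaNs with zeros
    return np.nan_to_num(arr_2d)

arr = np.random.default_rng(0).random((5, 4, 3))  # (n_obs, n_vars, n_time)
arr[0, 0, 0] = np.nan
out = zero_fill(arr)                # internally runs on a (15, 4) 2D view
assert out.shape == (5, 4, 3)       # 3D shape is restored on the way out
assert not np.isnan(out).any()
```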
84 changes: 47 additions & 37 deletions ehrapy/preprocessing/_imputation.py
@@ -8,18 +8,23 @@

import numpy as np
import pandas as pd
+import scipy.sparse as sp
from ehrdata._logger import logger
from sklearn.experimental import enable_iterative_imputer # noinspection PyUnresolvedReference
from sklearn.impute import SimpleImputer

from ehrapy import settings
-from ehrapy._compat import DaskArray, _raise_array_type_not_implemented, function_2D_only, use_ehrdata
+from ehrapy._compat import (
+    DaskArray,
+    _apply_over_time_axis,
+    _raise_array_type_not_implemented,
+    function_2D_only,
+    use_ehrdata,
+)
from ehrapy._progress import spinner
from ehrapy.anndata import _check_feature_types
from ehrapy.anndata._feature_specifications import _infer_numerical_column_indices
-from ehrapy.anndata.anndata_ext import (
-    _get_var_indices,
-)
+from ehrapy.anndata.anndata_ext import _get_var_indices

if TYPE_CHECKING:
from anndata import AnnData
@@ -106,7 +111,7 @@ def _replace_explicit(arr, replacement: str | int, impute_empty_strings: bool) -
_raise_array_type_not_implemented(_replace_explicit, type(arr))


-@_replace_explicit.register
+@_replace_explicit.register(np.ndarray)
def _(arr: np.ndarray, replacement: str | int, impute_empty_strings: bool) -> np.ndarray:
"""Replace one column or whole X with a value where missing values are stored."""
if not impute_empty_strings: # pragma: no cover
@@ -148,9 +153,36 @@ def _extract_impute_value(replacement: dict[str, str | int], column_name: str) -
return None


+@singledispatch
+def _simple_impute_function(arr, strategy: Literal["mean", "median", "most_frequent"]) -> None:
+    _raise_array_type_not_implemented(_simple_impute_function, type(arr))
+
+
+@_simple_impute_function.register(sp.coo_array)
+def _(arr: sp.coo_array, strategy: Literal["mean", "median", "most_frequent"]) -> sp.coo_array:
+    _raise_array_type_not_implemented(_simple_impute_function, type(arr))
+
+
+@_simple_impute_function.register(DaskArray)
+@_apply_over_time_axis
+def _(arr: DaskArray, strategy: Literal["mean", "median", "most_frequent"]) -> DaskArray:
+    import dask_ml.impute
+
+    arr_dtype = arr.dtype
+    return dask_ml.impute.SimpleImputer(strategy=strategy).fit_transform(arr.astype(float)).astype(arr_dtype)
+
+
+@_simple_impute_function.register(sp.csc_array)
+@_simple_impute_function.register(sp.csr_array)
+@_simple_impute_function.register(np.ndarray)
+@_apply_over_time_axis
+def _(arr: np.ndarray, strategy: Literal["mean", "median", "most_frequent"]) -> np.ndarray:
+    import sklearn

+    return sklearn.impute.SimpleImputer(strategy=strategy).fit_transform(arr)
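For readers less familiar with the pattern: `functools.singledispatch` selects the registered implementation based on the runtime type of the first argument, so dense NumPy and sparse CSR/CSC arrays go through scikit-learn, Dask arrays go through `dask_ml`, and unsupported types (such as `coo_array`) fail loudly. The `astype(float)`/`astype(arr_dtype)` round-trip in the Dask branch presumably keeps the input dtype intact while still allowing NaN handling during the fit. A minimal sketch of the dispatch idea, with a hypothetical `_fill` function:

```python
from functools import singledispatch

import numpy as np

@singledispatch
def _fill(arr):
    raise NotImplementedError(f"no implementation for {type(arr)}")

@_fill.register(np.ndarray)
def _(arr):
    return np.nan_to_num(arr)  # dense arrays handled directly

_fill(np.array([[1.0, np.nan]]))  # dispatches to the ndarray branch
# _fill([[1.0]])                  # a plain list would raise NotImplementedError
```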


@use_ehrdata(deprecated_after="1.0.0")
@function_2D_only()
@spinner("Performing simple impute")
def simple_impute(
edata: EHRData | AnnData,
var_names: Iterable[str] | None = None,
@@ -168,7 +200,7 @@ def simple_impute(
Args:
edata: Central data object.
var_names: A list of column names to apply imputation on (if None, impute all columns).
-        strategy: Imputation strategy to use. One of {'mean', 'median', 'most_frequent'}.
+        strategy: Imputation strategy to use. One of {'mean', 'median', 'most_frequent'}. If data is a `dask.array.Array`, only 'mean' is supported.
warning_threshold: Display a warning message if percentage of missing values exceeds this threshold.
layer: The layer to impute.
copy: Whether to return a copy of `edata` or modify it inplace.
@@ -186,39 +218,17 @@
if copy:
edata = edata.copy()

-    _warn_imputation_threshold(edata, var_names, threshold=warning_threshold, layer=layer)
+    # TODO: warn again if qc_metrics is 3D enabled
+    # _warn_imputation_threshold(edata, var_names, threshold=warning_threshold, layer=layer)

-    if strategy in {"median", "mean"}:
-        try:
-            _simple_impute(edata, var_names, strategy, layer)
-        except ValueError:
-            raise ValueError(
-                f"Can only impute numerical data using {strategy} strategy. Try to restrict imputation "
-                "to certain columns using var_names parameter or use a different mode."
-            ) from None
-    # most_frequent imputation works with non-numerical data as well
-    elif strategy == "most_frequent":
-        _simple_impute(edata, var_names, strategy, layer)
-    else:
-        raise ValueError(
-            f"Unknown impute strategy {strategy} for simple Imputation. Choose any of mean, median or most_frequent."
-        ) from None
+    var_indices = _get_var_indices(edata, edata.var_names if var_names is None else var_names)

-    return edata if copy else None
-
-
-def _simple_impute(edata: EHRData | AnnData, var_names: Iterable[str] | None, strategy: str, layer: str | None) -> None:
-    imputer = SimpleImputer(strategy=strategy)
    if layer is None:
-        if isinstance(var_names, Iterable) and all(isinstance(item, str) for item in var_names):
-            edata[:, var_names].X = imputer.fit_transform(edata[:, var_names].X)
-        else:
-            edata.X = imputer.fit_transform(edata.X)
+        edata.X[:, var_indices] = _simple_impute_function(edata.X[:, var_indices], strategy)
    else:
-        if isinstance(var_names, Iterable) and all(isinstance(item, str) for item in var_names):
-            edata[:, var_names].layers[layer] = imputer.fit_transform(edata[:, var_names].layers[layer])
-        else:
-            edata.layers[layer] = imputer.fit_transform(edata.layers[layer])
+        edata.layers[layer][:, var_indices] = _simple_impute_function(edata.layers[layer][:, var_indices], strategy)
+
+    return edata if copy else None


@_check_feature_types
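The net effect of the refactor: strategy validation and application both route through `_simple_impute_function`, and imputation now writes in place into only the selected columns. A usage sketch (assuming the public alias `ep.pp.simple_impute`; the printed values are just the column means):

```python
import numpy as np
import ehrdata as ed
import ehrapy as ep

X = np.array([[1.0, np.nan], [3.0, 4.0], [np.nan, 6.0]])
edata = ed.EHRData(X)

ep.pp.simple_impute(edata, strategy="mean")  # modifies edata in place
print(edata.X)  # [[1. 5.] [3. 4.] [2. 6.]]
```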
30 changes: 28 additions & 2 deletions tests/conftest.py
@@ -9,6 +9,7 @@
import numpy as np
import pandas as pd
import pytest
+import scipy.sparse as sp
from anndata import AnnData
from ehrdata.core.constants import CATEGORICAL_TAG, DEFAULT_TEM_LAYER_NAME, FEATURE_TYPE_KEY, NUMERIC_TAG
from matplotlib.testing.compare import compare_images
@@ -141,7 +142,11 @@ def mcar_edata(rng) -> ed.EHRData:
missing_indices = rng.choice(a=[False, True], size=data.shape, p=[1 - 0.1, 0.1])
data[missing_indices] = np.nan

-    return ed.EHRData(data)
+    data_3d = rng.random((100, 10, 3))
+    missing_indices = rng.choice(a=[False, True], size=data_3d.shape, p=[1 - 0.1, 0.1])
+    data_3d[missing_indices] = np.nan
+
+    return ed.EHRData(data, layers={DEFAULT_TEM_LAYER_NAME: data_3d})


@pytest.fixture
Expand All @@ -151,6 +156,20 @@ def edata_mini():
)


+@pytest.fixture
+def edata_mini_3D_missing_values():
+    tiny_mixed_array = np.array(
+        [
+            [[138, 139], [78, np.nan], [77, 76], [1, 2], ["A", "B"], ["Yes", np.nan]],
+            [[140, 141], [80, 81], [60, 90], [0, 1], ["A", "A"], ["Yes", "Yes"]],
+            [[148, 149], [77, 78], [110, np.nan], [0, 1], [np.nan, "B"], ["Yes", "Yes"]],
+            [[150, 151], [79, 80], [56, np.nan], [2, 3], ["B", "B"], ["Yes", "No"]],
+        ],
+        dtype=object,
+    )
+    return ed.EHRData(layers={DEFAULT_TEM_LAYER_NAME: tiny_mixed_array})


@pytest.fixture
def edata_mini_sample():
return ed.io.read_csv(f"{TEST_DATA_PATH}/dataset1.csv", columns_obs_only=["clinic_day"])
@@ -415,4 +434,11 @@ def as_dense_dask_array(a, chunk_size=1000):
return da.from_array(a, chunks=chunk_size)


-ARRAY_TYPES = (asarray, as_dense_dask_array)
+ARRAY_TYPES_NUMERIC = (
+    asarray,
+    as_dense_dask_array,
+    sp.csr_array,
+    sp.csc_array,
+)  # add coo_array once supported in AnnData
+ARRAY_TYPES_NUMERIC_3D_ABLE = (asarray, as_dense_dask_array)  # add coo_array once supported in AnnData
+ARRAY_TYPES_NONNUMERIC = (asarray, as_dense_dask_array)

Collaborator (Author): I suggest these collections of array types for testing, which I hope reduces some repetitiveness.

Member: I wonder whether we should have this here or even in our compat code, because we might also use this for the implementations?

Collaborator (Author): Agree actually, didn't occur to me before. -> asked @sueoglu to move in #967
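These collections are presumably consumed via `pytest.mark.parametrize`; a hypothetical sketch of the intended usage (the test name and body are illustrative, not from the PR):

```python
import ehrapy as ep
import numpy as np
import pytest

@pytest.mark.parametrize("array_type", ARRAY_TYPES_NUMERIC)
def test_simple_impute_mean(array_type, mcar_edata):
    # convert X to the array flavor under test, then impute in place
    mcar_edata.X = array_type(mcar_edata.X)
    ep.pp.simple_impute(mcar_edata, strategy="mean")
    assert not np.isnan(asarray(mcar_edata.X)).any()  # asarray: conftest helper
```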