Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions tests/unit_tests/data_validation/test_IQROutliersBarPlot.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,20 @@ def test_binary_exclusion(self):
# Check that binary column is not included in figures
figure_titles = [fig.layout.title.text for fig in results[:-1]]
self.assertNotIn("binary", figure_titles)

def test_boolean_dtype_excluded_from_raw_data(self):
n_samples = 100
df = pd.DataFrame(
{
"numeric": np.random.randn(n_samples),
"flag": np.random.choice([True, False], n_samples),
}
)
vm_dataset = vm.init_dataset(
input_id="test_boolean_dataset", dataset=df, __log=False
)

results = IQROutliersBarPlot(vm_dataset)
raw_data = results[-1]

self.assertNotIn("flag", raw_data.outlier_counts_by_feature.index)
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import unittest

from validmind.tests.model_validation.sklearn.WeakspotsDiagnosis import (
_prepare_metrics_and_thresholds,
)


class TestWeakspotsDiagnosisThresholds(unittest.TestCase):
def test_partial_thresholds_use_defaults_for_plotting(self):
_, plot_thresholds, pass_thresholds = _prepare_metrics_and_thresholds(
metrics=None,
thresholds={"accuracy": 0.65},
)

self.assertEqual(pass_thresholds, {"Accuracy": 0.65})
self.assertEqual(plot_thresholds["Accuracy"], 0.65)
self.assertEqual(plot_thresholds["Precision"], 0.5)
self.assertEqual(plot_thresholds["Recall"], 0.5)
self.assertEqual(plot_thresholds["F1"], 0.7)

def test_partial_thresholds_subset_for_pass_fail(self):
_, _, pass_thresholds = _prepare_metrics_and_thresholds(
metrics=None,
thresholds={"accuracy": 0.75, "f1": 0.55},
)

self.assertEqual(set(pass_thresholds.keys()), {"Accuracy", "F1"})


if __name__ == "__main__":
unittest.main()
16 changes: 10 additions & 6 deletions validmind/tests/data_validation/IQROutliersBarPlot.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,14 @@ def IQROutliersBarPlot(
"""
df = dataset.df

# Exclude binary/boolean features (IQR is not meaningful and quantile fails on bool)
eligible_columns = [
col for col in dataset.feature_columns_numeric if len(df[col].unique()) > 2
]

figures = []

for col in dataset.feature_columns_numeric:
# Skip binary features
if len(df[col].unique()) <= 2:
continue
for col in eligible_columns:

outliers = compute_outliers(df[col], threshold)
if outliers.empty:
Expand Down Expand Up @@ -121,8 +123,10 @@ def IQROutliersBarPlot(
)
figures.append(fig)

outliers_by_feature = df[dataset.feature_columns_numeric].apply(
lambda col: compute_outliers(col, threshold)
outliers_by_feature = (
df[eligible_columns].apply(lambda col: compute_outliers(col, threshold))
if eligible_columns
else df.iloc[:, 0:0]
)

return (
Expand Down
46 changes: 39 additions & 7 deletions validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,36 @@
}


def _normalize_dict_keys(d: Dict) -> Dict:
"""Normalize metric/threshold keys to title case (e.g. 'f1' -> 'F1')."""
return {k.title(): v for k, v in d.items()}


def _prepare_metrics_and_thresholds(
metrics: Optional[Dict[str, Callable]],
thresholds: Optional[Dict[str, float]],
) -> Tuple[Dict[str, Callable], Dict[str, float], Dict[str, float]]:
"""
Prepare metrics and threshold dicts for plotting and pass/fail checks.

Custom thresholds may specify only a subset of metrics (e.g. accuracy only).
Plotting uses default thresholds for any metric without an explicit value so
charts always show a reference line; pass/fail uses only the user-provided
thresholds when a custom dict is supplied.
"""
normalized_metrics = _normalize_dict_keys(metrics or DEFAULT_METRICS)
default_thresholds = _normalize_dict_keys(DEFAULT_THRESHOLDS)

if thresholds is not None:
pass_thresholds = _normalize_dict_keys(thresholds)
plot_thresholds = {**default_thresholds, **pass_thresholds}
else:
pass_thresholds = default_thresholds
plot_thresholds = default_thresholds

return normalized_metrics, plot_thresholds, pass_thresholds


def _compute_metrics(
results: dict,
metrics: Dict[str, Callable],
Expand Down Expand Up @@ -230,11 +260,9 @@ def WeakspotsDiagnosis(
"Column(s) provided in features_columns do not exist in the dataset"
)

metrics = metrics or DEFAULT_METRICS
metrics = {k.title(): v for k, v in metrics.items()}

thresholds = thresholds or DEFAULT_THRESHOLDS
thresholds = {k.title(): v for k, v in thresholds.items()}
metrics, plot_thresholds, pass_thresholds = _prepare_metrics_and_thresholds(
metrics, thresholds
)

results_headers = ["Slice", "Number of Records", "Feature"]
results_headers.extend(metrics.keys())
Expand Down Expand Up @@ -290,14 +318,18 @@ def WeakspotsDiagnosis(
results_2=r2,
feature_column=feature,
metric=metric,
threshold=thresholds[metric],
threshold=plot_thresholds[metric],
)

figures.append(fig)

# For simplicity, test has failed if any of the metrics is below the threshold. We will
# rely on visual assessment for this test for now.
if not df[df[list(thresholds.keys())].lt(thresholds).any(axis=1)].empty:
pass_columns = [c for c in pass_thresholds if c in metrics]
if (
pass_columns
and not df[df[pass_columns].lt(pass_thresholds).any(axis=1)].empty
):
passed = False
results_1 = pd.concat([results_1, pd.DataFrame(r1)])
results_2 = pd.concat([results_2, pd.DataFrame(r2)])
Expand Down
Loading