# Allowed scalar types for filtering
AllowedScalar = Union[int, float, str, bool]
# A filter accepts either one scalar or a list of acceptable scalars
AllowedFilterValue = Union[AllowedScalar, list[AllowedScalar]]


def _filter_params(self, target_key: str, target_value: AllowedFilterValue) -> None:
    """Restrict ``self.raw_df`` to the runs whose config matches one parameter filter.

    This is a method of ``scIBAggregator``. It reads ``self.raw_df`` and
    ``self.filter_allow_none`` and rebinds ``self.raw_df`` to a filtered copy.

    Parameters
    ----------
    target_key
        Name of the config entry to test.
    target_value
        The accepted value, or a list of accepted values. With a list, a run
        matches when its value equals any element.

    Notes
    -----
    - Runs whose ``config`` is missing or not a dict cannot be tested, so they
      are kept rather than silently dropped. NOTE(review): the intent is that
      ``_process_runs()`` deals with them afterwards -- confirm it does.
    - Runs where ``target_key`` is absent or ``None`` are kept only when
      ``self.filter_allow_none`` is truthy.
    - A config entry shaped like ``{"value": x}`` is unwrapped to ``x`` before
      comparison; any other dict is compared as-is.
    """
    # Normalise to a list so the scalar and list cases share one code path
    accepted = target_value if isinstance(target_value, list) else [target_value]

    # One positional flag per row; unlike label lookups this stays correct when
    # the index contains repeated labels.
    keep_mask: list[bool] = []
    for _, row in self.raw_df.iterrows():
        config = row.get("config")
        if not isinstance(config, dict):
            # Not filterable here; leave the run in place for later processing
            keep_mask.append(True)
            continue

        param_value = config.get(target_key)
        # Only unwrap when the 'value' key is really there, so nested dict
        # parameters do not raise KeyError
        if isinstance(param_value, dict) and "value" in param_value:
            param_value = param_value["value"]

        if param_value is None:
            keep_mask.append(bool(self.filter_allow_none))
        else:
            keep_mask.append(any(param_value == candidate for candidate in accepted))

    n_kept = sum(keep_mask)
    n_discarded = len(keep_mask) - n_kept
    if n_discarded:
        logger.info(
            "Masked %d runs not matching %s=%s. %d runs remain.",
            n_discarded,
            target_key,
            target_value,
            n_kept,
        )

    self.raw_df = self.raw_df.loc[keep_mask].copy()

    if self.raw_df.empty:
        logger.warning("All runs were masked, please relax the filtering constraints.")