Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 68 additions & 1 deletion src/scembed/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dataclasses import fields
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Union

import numpy as np
import pandas as pd
Expand All @@ -16,6 +17,10 @@
from scembed.logging import logger
from scembed.utils import _download_artifact_by_run_id

# Allowed scalar types for filtering
AllowedScalar = Union[int, float, str, bool]
AllowedFilterValue = Union[AllowedScalar, list[AllowedScalar]]


class scIBAggregator:
"""Aggregator for WandB sweep results with scIB metrics visualization.
Expand Down Expand Up @@ -155,7 +160,60 @@ def _unwrap_wandb_config(config_dict: dict) -> dict:
# Config is already at the top level
return config_dict

def fetch_runs(self) -> None:
def _filter_params(self, target_key, target_value) -> None:
"""Filter runs to match certain parameter criteria."""
keep_runs = []
discard_runs = []

for idx, row in self.raw_df.iterrows():
# Get the config. Missing config runs will be filtered out later in _process_runs()
config = row.get("config")
if config is None or not isinstance(config, dict):
continue

# Get the value of the target key in the config. Can be a dict with 'value' key or direct value
# Keeping this here for safety
param_value = config.get(target_key, None)
if isinstance(param_value, dict):
param_value = param_value["value"]

if param_value is None:
if self.filter_allow_none:
keep_runs.append(idx)
else:
discard_runs.append(idx)
continue

# If the desired target_value is not a list, just check for a match
if not isinstance(target_value, list):
if param_value == target_value:
keep_runs.append(idx)
else:
discard_runs.append(idx)
# If the desired target_value is a list, check if param_value is in that list
else:
if any([param_value == x for x in target_value]):
keep_runs.append(idx)
else:
discard_runs.append(idx)

if discard_runs:
logger.info(
"Masked %d runs not matching %s=%s. %d runs remain.",
len(discard_runs),
target_key,
target_value,
len(keep_runs),
)

self.raw_df = self.raw_df.loc[keep_runs].copy()

if self.raw_df.empty:
logger.warning("All runs were masked, please relax the filtering constraints.")

def fetch_runs(
self, filter_params: dict[str, AllowedFilterValue] | None = None, filter_allow_none: bool = True
) -> None:
"""Fetch runs from WandB and process into internal storage."""
logger.info("Fetching runs from %s/%s...", self.entity, self.project)

Expand All @@ -168,6 +226,10 @@ def fetch_runs(self) -> None:
self.missing_metrics_runs = []
self.available_scib_metrics = set()

# Set filtering options
self.filter_params = filter_params
self.filter_allow_none = filter_allow_none

# Initialize WandB API
api = wandb.Api()

Expand All @@ -187,6 +249,11 @@ def fetch_runs(self) -> None:

logger.info("Fetched %d runs", self.n_runs_fetched)

# Filter the data
if self.filter_params is not None:
for target_key, target_value in self.filter_params.items():
self._filter_params(target_key=target_key, target_value=target_value)

# Process the data
self._process_runs()

Expand Down
Loading