sae: shared probing primitives (eval metrics + ActivationBuffer)#1629

Open
polinabinder1 wants to merge 5 commits into main from pbinder/sae-interp-primitives

@polinabinder1 polinabinder1 commented Jun 11, 2026

Collaborator

Summary

Shared, model-agnostic SAE probing primitives in the sae package (sibling of loss_recovered/sparsity/dead_latents): scoring metrics + per-feature annotation, all pure functions of codes + labels.

Contents — sae.eval.probing

  • ActivationBuffer (codes + optional dense twin + per-token labels + instance ids)
  • AUROC: auroc_all, auroc_vec, best_single_train_test
  • decoders: fit_logreg / fit_softmax / macro_auroc / decode_eval
  • domain_f1 (precision-per-nt, recall-per-instance)
  • annotate_features (per-feature best concept by AUROC → the annotation table)
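The domain_f1 entry above pairs a per-token precision with a per-instance recall. A minimal NumPy sketch of that pairing for one feature at one threshold follows; the function name, argument layout, and the -1 "no instance" convention are assumptions for illustration, not the package's actual signature (which sweeps thresholds over all features in chunks).

```python
import numpy as np


def domain_f1_single(acts, concept_mask, inst_ids, threshold):
    """F1 for ONE feature at ONE threshold (hypothetical helper).

    acts:         [N] activations of the feature
    concept_mask: [N] bool, True where the token carries the concept
    inst_ids:     [N] int, instance id per token, -1 outside any instance (assumed)
    """
    fires = acts > threshold
    n_fire = int(fires.sum())
    # Precision is per token: of the tokens that fire, how many lie in the concept.
    precision = float((fires & concept_mask).sum()) / n_fire if n_fire else 0.0
    # Recall is per instance: an instance is recovered if any of its tokens fires.
    all_inst = np.unique(inst_ids[inst_ids >= 0])
    hit_inst = np.unique(inst_ids[fires & (inst_ids >= 0)])
    recall = len(hit_inst) / len(all_inst) if len(all_inst) else 0.0
    denom = precision + recall
    return 2.0 * precision * recall / denom if denom else 0.0


acts = np.array([0.9, 0.0, 0.0, 0.8, 0.7, 0.0])
concept = np.array([True, True, True, True, False, False])
inst = np.array([0, 0, 0, 1, -1, -1])
f1 = domain_f1_single(acts, concept, inst, threshold=0.5)
# precision = 2/3 (two of three firing tokens are in the concept),
# recall = 1.0 (both instances have a firing token), so F1 = 0.8
```

Scoring recall per instance means a feature that fires on a single token of each domain still gets full recall, while precision stays strict about where it fires.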

How to use

from sae.eval.probing import auroc_all, annotate_features
au  = auroc_all(codes, labels)                                   # [F, L]
ann = annotate_features(codes, labels, names, min_auroc=0.85)    # [{feature_id, label, auroc}]

pytest sae/tests/test_probing.py — CPU, no model.

Why hand-rolled (not sklearn)

GPU-vectorized over the whole ~32k-feature dictionary in one pass; sklearn is CPU/per-(scores,label) — CodonFM used it and had to subsample to ≤5k features. Package stays torch+numpy. Each metric is validated against an independent reference in the tests (pairwise-AUROC oracle, hand-computed domain_f1, etc.).

Base of: #1630 (eval labels) → #1636 (harness).

On external libraries (checked — not a win)

We evaluated replacing the hand-rolled metrics with sklearn/torchmetrics, function by function:

  • auroc_all — no library computes a vectorized [features × labels] AUROC matrix on GPU; sklearn.roc_auc_score / torchmetrics are CPU and per-(scores, label), so a ~32k-feature dictionary becomes a 32k-iteration CPU loop. Kept.
  • domain_f1, best_single_train_test, annotate_features — bespoke (instance-F1, winner's-curse, per-feature assignment); no library equivalent.
  • fit_logreg / fit_softmax / decode_eval (~62 lines) — the only sklearn-replaceable code, but they fit on the [N≈50k, F≈32k] SAE-code matrix, which is exactly where CodonFM tried sklearn.LogisticRegression and had to subsample to ≤5k features. Swapping reintroduces that coverage loss and a runtime dep. Net regression.
  • ActivationBuffer / split_indices / standardize: np.savez + 7-line helpers; nothing to gain.

Conclusion: kept the package torch+numpy-only. The metrics are standard formulas (Mann–Whitney rank-AUROC, Adam BCE/softmax, InterPLM instance-F1) vectorized for full-dictionary GPU scale, and each is validated against an independent reference in the tests.
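For readers unfamiliar with the rank formulation, here is a rough NumPy sketch of a Mann–Whitney rank-AUROC producing a [features × labels] matrix with one matrix product. It is not the package's auroc_all: the names are hypothetical, it runs on CPU, ranks are computed per column in a Python loop for clarity, and returning NaN for labels with no positives or no negatives is an assumption.

```python
import numpy as np


def avg_rank_1d(v):
    """1-based ranks of a 1-D array; tied values share their average rank."""
    _, inv, counts = np.unique(v, return_inverse=True, return_counts=True)
    ends = np.cumsum(counts)        # last rank of each tie group
    starts = ends - counts + 1      # first rank of each tie group
    return ((starts + ends) / 2.0)[inv]


def auroc_matrix(codes, labels):
    """codes: [N, F] float, labels: [N, L] bool -> AUROC matrix [F, L]."""
    ranks = np.stack(
        [avg_rank_1d(codes[:, f]) for f in range(codes.shape[1])], axis=1
    )                                        # [N, F]
    y = labels.astype(np.float64)            # [N, L]
    n_pos = y.sum(axis=0)                    # [L]
    n_neg = y.shape[0] - n_pos               # [L]
    rank_sum = ranks.T @ y                   # [F, L]: sum of positive ranks
    u = rank_sum - n_pos * (n_pos + 1) / 2   # Mann-Whitney U statistic
    denom = n_pos * n_neg
    auc = np.full(u.shape, np.nan)           # NaN marks degenerate labels
    valid = denom > 0
    auc[:, valid] = u[:, valid] / denom[valid]
    return auc


codes = np.array([[0.1, 0.9, 0.0],
                  [0.4, 0.8, 0.0],
                  [0.35, 0.2, 0.0],
                  [0.8, 0.1, 0.0]])
labels = np.array([[False, True], [False, True], [True, True], [True, True]])
au = auroc_matrix(codes, labels)
# column 0 -> [0.75, 0.0, 0.5]; column 1 has no negatives -> NaN
```

Averaging tied ranks matters for SAE codes, where most entries are exactly zero; without it a dead feature would not score 0.5.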

Expected output

  • Library (no script). pytest sae/tests/test_probing.py → 6 passed: AUROC vs the pairwise-definition oracle, domain_f1 vs a hand-computed reference, best_single winner's-curse flip, decode_eval separability, annotate_features best-concept, buffer roundtrip.
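To illustrate the winner's-curse idea behind best_single_train_test, the sketch below picks the best feature on training rows only and then scores that one feature on held-out rows. The function signature, return tuple, and toy data are assumptions made for this example; the actual test may construct its flip differently.

```python
import numpy as np


def auroc_1d(scores, y):
    """Pairwise AUROC for one score vector; ties count one half."""
    pos, neg = scores[y], scores[~y]
    if pos.size == 0 or neg.size == 0:
        return float("nan")
    diff = pos[:, None] - neg[None, :]
    return float(((diff > 0) + 0.5 * (diff == 0)).mean())


def best_single_train_test(codes, y, train_idx, test_idx):
    """Select the feature on train rows only, then report its test AUROC."""
    train_auc = np.array(
        [auroc_1d(codes[train_idx, f], y[train_idx]) for f in range(codes.shape[1])]
    )
    best = int(np.nanargmax(train_auc))
    test_auc = auroc_1d(codes[test_idx, best], y[test_idx])
    return best, float(train_auc[best]), test_auc


# Feature 0 separates the training rows perfectly but is reversed on the
# test rows; feature 1 is constant everywhere.
codes = np.array([[1, .5], [2, .5], [3, .5], [4, .5],
                  [4, .5], [3, .5], [2, .5], [1, .5]], dtype=float)
y = np.array([False, False, True, True, False, False, True, True])
train_idx, test_idx = np.arange(0, 4), np.arange(4, 8)
best, train_auc, test_auc = best_single_train_test(codes, y, train_idx, test_idx)
# best == 0, train_auc == 1.0, test_auc == 0.0
```

Reporting the train-set maximum would claim 1.0 here; the held-out number exposes that the selected feature does not generalize.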

Summary by CodeRabbit

  • New Features

    • Added comprehensive evaluation and probing utilities for sparse autoencoders, including AUROC metrics, feature annotation, classifier-based probing, and domain F1 scoring.
    • Introduced a data buffer utility for storing and persisting activation analysis data.
  • Tests

    • Added comprehensive test coverage for new evaluation metrics and utilities.

…ook)

Model-agnostic SAE eval/intervention primitives, factored out of the evo2 eval
(#1624) and steering (#1626) PRs so any recipe (evo2/codonfm/esm2) can reuse them
and the evo2-specific harnesses stack on top:

- sae.eval.probing: ActivationBuffer + scoring lenses (AUROC, linear/softmax
  decode, instance-level domain_f1) — pure functions of codes + labels.
- sae.steering: delta-clamp forward hook + steer() context manager.

CPU-testable, no torch-CUDA / no model. Tests: test_probing, test_steering (7).

Signed-off-by: Polina Binder <pbinder@nvidia.com>

coderabbitai Bot commented Jun 11, 2026

Contributor

Important

Review skipped

Auto reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: eaa20ce7-8293-44d3-a581-c0c24ee89738

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Walkthrough

This PR adds a comprehensive SAE feature-probing evaluation module (probing.py) to enable model-agnostic interpretation of learned features through metrics, classifiers, and annotation tools, along with an ActivationBuffer artifact for persistence and a full test suite validating correctness across all components.

Changes

SAE Probing Evaluation Suite

| Layer / File(s) | Summary |
| --- | --- |
| ActivationBuffer data structure and persistence: `bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py` (lines 1–65), `bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_probing.py` (lines 123–142) | Dataclass storing SAE feature codes, per-token boolean labels and names, optional dense residuals, and concept-to-instance id mappings; .save() serializes to typed .npz with per-concept instance arrays; .load() reconstructs the dataclass; .name_idx property maps label names to column indices. |
| Dataset utilities and standardization: `bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py` (lines 73–84) | split_indices performs deterministic train/test splitting via seeded torch.randperm; standardize computes mean and std on training rows with epsilon-clamped std normalization. |
| AUROC computation and best-feature selection: `bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py` (lines 86–145), `bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_probing.py` (lines 37–71) | auroc_all computes full [feature, label] AUROC matrix via chunked rank-statistics; auroc_vec handles single-vector AUROC with degenerate-case handling; best_single_train_test selects best feature on training set and reports test AUROC without winner's-curse bias; test oracle _auroc_ref validates against brute-force reference. |
| Feature concept annotation via AUROC thresholding: `bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py` (lines 147–174), `bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_probing.py` (lines 110–121) | annotate_features derives per-feature best-label annotations by selecting max AUROC across labels and filtering by configurable AUROC threshold; excludes low-information features. |
| Linear classifier training and macro-AUROC evaluation: `bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py` (lines 176–226), `bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_probing.py` (lines 89–108) | fit_logreg trains binary logistic regression; fit_softmax trains multinomial softmax; both use Adam with BCE-with-logits and cross-entropy respectively; macro_auroc computes macro one-vs-rest AUROC; decode_eval orchestrates training and dual metric reporting for test accuracy and macro AUROC. |
| Domain-adjusted F1 with instance-aware thresholding: `bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py` (lines 228–270), `bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_probing.py` (lines 73–87) | domain_f1 computes threshold-swept per-feature F1 by normalizing activations per-feature, remapping instance ids, aggregating per-instance firing via index_reduce_, combining precision from concept masks with recall from instance aggregation, and selecting best F1 threshold per feature in chunked passes. |
| Module public API and test setup: `bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/__init__.py` (lines 25–71), `bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_probing.py` (lines 1–35) | Imports and re-exports all probing.py utilities in __all__ for public access; test module imports and validates all components. |
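As a rough illustration of the dataset utilities summarized in the walkthrough, the sketch below shows a seeded split and train-only standardization. It uses NumPy's seeded Generator in place of the torch.randperm the summary mentions, and the parameter names (test_frac, seed, eps) and defaults are assumptions.

```python
import numpy as np


def split_indices(n, test_frac=0.2, seed=0):
    """Deterministic train/test split: same seed -> same permutation."""
    perm = np.random.default_rng(seed).permutation(n)
    n_test = int(round(n * test_frac))
    return perm[n_test:], perm[:n_test]       # (train_idx, test_idx)


def standardize(x, train_idx, eps=1e-6):
    """Z-score all rows using statistics from the training rows only."""
    mu = x[train_idx].mean(axis=0, keepdims=True)
    sd = x[train_idx].std(axis=0, keepdims=True)
    # Clamp std so constant (e.g. dead) features do not divide by zero.
    return (x - mu) / np.maximum(sd, eps)


x = np.arange(20, dtype=float).reshape(10, 2)
x[:, 1] = 3.0                                  # a constant column
train_idx, test_idx = split_indices(len(x), test_frac=0.2, seed=0)
z = standardize(x, train_idx)
```

Computing the mean and std on training rows only keeps test-set statistics from leaking into the fitted probe.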

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Poem

🐰 A warren of metrics, now bundled with care,
AUROC and F1 floating through air,
Buffers that save what the features unfold,
Linear probes seeking wisdom untold,
Domain-aware thresholds, adaptive and keen—
The richest of probing suites ever been seen! 🌟

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title accurately summarizes the main changes: introduction of shared probing primitives including ActivationBuffer and evaluation metrics for SAE feature probing. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 91.30% which is sufficient. The required threshold is 80.00%. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

Comment @coderabbitai help to get the list of available commands and usage tips.

copy-pr-bot Bot commented Jun 11, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

Replace weak sanity-bound asserts with one strong correctness check per non-trivial
metric, and drop trivia:
- auroc_all: validated against the pairwise AUROC definition (P(s+>s-)), an oracle
  independent of the argsort rank-sum impl (no new dependency).
- domain_f1, best_single_train_test, decode_eval: previously UNtested — now each has a
  hand-computed-reference / constructed-flip / separability test.
- dropped the standalone standardize + weak auroc-sanity tests (trivia / subsumed);
  split_indices folded into the buffer roundtrip.
- steering: merged the no-op identity into the exact-delta test (recon-cancellation),
  kept the tuple-output contract. 3 -> 2.
Net 7 -> 7 CPU tests, every one now a real correctness check.
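The pairwise definition named above, P(s+ > s-) with ties counted as one half, can be written directly as a brute-force oracle. This is a sketch of the idea only; the function name is hypothetical and the test file's _auroc_ref may differ in detail.

```python
import numpy as np


def auroc_pairwise(scores, y):
    """AUROC as the fraction of (positive, negative) pairs ranked correctly."""
    pos, neg = scores[y], scores[~y]
    if pos.size == 0 or neg.size == 0:
        return float("nan")                    # undefined without both classes
    diff = pos[:, None] - neg[None, :]         # all positive-negative pairs
    return float(((diff > 0) + 0.5 * (diff == 0)).mean())


scores = np.array([0.1, 0.4, 0.35, 0.8])
y = np.array([False, False, True, True])
ref = auroc_pairwise(scores, y)                # 3 of 4 pairs correct -> 0.75
```

The oracle is O(n_pos × n_neg) and shares no code path with a sort-based implementation, which is what makes agreement between the two meaningful.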

Signed-off-by: Polina Binder <pbinder@nvidia.com>
…1622)

Steering's only consumers (the live engine's clamp hook + the steer.py harness) both
live in the evo2 serve recipe (#1622), and the harness imports Evo2SAE from it. So the
steering primitive + harness move to a dedicated PR stacked on #1622, where the core
clamp-hook dedup can happen in-place. This base is now the probing library only.

Signed-off-by: Polina Binder <pbinder@nvidia.com>
@polinabinder1 polinabinder1 changed the title sae: shared interpretability primitives (probing + steering) sae: shared probing primitives (eval metrics + ActivationBuffer) Jun 11, 2026
…AUROC)

The persistence half of probing that was missing: turns a buffer (codes + concept labels)
into a feature->label table — for each feature, the concept it best separates (highest
AUROC), kept only above min_auroc. Model-agnostic; the recipe CLI just loads a buffer,
calls this, and writes the annotations parquet. + CPU test.
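A minimal sketch of the feature → label step described in this commit, starting from a precomputed [F, L] AUROC matrix. The helper name and the concept names are hypothetical; the record keys follow the {feature_id, label, auroc} shape shown in the PR description, and treating NaN entries as never-best is an assumption.

```python
import numpy as np


def annotate_from_auroc(auroc, names, min_auroc=0.85):
    """auroc: [F, L] matrix, names: L label names -> list of annotation dicts."""
    filled = np.where(np.isnan(auroc), -np.inf, auroc)   # NaN never wins
    best_label = filled.argmax(axis=1)                    # best concept per feature
    best_score = filled[np.arange(auroc.shape[0]), best_label]
    return [
        {"feature_id": int(f), "label": names[int(j)], "auroc": float(s)}
        for f, (j, s) in enumerate(zip(best_label, best_score))
        if s >= min_auroc                                 # drop weak features
    ]


auroc = np.array([[0.90, 0.60],
                  [0.50, 0.70],
                  [0.55, 0.95]])
ann = annotate_from_auroc(auroc, ["exon", "promoter"], min_auroc=0.85)
# feature 0 -> "exon" (0.90), feature 2 -> "promoter" (0.95); feature 1 is dropped
```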

Signed-off-by: Polina Binder <pbinder@nvidia.com>
@polinabinder1

Collaborator Author

@coderabbitai review

coderabbitai Bot commented Jun 12, 2026

Contributor
✅ Action performed

Review finished.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@coderabbitai coderabbitai Bot left a comment

Contributor

🧹 Nitpick comments (2)
bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py (2)

54-65: 💤 Low value

allow_pickle=True poses a deserialization risk if loading untrusted files.

This is acceptable for internal artifacts but worth documenting. If these buffers might come from external sources, consider validating provenance or using a safer serialization format.

     @classmethod
     def load(cls, path: str) -> "ActivationBuffer":
-        """Load an ActivationBuffer from an .npz written by save()."""
+        """Load an ActivationBuffer from an .npz written by save().
+
+        Warning:
+            Uses allow_pickle=True; only load files from trusted sources.
+        """
         z = np.load(path, allow_pickle=True)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py`
around lines 54 - 65, The load method in ActivationBuffer uses np.load(...,
allow_pickle=True) which is unsafe for untrusted files; change load to avoid
allow_pickle=True by default (use allow_pickle=False) or add an explicit
parameter (e.g., allow_pickle: bool = False) and fail with a clear error if
pickled objects are required, and update the ActivationBuffer.load docstring to
document the deserialization risk and the need to validate provenance when
loading external files; ensure references to ActivationBuffer.load and the local
variable z are used to implement and surface the safer behavior.
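One way the stricter option in this prompt could look is sketched below: load with allow_pickle=False and turn NumPy's ValueError into a clearer message. The helper name is hypothetical and it returns a plain dict rather than an ActivationBuffer; whether the real buffer can avoid object arrays entirely is not confirmed by this thread.

```python
import io

import numpy as np


def load_npz_no_pickle(file):
    """Read every array from an .npz with pickling disabled (hypothetical helper)."""
    with np.load(file, allow_pickle=False) as z:
        try:
            return {key: z[key] for key in z.files}
        except ValueError as err:
            # NumPy raises ValueError when an entry is a pickled object array.
            raise ValueError(
                "npz contains object arrays; re-save with typed arrays or "
                "load it only from a trusted source with pickling enabled"
            ) from err


# Typed arrays (floats, fixed-width strings) round-trip without pickle.
good = io.BytesIO()
np.savez(good, codes=np.zeros((2, 3), dtype=np.float32), names=np.array(["a", "b"]))
good.seek(0)
loaded = load_npz_no_pickle(good)

# An object-dtype array is stored via pickle and is rejected on load.
bad = io.BytesIO()
np.savez(bad, names=np.array(["a", None], dtype=object))
bad.seek(0)
```

Storing label names as a fixed-width string array instead of an object array is what lets the safe path work at all.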

243-245: 💤 Low value

Consider adding a comment explaining the +2 sizing for the remap tensor.

The +2 accounts for 0-indexing and ensures negative indexing (-1) wraps to a valid buffer position. While correct, this is subtle:

-    remap = torch.full((int(inst_ids.max().item()) + 2,), -1, device=dev, dtype=torch.long)
+    # +2: one for 0-indexing, one so that -1 wraps to a valid (unused) slot
+    remap = torch.full((int(inst_ids.max().item()) + 2,), -1, device=dev, dtype=torch.long)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py`
around lines 243 - 245, Add an inline comment above the remap creation
explaining why the size is int(inst_ids.max().item()) + 2: we need +1 for
0-based indexing of the maximum id and an extra slot so that using -1 as a
sentinel (when indexing remap with potentially -1 inst_ids) will wrap to a valid
buffer position instead of raising an out-of-bounds error; reference the remap
tensor and the subsequent remap[uniq.long()] / remap[inst_ids.long()] usage (and
the torch.full default -1) so readers understand the sentinel handling.
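The sizing argument can be checked on a toy example. The sketch below uses NumPy instead of torch and assumes -1 marks tokens outside any instance; it shows that with max_id + 2 slots the -1 index lands on an unused sentinel slot, whereas with only max_id + 1 slots it would silently alias the largest id.

```python
import numpy as np

inst_ids = np.array([7, 7, -1, 3, -1, 7, 9])      # -1: token in no instance (assumed)
uniq = np.unique(inst_ids[inst_ids >= 0])         # real ids: [3, 7, 9]

# max_id + 2 slots: indices 0..max_id, plus one spare slot reached by index -1.
remap = np.full(inst_ids.max() + 2, -1, dtype=np.int64)
remap[uniq] = np.arange(len(uniq))                # compact ids 0..len(uniq)-1
dense = remap[inst_ids]                           # -1 hits the spare slot, stays -1

# With only max_id + 1 slots, index -1 is the slot of the largest id (9).
tight = np.full(inst_ids.max() + 1, -1, dtype=np.int64)
tight[uniq] = np.arange(len(uniq))
aliased = tight[inst_ids]                         # -1 tokens wrongly map to id 9's slot
```

Here dense is [1, 1, -1, 0, -1, 1, 2], while aliased assigns the two out-of-instance tokens to compact id 2.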

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 23ddf87a-6a45-46a2-8264-db968ee016e5

📥 Commits

Reviewing files that changed from the base of the PR and between e407165 and 79df727.

📒 Files selected for processing (3)
  • bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/__init__.py
  • bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py
  • bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_probing.py

… sizing

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Polina Binder <pbinder@nvidia.com>
@polinabinder1

Collaborator Author

Addressed the two nitpicks in 57837ec7: documented the allow_pickle=True trust caveat on ActivationBuffer.load, and added a comment explaining the +2 remap-tensor sizing (index-by-max-id + sentinel headroom). Tests still green (6 passed).

@polinabinder1 polinabinder1 marked this pull request as ready for review June 12, 2026 05:32