sae: shared probing primitives (eval metrics + ActivationBuffer) #1629
polinabinder1 wants to merge 5 commits
…ook)

Model-agnostic SAE eval/intervention primitives, factored out of the evo2 eval (#1624) and steering (#1626) PRs so any recipe (evo2/codonfm/esm2) can reuse them and the evo2-specific harnesses stack on top:

- sae.eval.probing: ActivationBuffer + scoring lenses (AUROC, linear/softmax decode, instance-level domain_f1), all pure functions of codes + labels.
- sae.steering: delta-clamp forward hook + steer() context manager.

CPU-testable, no torch-CUDA / no model. Tests: test_probing, test_steering (7).

Signed-off-by: Polina Binder <pbinder@nvidia.com>
Replace weak sanity-bound asserts with one strong correctness check per non-trivial metric, and drop trivia:

- auroc_all: validated against the pairwise AUROC definition (P(s+ > s-)), an oracle independent of the argsort rank-sum impl (no new dependency).
- domain_f1, best_single_train_test, decode_eval: previously untested; each now has a hand-computed-reference / constructed-flip / separability test.
- dropped the standalone standardize + weak auroc-sanity tests (trivia / subsumed); split_indices folded into the buffer roundtrip.
- steering: merged the no-op identity into the exact-delta test (recon-cancellation), kept the tuple-output contract. 3 -> 2.

Net 7 -> 7 CPU tests, every one now a real correctness check.

Signed-off-by: Polina Binder <pbinder@nvidia.com>
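The oracle idea in this commit (checking a rank-sum AUROC against the pairwise definition) can be sketched as follows. This is a minimal NumPy illustration for a single score vector, not the code in `probing.py`; the function names `auroc_pairwise` and `auroc_ranksum` are hypothetical, and counting ties as half a win is an assumption about the intended convention.

```python
import numpy as np


def auroc_pairwise(scores, labels):
    """Oracle: P(s+ > s-) + 0.5 * P(s+ == s-) over all positive/negative pairs."""
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=bool)
    diff = scores[labels][:, None] - scores[~labels][None, :]
    return float(((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size)


def auroc_ranksum(scores, labels):
    """Mann-Whitney U from average ranks (tied scores share their mean rank)."""
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=bool)
    order = np.argsort(scores, kind="stable")
    base_ranks = np.arange(1, len(scores) + 1, dtype=float)
    # Average the 1-based ranks inside each group of equal scores.
    _, inv, counts = np.unique(scores[order], return_inverse=True, return_counts=True)
    mean_rank = np.bincount(inv, weights=base_ranks) / counts
    ranks = np.empty(len(scores), dtype=float)
    ranks[order] = mean_rank[inv]
    n_pos = int(labels.sum())
    n_neg = len(labels) - n_pos
    u = ranks[labels].sum() - n_pos * (n_pos + 1) / 2
    return float(u / (n_pos * n_neg))


# Illustrative data only: one tie-free case and one case with ties.
simple = ([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1])
tied = ([1, 1, 2, 2, 3], [0, 1, 0, 1, 1])
```

The two functions share no code path, which is what makes the pairwise version usable as an independent reference: it is O(n_pos × n_neg) and only suitable for small test inputs.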
…1622)

Steering's only consumers (the live engine's clamp hook + the steer.py harness) both live in the evo2 serve recipe (#1622), and the harness imports Evo2SAE from it. So the steering primitive + harness move to a dedicated PR stacked on #1622, where the core clamp-hook dedup can happen in-place. This base is now the probing library only.

Signed-off-by: Polina Binder <pbinder@nvidia.com>

…AUROC)

The persistence half of probing that was missing: turns a buffer (codes + concept labels) into a feature->label table. For each feature, it records the concept it best separates (highest AUROC), kept only above min_auroc. Model-agnostic; the recipe CLI just loads a buffer, calls this, and writes the annotations parquet. + CPU test.

Signed-off-by: Polina Binder <pbinder@nvidia.com>
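The feature->label assignment described in this commit can be sketched as below. This is a deliberately unvectorized NumPy illustration, not the PR's `annotate_features`; the name `annotate_features_sketch`, the row keys, the default threshold, and treating the threshold as inclusive are all assumptions made for the example.

```python
import numpy as np


def _auroc(scores, labels):
    """Pairwise-definition AUROC; ties count as half a win."""
    diff = scores[labels][:, None] - scores[~labels][None, :]
    return float(((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size)


def annotate_features_sketch(codes, labels, concept_names, min_auroc=0.7):
    """For each feature, keep the concept with the highest AUROC if it clears min_auroc.

    codes:  [N, F] feature activations
    labels: [N, C] boolean concept membership per token
    """
    codes = np.asarray(codes, dtype=float)
    labels = np.asarray(labels, dtype=bool)
    rows = []
    for f in range(codes.shape[1]):
        aurocs = [_auroc(codes[:, f], labels[:, c]) for c in range(labels.shape[1])]
        best = int(np.argmax(aurocs))
        if aurocs[best] >= min_auroc:  # inclusive threshold is an assumption
            rows.append({"feature": f, "label": concept_names[best], "auroc": aurocs[best]})
    return rows


# Illustrative data: feature 0 tracks concept A, feature 1 tracks B, feature 2 is dead.
codes = [[3, 1, 0], [2, 0, 0], [4, 2, 0], [0, 0, 0], [0, 5, 0], [0, 0, 0]]
labels = [[1, 1], [1, 0], [1, 1], [0, 0], [0, 1], [0, 0]]
rows = annotate_features_sketch(codes, labels, ["A", "B"])
```

A dead feature scores exactly 0.5 against every concept under the tie convention, so it falls below the threshold and gets no row in the table.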
@coderabbitai review

✅ Action performed: Review finished.
🧹 Nitpick comments (2)

bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py (2)

54-65: 💤 Low value

`allow_pickle=True` poses a deserialization risk if loading untrusted files.

This is acceptable for internal artifacts but worth documenting. If these buffers might come from external sources, consider validating provenance or using a safer serialization format.

```diff
     @classmethod
     def load(cls, path: str) -> "ActivationBuffer":
-        """Load an ActivationBuffer from an .npz written by save()."""
+        """Load an ActivationBuffer from an .npz written by save().
+
+        Warning:
+            Uses allow_pickle=True; only load files from trusted sources.
+        """
         z = np.load(path, allow_pickle=True)
```

🤖 Prompt for AI Agents

Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py` around lines 54 - 65, The load method in ActivationBuffer uses np.load(..., allow_pickle=True) which is unsafe for untrusted files; change load to avoid allow_pickle=True by default (use allow_pickle=False) or add an explicit parameter (e.g., allow_pickle: bool = False) and fail with a clear error if pickled objects are required, and update the ActivationBuffer.load docstring to document the deserialization risk and the need to validate provenance when loading external files; ensure references to ActivationBuffer.load and the local variable z are used to implement and surface the safer behavior.
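One way to act on the "safer serialization" suggestion is to keep every array in a native dtype so the `.npz` never needs pickling. The sketch below assumes hypothetical field names (`codes`, `labels`, `label_names`) and helper names (`save_buffer`, `load_buffer`); it does not reflect the actual `ActivationBuffer` layout, which this page does not show.

```python
import os
import tempfile

import numpy as np


def save_buffer(path, codes, labels, label_names):
    """Write arrays to .npz; names are stored as fixed-width unicode, not dtype=object."""
    np.savez(
        path,
        codes=np.asarray(codes, dtype=np.float32),
        labels=np.asarray(labels, dtype=np.int64),
        label_names=np.asarray(label_names, dtype=str),
    )


def load_buffer(path):
    """Load with allow_pickle=False so reading the file can never unpickle objects."""
    with np.load(path, allow_pickle=False) as z:
        return {name: z[name] for name in z.files}


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "buffer.npz")
    save_buffer(path, [[0.0, 1.5], [2.0, 0.0]], [0, 1], ["exon", "intron"])
    loaded = load_buffer(path)
```

Pickling is only needed when an array has `dtype=object` (for example a ragged list or mixed Python objects), so converting string metadata to a unicode array is usually enough to load with `allow_pickle=False`.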
243-245: 💤 Low value

Consider adding a comment explaining the `+2` sizing for the remap tensor.

The `+2` accounts for 0-indexing and ensures negative indexing (-1) wraps to a valid buffer position. While correct, this is subtle:

```diff
-    remap = torch.full((int(inst_ids.max().item()) + 2,), -1, device=dev, dtype=torch.long)
+    # +2: one for 0-indexing, one so that -1 wraps to a valid (unused) slot
+    remap = torch.full((int(inst_ids.max().item()) + 2,), -1, device=dev, dtype=torch.long)
```

🤖 Prompt for AI Agents

Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py` around lines 243 - 245, Add an inline comment above the remap creation explaining why the size is int(inst_ids.max().item()) + 2: we need +1 for 0-based indexing of the maximum id and an extra slot so that using -1 as a sentinel (when indexing remap with potentially -1 inst_ids) will wrap to a valid buffer position instead of raising an out-of-bounds error; reference the remap tensor and the subsequent remap[uniq.long()] / remap[inst_ids.long()] usage (and the torch.full default -1) so readers understand the sentinel handling.
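The sentinel behaviour behind the `+2` can be seen in a small NumPy analogue of the remap lookup. This is an illustration under the assumption that `inst_ids` uses `-1` for tokens outside any instance; the data values are made up, and the PR's code uses `torch` rather than NumPy.

```python
import numpy as np

# Per-token instance ids; -1 marks tokens that belong to no instance (assumed convention).
inst_ids = np.array([7, 7, -1, 3, 3, 3, -1, 9])
uniq = np.unique(inst_ids[inst_ids >= 0])  # sorted real ids: [3, 7, 9]

# Size max+2: slots 0..max for real ids, plus one spare slot at index max+1.
remap = np.full(int(inst_ids.max()) + 2, -1, dtype=np.int64)
remap[uniq] = np.arange(len(uniq))
dense = remap[inst_ids]  # index -1 wraps to the spare slot, which still holds -1

# With size max+1 there is no spare slot: index -1 wraps onto the largest real id.
remap_bad = np.full(int(inst_ids.max()) + 1, -1, dtype=np.int64)
remap_bad[uniq] = np.arange(len(uniq))
dense_bad = remap_bad[inst_ids]
```

In this NumPy version the undersized table does not raise; it silently maps the `-1` tokens to the dense index of instance 9, which is the kind of bug the spare slot guards against.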
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 23ddf87a-6a45-46a2-8264-db968ee016e5
📒 Files selected for processing (3)
- bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/__init__.py
- bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/eval/probing.py
- bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_probing.py
… sizing

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Polina Binder <pbinder@nvidia.com>
Addressed the two nitpicks in
Summary

Shared, model-agnostic SAE probing primitives in the `sae` package (sibling of `loss_recovered` / `sparsity` / `dead_latents`): scoring metrics + per-feature annotation, all pure functions of codes + labels.

Contents: `sae.eval.probing`

- `ActivationBuffer` (codes + optional dense twin + per-token labels + instance ids)
- `auroc_all`, `auroc_vec`, `best_single_train_test`
- `fit_logreg` / `fit_softmax` / `macro_auroc` / `decode_eval`
- `domain_f1` (precision-per-nt, recall-per-instance)
- `annotate_features` (per-feature best concept by AUROC → the annotation table)

How to use

`pytest sae/tests/test_probing.py` (CPU, no model).

Why hand-rolled (not sklearn)

GPU-vectorized over the whole ~32k-feature dictionary in one pass; sklearn is CPU/per-`(scores, label)`. CodonFM used it and had to subsample to ≤5k features. Package stays `torch`+`numpy`. Each metric is validated against an independent reference in the tests (pairwise-AUROC oracle, hand-computed `domain_f1`, etc.).

Base of: #1630 (eval labels) → #1636 (harness).
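The "precision-per-nt, recall-per-instance" shape of `domain_f1` can be illustrated with a small NumPy sketch. This is an assumption-laden reading of the InterPLM-style metric, not the PR's implementation: the function name `domain_f1_sketch`, the firing threshold, and the use of `inst_ids >= 0` to mean "inside a domain" are all hypothetical.

```python
import numpy as np


def domain_f1_sketch(acts, inst_ids, threshold=0.0):
    """Precision over firing tokens, recall over domain instances, then their F1.

    acts:     [N] activations of one feature
    inst_ids: [N] domain-instance id per token, -1 outside any domain (assumed)
    """
    acts = np.asarray(acts, dtype=float)
    inst_ids = np.asarray(inst_ids)
    fires = acts > threshold
    in_domain = inst_ids >= 0

    # Precision per nucleotide: share of firing tokens that lie inside a domain.
    n_fire = int(fires.sum())
    precision = float((fires & in_domain).sum()) / n_fire if n_fire else 0.0

    # Recall per instance: share of domain instances hit by at least one firing token.
    all_inst = np.unique(inst_ids[in_domain])
    hit_inst = np.unique(inst_ids[fires & in_domain])
    recall = len(hit_inst) / len(all_inst) if len(all_inst) else 0.0

    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return precision, recall, f1


# Illustrative data: two domain instances (ids 0 and 1), the feature fires on 4 tokens.
acts = [0, 2.0, 0, 1.0, 0, 0.5, 3.0, 0, 0, 0]
inst_ids = [0, 0, 0, -1, -1, 1, 1, 1, -1, -1]
precision, recall, f1 = domain_f1_sketch(acts, inst_ids)
```

Counting recall per instance rather than per token means a feature that marks only part of each domain is not penalized for the tokens it skips, as long as every instance is touched at least once.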
On external libraries (checked, not a win)

We evaluated replacing the hand-rolled metrics with `sklearn`/`torchmetrics`, function by function:

- `auroc_all`: no library computes a vectorized `[features × labels]` AUROC matrix on GPU; `sklearn.roc_auc_score` / `torchmetrics` are CPU and per-`(scores, label)`, so a ~32k-feature dictionary becomes a 32k-iteration CPU loop. Kept.
- `domain_f1`, `best_single_train_test`, `annotate_features`: bespoke (instance-F1, winner's-curse, per-feature assignment); no library equivalent.
- `fit_logreg` / `fit_softmax` / `decode_eval` (~62 lines): the only `sklearn`-replaceable code, but they fit on the `[N≈50k, F≈32k]` SAE-code matrix, which is exactly where CodonFM tried `sklearn.LogisticRegression` and had to subsample to ≤5k features. Swapping reintroduces that coverage loss and a runtime dep. Net regression.
- `ActivationBuffer` / `split_indices` / `standardize`: `np.savez` + 7-line helpers; nothing to gain.

Conclusion: kept the package `torch`+`numpy`-only. The metrics are standard formulas (Mann–Whitney rank-AUROC, Adam BCE/softmax, InterPLM instance-F1) vectorized for full-dictionary GPU scale, and each is validated against an independent reference in the tests.

Expected output

`pytest sae/tests/test_probing.py` → 6 passed: AUROC vs the pairwise-definition oracle, `domain_f1` vs a hand-computed reference, `best_single` winner's-curse flip, `decode_eval` separability, `annotate_features` best-concept, buffer roundtrip.
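The winner's-curse effect that `best_single_train_test` is said to guard against can be reproduced on pure noise. The sketch below is a NumPy illustration only: `auroc_columns` is a hypothetical helper that assumes tie-free columns (true for continuous random data, not for sparse SAE codes), and the sizes, seed, and 50/50 split are arbitrary.

```python
import numpy as np


def auroc_columns(codes, y):
    """Rank-based AUROC of every column of codes [N, F] against boolean y [N].

    Assumes no tied values within a column; ties would need average ranks.
    """
    ranks = codes.argsort(axis=0).argsort(axis=0) + 1  # 1-based rank per column
    n_pos = int(y.sum())
    n_neg = len(y) - n_pos
    u = ranks[y].sum(axis=0) - n_pos * (n_pos + 1) / 2
    return u / (n_pos * n_neg)


rng = np.random.default_rng(0)
n_tokens, n_features = 200, 500
codes = rng.normal(size=(n_tokens, n_features))  # noise: no feature truly predicts y
y = rng.random(n_tokens) < 0.5

train = np.arange(n_tokens) < n_tokens // 2
test = ~train

# Select the best feature on the train half, then score that same feature on held-out data.
train_auc = auroc_columns(codes[train], y[train])
best = int(np.argmax(train_auc))
test_auc = float(auroc_columns(codes[test], y[test])[best])
print(f"best feature {best}: train AUROC {train_auc[best]:.3f}, test AUROC {test_auc:.3f}")
```

Because the maximum over many noisy AUROCs is biased upward, the train score of the selected feature looks informative even though the data is random; reporting the held-out score of that same feature is the honest number.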