Skip to content

feat: add RMSD metric#227

Open
k-chrispens wants to merge 2 commits into
mainfrom
kmc/add-rmsd-metric
Open

feat: add RMSD metric#227
k-chrispens wants to merge 2 commits into
mainfrom
kmc/add-rmsd-metric

Conversation

@k-chrispens
Copy link
Copy Markdown
Collaborator

@k-chrispens k-chrispens commented Apr 24, 2026

Adds RMSD metric, which we can use to hook up to downstream clustering analysis if we want.

Summary by CodeRabbit

  • New Features

    • Added an All-atom RMSD metric: computes global RMSD with optional superposition, per-residue RMSD breakdown, and optional per-batch RMSD logging. Selections can limit reported residues without changing the global RMSD.
  • Tests

    • Added tests and shared fixtures validating identity cases, selection behavior, and end-to-end RMSD values.

Review Change Stack

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Apr 24, 2026

📝 Walkthrough

Walkthrough

Adds a new AllAtomRMSD metric that computes all-atom RMSD between predicted and reference atom stacks (optional Kabsch superposition), groups per-residue RMSD by token, supports mask-based residue selection for reported scores, and includes tests plus shared fixtures for altloc/backbone selections.

Changes

AllAtomRMSD Metric Implementation

Layer / File(s) Summary
Class surface and config
src/sampleworks/metrics/rmsd.py
Defines AllAtomRMSD class, constructor, and kwargs_to_compute_args/optional_kwargs exposing compute inputs.
Input annotation and common-atom filtering
src/sampleworks/metrics/rmsd.py
Annotates global token IDs, filters predicted and reference stacks to common atoms in matching order, converts to AtomArrayStack, and errors when no shared atoms exist.
Superposition and global RMSD
src/sampleworks/metrics/rmsd.py
When superimpose=True performs Kabsch alignment on common atoms and computes per-model global RMSD series and aggregated best_of_N values.
Selection and per-residue scoring
src/sampleworks/metrics/rmsd.py
Uses predicted stack mask (when selection provided) to derive selected token IDs, groups atoms by token into residue keys, and computes per-residue RMSD lists (selection limits reported residue keys only).
Result assembly and optional logging
src/sampleworks/metrics/rmsd.py
Assembles output dict (best_of_1_rmsd, best_of_{N}_rmsd, residue_rmsd_scores) and conditionally adds all_atom_rmsd_<i> batch logs.

Test Fixtures & Refactoring

Layer / File(s) Summary
Shared fixtures for altloc/backbone
tests/metrics/conftest.py
Adds altlocA_backbone and altlocB_backbone module-scoped fixtures that select altlocs A/B, restrict to backbone atoms, and return AtomArrayStack.
lDDT test module cleanup
tests/metrics/test_lddt_metrics.py
Removes local altloc/backbone fixture implementations so tests use the shared conftest fixtures.

AllAtomRMSD Tests

Layer / File(s) Summary
RMSD test cases
tests/metrics/test_rmsd_metrics.py
Adds identity test (zero RMSD), end-to-end selected-residue RMSD check against expected values, and selection-preserves-global RMSD test.

Sequence Diagram(s)

sequenceDiagram
  participant Test
  participant AllAtomRMSD
  participant Annotator
  participant Filter
  participant Kabsch
  participant RMSDCalc

  Test->>AllAtomRMSD: compute(predicted, reference, selection?)
  AllAtomRMSD->>Annotator: annotate global token IDs
  Annotator-->>AllAtomRMSD: annotated stacks
  AllAtomRMSD->>Filter: filter to common atoms (matching order)
  Filter-->>AllAtomRMSD: filtered stacks
  alt superimpose = True
    AllAtomRMSD->>Kabsch: compute alignment (pred → ref)
    Kabsch-->>AllAtomRMSD: superimposed predicted coords
  end
  AllAtomRMSD->>RMSDCalc: compute global RMSD per model
  RMSDCalc-->>AllAtomRMSD: per-model RMSD series
  AllAtomRMSD->>Annotator: group atoms by token -> residue keys
  Annotator-->>AllAtomRMSD: per-residue atom groups
  alt selection provided
    AllAtomRMSD->>AllAtomRMSD: mask -> selected token IDs
  end
  AllAtomRMSD->>RMSDCalc: compute per-residue RMSD lists
  RMSDCalc-->>AllAtomRMSD: residue_rmsd_scores
  AllAtomRMSD-->>Test: {best_of_1_rmsd, best_of_N_rmsd, residue_rmsd_scores, optional logs}
Loading

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Poem

🐰 Hops of joy! A metric new to measure,
Atoms align, Kabsch twirls with gentle pleasure,
Residues scored, lists tidy and neat,
Global numbers steady, per-residue beat,
Tests hop in to keep the math precise.

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 33.33% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title 'feat: add RMSD metric' directly and clearly describes the main change—adding a new RMSD metric class. It accurately reflects the primary work across all modified files.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch kmc/add-rmsd-metric

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds an all-atom RMSD metric to sampleworks.metrics with residue-level breakdown, plus tests, to enable RMSD-based evaluation consistent with existing LDDT metric outputs.

Changes:

  • Introduce AllAtomRMSD metric producing global and per-residue RMSD outputs.
  • Add end-to-end and identity tests for the RMSD metric.
  • Refactor LDDT metric tests to use shared fixtures in a new tests/metrics/conftest.py.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.

File Description
src/sampleworks/metrics/rmsd.py Adds the new RMSD metric implementation and output schema.
tests/metrics/test_rmsd_metrics.py Adds RMSD-specific test coverage (identity, end-to-end, selection behavior).
tests/metrics/conftest.py Adds shared altloc/backbone fixtures used across metric tests.
tests/metrics/test_lddt_metrics.py Removes duplicated fixtures in favor of the shared conftest fixtures.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +1 to +6
"""RMSD metric, developed starting from :class:`sampleworks.metrics.lddt.AllAtomLDDT`.

The metric returns a global per model RMSD and a per residue RMSD dictionary,
so it can be plugged into the same clustering code as the LDDT
metric.
"""
Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR title mentions an "associated clustering script", but this PR only adds the RMSD metric/tests/fixtures. Either add the clustering script changes or update the PR title/description so it matches what’s actually being merged.

Copilot uses AI. Check for mistakes.
Comment thread src/sampleworks/metrics/rmsd.py Outdated
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🧹 Nitpick comments (1)
tests/metrics/test_rmsd_metrics.py (1)

9-66: Consider adding coverage for superimpose=True and log_rmsd_for_every_batch=True.

The current tests exercise only the default constructor. Both superimpose (Kabsch path, including its subtle interaction with selection — see rmsd.py review) and log_rmsd_for_every_batch (which mutates the result dict) are untested. A quick test of the Kabsch path would be to rigidly rotate altlocA_backbone coords, compute RMSD with and without superimpose=True, and assert the superimposed RMSD is ~0 while the un-superimposed RMSD is large.

Want me to draft the additional tests, including one that composes superimpose=True with a non-None selection to lock in behavior?

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/metrics/test_rmsd_metrics.py` around lines 9 - 66, Add tests that
exercise AllAtomRMSD.compute with the optional flags: (1) create a rigidly
rotated copy of altlocA_backbone and assert that compute(..., superimpose=False)
yields a large best_of_1_rmsd while compute(..., superimpose=True) yields ~0,
and also verify combining superimpose=True with a non-None selection string
keeps the global best_of_1_rmsd equal to the unselected call; (2) call compute
with log_rmsd_for_every_batch=True and assert the returned dict is mutated to
include the per-batch/log entries (e.g., additional keys or batch-level RMSD
structure) compared to the default call; reference AllAtomRMSD, compute,
superimpose, log_rmsd_for_every_batch, and selection to locate where to add
these new tests.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/sampleworks/metrics/rmsd.py`:
- Around line 145-149: The result dict currently sets both "best_of_1_rmsd" and
f"best_of_{len(global_rmsd)}_rmsd", causing a key collision when
len(global_rmsd) == 1; update the code that builds the result (the result:
dict[...] assignment using global_rmsd and residue_rmsd_scores) to always
include "best_of_1_rmsd" but only add the f"best_of_{len(global_rmsd)}_rmsd"
entry when len(global_rmsd) > 1 (i.e., guard the min-across-N key with an if
len(global_rmsd) > 1 to avoid overwriting or duplicating the same key).
- Around line 69-75: The docstring incorrectly uses the term "superpose" but the
actual parameter is named "superimpose"; update the docstring text around the
selection/residue_rmsd_scores description to replace "superpose=True" with
"superimpose=True" (and any other occurrences of "superpose") so it matches the
function signature (e.g., where residue_rmsd_scores and the superimpose
parameter are documented) and keep the clarifying sentence that atoms used for
Kabsch superposition/superimposition always use every atom common to the
predicted and reference stacks.
- Around line 108-125: The code may lose atomworks-added methods on
pred_aa_stack when you reassign it with superimpose(pred_aa_stack), causing
mask() to be missing when computing selected_token_ids; move the selection logic
to run before calling superimpose: compute mask = pred_aa_stack.mask and derive
selected_token_ids from the original pred_aa_stack (e.g., selected_arr =
pred_aa_stack[0, mask] and unique token_id extraction) while keeping the
superimpose(pred_aa_stack) call afterward so you only reassign pred_aa_stack
after selected_token_ids is resolved.

---

Nitpick comments:
In `@tests/metrics/test_rmsd_metrics.py`:
- Around line 9-66: Add tests that exercise AllAtomRMSD.compute with the
optional flags: (1) create a rigidly rotated copy of altlocA_backbone and assert
that compute(..., superimpose=False) yields a large best_of_1_rmsd while
compute(..., superimpose=True) yields ~0, and also verify combining
superimpose=True with a non-None selection string keeps the global
best_of_1_rmsd equal to the unselected call; (2) call compute with
log_rmsd_for_every_batch=True and assert the returned dict is mutated to include
the per-batch/log entries (e.g., additional keys or batch-level RMSD structure)
compared to the default call; reference AllAtomRMSD, compute, superimpose,
log_rmsd_for_every_batch, and selection to locate where to add these new tests.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: be81920d-2553-4651-b5af-48727b2bae6d

📥 Commits

Reviewing files that changed from the base of the PR and between fbf6d38 and 2d3c16d.

📒 Files selected for processing (4)
  • src/sampleworks/metrics/rmsd.py
  • tests/metrics/conftest.py
  • tests/metrics/test_lddt_metrics.py
  • tests/metrics/test_rmsd_metrics.py
💤 Files with no reviewable changes (1)
  • tests/metrics/test_lddt_metrics.py

Comment thread src/sampleworks/metrics/rmsd.py
Comment on lines +108 to +125
if self.superimpose:
pred_aa_stack, _ = superimpose(gt_ref, pred_aa_stack)

tok_idx = cast(np.ndarray, gt_ref.token_id).astype(np.int64)

# Resolve the subset of tokens to report, if a residue selection was given.
selected_token_ids: set[int] | None = None
if selection is not None:
mask_fn = pred_aa_stack.mask
if mask_fn is None:
raise RuntimeError(
"pred_aa_stack does not support mask(). Load atom arrays with "
"`atomworks.io.utils.io_utils.load_any()` to access this method."
)
mask = mask_fn(selection)
selected_arr = cast(AtomArray, pred_aa_stack[0, mask])
if selected_arr.token_id is not None:
selected_token_ids = {int(t) for t in np.unique(selected_arr.token_id)}
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check whether biotite's superimpose() preserves an atomworks-added .mask() method / attribute,
# and whether AtomArrayStack has a built-in `mask` anywhere in biotite usage.
rg -nP --type=py -C3 '\bsuperimpose\s*\(' -g '!**/test/**' -g '!**/tests/**'
echo '---'
# Look for how .mask is attached to atom arrays in the codebase (likely via atomworks load_any)
rg -nP --type=py -C3 '\b\.mask\s*\(' 
echo '---'
# Confirm where `mask` as an attribute is set on atom arrays
rg -nP --type=py -C3 'set_annotation\(\s*[\"\x27]mask[\"\x27]|\.mask\s*=\s*'

Repository: diff-use/sampleworks

Length of output: 4721


🌐 Web query:

biotite.structure.superimpose return type AtomArrayStack annotations preserved

💡 Result:

biotite.structure.superimpose returns a tuple (fitted, transformation), where fitted is an AtomArray or AtomArrayStack (matching the type of the input mobile structure) with updated coordinates that superimpose it onto the fixed structure, and transformation is an AffineTransformation object. The fitted object is explicitly described as "A copy of the mobile structure(s), superimposed on the fixed structure(s)." Since it is a copy of the input AtomArrayStack, and AtomArrayStack objects store annotation arrays shared across all models, the annotations are preserved in the output. This is confirmed in the official documentation and source code. The AffineTransformation.apply method, used internally and exposed for reuse, also creates a copy of the input structure (via atoms.copy) and only updates the .coord attribute, preserving all annotations.

Citations:


🏁 Script executed:

# Find where .mask is defined or attached in atomworks code
find . -type f -name "*.py" | xargs grep -l "\.mask\s*=" | head -20
echo "---"
# Look for load_any implementation 
rg -nP --type=py "def load_any|load_any" -A 10 -g '!**/test/**' -g '!**/tests/**' | head -50
echo "---"
# Search for how mask is set as annotation
rg -nP --type=py "set_annotation.*mask|mask.*annotation" | head -20

Repository: diff-use/sampleworks

Length of output: 3768


🏁 Script executed:

# Check if atomworks is imported and how it modifies atom arrays
rg -nP --type=py "import.*atomworks|from.*atomworks" | head -20
echo "---"
# Look for AtomArrayStack extensions or monkey-patching
rg -nP --type=py "AtomArrayStack\.|atom_array\.mask" -B2 -A2 | head -60

Repository: diff-use/sampleworks

Length of output: 6102


🏁 Script executed:

# Search for test files related to rmsd.py
find . -type f -name "*test*rmsd*" -o -name "*rmsd*test*" | head -10
echo "---"
# Look for any usage combining superimpose with mask in the codebase
rg -nP --type=py "superimpose.*mask|mask.*superimpose" -B3 -A3
echo "---"
# Check the rmsd.py file context around superimpose to understand the flow
head -n 130 src/sampleworks/metrics/rmsd.py | tail -n 30

Repository: diff-use/sampleworks

Length of output: 1475


🏁 Script executed:

# Check if there are any tests that call rmsd with both superimpose=True and selection parameter
rg -nP --type=py "superimpose.*True.*selection|selection.*superimpose" -B2 -A2
echo "---"
# Look for RMSD metric usage in tests with both parameters
rg -nP "RMSD\(.*superimpose.*True" -B2 -A5

Repository: diff-use/sampleworks

Length of output: 67


🏁 Script executed:

# Check test file for rmsd to see if any tests use both superimpose=True and selection
cat tests/metrics/test_rmsd_metrics.py | grep -n -A10 "superimpose.*True\|selection"

Repository: diff-use/sampleworks

Length of output: 1133


🏁 Script executed:

# More targeted: look for test cases that instantiate RMSD with superimpose=True
rg -nP --type=py "RMSD\(" tests/metrics/test_rmsd_metrics.py -A5

Repository: diff-use/sampleworks

Length of output: 982


🏁 Script executed:

# Check the rmsd.py file to see full context including imports and how superimpose is used
cat src/sampleworks/metrics/rmsd.py | head -n 135 | tail -n 35

Repository: diff-use/sampleworks

Length of output: 1688


🏁 Script executed:

# Check the AllAtomRMSD class definition and its default parameters
rg -nP --type=py "class AllAtomRMSD" -A30 src/sampleworks/metrics/rmsd.py
echo "---"
# Check if there's a superimpose parameter in the RMSD class
rg -nP --type=py "superimpose" src/sampleworks/metrics/rmsd.py | head -20

Repository: diff-use/sampleworks

Length of output: 1617


🏁 Script executed:

# Look for any test that explicitly sets superimpose=True
rg -nP --type=py "superimpose\s*=\s*True|superimpose=True" tests/metrics/test_rmsd_metrics.py
echo "---"
# Check the test fixtures to understand test structure
head -n 100 tests/metrics/test_rmsd_metrics.py

Repository: diff-use/sampleworks

Length of output: 2623


🏁 Script executed:

# Search for any instantiations of AllAtomRMSD with superimpose=True in the entire codebase
rg -nP --type=py "AllAtomRMSD\s*\(\s*superimpose\s*=\s*True|AllAtomRMSD\(\s*True" 
echo "---"
# Also check if there's a way to set superimpose through config or kwargs
rg -nP --type=py "AllAtomRMSD.*superimpose" -B2 -A2

Repository: diff-use/sampleworks

Length of output: 67


🏁 Script executed:

# Let's check conftest.py to see how test fixtures are created
head -n 80 tests/metrics/conftest.py

Repository: diff-use/sampleworks

Length of output: 828


🏁 Script executed:

# Let's check if there are any tests that specifically test superimpose=True with selection
rg -nP --type=py "superimpose\s*=\s*True" tests/ -A10 -B5 
echo "---"
# Double-check: any usage of AllAtomRMSD with both superimpose and selection anywhere
rg -nP --type=py "AllAtomRMSD.*superimpose.*True.*selection|selection.*AllAtomRMSD.*superimpose.*True"

Repository: diff-use/sampleworks

Length of output: 67


🏁 Script executed:

# Let's understand: is .mask a method or property? Search for how it's defined/used
rg -nP --type=py "\.mask\s*\(" src/sampleworks -B1 -A1 | head -40

Repository: diff-use/sampleworks

Length of output: 1564


🏁 Script executed:

# Search for any actual usage in the codebase of superimpose followed by mask calls
rg -nP --type=py "superimpose\(" -A20 | grep -A15 "\.mask\("
echo "---"
# Check if there are integration tests or real usage patterns
find tests -name "*.py" -exec grep -l "superimpose.*True" {} \;
echo "---"
# Search for any documentation or comments about this interaction
rg -nP --type=py "superimpose.*mask|mask.*superimpose" -B5 -A5

Repository: diff-use/sampleworks

Length of output: 112


🏁 Script executed:

# Let's check how atomworks loads structures to understand what .mask() is
rg -nP --type=py "class.*AtomArray|def.*mask" -A5 | head -60
echo "---"
# Critical: check if .mask is being set as an attribute on the object
rg -nP --type=py "\.mask\s*=" -B3 -A3

Repository: diff-use/sampleworks

Length of output: 4634


🏁 Script executed:

# Look at the actual test file again to see if there's any test with superimpose=True
cat tests/metrics/test_rmsd_metrics.py

Repository: diff-use/sampleworks

Length of output: 2619


🏁 Script executed:

# Search for any AllAtomRMSD instantiation with superimpose parameter
rg -nP "AllAtomRMSD\(" tests/ -A3 -B1

Repository: diff-use/sampleworks

Length of output: 1352


🏁 Script executed:

# Let's verify: does biotite.superimpose actually preserve custom methods/attributes?
# Check what biotite's copy mechanism does
rg -nP --type=py "def copy\|\.copy\(\)" src/sampleworks -B2 -A5 | head -40
echo "---"
# Check imports in rmsd.py to see what's available from biotite
head -n 20 src/sampleworks/metrics/rmsd.py

Repository: diff-use/sampleworks

Length of output: 775


🏁 Script executed:

# Final check: verify whether .mask is truly callable or if it's just checked for None
rg -nP "mask_fn\s*=" src/sampleworks/metrics/rmsd.py -B2 -A3

Repository: diff-use/sampleworks

Length of output: 365


🏁 Script executed:

# Critical: check if there are any existing bug reports or comments about this interaction
rg -nP "superimpose.*selection|selection.*superimpose|mask.*after.*superimpose" -B3 -A3
echo "---"
# Check if .mask is set/assigned somewhere on AtomArrayStack
rg -nP "pred_aa_stack.*mask|\.mask\s*=" src/sampleworks/metrics/rmsd.py -B2 -A2

Repository: diff-use/sampleworks

Length of output: 1008


🏁 Script executed:

# Search for any comments or issues related to this problem
rg -nP "FIXME|TODO|BUG|ISSUE" src/sampleworks/metrics/rmsd.py -B2 -A2 | head -30

Repository: diff-use/sampleworks

Length of output: 46


Potential runtime failure when combining superimpose=True with selection.

pred_aa_stack is reassigned on line 109 to the output of biotite.structure.superimpose(). The subsequent .mask() call on line 116 may fail if custom methods added by atomworks do not survive biotite's copy operation during superimposition. While the code includes a defensive check on line 117, this creates a brittle dependency on biotite's implementation details.

Resolve selected_token_ids before calling superimpose() to eliminate this risk:

🛠️ Suggested restructuring
-        # 3. Optional Kabsch superposition (always on every common atom, regardless of
-        # any `selection`).
-        gt_ref = gt_aa_stack[0]
-        if self.superimpose:
-            pred_aa_stack, _ = superimpose(gt_ref, pred_aa_stack)
-
-        tok_idx = cast(np.ndarray, gt_ref.token_id).astype(np.int64)
-
-        # Resolve the subset of tokens to report, if a residue selection was given.
-        selected_token_ids: set[int] | None = None
-        if selection is not None:
-            mask_fn = pred_aa_stack.mask
+        gt_ref = gt_aa_stack[0]
+        tok_idx = cast(np.ndarray, gt_ref.token_id).astype(np.int64)
+
+        # Resolve the subset of tokens to report BEFORE superimposing, to avoid
+        # depending on whether custom atomworks methods survive biotite's copy.
+        selected_token_ids: set[int] | None = None
+        if selection is not None:
+            mask_fn = pred_aa_stack.mask
             if mask_fn is None:
                 raise RuntimeError(
                     "pred_aa_stack does not support mask(). Load atom arrays with "
                     "`atomworks.io.utils.io_utils.load_any()` to access this method."
                 )
             mask = mask_fn(selection)
             selected_arr = cast(AtomArray, pred_aa_stack[0, mask])
             if selected_arr.token_id is not None:
                 selected_token_ids = {int(t) for t in np.unique(selected_arr.token_id)}
+
+        # 3. Optional Kabsch superposition (always on every common atom, regardless of
+        # any `selection`).
+        if self.superimpose:
+            pred_aa_stack, _ = superimpose(gt_ref, pred_aa_stack)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/sampleworks/metrics/rmsd.py` around lines 108 - 125, The code may lose
atomworks-added methods on pred_aa_stack when you reassign it with
superimpose(pred_aa_stack), causing mask() to be missing when computing
selected_token_ids; move the selection logic to run before calling superimpose:
compute mask = pred_aa_stack.mask and derive selected_token_ids from the
original pred_aa_stack (e.g., selected_arr = pred_aa_stack[0, mask] and unique
token_id extraction) while keeping the superimpose(pred_aa_stack) call afterward
so you only reassign pred_aa_stack after selected_token_ids is resolved.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@k-chrispens definitely look into this and harden it.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!

Comment thread src/sampleworks/metrics/rmsd.py
@k-chrispens k-chrispens changed the title feat: add RMSD metric and associated clustering script feat: add RMSD metric Apr 24, 2026
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
# Resolve the subset of tokens to report, if a residue selection was given.
selected_token_ids: set[int] | None = None
if selection is not None:
mask_fn = pred_aa_stack.mask
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it necessary to create this variable mask_fn? Why not just use pred_aa_stack.mask?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, won't this raise an attribute error if the mask function isn't there?

Comment on lines +108 to +125
if self.superimpose:
pred_aa_stack, _ = superimpose(gt_ref, pred_aa_stack)

tok_idx = cast(np.ndarray, gt_ref.token_id).astype(np.int64)

# Resolve the subset of tokens to report, if a residue selection was given.
selected_token_ids: set[int] | None = None
if selection is not None:
mask_fn = pred_aa_stack.mask
if mask_fn is None:
raise RuntimeError(
"pred_aa_stack does not support mask(). Load atom arrays with "
"`atomworks.io.utils.io_utils.load_any()` to access this method."
)
mask = mask_fn(selection)
selected_arr = cast(AtomArray, pred_aa_stack[0, mask])
if selected_arr.token_id is not None:
selected_token_ids = {int(t) for t in np.unique(selected_arr.token_id)}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@k-chrispens definitely look into this and harden it.

Comment thread src/sampleworks/metrics/rmsd.py
Comment thread tests/metrics/conftest.py
)


@pytest.fixture(scope="module")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should these have something else for their scope now? I.e., is it a problem for them to be kept around for multiple test modules?

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

♻️ Duplicate comments (2)
src/sampleworks/metrics/rmsd.py (2)

145-149: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Key collision when N=1: best_of_1_rmsd and best_of_{N}_rmsd resolve to the same key.

When the stack contains a single model, lines 146 and 147 both write to "best_of_1_rmsd". While the values are identical, downstream consumers expecting a distinct best_of_{N}_rmsd entry when N > 1 may break. Guard the min-across-N key with if len(global_rmsd) > 1.

🔧 Proposed fix
         result: dict[str, Any] = {
             "best_of_1_rmsd": float(global_rmsd[0]),
-            f"best_of_{len(global_rmsd)}_rmsd": float(global_rmsd.min()),
             "residue_rmsd_scores": residue_rmsd_scores,
         }
+        if len(global_rmsd) > 1:
+            result[f"best_of_{len(global_rmsd)}_rmsd"] = float(global_rmsd.min())
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/sampleworks/metrics/rmsd.py` around lines 145 - 149, The result dict
currently always sets both "best_of_1_rmsd" and
f"best_of_{len(global_rmsd)}_rmsd", causing a key collision when
len(global_rmsd)==1; update the code that builds result (the block referencing
global_rmsd and residue_rmsd_scores) to always include "best_of_1_rmsd" but only
add the f"best_of_{len(global_rmsd)}_rmsd" entry when len(global_rmsd) > 1
(i.e., guard the min-across-N key with if len(global_rmsd) > 1) so downstream
consumers see the aggregated key only for N>1.

108-125: ⚠️ Potential issue | 🟠 Major | ⚖️ Poor tradeoff

Potential runtime failure when combining superimpose=True with selection.

The code reassigns pred_aa_stack on line 109 via superimpose(), then accesses .mask() on line 116. If biotite's copy operation does not preserve atomworks-added methods, this will fail. Resolve selected_token_ids before calling superimpose() to eliminate the dependency on biotite's internal copy behavior.

🛠️ Suggested restructuring

Move selection logic before superimpose:

         gt_ref = gt_aa_stack[0]
+        tok_idx = cast(np.ndarray, gt_ref.token_id).astype(np.int64)
+
+        # Resolve the subset of tokens to report BEFORE superimposing
+        selected_token_ids: set[int] | None = None
+        if selection is not None:
+            mask_fn = pred_aa_stack.mask
+            if mask_fn is None:
+                raise RuntimeError(
+                    "pred_aa_stack does not support mask(). Load atom arrays with "
+                    "`atomworks.io.utils.io_utils.load_any()` to access this method."
+                )
+            mask = mask_fn(selection)
+            selected_arr = cast(AtomArray, pred_aa_stack[0, mask])
+            if selected_arr.token_id is not None:
+                selected_token_ids = {int(t) for t in np.unique(selected_arr.token_id)}
+
+        # 3. Optional Kabsch superposition
         if self.superimpose:
             pred_aa_stack, _ = superimpose(gt_ref, pred_aa_stack)
-
-        tok_idx = cast(np.ndarray, gt_ref.token_id).astype(np.int64)
-
-        # Resolve the subset of tokens to report, if a residue selection was given.
-        selected_token_ids: set[int] | None = None
-        if selection is not None:
-            mask_fn = pred_aa_stack.mask
-            if mask_fn is None:
-                raise RuntimeError(
-                    "pred_aa_stack does not support mask(). Load atom arrays with "
-                    "`atomworks.io.utils.io_utils.load_any()` to access this method."
-                )
-            mask = mask_fn(selection)
-            selected_arr = cast(AtomArray, pred_aa_stack[0, mask])
-            if selected_arr.token_id is not None:
-                selected_token_ids = {int(t) for t in np.unique(selected_arr.token_id)}
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/sampleworks/metrics/rmsd.py` around lines 108 - 125, The selection-based
token resolution should occur before mutating pred_aa_stack via superimpose to
avoid losing atomworks methods; move the block that computes selected_token_ids
(the selection check, mask_fn = pred_aa_stack.mask, mask = mask_fn(selection),
selected_arr = cast(AtomArray, pred_aa_stack[0, mask]) and extraction of
token_id) to before the superimpose call, leaving only the
superimpose(pred_aa_stack, ...) call afterward and preserving use of
pred_aa_stack in superimpose; ensure you still reference the same names
(selected_token_ids, selection, pred_aa_stack, mask_fn) so downstream code uses
the precomputed selected_token_ids.
🧹 Nitpick comments (1)
tests/metrics/test_rmsd_metrics.py (1)

9-9: Address or remove the TODO comment.

The TODO suggests making tests more rigorous. While the current tests cover core functionality well, consider adding tests for edge cases (e.g., superimpose=True vs False, AtomArray vs AtomArrayStack inputs, error conditions).

Would you like me to suggest additional test cases or open an issue to track this enhancement?

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/metrics/test_rmsd_metrics.py` at line 9, Remove or address the TODO in
tests/metrics/test_rmsd_metrics.py: either delete the comment or add concrete
unit tests that cover the suggested edge cases — specifically add tests
exercising rmsd functions with superimpose=True and superimpose=False, with both
AtomArray and AtomArrayStack inputs, and tests for expected error conditions
(invalid shapes/types). If you prefer not to implement now, replace the TODO
with a short comment pointing to a newly opened issue tracking these
enhancements and include the issue number.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@src/sampleworks/metrics/rmsd.py`:
- Around line 88-93: Add a short rationale to each `# ty:
ignore[invalid-argument-type]` next to the calls to
add_global_token_id_annotation for predicted_atom_array_stack and
ground_truth_atom_array_stack explaining why the runtime type is valid despite
the type checker complaint (e.g., the arrays are numpy/torch tensors or a Union
that narrows at runtime, or the helper performs casting/validation), for
example: indicate that the function accepts array-like structures and we know at
this call site they conform to the expected shape/dtype so the static checker is
a false positive; update the two comments that follow the calls to
predicted_atom_array_stack and ground_truth_atom_array_stack accordingly.

In `@tests/metrics/test_rmsd_metrics.py`:
- Line 23: The test function test_all_atom_rmsd_end_to_end is missing a
NumPy-style docstring; add a brief NumPy-style docstring inside the function
that states what the test verifies (e.g., that all-atom RMSD between
altlocA_backbone and altlocB_backbone is computed correctly end-to-end), include
expected inputs (fixtures altlocA_backbone, altlocB_backbone) and the high-level
assertion intent so readers understand the purpose of
test_all_atom_rmsd_end_to_end.

---

Duplicate comments:
In `@src/sampleworks/metrics/rmsd.py`:
- Around line 145-149: The result dict currently always sets both
"best_of_1_rmsd" and f"best_of_{len(global_rmsd)}_rmsd", causing a key collision
when len(global_rmsd)==1; update the code that builds result (the block
referencing global_rmsd and residue_rmsd_scores) to always include
"best_of_1_rmsd" but only add the f"best_of_{len(global_rmsd)}_rmsd" entry when
len(global_rmsd) > 1 (i.e., guard the min-across-N key with if len(global_rmsd)
> 1) so downstream consumers see the aggregated key only for N>1.
- Around line 108-125: The selection-based token resolution should occur before
mutating pred_aa_stack via superimpose to avoid losing atomworks methods; move
the block that computes selected_token_ids (the selection check, mask_fn =
pred_aa_stack.mask, mask = mask_fn(selection), selected_arr = cast(AtomArray,
pred_aa_stack[0, mask]) and extraction of token_id) to before the superimpose
call, leaving only the superimpose(pred_aa_stack, ...) call afterward and
preserving use of pred_aa_stack in superimpose; ensure you still reference the
same names (selected_token_ids, selection, pred_aa_stack, mask_fn) so downstream
code uses the precomputed selected_token_ids.

---

Nitpick comments:
In `@tests/metrics/test_rmsd_metrics.py`:
- Line 9: Remove or address the TODO in tests/metrics/test_rmsd_metrics.py:
either delete the comment or add concrete unit tests that cover the suggested
edge cases — specifically add tests exercising rmsd functions with
superimpose=True and superimpose=False, with both AtomArray and AtomArrayStack
inputs, and tests for expected error conditions (invalid shapes/types). If you
prefer not to implement now, replace the TODO with a short comment pointing to a
newly opened issue tracking these enhancements and include the issue number.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 1acddfb1-aff1-467f-96f5-a0653cc0d581

📥 Commits

Reviewing files that changed from the base of the PR and between 2d3c16d and b33579d.

📒 Files selected for processing (4)
  • src/sampleworks/metrics/rmsd.py
  • tests/metrics/conftest.py
  • tests/metrics/test_lddt_metrics.py
  • tests/metrics/test_rmsd_metrics.py
💤 Files with no reviewable changes (1)
  • tests/metrics/test_lddt_metrics.py
✅ Files skipped from review due to trivial changes (1)
  • tests/metrics/conftest.py

Comment on lines +88 to +93
predicted_atom_array_stack = add_global_token_id_annotation(
predicted_atom_array_stack # ty: ignore[invalid-argument-type]
)
ground_truth_atom_array_stack = add_global_token_id_annotation(
ground_truth_atom_array_stack # ty: ignore[invalid-argument-type]
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify type ignore syntax used across the codebase
rg -nP --type=py '# ty: ignore|# type: ignore' | head -30
echo "---"
# Check pyproject.toml for type checker configuration
rg -nP 'type.*check|mypy|pyright|ruff.*typing' pyproject.toml -A5

Repository: diff-use/sampleworks

Length of output: 4211


Add rationale to # ty: ignore[invalid-argument-type] comments in RMSD
In src/sampleworks/metrics/rmsd.py (lines 89 and 92), the # ty: ignore[invalid-argument-type] directives should include a brief explanation of why the argument types don’t match the type checker’s expectation (beyond just the error code).

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/sampleworks/metrics/rmsd.py` around lines 88 - 93, Add a short rationale
to each `# ty: ignore[invalid-argument-type]` next to the calls to
add_global_token_id_annotation for predicted_atom_array_stack and
ground_truth_atom_array_stack explaining why the runtime type is valid despite
the type checker complaint (e.g., the arrays are numpy/torch tensors or a Union
that narrows at runtime, or the helper performs casting/validation), for
example: indicate that the function accepts array-like structures and we know at
this call site they conform to the expected shape/dtype so the static checker is
a false positive; update the two comments that follow the calls to
predicted_atom_array_stack and ground_truth_atom_array_stack accordingly.

assert scores == pytest.approx([0.0], abs=1e-6), f"nonzero identity RMSD at {residue}"


def test_all_atom_rmsd_end_to_end(altlocA_backbone, altlocB_backbone):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Add a docstring to the test function.

All functions should include docstrings per the guidelines. Consider adding a brief description of what this test verifies.

📝 Suggested docstring
 def test_all_atom_rmsd_end_to_end(altlocA_backbone, altlocB_backbone):
+    """Verify RMSD computation with selection matches expected values for altloc comparison."""
     selection_string = "res_id > 179 and res_id < 190"

As per coding guidelines: "Always include NumPy-style docstrings for every function and class."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/metrics/test_rmsd_metrics.py` at line 23, The test function
test_all_atom_rmsd_end_to_end is missing a NumPy-style docstring; add a brief
NumPy-style docstring inside the function that states what the test verifies
(e.g., that all-atom RMSD between altlocA_backbone and altlocB_backbone is
computed correctly end-to-end), include expected inputs (fixtures
altlocA_backbone, altlocB_backbone) and the high-level assertion intent so
readers understand the purpose of test_all_atom_rmsd_end_to_end.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants