Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
3ec2bf6
Add valdo.pipeline CLI with per-stage YAML-driven execution
minhuanli May 4, 2026
9e3b8d7
Remove twinning-specific language from reindex documentation
minhuanli May 4, 2026
f103984
Add valdo.refine prerequisite guidance before add_phases_and_blobs
minhuanli May 4, 2026
f09b994
Add PTP1B_pipeline/ to .gitignore
minhuanli May 4, 2026
59e93bb
Add reindex validation plots and skip-if-done detection
minhuanli May 4, 2026
e76470d
Fix Scaler_pool.batch_scaling call — when_opt not supported in pool v…
minhuanli May 4, 2026
868d621
Support .txt file path in expand_glob_field for explicit file lists
minhuanli May 4, 2026
743aebb
Add scale validation plots and skip-if-done detection
minhuanli May 4, 2026
52b3bd9
Add PTP1B pipeline run docs, configs, and fix VAE training defaults
minhuanli May 4, 2026
439b87c
Add validation plots from reindex and scale stages
minhuanli May 4, 2026
50a2b0f
Fix NaN in VAE training; add scale all-NaN detection; clean dataset
minhuanli May 4, 2026
f0798b0
Add reconstruct and rescale steps (pipeline steps 6-7)
minhuanli May 4, 2026
7af4538
Add tag_blobs step and AUC plot script for PTP1B pipeline
minhuanli May 4, 2026
472ed71
Rewrite PIPELINE_RUN.md as a new-user-friendly guide; add helper scripts
minhuanli May 4, 2026
a83d2c2
Align pipeline defaults with Doeke's settings; fix bugs and improve c…
minhuanli May 8, 2026
796da57
Add filter stage, AUC-vs-N plot, bound model fix; update pipeline run…
minhuanli May 8, 2026
1f5864a
Update AUC/ROC plots from full 1617-dataset run
minhuanli May 11, 2026
16fc1cf
Fix add_phases phase-file lookup; add metric and notebook scripts
minhuanli May 11, 2026
cffba70
Add Keedy/Ginn split to heavy atom peak metric
minhuanli May 11, 2026
7e0220d
Add ablation metrics scripts and self-contained mapping file
minhuanli May 11, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
dev/
scaled_mtzs/
*pkl
PTP1B_pipeline/
CLAUDE.md

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
429 changes: 429 additions & 0 deletions PTP1B_pipeline/PIPELINE_RUN.md

Large diffs are not rendered by default.

88 changes: 88 additions & 0 deletions PTP1B_pipeline/collect_ablation_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#!/usr/bin/env python
"""
Collect ablation metrics across all hyperparameter settings and print a comparison table.

Run from PTP1B_pipeline/ after ablation jobs have completed:
python collect_ablation_metrics.py
"""

import os
import sys
import csv
import glob as _glob

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from compute_valdo_metrics import compute_metrics

# Directory containing this script; every path below is resolved relative to it.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
ABLATION_DIR = os.path.join(SCRIPT_DIR, "ablation")
OUTPUT_CSV = os.path.join(ABLATION_DIR, "ablation_metrics.csv")

# (display label, recons_phased directory) for each hyperparameter setting.
# The baseline lives directly under SCRIPT_DIR; every other run lives under
# ABLATION_DIR/<run name>/.
SETTINGS = [
    ("baseline (latent=7, relu, wkl=1, [3,6]/100)",
     os.path.join(SCRIPT_DIR, "vae", "recons_phased")),
    ("latent_dim_3",
     os.path.join(ABLATION_DIR, "latent_dim_3", "vae", "recons_phased")),
    ("latent_dim_5",
     os.path.join(ABLATION_DIR, "latent_dim_5", "vae", "recons_phased")),
    ("latent_dim_9",
     os.path.join(ABLATION_DIR, "latent_dim_9", "vae", "recons_phased")),
    ("activation_tanh",
     os.path.join(ABLATION_DIR, "activation_tanh", "vae", "recons_phased")),
    ("w_kl_0.1",
     os.path.join(ABLATION_DIR, "w_kl_0.1", "vae", "recons_phased")),
    ("w_kl_10",
     os.path.join(ABLATION_DIR, "w_kl_10", "vae", "recons_phased")),
    ("hidden_small ([2,4]/100)",
     os.path.join(ABLATION_DIR, "hidden_small", "vae", "recons_phased")),
    ("hidden_large ([4,8]/100)",
     os.path.join(ABLATION_DIR, "hidden_large", "vae", "recons_phased")),
    ("hidden_wide ([3,6]/200)",
     os.path.join(ABLATION_DIR, "hidden_wide", "vae", "recons_phased")),
]

# CSV column order; the first column is the label, the rest are keys of the
# dict returned by compute_metrics().
FIELDS = ["setting", "apo_mean", "apo_std", "n_apo",
          "keedy_mean", "n_keedy", "ginn_mean", "n_ginn", "all_mean", "n_all"]

COL_W = 38  # width of the left-aligned "Setting" column in the printed table
rows = []

print(f"{'Setting':<{COL_W}} {'Apo mean':>10} {'Keedy HA':>10} {'Ginn HA':>10} {'All HA':>10} {'N_apo':>6}")
print("-" * (COL_W + 50))

for name, phased_dir in SETTINGS:
    # A setting counts as "ready" only if its output directory exists and
    # already contains at least one phased MTZ.
    n_mtz = len(_glob.glob(os.path.join(phased_dir, "*.mtz"))) if os.path.isdir(phased_dir) else 0

    if n_mtz == 0:
        print(f"{name:<{COL_W}} {'NOT READY':>10}")
        rows.append({"setting": name, **{field: "N/A" for field in FIELDS[1:]}})
        continue

    m = compute_metrics(phased_dir)
    print(
        f"{name:<{COL_W}} "
        f"{m['apo_mean']:>10.3f} "
        f"{m['keedy_mean']:>10.3f} "
        f"{m['ginn_mean']:>10.3f} "
        f"{m['all_mean']:>10.3f} "
        f"{m['n_apo']:>6}"
    )
    rows.append({
        "setting": name,
        "apo_mean": f"{m['apo_mean']:.4f}",
        "apo_std": f"{m['apo_std']:.4f}",
        "n_apo": m["n_apo"],
        "keedy_mean": f"{m['keedy_mean']:.4f}",
        "n_keedy": m["n_keedy"],
        "ginn_mean": f"{m['ginn_mean']:.4f}",
        "n_ginn": m["n_ginn"],
        "all_mean": f"{m['all_mean']:.4f}",
        "n_all": m["n_all"],
    })

# The baseline does not live under ABLATION_DIR, so that directory may not
# exist yet; create it rather than crash after all metrics were computed.
os.makedirs(ABLATION_DIR, exist_ok=True)
with open(OUTPUT_CSV, "w", newline="") as csv_file:
    writer = csv.DictWriter(csv_file, fieldnames=FIELDS)
    writer.writeheader()
    writer.writerows(rows)

print(f"\nCSV written to: {OUTPUT_CSV}")
270 changes: 270 additions & 0 deletions PTP1B_pipeline/compute_valdo_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
#!/usr/bin/env python
"""
Compute VAE map quality metrics for the PTP1B VALDO pipeline:
1. Apo peak metric — mean highest WDF peak across true-apo datasets
2. Heavy atom metric — WDF map value at ligand heavy atom (Cl/Br/S/I) positions

Run from PTP1B_pipeline/:
python compute_valdo_metrics.py [--recons-phased PATH]
"""

import argparse
import os
import re
import glob
import numpy as np
import gemmi
from tqdm import tqdm

# Directory containing this script; the default input locations below are
# resolved relative to it.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# Default directory of phased MTZ files (default value of --recons-phased).
RECONS_PHASED = os.path.join(SCRIPT_DIR, "vae", "recons_phased")
# Directory globbed for XXXX.pdb bound models in the heavy atom metric.
BOUND_MODELS_STD = os.path.join(SCRIPT_DIR, "bound_models_standardized")
# Directory whose file names classify_bound_models() parses into Keedy/Ginn IDs.
ALL_SUPERPOSED_V2 = os.path.join(SCRIPT_DIR, "all_superposed_v2")
# Mapping file read by load_apo_ids() to find datasets without a ligand .cif.
MAPPING_TXT = os.path.join(SCRIPT_DIR, "ligand_cif_to_dataset_mapping.txt")

# MTZ column labels (amplitude, phase) handed to transform_f_phi_to_map().
DIFF_COL = "WDF"
PHASE_COL = "PH2FOFCWT"


def classify_bound_models(superposed_dir):
    """
    Split the dataset IDs found in ``superposed_dir`` by file-name convention.

    Each directory entry is tested against two name prefixes, where XXXX is
    a 4-digit dataset ID:
        Keedy: PTP1B_yXXXX_*
        Ginn:  yXXXX_cluster4x_*
    ``README.txt`` is skipped explicitly; any other non-matching name is
    simply ignored.

    Returns a ``(keedy_ids, ginn_ids)`` tuple of sets of ID strings. An ID
    that occurs under both conventions is present in both sets.
    """
    keedy_pattern = re.compile(r"PTP1B_y(\d{4})_")
    ginn_pattern = re.compile(r"y(\d{4})_cluster4x_")

    entries = [entry for entry in os.listdir(superposed_dir) if entry != "README.txt"]

    keedy_ids = {hit.group(1) for hit in map(keedy_pattern.match, entries) if hit}
    ginn_ids = {hit.group(1) for hit in map(ginn_pattern.match, entries) if hit}
    return keedy_ids, ginn_ids


def load_apo_ids(mapping_txt):
    """
    Return the set of 4-digit dataset IDs for which no ligand was soaked.

    The mapping file is read line by line. Blank lines and lines ending in
    ``.cif`` (a ligand is listed) are skipped. A remaining line is expected
    to look like ``PTP1B-yXXXX:``; the four digits immediately before the
    trailing colon are taken as the ID. Lines that do not end in
    ``<4 digits>:`` (headers, notes) are ignored rather than being sliced
    into a meaningless ID.
    """
    id_before_colon = re.compile(r"(\d{4}):$")
    apo_ids = set()
    with open(mapping_txt) as f:
        for line in f:
            stripped = line.strip()
            if not stripped or stripped.endswith(".cif"):
                continue
            hit = id_before_colon.search(stripped)
            if hit:
                apo_ids.add(hit.group(1))
    return apo_ids


def find_mtz(dataset_id, phased_dir):
    """
    Locate the phased MTZ for a 4-digit dataset ID inside ``phased_dir``.

    ``<id>_0.mtz`` is preferred, then ``<id>_1.mtz``. If neither exists, the
    lexicographically smallest ``<id>_*.mtz`` is returned so the choice does
    not depend on directory listing order. Returns None when nothing matches.
    """
    for suffix in ("_0", "_1"):
        path = os.path.join(phased_dir, f"{dataset_id}{suffix}.mtz")
        if os.path.exists(path):
            return path
    matches = glob.glob(os.path.join(phased_dir, f"{dataset_id}_*.mtz"))
    # glob returns entries in arbitrary order; min() makes the pick stable.
    return min(matches, default=None)


def extract_id_from_mtz(mtz_path):
    """Return the leading 4-digit ID of an MTZ file name such as 0049_0.mtz.

    Only the base name is inspected; it must start with four digits followed
    by an underscore. Returns None otherwise.
    """
    file_name = os.path.basename(mtz_path)
    hit = re.match(r"(\d{4})_", file_name)
    if hit is None:
        return None
    return hit.group(1)


def extract_id_from_pdb(pdb_path):
    """Return the 4-digit ID of a bound-model file named exactly XXXX.pdb.

    Only the base name is inspected. Returns None for any other name.
    """
    file_name = os.path.basename(pdb_path)
    hit = re.match(r"(\d{4})\.pdb$", file_name)
    if hit is None:
        return None
    return hit.group(1)


def wdf_grid(mtz_path):
    """Read an MTZ file and return its difference map as a normalized grid.

    The map is computed from the DIFF_COL amplitude and PHASE_COL phase
    columns with a sample rate of 3.0, then normalized in place with the
    grid's ``normalize()`` before being returned.
    """
    density = gemmi.read_mtz_file(mtz_path).transform_f_phi_to_map(
        DIFF_COL, PHASE_COL, sample_rate=3.0
    )
    density.normalize()
    return density


def heavy_atom_peak(grid, pdb_path):
    """
    Return the highest grid value found at the Cl/Br/S/I atoms of LIG residues.

    Atoms are selected from the first model of the PDB file with the
    selection "[CL,Br,S,I]" and kept only if their residue is named "LIG".
    Each atom position is fractionalized with the structure's unit cell,
    mapped through every symmetry operation of the grid's space group,
    wrapped into [0, 1), and rounded to the nearest grid point; the maximum
    of all sampled values is returned as a float.

    Returns None if no matching atoms are found.
    """
    structure = gemmi.read_pdb(pdb_path)
    heavy_selection = gemmi.Selection("[CL,Br,S,I]")
    heavy_model = heavy_selection.copy_model_selection(structure[0])
    ligand_hits = [hit for hit in heavy_model.all() if hit.residue.name == "LIG"]
    if not ligand_hits:
        return None

    dims = (grid.nu, grid.nv, grid.nw)
    symops = grid.spacegroup.operations()

    samples = []
    for hit in ligand_hits:
        frac = structure.cell.fractionalize(hit.atom.pos)
        for symop in symops:
            image = symop.apply_to_xyz(frac.tolist())
            # Wrap each fractional coordinate into the unit cell, then snap
            # it to the nearest grid index along that axis.
            u, v, w = (
                round((coord - np.floor(coord)) * n_pts) % n_pts
                for coord, n_pts in zip(image, dims)
            )
            samples.append(grid.get_value(u, v, w))

    return float(np.max(samples))


def report_ha(results):
    """Print a per-dataset listing and mean/std summary of heavy atom peaks.

    ``results`` is a sequence of ``(dataset_id, value)`` pairs. When it is
    empty, a single "No results." line is printed instead.
    """
    if not results:
        print(" No results.")
        return

    print(f" Datasets with heavy atoms : {len(results)}")
    peaks = []
    for dataset_id, peak in results:
        print(f" {dataset_id}: {peak:.4f}")
        peaks.append(peak)

    mean_peak = np.mean(peaks)
    std_peak = np.std(peaks)
    print(f" Mean WDF peak : {mean_peak:.4f} Std : {std_peak:.4f}")


def compute_metrics(recons_phased_dir):
    """
    Compute apo peak and heavy atom peak metrics for one recons_phased directory.

    Metric 1 (apo peak): for every ``*.mtz`` in ``recons_phased_dir`` whose ID
    is in the apo set from MAPPING_TXT, take the maximum of the normalized
    map returned by wdf_grid().
    Metric 2 (heavy atom peak): for every ``XXXX.pdb`` in BOUND_MODELS_STD
    with a matching MTZ, take heavy_atom_peak(); results are also split into
    Keedy and Ginn subsets via classify_bound_models(ALL_SUPERPOSED_V2).

    A dataset that raises while being processed is skipped (best effort),
    but the failure is reported on stderr so a reduced ``n_*`` count can be
    traced back to its cause.

    Returns a dict with keys: n_apo, apo_mean, apo_std,
    n_keedy, keedy_mean, n_ginn, ginn_mean, n_all, all_mean.
    Means/std are NaN when the corresponding list is empty.
    """
    import sys  # local import: only needed for the stderr diagnostics below

    def report_skip(metric, did, path, exc):
        # tqdm.write keeps the message from being garbled by the progress bar.
        tqdm.write(f"[compute_metrics] {metric}: skipped {did} ({path}): {exc}",
                   file=sys.stderr)

    def safe_mean(pairs):
        vals = [v for _, v in pairs]
        return float(np.mean(vals)) if vals else float("nan")

    apo_ids = load_apo_ids(MAPPING_TXT)
    keedy_ids, ginn_ids = classify_bound_models(ALL_SUPERPOSED_V2)
    all_mtz = sorted(glob.glob(os.path.join(recons_phased_dir, "*.mtz")))

    # ── Metric 1: Apo peak ──────────────────────────────────────────────────
    apo_peaks = []
    for mtz_path in tqdm(all_mtz, desc="Apo peak", leave=False):
        did = extract_id_from_mtz(mtz_path)
        if did is None or did not in apo_ids:
            continue
        try:
            grid = wdf_grid(mtz_path)
            apo_peaks.append(float(np.max(grid.array)))
        except Exception as exc:
            report_skip("apo peak", did, mtz_path, exc)

    # ── Metric 2: Heavy atom peak ────────────────────────────────────────────
    all_pdb = sorted(glob.glob(os.path.join(BOUND_MODELS_STD, "*.pdb")))
    ha_all = []

    for pdb_path in tqdm(all_pdb, desc="Heavy atom peak", leave=False):
        did = extract_id_from_pdb(pdb_path)
        if did is None:
            continue
        mtz_path = find_mtz(did, recons_phased_dir)
        if mtz_path is None:
            continue
        try:
            grid = wdf_grid(mtz_path)
            val = heavy_atom_peak(grid, pdb_path)
        except Exception as exc:
            report_skip("heavy atom peak", did, mtz_path, exc)
            continue
        if val is not None:
            ha_all.append((did, val))

    ha_keedy = [(did, v) for did, v in ha_all if did in keedy_ids]
    ha_ginn = [(did, v) for did, v in ha_all if did in ginn_ids]

    return {
        "n_apo": len(apo_peaks),
        "apo_mean": float(np.mean(apo_peaks)) if apo_peaks else float("nan"),
        "apo_std": float(np.std(apo_peaks)) if apo_peaks else float("nan"),
        "n_keedy": len(ha_keedy),
        "keedy_mean": safe_mean(ha_keedy),
        "n_ginn": len(ha_ginn),
        "ginn_mean": safe_mean(ha_ginn),
        "n_all": len(ha_all),
        "all_mean": safe_mean(ha_all),
    }


def main():
    """
    Command-line entry point: print both metrics for one recons_phased directory.

    Parses ``--recons-phased`` (default RECONS_PHASED), then prints
      1. the apo peak metric (count, failures with their error, mean, std);
      2. the heavy atom peak metric, reported separately for Keedy, Ginn and
         all bound models via report_ha().
    Datasets that raise while being processed are skipped and listed with
    their error message instead of aborting the run.
    """
    parser = argparse.ArgumentParser(description="Compute VALDO map quality metrics.")
    parser.add_argument("--recons-phased", default=RECONS_PHASED,
                        help="Path to recons_phased directory (default: vae/recons_phased/)")
    args = parser.parse_args()

    print(f"Phased MTZ dir: {args.recons_phased}")
    apo_ids = load_apo_ids(MAPPING_TXT)
    keedy_ids, ginn_ids = classify_bound_models(ALL_SUPERPOSED_V2)
    all_mtz = sorted(glob.glob(os.path.join(args.recons_phased, "*.mtz")))

    # ── Metric 1: Apo peak ──────────────────────────────────────────────────
    print("=== APO PEAK METRIC ===")
    apo_peaks = []
    apo_failed = []

    for mtz_path in tqdm(all_mtz, desc="Apo peak"):
        did = extract_id_from_mtz(mtz_path)
        if did is None or did not in apo_ids:
            continue
        try:
            grid = wdf_grid(mtz_path)
            apo_peaks.append(float(np.max(grid.array)))
        except Exception as e:
            apo_failed.append((did, str(e)))

    print(f"Apo datasets processed : {len(apo_peaks)}")
    if apo_failed:
        print(f"Failed : {len(apo_failed)}")
        for did, err in apo_failed:
            print(f" {did}: {err}")
    # Guard the empty case explicitly: same "nan" output, but without numpy's
    # "Mean of empty slice" RuntimeWarning.
    apo_mean = np.mean(apo_peaks) if apo_peaks else float("nan")
    apo_std = np.std(apo_peaks) if apo_peaks else float("nan")
    print(f"Mean highest WDF peak : {apo_mean:.4f}")
    print(f"Std : {apo_std:.4f}")

    # ── Metric 2: Heavy atom peak ────────────────────────────────────────────
    print("\n=== HEAVY ATOM PEAK METRIC ===")
    all_pdb = sorted(glob.glob(os.path.join(BOUND_MODELS_STD, "*.pdb")))
    ha_all = []
    ha_no_heavy = []
    ha_missing = []   # dataset IDs with no phased MTZ
    ha_failed = []    # (dataset ID, error message) for datasets that raised

    for pdb_path in tqdm(all_pdb, desc="Heavy atom peak"):
        did = extract_id_from_pdb(pdb_path)
        if did is None:
            continue
        mtz_path = find_mtz(did, args.recons_phased)
        if mtz_path is None:
            ha_missing.append(did)
            continue
        try:
            grid = wdf_grid(mtz_path)
            val = heavy_atom_peak(grid, pdb_path)
        except Exception as e:
            ha_failed.append((did, str(e)))
            continue
        if val is None:
            ha_no_heavy.append(did)
        else:
            ha_all.append((did, val))

    if ha_no_heavy:
        print(f"No Cl/Br/S/I in LIG : {len(ha_no_heavy)} — {ha_no_heavy}")
    if ha_missing or ha_failed:
        # Same combined count as before; failures are additionally itemised
        # so their cause is visible, mirroring the apo section above.
        print(f"MTZ missing / failed : {len(ha_missing) + len(ha_failed)}")
        for did, err in ha_failed:
            print(f" {did}: {err}")

    ha_keedy = [(did, v) for did, v in ha_all if did in keedy_ids]
    ha_ginn = [(did, v) for did, v in ha_all if did in ginn_ids]

    print("\n-- Keedy bound models --")
    report_ha(ha_keedy)
    print("\n-- Ginn bound models --")
    report_ha(ha_ginn)
    print("\n-- All bound models --")
    report_ha(ha_all)


if __name__ == "__main__":
    main()
23 changes: 23 additions & 0 deletions PTP1B_pipeline/configs/config_add_phases_and_blobs.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# NOTE(review): every path below is absolute and rooted at
# /mnt/home/mli10/projects/valdo/PTP1B_pipeline/ — adjust before running on
# another machine or account.
file_list: "/mnt/home/mli10/projects/valdo/PTP1B_pipeline/vae/recons/*.mtz"
# NOTE(review): phasing_path and model_folder point at the same refine_output/
# directory — presumably intentional; confirm.
phasing_path: "/mnt/home/mli10/projects/valdo/PTP1B_pipeline/refine_1nwl/refine_output/"
output_folder: "/mnt/home/mli10/projects/valdo/PTP1B_pipeline/vae/recons_phased/"
blob_output_folder: "/mnt/home/mli10/projects/valdo/PTP1B_pipeline/vae/blobs/"
model_folder: "/mnt/home/mli10/projects/valdo/PTP1B_pipeline/refine_1nwl/refine_output/"
# Input and output phase column labels are identical here.
phase_2FOFC_col_in: "PH2FOFCWT"
phase_FOFC_col_in: "PHFOFCWT"
phase_2FOFC_col_out: "PH2FOFCWT"
phase_FOFC_col_out: "PHFOFCWT"
rfree_label_in: null
sigF_col: "SIGF-obs-scaled"
diff_col: "diff"
sigdF_pct: 95.0
absdF_pct: 99.99
F_col: "F-obs-scaled"
recons_col: "recons"
extrapolate_factors: [2, 4, 6, 8, 16]
# blob_diff_col / phase_col carry the same labels that compute_valdo_metrics.py
# reads as DIFF_COL / PHASE_COL ("WDF" / "PH2FOFCWT").
blob_diff_col: "WDF"
phase_col: "PH2FOFCWT"
cutoff: 3.5
radius_in_A: 4.0
prefix: ""
ncpu: 8
Loading