Skip to content
Merged
48 changes: 48 additions & 0 deletions src/fastmdanalysis/analysis/dihedrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,15 @@ def plot(
mask = np.isin(x, residues)
x = x[mask]
y = y[mask]
if yerr is not None:
yerr = yerr[mask]

plot_matrix = np.column_stack([x, y])
header = f"residue_index phi_mean_{self.units}"
if yerr is not None:
plot_matrix = np.column_stack([plot_matrix, yerr])
header = f"{header} phi_std_{self.units}"
self._save_data(plot_matrix, "phi_avg_plot", header=header)

# Plot
fig, ax = plt.subplots(figsize=figsize)
Expand Down Expand Up @@ -392,6 +401,15 @@ def plot(self, **kwargs) -> str:
mask = np.isin(x, residues)
x = x[mask]
y = y[mask]
if yerr is not None:
yerr = yerr[mask]

plot_matrix = np.column_stack([x, y])
header = f"residue_index psi_mean_{self.units}"
if yerr is not None:
plot_matrix = np.column_stack([plot_matrix, yerr])
header = f"{header} psi_std_{self.units}"
self._save_data(plot_matrix, "psi_avg_plot", header=header)

fig, ax = plt.subplots(figsize=kwargs.get("figsize", (12, 6)))
ax.errorbar(
Expand Down Expand Up @@ -566,6 +584,15 @@ def plot(self, **kwargs) -> str:
mask = np.isin(x, residues)
x = x[mask]
y = y[mask]
if yerr is not None:
yerr = yerr[mask]

plot_matrix = np.column_stack([x, y])
header = f"residue_index omega_mean_{self.units}"
if yerr is not None:
plot_matrix = np.column_stack([plot_matrix, yerr])
header = f"{header} omega_std_{self.units}"
self._save_data(plot_matrix, "omega_avg_plot", header=header)

fig, ax = plt.subplots(figsize=kwargs.get("figsize", (12, 6)))
ax.errorbar(
Expand Down Expand Up @@ -759,6 +786,13 @@ def plot_ramachandran(
if psi_std is not None:
psi_std = psi_std[mask]

avg_matrix = np.column_stack([res_indices, x, y])
header = f"residue_index phi_mean_{self.units} psi_mean_{self.units}"
if phi_std is not None and psi_std is not None:
avg_matrix = np.column_stack([avg_matrix, phi_std, psi_std])
header = f"{header} phi_std_{self.units} psi_std_{self.units}"
self._save_data(avg_matrix, "ramachandran_avg", header=header)

fig, ax = plt.subplots(figsize=figsize)
cmap = plt.get_cmap("viridis")
norm = plt.Normalize(vmin=res_indices.min(), vmax=res_indices.max()) if len(res_indices) else None
Expand All @@ -782,6 +816,10 @@ def plot_ramachandran(
ax.set_ylabel(f"Psi ({self.units})")
ax.grid(True, alpha=0.3)

limit = 180.0 if self.units == "degrees" else np.pi
ax.set_xlim(-limit, limit)
ax.set_ylim(-limit, limit)

if len(res_indices):
mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)
cbar = plt.colorbar(mappable, ax=ax)
Expand Down Expand Up @@ -817,8 +855,18 @@ def plot_ramachandran(
ax_res.set_xlabel(f"Phi ({self.units})")
ax_res.set_ylabel(f"Psi ({self.units})")
ax_res.grid(True, alpha=0.3)
ax_res.set_xlim(-limit, limit)
ax_res.set_ylim(-limit, limit)
fig_res.tight_layout()

frame_matrix = np.column_stack([phi_angles[:, idx], psi_angles[:, idx]])
frame_header = f"phi_{self.units} psi_{self.units}"
self._save_data(
frame_matrix,
f"ramachandran_res{res}",
header=frame_header,
)

per_path = self._save_plot(
fig_res,
"ramachandran",
Expand Down
6 changes: 0 additions & 6 deletions tests/test_dihedrals_residue_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,9 @@ def test_phi_residue_alias_no_unknown_warning(fastmda):

messages = [str(wi.message) for wi in w]
assert not _unknown_warnings(messages)

assert analysis.data.shape[0] == 1

assert "phi_residues" in analysis.results
assert list(np.asarray(analysis.results["phi_residues"]).astype(int)) == [0]

assert "phi_avg_filtered" in analysis.results
filtered = analysis.results["phi_avg_filtered"]
assert filtered.shape[0] == 1
Expand All @@ -32,16 +29,13 @@ def test_dihedrals_residue_selection_alias_propagates(fastmda):

messages = [str(wi.message) for wi in w]
assert not _unknown_warnings(messages)

# Ensure the combined analysis truly computed only 2 residues worth of data
assert analysis.results["phi_avg"].shape[0] == 2
assert analysis.results["psi_avg"].shape[0] == 2
assert analysis.results["omega_avg"].shape[0] == 2

assert list(np.asarray(analysis.results["phi_residues"]).astype(int)) == [0, 1]
assert list(np.asarray(analysis.results["psi_residues"]).astype(int)) == [0, 1]
assert list(np.asarray(analysis.results["omega_residues"]).astype(int)) == [0, 1]

for key in ("phi_avg_filtered", "psi_avg_filtered", "omega_avg_filtered"):
assert key in analysis.results
filtered = analysis.results[key]
Expand Down