diff --git a/src/fastmdanalysis/analysis/dihedrals.py b/src/fastmdanalysis/analysis/dihedrals.py index a0cb5f7..eca2762 100644 --- a/src/fastmdanalysis/analysis/dihedrals.py +++ b/src/fastmdanalysis/analysis/dihedrals.py @@ -94,6 +94,9 @@ def __init__( self.units: str = units self.strict = strict + # Residue indices corresponding to the rows in self.data + self.residue_indices: Optional[np.ndarray] = None + # Populated during run() self.data: Optional[np.ndarray] = None self.results: Dict[str, np.ndarray] = {} @@ -115,6 +118,15 @@ def run(self) -> Dict[str, np.ndarray]: if angles.size == 0: raise AnalysisError("No phi angles found in trajectory (no protein?)") + # Restrict computation to selected residues (0-based dihedral/residue index) + full_res_indices = np.arange(angles.shape[1]) + if self.residues is not None: + res_list = [self.residues] if isinstance(self.residues, int) else list(self.residues) + angles = angles[:, res_list] + self.residue_indices = np.asarray(res_list, dtype=int) + else: + self.residue_indices = full_res_indices + # Circular mean per residue n_residues = angles.shape[1] avg_angles = np.zeros(n_residues) @@ -132,16 +144,12 @@ def run(self) -> Dict[str, np.ndarray]: avg_angles = np.degrees(avg_angles) self.data = avg_angles.reshape(-1, 1) - self.results = {"phi_avg": self.data} - - # Filter by residues if specified - if self.residues is not None: - if isinstance(self.residues, int): - res_list = [self.residues] - else: - res_list = list(self.residues) - filtered_data = self.data[res_list] - self.results["phi_avg_filtered"] = filtered_data.reshape(-1, 1) + # If residues were provided, self.data is already filtered. + self.results = { + "phi_avg": self.data, + "phi_avg_filtered": self.data, + "phi_residues": self.residue_indices, + } # Save data self._save_data(self.data, "phi_avg", header=f"phi_avg_{self.units}") @@ -197,8 +205,12 @@ def plot( raise AnalysisError("No phi data available to plot.") y = np.asarray(data, dtype=float).flatten() - n = len(y) - x = np.arange(n) + + # X-axis should reflect residue indices of the computed data (not 0..N-1) + if self.residue_indices is not None and len(self.residue_indices) == len(y): + x = self.residue_indices.astype(int) + else: + x = np.arange(len(y)) # Filter residues if residues is not None: @@ -278,6 +290,8 @@ def __init__( self.data = None self.results = {} + self.residue_indices: Optional[np.ndarray] = None + def run(self) -> Dict[str, np.ndarray]: logger.info("Starting Psi analysis") try: @@ -285,6 +299,14 @@ def run(self) -> Dict[str, np.ndarray]: if angles.size == 0: raise AnalysisError("No psi angles found in trajectory") + full_res_indices = np.arange(angles.shape[1]) + if self.residues is not None: + res_list = [self.residues] if isinstance(self.residues, int) else list(self.residues) + angles = angles[:, res_list] + self.residue_indices = np.asarray(res_list, dtype=int) + else: + self.residue_indices = full_res_indices + n_residues = angles.shape[1] avg_angles = np.zeros(n_residues) for i in range(n_residues): @@ -299,12 +321,11 @@ def run(self) -> Dict[str, np.ndarray]: avg_angles = np.degrees(avg_angles) self.data = avg_angles.reshape(-1, 1) - self.results = {"psi_avg": self.data} - - if self.residues is not None: - res_list = [self.residues] if isinstance(self.residues, int) else list(self.residues) - filtered_data = self.data[res_list] - self.results["psi_avg_filtered"] = filtered_data.reshape(-1, 1) + self.results = { + "psi_avg": self.data, + "psi_avg_filtered": self.data, + "psi_residues": self.residue_indices, + } self._save_data(self.data, "psi_avg", header=f"psi_avg_{self.units}") plot_path = self.plot() @@ -327,8 +348,10 @@ def plot(self, **kwargs) -> str: raise AnalysisError("No psi data available to plot.") y = np.asarray(kwargs["data"], dtype=float).flatten() - n = len(y) - x = np.arange(n) + if self.residue_indices is not None and len(self.residue_indices) == len(y): + x = self.residue_indices.astype(int) + else: + x = np.arange(len(y)) # Filter residues residues = kwargs.get("residues") @@ -406,6 +429,8 @@ def __init__( self.data = None self.results = {} + self.residue_indices: Optional[np.ndarray] = None + def run(self) -> Dict[str, np.ndarray]: logger.info("Starting Omega analysis") try: @@ -413,6 +438,14 @@ def run(self) -> Dict[str, np.ndarray]: if angles.size == 0: raise AnalysisError("No omega angles found in trajectory") + full_res_indices = np.arange(angles.shape[1]) + if self.residues is not None: + res_list = [self.residues] if isinstance(self.residues, int) else list(self.residues) + angles = angles[:, res_list] + self.residue_indices = np.asarray(res_list, dtype=int) + else: + self.residue_indices = full_res_indices + n_residues = angles.shape[1] avg_angles = np.zeros(n_residues) for i in range(n_residues): @@ -427,12 +460,11 @@ def run(self) -> Dict[str, np.ndarray]: avg_angles = np.degrees(avg_angles) self.data = avg_angles.reshape(-1, 1) - self.results = {"omega_avg": self.data} - - if self.residues is not None: - res_list = [self.residues] if isinstance(self.residues, int) else list(self.residues) - filtered_data = self.data[res_list] - self.results["omega_avg_filtered"] = filtered_data.reshape(-1, 1) + self.results = { + "omega_avg": self.data, + "omega_avg_filtered": self.data, + "omega_residues": self.residue_indices, + } self._save_data(self.data, "omega_avg", header=f"omega_avg_{self.units}") plot_path = self.plot() @@ -454,8 +486,10 @@ def plot(self, **kwargs) -> str: raise AnalysisError("No omega data available to plot.") y = np.asarray(kwargs["data"], dtype=float).flatten() - n = len(y) - x = np.arange(n) + if self.residue_indices is not None and len(self.residue_indices) == len(y): + x = self.residue_indices.astype(int) + else: + x = np.arange(len(y)) # Filter residues residues = kwargs.get("residues") @@ -602,8 +636,12 @@ def plot_ramachandran( x = phi_data y = psi_data - n = len(x) - res_indices = np.arange(n) + # Use residue indices when the per-angle analyses were residue-filtered + res_indices = self.results.get("phi_residues") + if res_indices is None: + res_indices = np.arange(len(x)) + else: + res_indices = np.asarray(res_indices, dtype=int) # Filter residues if residues is not None: diff --git a/tests/test_dihedrals_residue_options.py b/tests/test_dihedrals_residue_options.py index 36728cf..41a6e7e 100644 --- a/tests/test_dihedrals_residue_options.py +++ b/tests/test_dihedrals_residue_options.py @@ -14,6 +14,11 @@ def test_phi_residue_alias_no_unknown_warning(fastmda): messages = [str(wi.message) for wi in w] assert not _unknown_warnings(messages) + assert analysis.data.shape[0] == 1 + + assert "phi_residues" in analysis.results + assert list(np.asarray(analysis.results["phi_residues"]).astype(int)) == [0] + assert "phi_avg_filtered" in analysis.results filtered = analysis.results["phi_avg_filtered"] assert filtered.shape[0] == 1 @@ -28,6 +33,15 @@ def test_dihedrals_residue_selection_alias_propagates(fastmda): messages = [str(wi.message) for wi in w] assert not _unknown_warnings(messages) + # Ensure the combined analysis truly computed only 2 residues worth of data + assert analysis.results["phi_avg"].shape[0] == 2 + assert analysis.results["psi_avg"].shape[0] == 2 + assert analysis.results["omega_avg"].shape[0] == 2 + + assert list(np.asarray(analysis.results["phi_residues"]).astype(int)) == [0, 1] + assert list(np.asarray(analysis.results["psi_residues"]).astype(int)) == [0, 1] + assert list(np.asarray(analysis.results["omega_residues"]).astype(int)) == [0, 1] + for key in ("phi_avg_filtered", "psi_avg_filtered", "omega_avg_filtered"): assert key in analysis.results filtered = analysis.results[key]