diff --git a/src/fastmdanalysis/analysis/dihedrals.py b/src/fastmdanalysis/analysis/dihedrals.py index b4b9020..a0cb5f7 100644 --- a/src/fastmdanalysis/analysis/dihedrals.py +++ b/src/fastmdanalysis/analysis/dihedrals.py @@ -36,7 +36,8 @@ class PhiAnalysis(BaseAnalysis): """ _ALIASES = { - "residues": "residue_selection", + "residue": "residues", + "residue_selection": "residues", } def __init__( @@ -241,7 +242,8 @@ class PsiAnalysis(BaseAnalysis): # Similar structure to PhiAnalysis, but for psi angles _ALIASES = { - "residues": "residue_selection", + "residue": "residues", + "residue_selection": "residues", } def __init__( @@ -253,13 +255,15 @@ def __init__( **kwargs ): logger.info("Initializing Psi analysis") + warn_unknown = kwargs.pop("_warn_unknown", False) + analysis_opts = {"residues": residues, "units": units, "strict": strict} analysis_opts.update(kwargs) forwarder = OptionsForwarder(aliases=self._ALIASES, strict=strict) resolved = forwarder.apply_aliases(analysis_opts) resolved = forwarder.filter_known( - resolved, {"residues", "units", "strict", "output"}, context="psi" + resolved, {"residues", "units", "strict", "output"}, context="psi", warn=warn_unknown ) residues = resolved.get("residues", None) @@ -366,7 +370,8 @@ class OmegaAnalysis(BaseAnalysis): """ _ALIASES = { - "residues": "residue_selection", + "residue": "residues", + "residue_selection": "residues", } def __init__( @@ -378,13 +383,15 @@ def __init__( **kwargs ): logger.info("Initializing Omega analysis") + warn_unknown = kwargs.pop("_warn_unknown", False) + analysis_opts = {"residues": residues, "units": units, "strict": strict} analysis_opts.update(kwargs) forwarder = OptionsForwarder(aliases=self._ALIASES, strict=strict) resolved = forwarder.apply_aliases(analysis_opts) resolved = forwarder.filter_known( - resolved, {"residues", "units", "strict", "output"}, context="omega" + resolved, {"residues", "units", "strict", "output"}, context="omega", warn=warn_unknown ) residues = resolved.get("residues", None) @@ -489,6 +496,11 @@ class DihedralsAnalysis(BaseAnalysis): Combined dihedral analysis for phi, psi, omega with Ramachandran plotting. """ + _ALIASES = { + "residue": "residues", + "residue_selection": "residues", + } + def __init__( self, trajectory: md.Trajectory, @@ -499,13 +511,15 @@ def __init__( **kwargs ): logger.info("Initializing Dihedrals analysis") + warn_unknown = kwargs.pop("_warn_unknown", False) + analysis_opts = {"types": types, "residues": residues, "units": units, "strict": strict} analysis_opts.update(kwargs) - forwarder = OptionsForwarder(strict=strict) + forwarder = OptionsForwarder(aliases=self._ALIASES, strict=strict) resolved = forwarder.apply_aliases(analysis_opts) resolved = forwarder.filter_known( - resolved, {"types", "residues", "units", "strict", "output"}, context="dihedrals" + resolved, {"types", "residues", "units", "strict", "output"}, context="dihedrals", warn=warn_unknown ) types = resolved.get("types", ["phi", "psi", "omega"]) diff --git a/tests/test_dihedrals_residue_options.py b/tests/test_dihedrals_residue_options.py new file mode 100644 index 0000000..36728cf --- /dev/null +++ b/tests/test_dihedrals_residue_options.py @@ -0,0 +1,36 @@ +import warnings +import numpy as np + + +def _unknown_warnings(messages): + return [m for m in messages if "Unknown options" in m or "Unsupported options" in m] + + +def test_phi_residue_alias_no_unknown_warning(fastmda): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + analysis = fastmda.phi(residue=0) + + messages = [str(wi.message) for wi in w] + assert not _unknown_warnings(messages) + + assert "phi_avg_filtered" in analysis.results + filtered = analysis.results["phi_avg_filtered"] + assert filtered.shape[0] == 1 + np.testing.assert_allclose(filtered[0], analysis.data[0], atol=1e-6) + + +def test_dihedrals_residue_selection_alias_propagates(fastmda): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + analysis = fastmda.dihedrals(residue_selection=[0, 1]) + + messages = [str(wi.message) for wi in w] + assert not _unknown_warnings(messages) + + for key in ("phi_avg_filtered", "psi_avg_filtered", "omega_avg_filtered"): + assert key in analysis.results + filtered = analysis.results[key] + assert filtered.shape[0] == 2 + base_key = key.replace("_filtered", "") + np.testing.assert_allclose(filtered[:, 0], analysis.results[base_key][:2, 0], atol=1e-6)