diff --git a/pyrit/setup/configuration_loader.py b/pyrit/setup/configuration_loader.py index 6f720b303..d31476eb9 100644 --- a/pyrit/setup/configuration_loader.py +++ b/pyrit/setup/configuration_loader.py @@ -14,6 +14,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union from pyrit.common.path import DEFAULT_CONFIG_PATH +from pyrit.common.utils import verify_and_resolve_path from pyrit.common.yaml_loadable import YamlLoadable from pyrit.identifiers.class_name_utils import class_name_to_snake_case from pyrit.setup.initialization import ( @@ -181,6 +182,35 @@ def from_dict(cls, data: dict[str, Any]) -> "ConfigurationLoader": filtered_data = {k: v for k, v in data.items() if v is not None} return cls(**filtered_data) + @classmethod + def from_yaml_file(cls, file: pathlib.Path | str) -> "ConfigurationLoader": + """ + Create a ConfigurationLoader from a YAML file. + + Relative initialization_scripts and env_files paths are eagerly resolved + against the config file's directory so they don't depend on the caller's + working directory. + + Args: + file (pathlib.Path | str): Path to the YAML configuration file. + + Returns: + ConfigurationLoader: A new instance with relative paths resolved. + """ + resolved_file = verify_and_resolve_path(file) + config: ConfigurationLoader = super().from_yaml_file(resolved_file) + config._make_relative_paths_absolute(base_dir=resolved_file.parent) + return config + + def _make_relative_paths_absolute(self, *, base_dir: pathlib.Path) -> None: + """Resolve relative initialization_scripts and env_files against a base directory.""" + if self.initialization_scripts: + self.initialization_scripts = [ + str(base_dir / s) if not pathlib.Path(s).is_absolute() else s for s in self.initialization_scripts + ] + if self.env_files: + self.env_files = [str(base_dir / e) if not pathlib.Path(e).is_absolute() else e for e in self.env_files] + @staticmethod def load_with_overrides( config_file: Optional[pathlib.Path] = None, diff --git a/tests/unit/setup/test_configuration_loader.py b/tests/unit/setup/test_configuration_loader.py index 76776bf51..ed09c8b80 100644 --- a/tests/unit/setup/test_configuration_loader.py +++ b/tests/unit/setup/test_configuration_loader.py @@ -239,6 +239,40 @@ def test_resolve_initialization_scripts_relative_path(self): # Check path ends with expected components (works on both Unix and Windows) assert resolved[0].parts[-2:] == ("relative", "script.py") + @pytest.mark.parametrize( + ("field_name", "relative_path"), + [ + ("initialization_scripts", "scripts/init.py"), + ("env_files", "env/local.env"), + ], + ) + def test_from_yaml_file_resolves_relative_paths_from_config_directory(self, tmp_path, field_name, relative_path): + """Relative paths from YAML are resolved from the config file's directory.""" + config_dir = tmp_path / "configs" + config_dir.mkdir() + config_path = config_dir / "pyrit.yaml" + config_path.write_text(f"{field_name}:\n - ./{relative_path}\n", encoding="utf-8") + + config = ConfigurationLoader.from_yaml_file(config_path) + + actual = getattr(config, field_name) + assert actual == [str(config_dir / relative_path)] + + def test_from_yaml_file_preserves_absolute_paths(self, tmp_path): + """Absolute paths in YAML are not changed by from_yaml_file.""" + abs_script = str(tmp_path / "absolute" / "script.py") + abs_env = str(tmp_path / "absolute" / ".env") + config_path = tmp_path / "pyrit.yaml" + config_path.write_text( + f"initialization_scripts:\n - {abs_script}\nenv_files:\n - {abs_env}\n", + encoding="utf-8", + ) + + config = ConfigurationLoader.from_yaml_file(config_path) + + assert config.initialization_scripts == [abs_script] + assert config.env_files == [abs_env] + def test_resolve_env_files_none_returns_none(self): """Test that None (default) returns None to signal 'use defaults'.""" config = ConfigurationLoader() @@ -429,6 +463,52 @@ def test_load_with_overrides_env_files_override(self, mock_default_path): assert config.env_files == ["/path/to/.env"] + @mock.patch("pyrit.setup.configuration_loader.DEFAULT_CONFIG_PATH") + @pytest.mark.parametrize( + ("field_name", "relative_path"), + [ + ("initialization_scripts", "scripts/init.py"), + ("env_files", "env/local.env"), + ], + ) + def test_load_with_overrides_resolves_config_file_relative_paths_from_config_dir( + self, mock_default_path, tmp_path, field_name, relative_path + ): + """Config-file relative paths are resolved from the config file's directory.""" + mock_default_path.exists.return_value = False + config_dir = tmp_path / "configs" + config_dir.mkdir() + config_path = config_dir / "pyrit.yaml" + config_path.write_text(f"{field_name}:\n - ./{relative_path}\n", encoding="utf-8") + + config = ConfigurationLoader.load_with_overrides(config_file=config_path) + + assert getattr(config, field_name) == [str(config_dir / relative_path)] + + @mock.patch("pyrit.setup.configuration_loader.DEFAULT_CONFIG_PATH") + @pytest.mark.parametrize( + ("field_name", "relative_path"), + [ + ("initialization_scripts", "scripts/override.py"), + ("env_files", "env/override.env"), + ], + ) + def test_load_with_overrides_cli_relative_paths_resolve_from_cwd( + self, mock_default_path, tmp_path, field_name, relative_path + ): + """CLI path overrides resolve relative paths from the current directory, not the config dir.""" + mock_default_path.exists.return_value = False + config_dir = tmp_path / "configs" + config_dir.mkdir() + config_path = config_dir / "pyrit.yaml" + config_path.write_text(f"{field_name}:\n - ./from-config-placeholder\n", encoding="utf-8") + + config = ConfigurationLoader.load_with_overrides(config_file=config_path, **{field_name: [relative_path]}) + resolver = "_resolve_initialization_scripts" if "scripts" in field_name else "_resolve_env_files" + resolved = getattr(config, resolver)() + + assert resolved == [pathlib.Path.cwd() / relative_path] + @mock.patch("pyrit.setup.configuration_loader.DEFAULT_CONFIG_PATH") def test_load_with_overrides_converts_sequence_to_list(self, mock_default_path): """Test that Sequence inputs are converted to list for dataclass compatibility.""" @@ -457,19 +537,21 @@ def test_load_with_overrides_explicit_config_file_not_found(self): ConfigurationLoader.load_with_overrides(config_file=non_existent_path) @mock.patch("pyrit.setup.configuration_loader.DEFAULT_CONFIG_PATH") - def test_load_with_overrides_explicit_config_file_overrides_default(self, mock_default_path): + def test_load_with_overrides_explicit_config_file_overrides_default(self, mock_default_path, tmp_path): """Test explicit config file values override default config file.""" mock_default_path.exists.return_value = False - # Create a temp config file - yaml_content = """ + # Create a temp config file with absolute paths + abs_script = str(tmp_path / "explicit" / "script.py") + abs_env = str(tmp_path / "explicit" / ".env") + yaml_content = f""" memory_db_type: azure_sql initializers: - explicit_init initialization_scripts: - - /explicit/script.py + - {abs_script} env_files: - - /explicit/.env + - {abs_env} """ with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: f.write(yaml_content) @@ -480,8 +562,8 @@ def test_load_with_overrides_explicit_config_file_overrides_default(self, mock_d assert config.memory_db_type == "azure_sql" assert config._initializer_configs[0].name == "explicit_init" - assert config.initialization_scripts == ["/explicit/script.py"] - assert config.env_files == ["/explicit/.env"] + assert config.initialization_scripts == [abs_script] + assert config.env_files == [abs_env] finally: config_path.unlink()