Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions pyrit/setup/configuration_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import TYPE_CHECKING, Any, Optional, Union

from pyrit.common.path import DEFAULT_CONFIG_PATH
from pyrit.common.utils import verify_and_resolve_path
from pyrit.common.yaml_loadable import YamlLoadable
from pyrit.identifiers.class_name_utils import class_name_to_snake_case
from pyrit.setup.initialization import (
Expand Down Expand Up @@ -181,6 +182,35 @@ def from_dict(cls, data: dict[str, Any]) -> "ConfigurationLoader":
filtered_data = {k: v for k, v in data.items() if v is not None}
return cls(**filtered_data)

@classmethod
def from_yaml_file(cls, file: pathlib.Path | str) -> "ConfigurationLoader":
"""
Create a ConfigurationLoader from a YAML file.

Relative initialization_scripts and env_files paths are eagerly resolved
against the config file's directory so they don't depend on the caller's
working directory.

Args:
file (pathlib.Path | str): Path to the YAML configuration file.

Returns:
ConfigurationLoader: A new instance with relative paths resolved.
"""
resolved_file = verify_and_resolve_path(file)
config: ConfigurationLoader = super().from_yaml_file(resolved_file)
config._make_relative_paths_absolute(base_dir=resolved_file.parent)
return config

def _make_relative_paths_absolute(self, *, base_dir: pathlib.Path) -> None:
"""Resolve relative initialization_scripts and env_files against a base directory."""
if self.initialization_scripts:
self.initialization_scripts = [
str(base_dir / s) if not pathlib.Path(s).is_absolute() else s for s in self.initialization_scripts
]
if self.env_files:
self.env_files = [str(base_dir / e) if not pathlib.Path(e).is_absolute() else e for e in self.env_files]

@staticmethod
def load_with_overrides(
config_file: Optional[pathlib.Path] = None,
Expand Down
96 changes: 89 additions & 7 deletions tests/unit/setup/test_configuration_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,40 @@ def test_resolve_initialization_scripts_relative_path(self):
# Check path ends with expected components (works on both Unix and Windows)
assert resolved[0].parts[-2:] == ("relative", "script.py")

@pytest.mark.parametrize(
("field_name", "relative_path"),
[
("initialization_scripts", "scripts/init.py"),
("env_files", "env/local.env"),
],
)
def test_from_yaml_file_resolves_relative_paths_from_config_directory(self, tmp_path, field_name, relative_path):
"""Relative paths from YAML are resolved from the config file's directory."""
config_dir = tmp_path / "configs"
config_dir.mkdir()
config_path = config_dir / "pyrit.yaml"
config_path.write_text(f"{field_name}:\n - ./{relative_path}\n", encoding="utf-8")

config = ConfigurationLoader.from_yaml_file(config_path)

actual = getattr(config, field_name)
assert actual == [str(config_dir / relative_path)]

def test_from_yaml_file_preserves_absolute_paths(self, tmp_path):
"""Absolute paths in YAML are not changed by from_yaml_file."""
abs_script = str(tmp_path / "absolute" / "script.py")
abs_env = str(tmp_path / "absolute" / ".env")
config_path = tmp_path / "pyrit.yaml"
config_path.write_text(
f"initialization_scripts:\n - {abs_script}\nenv_files:\n - {abs_env}\n",
encoding="utf-8",
)

config = ConfigurationLoader.from_yaml_file(config_path)

assert config.initialization_scripts == [abs_script]
assert config.env_files == [abs_env]

def test_resolve_env_files_none_returns_none(self):
"""Test that None (default) returns None to signal 'use defaults'."""
config = ConfigurationLoader()
Expand Down Expand Up @@ -429,6 +463,52 @@ def test_load_with_overrides_env_files_override(self, mock_default_path):

assert config.env_files == ["/path/to/.env"]

@mock.patch("pyrit.setup.configuration_loader.DEFAULT_CONFIG_PATH")
@pytest.mark.parametrize(
("field_name", "relative_path"),
[
("initialization_scripts", "scripts/init.py"),
("env_files", "env/local.env"),
],
)
def test_load_with_overrides_resolves_config_file_relative_paths_from_config_dir(
self, mock_default_path, tmp_path, field_name, relative_path
):
"""Config-file relative paths are resolved from the config file's directory."""
mock_default_path.exists.return_value = False
config_dir = tmp_path / "configs"
config_dir.mkdir()
config_path = config_dir / "pyrit.yaml"
config_path.write_text(f"{field_name}:\n - ./{relative_path}\n", encoding="utf-8")

config = ConfigurationLoader.load_with_overrides(config_file=config_path)

assert getattr(config, field_name) == [str(config_dir / relative_path)]

@mock.patch("pyrit.setup.configuration_loader.DEFAULT_CONFIG_PATH")
@pytest.mark.parametrize(
("field_name", "relative_path"),
[
("initialization_scripts", "scripts/override.py"),
("env_files", "env/override.env"),
],
)
def test_load_with_overrides_cli_relative_paths_resolve_from_cwd(
self, mock_default_path, tmp_path, field_name, relative_path
):
"""CLI path overrides resolve relative paths from the current directory, not the config dir."""
mock_default_path.exists.return_value = False
config_dir = tmp_path / "configs"
config_dir.mkdir()
config_path = config_dir / "pyrit.yaml"
config_path.write_text(f"{field_name}:\n - ./from-config-placeholder\n", encoding="utf-8")

config = ConfigurationLoader.load_with_overrides(config_file=config_path, **{field_name: [relative_path]})
resolver = "_resolve_initialization_scripts" if "scripts" in field_name else "_resolve_env_files"
resolved = getattr(config, resolver)()

assert resolved == [pathlib.Path.cwd() / relative_path]

@mock.patch("pyrit.setup.configuration_loader.DEFAULT_CONFIG_PATH")
def test_load_with_overrides_converts_sequence_to_list(self, mock_default_path):
"""Test that Sequence inputs are converted to list for dataclass compatibility."""
Expand Down Expand Up @@ -457,19 +537,21 @@ def test_load_with_overrides_explicit_config_file_not_found(self):
ConfigurationLoader.load_with_overrides(config_file=non_existent_path)

@mock.patch("pyrit.setup.configuration_loader.DEFAULT_CONFIG_PATH")
def test_load_with_overrides_explicit_config_file_overrides_default(self, mock_default_path):
def test_load_with_overrides_explicit_config_file_overrides_default(self, mock_default_path, tmp_path):
"""Test explicit config file values override default config file."""
mock_default_path.exists.return_value = False

# Create a temp config file
yaml_content = """
# Create a temp config file with absolute paths
abs_script = str(tmp_path / "explicit" / "script.py")
abs_env = str(tmp_path / "explicit" / ".env")
yaml_content = f"""
memory_db_type: azure_sql
initializers:
- explicit_init
initialization_scripts:
- /explicit/script.py
- {abs_script}
env_files:
- /explicit/.env
- {abs_env}
"""
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
f.write(yaml_content)
Expand All @@ -480,8 +562,8 @@ def test_load_with_overrides_explicit_config_file_overrides_default(self, mock_d

assert config.memory_db_type == "azure_sql"
assert config._initializer_configs[0].name == "explicit_init"
assert config.initialization_scripts == ["/explicit/script.py"]
assert config.env_files == ["/explicit/.env"]
assert config.initialization_scripts == [abs_script]
assert config.env_files == [abs_env]
finally:
config_path.unlink()

Expand Down
Loading