Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion infra/parameters.example.json
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"value": "YOUR_DATABASE_NAME"
},
"pyritInitializer": {
"value": "targets airt"
"value": "target airt"
},
"envSecretName": {
"value": "env-global"
Expand Down
54 changes: 34 additions & 20 deletions pyrit/setup/initializers/airt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"""

import json
import logging
import os
from collections.abc import Callable

Expand Down Expand Up @@ -38,6 +39,8 @@
from pyrit.score.float_scale.self_ask_scale_scorer import SelfAskScaleScorer
from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer

logger = logging.getLogger(__name__)


class AIRTInitializer(PyRITInitializer):
"""
Expand Down Expand Up @@ -275,32 +278,43 @@ def _setup_adversarial_targets(self, *, endpoint: str, api_key: str, model_name:

def _validate_operation_fields(self) -> None:
"""
Check that mandatory global memory labels (operation, operator)
are populated.
Ensure operator and operation are populated in GLOBAL_MEMORY_LABELS.

Reads operator/operation from .pyrit_conf if it exists, then merges
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feels weird to do such strict enforcement but only if the file exists, and nothing otherwise.

In the GUI, it's plenty obvious that you should set them an if it's the deployed GUI it will auto-populate the operator anyway. I think the primary reason this exists is for enforcing these labels on scanner runs. Are those possible without conf file? If not, then I think this is fine.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question! Yeah, scanner runs can work without .pyrit_conf. The original #1578 code didn’t enforce anything there either, it just crashed with a FileNotFoundError. So this change isn’t introducing a new gap, just turning a crash into a graceful skip.
Also, I believe all scanner examples use --initializers target load_default_datasets, not airt. The airt initializer is really only used by pyrit_backend for the GUI, where the operator gets set per-user at runtime. Happy to follow up if we want stricter label enforcement everywhere!

them into GLOBAL_MEMORY_LABELS. In container/GUI deployments where
.pyrit_conf is not present, the labels are set per-user by the GUI
at runtime, so this method is a no-op.

Raises:
ValueError: If mandatory global memory labels are missing.
ValueError: If .pyrit_conf exists but is missing operator or operation.
"""
with open(DEFAULT_CONFIG_PATH) as f:
data = yaml.load(f, Loader=yaml.SafeLoader)
raw_labels = os.environ.get("GLOBAL_MEMORY_LABELS")
labels = dict(json.loads(raw_labels)) if raw_labels else {}

if "operator" not in data:
raise ValueError(
"Error: `operator` was not set in .pyrit_conf. This is a required value for the AIRTInitializer."
)
if DEFAULT_CONFIG_PATH.exists():
with open(DEFAULT_CONFIG_PATH) as f:
data = yaml.load(f, Loader=yaml.SafeLoader) or {}

if "operation" not in data:
raise ValueError(
"Error: `operation` was not set in .pyrit_conf. This is a required value for the AIRTInitializer."
)
if "operator" not in data:
raise ValueError(
"Error: `operator` was not set in .pyrit_conf. This is a required value for the AIRTInitializer."
)

raw_labels = os.environ.get("GLOBAL_MEMORY_LABELS")
labels = dict(json.loads(raw_labels)) if raw_labels else {}
if "operation" not in data:
raise ValueError(
"Error: `operation` was not set in .pyrit_conf. This is a required value for the AIRTInitializer."
)

if "operator" not in labels:
labels["operator"] = data["operator"]
if "operator" not in labels:
labels["operator"] = data["operator"]

if "operation" not in labels:
labels["operation"] = data["operation"]
if "operation" not in labels:
labels["operation"] = data["operation"]

os.environ["GLOBAL_MEMORY_LABELS"] = json.dumps(labels)
os.environ["GLOBAL_MEMORY_LABELS"] = json.dumps(labels)
else:
logger.info(
"No .pyrit_conf found at %s — skipping operator/operation validation. "
"In GUI mode, these labels are set per-user at runtime.",
DEFAULT_CONFIG_PATH,
)
60 changes: 60 additions & 0 deletions tests/unit/setup/test_airt_initializer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import os
import sys
from unittest.mock import patch
Expand Down Expand Up @@ -214,6 +215,65 @@ def test_validate_missing_operation_raises_error(self, tmp_path):
):
init._validate_operation_fields()

def test_validate_operation_fields_skips_when_pyrit_conf_missing(self, tmp_path):
"""Test that _validate_operation_fields does not crash when .pyrit_conf is missing.

In container/GUI deployments, .pyrit_conf does not exist. The method should
skip validation gracefully instead of raising FileNotFoundError.
"""
nonexistent_path = tmp_path / "nonexistent" / ".pyrit_conf"
init = AIRTInitializer()
with patch("pyrit.setup.initializers.airt.DEFAULT_CONFIG_PATH", nonexistent_path):
# Should not raise
init._validate_operation_fields()

def test_validate_operation_fields_preserves_existing_labels_when_pyrit_conf_missing(self, tmp_path):
"""Test that existing GLOBAL_MEMORY_LABELS are preserved when .pyrit_conf is missing."""
nonexistent_path = tmp_path / "nonexistent" / ".pyrit_conf"
init = AIRTInitializer()
with (
patch("pyrit.setup.initializers.airt.DEFAULT_CONFIG_PATH", nonexistent_path),
patch.dict("os.environ", {"GLOBAL_MEMORY_LABELS": '{"operator": "gui_user", "operation": "gui_op"}'}),
):
init._validate_operation_fields()
# Existing labels should remain untouched
labels = json.loads(os.environ["GLOBAL_MEMORY_LABELS"])
assert labels["operator"] == "gui_user"
assert labels["operation"] == "gui_op"

def test_validate_operation_fields_merges_conf_into_labels(self, tmp_path):
"""Test that .pyrit_conf values are merged into GLOBAL_MEMORY_LABELS when labels are missing."""
conf_file = tmp_path / ".pyrit_conf"
conf_file.write_text(yaml.dump({"operator": "conf_user", "operation": "conf_op"}))
init = AIRTInitializer()
with (
patch("pyrit.setup.initializers.airt.DEFAULT_CONFIG_PATH", conf_file),
patch.dict("os.environ", {}, clear=False),
):
# Remove GLOBAL_MEMORY_LABELS if present
os.environ.pop("GLOBAL_MEMORY_LABELS", None)
init._validate_operation_fields()
labels = json.loads(os.environ["GLOBAL_MEMORY_LABELS"])
assert labels["operator"] == "conf_user"
assert labels["operation"] == "conf_op"

def test_validate_operation_fields_does_not_overwrite_existing_labels(self, tmp_path):
"""Test that .pyrit_conf values do not overwrite existing GLOBAL_MEMORY_LABELS entries."""
conf_file = tmp_path / ".pyrit_conf"
conf_file.write_text(yaml.dump({"operator": "conf_user", "operation": "conf_op"}))
init = AIRTInitializer()
with (
patch("pyrit.setup.initializers.airt.DEFAULT_CONFIG_PATH", conf_file),
patch.dict(
"os.environ",
{"GLOBAL_MEMORY_LABELS": '{"operator": "existing_user", "operation": "existing_op"}'},
),
):
init._validate_operation_fields()
labels = json.loads(os.environ["GLOBAL_MEMORY_LABELS"])
assert labels["operator"] == "existing_user"
assert labels["operation"] == "existing_op"

def test_validate_db_connection_raises_error(self):
"""Test that validate raises error when AZURE_SQL_DB_CONNECTION_STRING is missing."""
del os.environ["AZURE_SQL_DB_CONNECTION_STRING"]
Expand Down
Loading