From 7a47e8c2c084a571ce162884c32cb959f6ab8ee9 Mon Sep 17 00:00:00 2001 From: Varun Joginpalli Date: Thu, 23 Apr 2026 21:30:46 +0000 Subject: [PATCH] Deployment Bug Fixes --- infra/parameters.example.json | 2 +- pyrit/setup/initializers/airt.py | 54 ++++++++++++-------- tests/unit/setup/test_airt_initializer.py | 60 +++++++++++++++++++++++ 3 files changed, 95 insertions(+), 21 deletions(-) diff --git a/infra/parameters.example.json b/infra/parameters.example.json index 28380a12f1..e82a5a0e0a 100644 --- a/infra/parameters.example.json +++ b/infra/parameters.example.json @@ -27,7 +27,7 @@ "value": "YOUR_DATABASE_NAME" }, "pyritInitializer": { - "value": "targets airt" + "value": "target airt" }, "envSecretName": { "value": "env-global" diff --git a/pyrit/setup/initializers/airt.py b/pyrit/setup/initializers/airt.py index a0e61c52d4..8fc62ea5e0 100644 --- a/pyrit/setup/initializers/airt.py +++ b/pyrit/setup/initializers/airt.py @@ -9,6 +9,7 @@ """ import json +import logging import os from collections.abc import Callable @@ -38,6 +39,8 @@ from pyrit.score.float_scale.self_ask_scale_scorer import SelfAskScaleScorer from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer +logger = logging.getLogger(__name__) + class AIRTInitializer(PyRITInitializer): """ @@ -275,32 +278,43 @@ def _setup_adversarial_targets(self, *, endpoint: str, api_key: str, model_name: def _validate_operation_fields(self) -> None: """ - Check that mandatory global memory labels (operation, operator) - are populated. + Ensure operator and operation are populated in GLOBAL_MEMORY_LABELS. + + Reads operator/operation from .pyrit_conf if it exists, then merges + them into GLOBAL_MEMORY_LABELS. In container/GUI deployments where + .pyrit_conf is not present, the labels are set per-user by the GUI + at runtime, so this method is a no-op. Raises: - ValueError: If mandatory global memory labels are missing. + ValueError: If .pyrit_conf exists but is missing operator or operation. """ - with open(DEFAULT_CONFIG_PATH) as f: - data = yaml.load(f, Loader=yaml.SafeLoader) + raw_labels = os.environ.get("GLOBAL_MEMORY_LABELS") + labels = dict(json.loads(raw_labels)) if raw_labels else {} - if "operator" not in data: - raise ValueError( - "Error: `operator` was not set in .pyrit_conf. This is a required value for the AIRTInitializer." - ) + if DEFAULT_CONFIG_PATH.exists(): + with open(DEFAULT_CONFIG_PATH) as f: + data = yaml.load(f, Loader=yaml.SafeLoader) or {} - if "operation" not in data: - raise ValueError( - "Error: `operation` was not set in .pyrit_conf. This is a required value for the AIRTInitializer." - ) + if "operator" not in data: + raise ValueError( + "Error: `operator` was not set in .pyrit_conf. This is a required value for the AIRTInitializer." + ) - raw_labels = os.environ.get("GLOBAL_MEMORY_LABELS") - labels = dict(json.loads(raw_labels)) if raw_labels else {} + if "operation" not in data: + raise ValueError( + "Error: `operation` was not set in .pyrit_conf. This is a required value for the AIRTInitializer." + ) - if "operator" not in labels: - labels["operator"] = data["operator"] + if "operator" not in labels: + labels["operator"] = data["operator"] - if "operation" not in labels: - labels["operation"] = data["operation"] + if "operation" not in labels: + labels["operation"] = data["operation"] - os.environ["GLOBAL_MEMORY_LABELS"] = json.dumps(labels) + os.environ["GLOBAL_MEMORY_LABELS"] = json.dumps(labels) + else: + logger.info( + "No .pyrit_conf found at %s — skipping operator/operation validation. " + "In GUI mode, these labels are set per-user at runtime.", + DEFAULT_CONFIG_PATH, + ) diff --git a/tests/unit/setup/test_airt_initializer.py b/tests/unit/setup/test_airt_initializer.py index 95d96c90a4..34d99d24ce 100644 --- a/tests/unit/setup/test_airt_initializer.py +++ b/tests/unit/setup/test_airt_initializer.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +import json import os import sys from unittest.mock import patch @@ -214,6 +215,65 @@ def test_validate_missing_operation_raises_error(self, tmp_path): ): init._validate_operation_fields() + def test_validate_operation_fields_skips_when_pyrit_conf_missing(self, tmp_path): + """Test that _validate_operation_fields does not crash when .pyrit_conf is missing. + + In container/GUI deployments, .pyrit_conf does not exist. The method should + skip validation gracefully instead of raising FileNotFoundError. + """ + nonexistent_path = tmp_path / "nonexistent" / ".pyrit_conf" + init = AIRTInitializer() + with patch("pyrit.setup.initializers.airt.DEFAULT_CONFIG_PATH", nonexistent_path): + # Should not raise + init._validate_operation_fields() + + def test_validate_operation_fields_preserves_existing_labels_when_pyrit_conf_missing(self, tmp_path): + """Test that existing GLOBAL_MEMORY_LABELS are preserved when .pyrit_conf is missing.""" + nonexistent_path = tmp_path / "nonexistent" / ".pyrit_conf" + init = AIRTInitializer() + with ( + patch("pyrit.setup.initializers.airt.DEFAULT_CONFIG_PATH", nonexistent_path), + patch.dict("os.environ", {"GLOBAL_MEMORY_LABELS": '{"operator": "gui_user", "operation": "gui_op"}'}), + ): + init._validate_operation_fields() + # Existing labels should remain untouched + labels = json.loads(os.environ["GLOBAL_MEMORY_LABELS"]) + assert labels["operator"] == "gui_user" + assert labels["operation"] == "gui_op" + + def test_validate_operation_fields_merges_conf_into_labels(self, tmp_path): + """Test that .pyrit_conf values are merged into GLOBAL_MEMORY_LABELS when labels are missing.""" + conf_file = tmp_path / ".pyrit_conf" + conf_file.write_text(yaml.dump({"operator": "conf_user", "operation": "conf_op"})) + init = AIRTInitializer() + with ( + patch("pyrit.setup.initializers.airt.DEFAULT_CONFIG_PATH", conf_file), + patch.dict("os.environ", {}, clear=False), + ): + # Remove GLOBAL_MEMORY_LABELS if present + os.environ.pop("GLOBAL_MEMORY_LABELS", None) + init._validate_operation_fields() + labels = json.loads(os.environ["GLOBAL_MEMORY_LABELS"]) + assert labels["operator"] == "conf_user" + assert labels["operation"] == "conf_op" + + def test_validate_operation_fields_does_not_overwrite_existing_labels(self, tmp_path): + """Test that .pyrit_conf values do not overwrite existing GLOBAL_MEMORY_LABELS entries.""" + conf_file = tmp_path / ".pyrit_conf" + conf_file.write_text(yaml.dump({"operator": "conf_user", "operation": "conf_op"})) + init = AIRTInitializer() + with ( + patch("pyrit.setup.initializers.airt.DEFAULT_CONFIG_PATH", conf_file), + patch.dict( + "os.environ", + {"GLOBAL_MEMORY_LABELS": '{"operator": "existing_user", "operation": "existing_op"}'}, + ), + ): + init._validate_operation_fields() + labels = json.loads(os.environ["GLOBAL_MEMORY_LABELS"]) + assert labels["operator"] == "existing_user" + assert labels["operation"] == "existing_op" + def test_validate_db_connection_raises_error(self): """Test that validate raises error when AZURE_SQL_DB_CONNECTION_STRING is missing.""" del os.environ["AZURE_SQL_DB_CONNECTION_STRING"]