Skip to content

Commit 936ae43

Browse files
authored
FIX: AIRTInitializer container crash and parameters.example.json typo (#1648)
1 parent aabf9bd commit 936ae43

3 files changed

Lines changed: 95 additions & 21 deletions

File tree

infra/parameters.example.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
"value": "YOUR_DATABASE_NAME"
2828
},
2929
"pyritInitializer": {
30-
"value": "targets airt"
30+
"value": "target airt"
3131
},
3232
"envSecretName": {
3333
"value": "env-global"

pyrit/setup/initializers/airt.py

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
"""
1010

1111
import json
12+
import logging
1213
import os
1314
from collections.abc import Callable
1415

@@ -38,6 +39,8 @@
3839
from pyrit.score.float_scale.self_ask_scale_scorer import SelfAskScaleScorer
3940
from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer
4041

42+
logger = logging.getLogger(__name__)
43+
4144

4245
class AIRTInitializer(PyRITInitializer):
4346
"""
@@ -275,32 +278,43 @@ def _setup_adversarial_targets(self, *, endpoint: str, api_key: str, model_name:
275278

276279
def _validate_operation_fields(self) -> None:
277280
"""
278-
Check that mandatory global memory labels (operation, operator)
279-
are populated.
281+
Ensure operator and operation are populated in GLOBAL_MEMORY_LABELS.
282+
283+
Reads operator/operation from .pyrit_conf if it exists, then merges
284+
them into GLOBAL_MEMORY_LABELS. In container/GUI deployments where
285+
.pyrit_conf is not present, the labels are set per-user by the GUI
286+
at runtime, so this method is a no-op.
280287
281288
Raises:
282-
ValueError: If mandatory global memory labels are missing.
289+
ValueError: If .pyrit_conf exists but is missing operator or operation.
283290
"""
284-
with open(DEFAULT_CONFIG_PATH) as f:
285-
data = yaml.load(f, Loader=yaml.SafeLoader)
291+
raw_labels = os.environ.get("GLOBAL_MEMORY_LABELS")
292+
labels = dict(json.loads(raw_labels)) if raw_labels else {}
286293

287-
if "operator" not in data:
288-
raise ValueError(
289-
"Error: `operator` was not set in .pyrit_conf. This is a required value for the AIRTInitializer."
290-
)
294+
if DEFAULT_CONFIG_PATH.exists():
295+
with open(DEFAULT_CONFIG_PATH) as f:
296+
data = yaml.load(f, Loader=yaml.SafeLoader) or {}
291297

292-
if "operation" not in data:
293-
raise ValueError(
294-
"Error: `operation` was not set in .pyrit_conf. This is a required value for the AIRTInitializer."
295-
)
298+
if "operator" not in data:
299+
raise ValueError(
300+
"Error: `operator` was not set in .pyrit_conf. This is a required value for the AIRTInitializer."
301+
)
296302

297-
raw_labels = os.environ.get("GLOBAL_MEMORY_LABELS")
298-
labels = dict(json.loads(raw_labels)) if raw_labels else {}
303+
if "operation" not in data:
304+
raise ValueError(
305+
"Error: `operation` was not set in .pyrit_conf. This is a required value for the AIRTInitializer."
306+
)
299307

300-
if "operator" not in labels:
301-
labels["operator"] = data["operator"]
308+
if "operator" not in labels:
309+
labels["operator"] = data["operator"]
302310

303-
if "operation" not in labels:
304-
labels["operation"] = data["operation"]
311+
if "operation" not in labels:
312+
labels["operation"] = data["operation"]
305313

306-
os.environ["GLOBAL_MEMORY_LABELS"] = json.dumps(labels)
314+
os.environ["GLOBAL_MEMORY_LABELS"] = json.dumps(labels)
315+
else:
316+
logger.info(
317+
"No .pyrit_conf found at %s — skipping operator/operation validation. "
318+
"In GUI mode, these labels are set per-user at runtime.",
319+
DEFAULT_CONFIG_PATH,
320+
)

tests/unit/setup/test_airt_initializer.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT license.
33

4+
import json
45
import os
56
import sys
67
from unittest.mock import patch
@@ -214,6 +215,65 @@ def test_validate_missing_operation_raises_error(self, tmp_path):
214215
):
215216
init._validate_operation_fields()
216217

218+
def test_validate_operation_fields_skips_when_pyrit_conf_missing(self, tmp_path):
219+
"""Test that _validate_operation_fields does not crash when .pyrit_conf is missing.
220+
221+
In container/GUI deployments, .pyrit_conf does not exist. The method should
222+
skip validation gracefully instead of raising FileNotFoundError.
223+
"""
224+
nonexistent_path = tmp_path / "nonexistent" / ".pyrit_conf"
225+
init = AIRTInitializer()
226+
with patch("pyrit.setup.initializers.airt.DEFAULT_CONFIG_PATH", nonexistent_path):
227+
# Should not raise
228+
init._validate_operation_fields()
229+
230+
def test_validate_operation_fields_preserves_existing_labels_when_pyrit_conf_missing(self, tmp_path):
231+
"""Test that existing GLOBAL_MEMORY_LABELS are preserved when .pyrit_conf is missing."""
232+
nonexistent_path = tmp_path / "nonexistent" / ".pyrit_conf"
233+
init = AIRTInitializer()
234+
with (
235+
patch("pyrit.setup.initializers.airt.DEFAULT_CONFIG_PATH", nonexistent_path),
236+
patch.dict("os.environ", {"GLOBAL_MEMORY_LABELS": '{"operator": "gui_user", "operation": "gui_op"}'}),
237+
):
238+
init._validate_operation_fields()
239+
# Existing labels should remain untouched
240+
labels = json.loads(os.environ["GLOBAL_MEMORY_LABELS"])
241+
assert labels["operator"] == "gui_user"
242+
assert labels["operation"] == "gui_op"
243+
244+
def test_validate_operation_fields_merges_conf_into_labels(self, tmp_path):
245+
"""Test that .pyrit_conf values are merged into GLOBAL_MEMORY_LABELS when labels are missing."""
246+
conf_file = tmp_path / ".pyrit_conf"
247+
conf_file.write_text(yaml.dump({"operator": "conf_user", "operation": "conf_op"}))
248+
init = AIRTInitializer()
249+
with (
250+
patch("pyrit.setup.initializers.airt.DEFAULT_CONFIG_PATH", conf_file),
251+
patch.dict("os.environ", {}, clear=False),
252+
):
253+
# Remove GLOBAL_MEMORY_LABELS if present
254+
os.environ.pop("GLOBAL_MEMORY_LABELS", None)
255+
init._validate_operation_fields()
256+
labels = json.loads(os.environ["GLOBAL_MEMORY_LABELS"])
257+
assert labels["operator"] == "conf_user"
258+
assert labels["operation"] == "conf_op"
259+
260+
def test_validate_operation_fields_does_not_overwrite_existing_labels(self, tmp_path):
261+
"""Test that .pyrit_conf values do not overwrite existing GLOBAL_MEMORY_LABELS entries."""
262+
conf_file = tmp_path / ".pyrit_conf"
263+
conf_file.write_text(yaml.dump({"operator": "conf_user", "operation": "conf_op"}))
264+
init = AIRTInitializer()
265+
with (
266+
patch("pyrit.setup.initializers.airt.DEFAULT_CONFIG_PATH", conf_file),
267+
patch.dict(
268+
"os.environ",
269+
{"GLOBAL_MEMORY_LABELS": '{"operator": "existing_user", "operation": "existing_op"}'},
270+
),
271+
):
272+
init._validate_operation_fields()
273+
labels = json.loads(os.environ["GLOBAL_MEMORY_LABELS"])
274+
assert labels["operator"] == "existing_user"
275+
assert labels["operation"] == "existing_op"
276+
217277
def test_validate_db_connection_raises_error(self):
218278
"""Test that validate raises error when AZURE_SQL_DB_CONNECTION_STRING is missing."""
219279
del os.environ["AZURE_SQL_DB_CONNECTION_STRING"]

0 commit comments

Comments
 (0)