|
1 | 1 | # Copyright (c) Microsoft Corporation. |
2 | 2 | # Licensed under the MIT license. |
3 | 3 |
|
| 4 | +import json |
4 | 5 | import os |
5 | 6 | import sys |
6 | 7 | from unittest.mock import patch |
@@ -214,6 +215,65 @@ def test_validate_missing_operation_raises_error(self, tmp_path): |
214 | 215 | ): |
215 | 216 | init._validate_operation_fields() |
216 | 217 |
|
| 218 | + def test_validate_operation_fields_skips_when_pyrit_conf_missing(self, tmp_path): |
| 219 | + """Test that _validate_operation_fields does not crash when .pyrit_conf is missing. |
| 220 | +
|
| 221 | + In container/GUI deployments, .pyrit_conf does not exist. The method should |
| 222 | + skip validation gracefully instead of raising FileNotFoundError. |
| 223 | + """ |
| 224 | + nonexistent_path = tmp_path / "nonexistent" / ".pyrit_conf" |
| 225 | + init = AIRTInitializer() |
| 226 | + with patch("pyrit.setup.initializers.airt.DEFAULT_CONFIG_PATH", nonexistent_path): |
| 227 | + # Should not raise |
| 228 | + init._validate_operation_fields() |
| 229 | + |
| 230 | + def test_validate_operation_fields_preserves_existing_labels_when_pyrit_conf_missing(self, tmp_path): |
| 231 | + """Test that existing GLOBAL_MEMORY_LABELS are preserved when .pyrit_conf is missing.""" |
| 232 | + nonexistent_path = tmp_path / "nonexistent" / ".pyrit_conf" |
| 233 | + init = AIRTInitializer() |
| 234 | + with ( |
| 235 | + patch("pyrit.setup.initializers.airt.DEFAULT_CONFIG_PATH", nonexistent_path), |
| 236 | + patch.dict("os.environ", {"GLOBAL_MEMORY_LABELS": '{"operator": "gui_user", "operation": "gui_op"}'}), |
| 237 | + ): |
| 238 | + init._validate_operation_fields() |
| 239 | + # Existing labels should remain untouched |
| 240 | + labels = json.loads(os.environ["GLOBAL_MEMORY_LABELS"]) |
| 241 | + assert labels["operator"] == "gui_user" |
| 242 | + assert labels["operation"] == "gui_op" |
| 243 | + |
| 244 | + def test_validate_operation_fields_merges_conf_into_labels(self, tmp_path): |
| 245 | + """Test that .pyrit_conf values are merged into GLOBAL_MEMORY_LABELS when labels are missing.""" |
| 246 | + conf_file = tmp_path / ".pyrit_conf" |
| 247 | + conf_file.write_text(yaml.dump({"operator": "conf_user", "operation": "conf_op"})) |
| 248 | + init = AIRTInitializer() |
| 249 | + with ( |
| 250 | + patch("pyrit.setup.initializers.airt.DEFAULT_CONFIG_PATH", conf_file), |
| 251 | + patch.dict("os.environ", {}, clear=False), |
| 252 | + ): |
| 253 | + # Remove GLOBAL_MEMORY_LABELS if present |
| 254 | + os.environ.pop("GLOBAL_MEMORY_LABELS", None) |
| 255 | + init._validate_operation_fields() |
| 256 | + labels = json.loads(os.environ["GLOBAL_MEMORY_LABELS"]) |
| 257 | + assert labels["operator"] == "conf_user" |
| 258 | + assert labels["operation"] == "conf_op" |
| 259 | + |
| 260 | + def test_validate_operation_fields_does_not_overwrite_existing_labels(self, tmp_path): |
| 261 | + """Test that .pyrit_conf values do not overwrite existing GLOBAL_MEMORY_LABELS entries.""" |
| 262 | + conf_file = tmp_path / ".pyrit_conf" |
| 263 | + conf_file.write_text(yaml.dump({"operator": "conf_user", "operation": "conf_op"})) |
| 264 | + init = AIRTInitializer() |
| 265 | + with ( |
| 266 | + patch("pyrit.setup.initializers.airt.DEFAULT_CONFIG_PATH", conf_file), |
| 267 | + patch.dict( |
| 268 | + "os.environ", |
| 269 | + {"GLOBAL_MEMORY_LABELS": '{"operator": "existing_user", "operation": "existing_op"}'}, |
| 270 | + ), |
| 271 | + ): |
| 272 | + init._validate_operation_fields() |
| 273 | + labels = json.loads(os.environ["GLOBAL_MEMORY_LABELS"]) |
| 274 | + assert labels["operator"] == "existing_user" |
| 275 | + assert labels["operation"] == "existing_op" |
| 276 | + |
217 | 277 | def test_validate_db_connection_raises_error(self): |
218 | 278 | """Test that validate raises error when AZURE_SQL_DB_CONNECTION_STRING is missing.""" |
219 | 279 | del os.environ["AZURE_SQL_DB_CONNECTION_STRING"] |
|
0 commit comments