diff --git a/pyrit/executor/attack/component/conversation_manager.py b/pyrit/executor/attack/component/conversation_manager.py index 7a27cb5666..af9c2dee62 100644 --- a/pyrit/executor/attack/component/conversation_manager.py +++ b/pyrit/executor/attack/component/conversation_manager.py @@ -20,7 +20,7 @@ ) from pyrit.prompt_normalizer.prompt_normalizer import PromptNormalizer from pyrit.prompt_target import PromptTarget -from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget +from pyrit.prompt_target.common.target_capabilities import CapabilityName if TYPE_CHECKING: from pyrit.executor.attack.core import AttackContext @@ -242,7 +242,7 @@ def get_last_message( def set_system_prompt( self, *, - target: PromptChatTarget, + target: PromptTarget, conversation_id: str, system_prompt: str, labels: Optional[dict[str, str]] = None, @@ -251,11 +251,16 @@ def set_system_prompt( Set or update the system prompt for a conversation. Args: - target: The chat target to set the system prompt on. + target: The target to set the system prompt on. Must handle the + SYSTEM_PROMPT capability (natively or via an ADAPT policy). conversation_id: Unique identifier for the conversation. system_prompt: The system prompt text. labels: Optional labels to associate with the system prompt. + + Raises: + ValueError: If target cannot handle the SYSTEM_PROMPT capability. """ + target.configuration.ensure_can_handle(capability=CapabilityName.SYSTEM_PROMPT) target.set_system_prompt( system_prompt=system_prompt, conversation_id=conversation_id, @@ -283,7 +288,7 @@ async def initialize_context_async( 3. Updates context.executed_turns for multi-turn attacks 4. 
Sets context.next_message if there's an unanswered user message - For PromptChatTarget: + For chat-capable PromptTarget: - Adds prepended messages to memory with simulated_assistant role - All messages get new UUIDs @@ -306,7 +311,7 @@ async def initialize_context_async( Raises: ValueError: If conversation_id is empty, or if prepended_conversation - requires a PromptChatTarget but target is not one. + requires a chat-capable PromptTarget but target is not one. """ if not conversation_id: raise ValueError("conversation_id cannot be empty") @@ -321,8 +326,11 @@ async def initialize_context_async( logger.debug(f"No prepended conversation for context initialization: {conversation_id}") return state - # Handle target type compatibility - is_chat_target = isinstance(target, PromptChatTarget) + # Targets that don't natively support editable history cannot consume a + # prepended multi-message conversation as-is — route them to the + # single-string fallback path. Type identity (PromptChatTarget) is a + # legacy signal for this; capability-based routing is the durable form. + is_chat_target = target.configuration.includes(capability=CapabilityName.EDITABLE_HISTORY) if not is_chat_target: return await self._handle_non_chat_target_async( context=context, @@ -366,8 +374,8 @@ async def _handle_non_chat_target_async( if config.non_chat_target_behavior == "raise": raise ValueError( - "prepended_conversation requires the objective target to be a PromptChatTarget. " - "Non-chat objective targets do not support conversation history. " + "prepended_conversation requires the objective target to be a chat-capable " + "PromptTarget. Non-chat objective targets do not support conversation history. " "Use PrependedConversationConfig with non_chat_target_behavior='normalize_first_turn' " "to normalize the conversation into the first message instead." 
) diff --git a/pyrit/executor/attack/component/prepended_conversation_config.py b/pyrit/executor/attack/component/prepended_conversation_config.py index c78ffad767..fddeae5371 100644 --- a/pyrit/executor/attack/component/prepended_conversation_config.py +++ b/pyrit/executor/attack/component/prepended_conversation_config.py @@ -22,7 +22,7 @@ class PrependedConversationConfig: This class provides control over: - Which message roles should have request converters applied - How to normalize conversation history for non-chat objective targets - - What to do when the objective target is not a PromptChatTarget + - What to do when the objective target is not a chat-capable PromptTarget """ # Roles for which request converters should be applied to prepended messages. @@ -36,13 +36,13 @@ class PrependedConversationConfig: # ConversationContextNormalizer is used that produces "Turn N: User/Assistant" format. message_normalizer: Optional[MessageStringNormalizer] = None - # Behavior when the target is a PromptTarget but not a PromptChatTarget: + # Behavior when the target is a PromptTarget but not a chat-capable PromptTarget: # - "normalize_first_turn": Normalize the prepended conversation into a string and # store it in ConversationState.normalized_prepended_context. This context will be # prepended to the first message sent to the target. Uses objective_target_context_normalizer # if provided, otherwise falls back to ConversationContextNormalizer. # - "raise": Raise a ValueError. Use this when prepended conversation history must be - # maintained by the target (i.e., target must be a PromptChatTarget). + # maintained by the target (i.e., target must be a chat-capable PromptTarget). 
non_chat_target_behavior: Literal["normalize_first_turn", "raise"] = "normalize_first_turn" def get_message_normalizer(self) -> MessageStringNormalizer: diff --git a/pyrit/executor/attack/core/attack_strategy.py b/pyrit/executor/attack/core/attack_strategy.py index dac86d10aa..efe2dcca8e 100644 --- a/pyrit/executor/attack/core/attack_strategy.py +++ b/pyrit/executor/attack/core/attack_strategy.py @@ -8,7 +8,7 @@ import time from abc import ABC from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, overload +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, overload from pyrit.common.logger import logger from pyrit.executor.attack.core.attack_parameters import AttackParameters, AttackParamsT @@ -27,6 +27,7 @@ ConversationReference, Message, ) +from pyrit.prompt_target.common.target_requirements import TargetRequirements if TYPE_CHECKING: from pyrit.executor.attack.core.attack_config import AttackScoringConfig @@ -233,6 +234,10 @@ class AttackStrategy(Strategy[AttackStrategyContextT, AttackStrategyResultT], Id Defines the interface for executing attacks and handling results. """ + #: Capability requirements placed on ``objective_target``. Subclasses + #: override to declare what the attack needs. Validated in ``__init__``. 
+ TARGET_REQUIREMENTS: ClassVar[TargetRequirements] = TargetRequirements() + def __init__( self, *, @@ -259,6 +264,7 @@ def __init__( ), logger=logger, ) + type(self).TARGET_REQUIREMENTS.validate(target=objective_target) self._objective_target = objective_target self._params_type = params_type # Guard so subclasses that set converters before calling super() aren't clobbered diff --git a/pyrit/executor/attack/multi_turn/chunked_request.py b/pyrit/executor/attack/multi_turn/chunked_request.py index 1a70c89195..7280ea838a 100644 --- a/pyrit/executor/attack/multi_turn/chunked_request.py +++ b/pyrit/executor/attack/multi_turn/chunked_request.py @@ -29,6 +29,8 @@ ) from pyrit.prompt_normalizer import PromptNormalizer from pyrit.prompt_target import PromptTarget +from pyrit.prompt_target.common.target_capabilities import CapabilityName +from pyrit.prompt_target.common.target_requirements import TargetRequirements if TYPE_CHECKING: from pyrit.score import TrueFalseScorer @@ -82,6 +84,15 @@ class ChunkedRequestAttack(MultiTurnAttackStrategy[ChunkedRequestAttackContext, """ ).strip() + # Chunked request issues multiple distinct turns that depend on the target + # remembering prior responses. History-squash adaptation would collapse + # them into a single prompt and silently break the attack's semantics. + # Declare MULTI_TURN as ``native_required`` so adaptation is rejected at + # construction time. + TARGET_REQUIREMENTS = TargetRequirements( + native_required=frozenset({CapabilityName.MULTI_TURN}), + ) + @apply_defaults def __init__( self, @@ -226,16 +237,7 @@ async def _setup_async(self, *, context: ChunkedRequestAttackContext) -> None: Args: context (ChunkedRequestAttackContext): The attack context containing attack parameters. - - Raises: - ValueError: If the objective target does not support multi-turn conversations. """ - if not self._objective_target.capabilities.supports_multi_turn: - raise ValueError( - "ChunkedRequestAttack requires a multi-turn target. 
" - "The objective target does not support multi-turn conversations." - ) - # Ensure the context has a session context.session = ConversationSession() diff --git a/pyrit/executor/attack/multi_turn/crescendo.py b/pyrit/executor/attack/multi_turn/crescendo.py index 4a180d5df3..434b4280b6 100644 --- a/pyrit/executor/attack/multi_turn/crescendo.py +++ b/pyrit/executor/attack/multi_turn/crescendo.py @@ -43,7 +43,9 @@ SeedPrompt, ) from pyrit.prompt_normalizer import PromptNormalizer -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import PromptTarget +from pyrit.prompt_target.common.target_capabilities import CapabilityName +from pyrit.prompt_target.common.target_requirements import TargetRequirements from pyrit.score import ( FloatScaleThresholdScorer, Scorer, @@ -112,6 +114,16 @@ class CrescendoAttack(MultiTurnAttackStrategy[CrescendoAttackContext, CrescendoA You can learn more about the Crescendo attack [@russinovich2024crescendo]. """ + # Crescendo fundamentally relies on multi-turn conversation history to + # gradually escalate prompts; history-squash adaptation would collapse the + # conversation into a single prompt and silently break the attack's + # semantics. Declare MULTI_TURN as native_required so adaptation is + # rejected at construction time. 
+ TARGET_REQUIREMENTS = TargetRequirements( + required=frozenset({CapabilityName.EDITABLE_HISTORY, CapabilityName.MULTI_TURN}), + native_required=frozenset({CapabilityName.MULTI_TURN}), + ) + # Default system prompt template path for Crescendo attack DEFAULT_ADVERSARIAL_CHAT_SYSTEM_PROMPT_TEMPLATE_PATH: Path = ( Path(EXECUTOR_SEED_PROMPT_PATH) / "crescendo" / "crescendo_variant_1.yaml" @@ -121,7 +133,7 @@ class CrescendoAttack(MultiTurnAttackStrategy[CrescendoAttackContext, CrescendoA def __init__( self, *, - objective_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] attack_adversarial_config: AttackAdversarialConfig, attack_converter_config: Optional[AttackConverterConfig] = None, attack_scoring_config: Optional[AttackScoringConfig] = None, @@ -134,7 +146,8 @@ def __init__( Initialize the Crescendo attack strategy. Args: - objective_target (PromptChatTarget): The target system to attack. Must be a PromptChatTarget. + objective_target (PromptTarget): The target system to attack. Must + support editable conversation history. attack_adversarial_config (AttackAdversarialConfig): Configuration for the adversarial component, including the adversarial chat target and optional system prompt path. attack_converter_config (Optional[AttackConverterConfig]): Configuration for attack converters, @@ -148,7 +161,7 @@ def __init__( application by role, message normalization, and non-chat target behavior. Raises: - ValueError: If objective_target is not a PromptChatTarget. + ValueError: If objective_target does not natively support editable history. 
""" # Initialize base class super().__init__(objective_target=objective_target, logger=logger, context_type=CrescendoAttackContext) @@ -257,17 +270,7 @@ async def _setup_async(self, *, context: CrescendoAttackContext) -> None: Args: context (CrescendoAttackContext): Attack context with configuration - - Raises: - ValueError: If the objective target does not support multi-turn conversations. """ - if not self._objective_target.capabilities.supports_multi_turn: - raise ValueError( - "CrescendoAttack requires a multi-turn target. Crescendo fundamentally relies on " - "multi-turn conversation history to gradually escalate prompts. " - "Use RedTeamingAttack or TreeOfAttacksWithPruning instead." - ) - # Ensure the context has a session context.session = ConversationSession() diff --git a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py index a9d4b75adc..3239c5568e 100644 --- a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py +++ b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py @@ -29,6 +29,8 @@ ) from pyrit.prompt_normalizer import PromptNormalizer from pyrit.prompt_target import PromptTarget +from pyrit.prompt_target.common.target_capabilities import CapabilityName +from pyrit.prompt_target.common.target_requirements import TargetRequirements from pyrit.score import Scorer if TYPE_CHECKING: @@ -123,6 +125,15 @@ class MultiPromptSendingAttack(MultiTurnAttackStrategy[MultiTurnAttackContext[An and multiple scorer types for comprehensive evaluation. """ + # Sending a sequence of distinct prompts depends on the target maintaining + # conversation state between them. History-squash adaptation would collapse + # them into one message and silently break the attack's sequencing + # semantics. Declare MULTI_TURN as ``native_required`` so adaptation is + # rejected at construction time. 
+ TARGET_REQUIREMENTS = TargetRequirements( + native_required=frozenset({CapabilityName.MULTI_TURN}), + ) + @apply_defaults def __init__( self, @@ -204,16 +215,7 @@ async def _setup_async(self, *, context: MultiTurnAttackContext[Any]) -> None: Args: context (MultiTurnAttackContext): The attack context containing attack parameters. - - Raises: - ValueError: If the objective target does not support multi-turn conversations. """ - if not self._objective_target.capabilities.supports_multi_turn: - raise ValueError( - "MultiPromptSendingAttack requires a multi-turn target. " - "The objective target does not support multi-turn conversations." - ) - # Ensure the context has a session (like red_teaming.py does) context.session = ConversationSession() diff --git a/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py b/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py index 04c8084f7b..007845d864 100644 --- a/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py +++ b/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py @@ -18,6 +18,7 @@ ) from pyrit.memory import CentralMemory from pyrit.models import ConversationReference, ConversationType +from pyrit.prompt_target.common.target_capabilities import CapabilityName if TYPE_CHECKING: from pyrit.models import ( @@ -117,7 +118,7 @@ def _rotate_conversation_for_single_turn_target( Args: context: The current attack context. 
""" - if self._objective_target.capabilities.supports_multi_turn: + if self._objective_target.configuration.includes(capability=CapabilityName.MULTI_TURN): return if context.executed_turns == 0: diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index 7ea7f927b7..e8da296d60 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -50,7 +50,9 @@ SeedPrompt, ) from pyrit.prompt_normalizer import PromptConverterConfiguration, PromptNormalizer -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import PromptChatTarget, PromptTarget +from pyrit.prompt_target.common.target_capabilities import CapabilityName +from pyrit.prompt_target.common.target_requirements import TargetRequirements from pyrit.score import ( FloatScaleThresholdScorer, Scorer, @@ -64,6 +66,13 @@ logger = logging.getLogger(__name__) +# TAP sets a system prompt on its adversarial target and drives a multi-turn dialogue through it. +# Both capabilities must be natively supported — adaptation would silently change the semantics +# (e.g. history-squash normalization would collapse the escalation into a single turn). +_ADVERSARIAL_REQUIREMENTS = TargetRequirements( + native_required=frozenset({CapabilityName.MULTI_TURN, CapabilityName.SYSTEM_PROMPT}), +) + class TAPAttackScoringConfig(AttackScoringConfig): """ @@ -257,7 +266,7 @@ class _TreeOfAttacksNode: def __init__( self, *, - objective_target: PromptChatTarget, + objective_target: PromptTarget, adversarial_chat: PromptChatTarget, adversarial_chat_seed_prompt: SeedPrompt, adversarial_chat_prompt_template: SeedPrompt, @@ -279,7 +288,7 @@ def __init__( Initialize a tree node. Args: - objective_target (PromptChatTarget): The target to attack. + objective_target (PromptTarget): The target to attack. adversarial_chat (PromptChatTarget): The chat target for generating adversarial prompts. 
adversarial_chat_seed_prompt (SeedPrompt): The seed prompt for the first turn. adversarial_chat_prompt_template (SeedPrompt): The template for subsequent turns. @@ -780,7 +789,7 @@ def duplicate(self) -> "_TreeOfAttacksNode": # For single-turn targets, duplicate only the system messages (e.g., system prompt # from prepended conversation) so the target retains its configuration without # carrying over attack turn history that would cause validation errors. - if self._objective_target.capabilities.supports_multi_turn: + if self._objective_target.configuration.includes(capability=CapabilityName.MULTI_TURN): duplicate_node.objective_target_conversation_id = self._memory.duplicate_conversation( conversation_id=self.objective_target_conversation_id ) @@ -1254,7 +1263,7 @@ class TreeOfAttacksWithPruningAttack(AttackStrategy[TAPAttackContext, TAPAttackR def __init__( self, *, - objective_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] attack_adversarial_config: AttackAdversarialConfig, attack_converter_config: Optional[AttackConverterConfig] = None, attack_scoring_config: Optional[AttackScoringConfig] = None, @@ -1271,7 +1280,7 @@ def __init__( Initialize the Tree of Attacks with Pruning attack strategy. Args: - objective_target (PromptChatTarget): The target system to attack. + objective_target (PromptTarget): The target system to attack. attack_adversarial_config (AttackAdversarialConfig): Configuration for the adversarial chat component. attack_converter_config (Optional[AttackConverterConfig]): Configuration for attack converters. Defaults to None. @@ -1293,7 +1302,8 @@ def __init__( Raises: ValueError: If attack_scoring_config uses a non-FloatScaleThresholdScorer objective scorer, - if target is not PromptChatTarget, or if parameters are invalid. + if the adversarial target does not natively support the capabilities TAP needs, + or if parameters are invalid. 
""" # Validate tree parameters if tree_depth < 1: @@ -1322,8 +1332,14 @@ def __init__( # Initialize adversarial configuration self._adversarial_chat = attack_adversarial_config.target - if not isinstance(self._adversarial_chat, PromptChatTarget): - raise ValueError("The adversarial target must be a PromptChatTarget for TAP attack.") + # TAP sets a system prompt on the adversarial target and drives a + # multi-turn dialogue through it; both capabilities must be native. + # (The class-level ``TARGET_REQUIREMENTS`` inherited from ``AttackStrategy`` + # only covers ``objective_target``; this is a separate target.) + try: + _ADVERSARIAL_REQUIREMENTS.validate(target=self._adversarial_chat) + except ValueError as exc: + raise ValueError(f"TreeOfAttacksWithPruningAttack {exc}") from exc # Load system prompts self._adversarial_chat_system_prompt_path = ( @@ -1857,7 +1873,7 @@ def _create_attack_node( generate adversarial prompts and evaluate responses. """ node = _TreeOfAttacksNode( - objective_target=cast("PromptChatTarget", self._objective_target), + objective_target=self._objective_target, adversarial_chat=self._adversarial_chat, adversarial_chat_seed_prompt=self._adversarial_chat_seed_prompt, adversarial_chat_system_seed_prompt=self._adversarial_chat_system_seed_prompt, diff --git a/pyrit/prompt_converter/denylist_converter.py b/pyrit/prompt_converter/denylist_converter.py index 46f427caef..7cbc6f5e9e 100644 --- a/pyrit/prompt_converter/denylist_converter.py +++ b/pyrit/prompt_converter/denylist_converter.py @@ -10,7 +10,7 @@ from pyrit.models import PromptDataType, SeedPrompt from pyrit.prompt_converter.llm_generic_text_converter import LLMGenericTextConverter from pyrit.prompt_converter.prompt_converter import ConverterResult -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import PromptTarget logger = logging.getLogger(__name__) @@ -19,14 +19,14 @@ class DenylistConverter(LLMGenericTextConverter): """ Replaces forbidden words or phrases in 
a prompt with synonyms using an LLM. - An existing ``PromptChatTarget`` is used to perform the conversion (like Azure OpenAI). + An existing ``PromptTarget`` is used to perform the conversion (like Azure OpenAI). """ @apply_defaults def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] system_prompt_template: Optional[SeedPrompt] = None, denylist: list[str] | None = None, ): @@ -34,7 +34,7 @@ def __init__( Initialize the converter with a target, an optional system prompt template, and a denylist. Args: - converter_target (PromptChatTarget): The target for the prompt conversion. + converter_target (PromptTarget): The target for the prompt conversion. Can be omitted if a default has been configured via PyRIT initialization. system_prompt_template (Optional[SeedPrompt]): The system prompt template to use for the conversion. If not provided, a default template will be used. 
diff --git a/pyrit/prompt_converter/llm_generic_text_converter.py b/pyrit/prompt_converter/llm_generic_text_converter.py index f56990247f..9cb15ef1f6 100644 --- a/pyrit/prompt_converter/llm_generic_text_converter.py +++ b/pyrit/prompt_converter/llm_generic_text_converter.py @@ -15,7 +15,7 @@ SeedPrompt, ) from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS, PromptTarget logger = logging.getLogger(__name__) @@ -27,12 +27,13 @@ class LLMGenericTextConverter(PromptConverter): SUPPORTED_INPUT_TYPES = ("text",) SUPPORTED_OUTPUT_TYPES = ("text",) + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS @apply_defaults def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] system_prompt_template: Optional[SeedPrompt] = None, user_prompt_template_with_objective: Optional[SeedPrompt] = None, **kwargs: Any, @@ -41,8 +42,10 @@ def __init__( Initialize the converter with a target and optional prompt templates. Args: - converter_target (PromptChatTarget): The endpoint that converts the prompt. - Can be omitted if a default has been configured via PyRIT initialization. + converter_target (PromptTarget): The endpoint that converts the prompt. Must satisfy + ``CHAT_CONSUMER_REQUIREMENTS`` (multi-turn + editable history capabilities, possibly + via normalization-pipeline adaptation). Can be omitted if a default has been configured + via PyRIT initialization. system_prompt_template (SeedPrompt, Optional): The prompt template to set as the system prompt. user_prompt_template_with_objective (SeedPrompt, Optional): The prompt template to set as the user prompt. expects @@ -51,6 +54,7 @@ def __init__( Raises: ValueError: If converter_target is not provided and no default has been configured. 
""" + super().__init__(converter_target=converter_target) self._converter_target = converter_target self._system_prompt_template = system_prompt_template self._prompt_kwargs = kwargs diff --git a/pyrit/prompt_converter/malicious_question_generator_converter.py b/pyrit/prompt_converter/malicious_question_generator_converter.py index 41a7848458..5725fff9c4 100644 --- a/pyrit/prompt_converter/malicious_question_generator_converter.py +++ b/pyrit/prompt_converter/malicious_question_generator_converter.py @@ -10,7 +10,7 @@ from pyrit.models import PromptDataType, SeedPrompt from pyrit.prompt_converter.llm_generic_text_converter import LLMGenericTextConverter from pyrit.prompt_converter.prompt_converter import ConverterResult -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import PromptTarget logger = logging.getLogger(__name__) @@ -19,21 +19,21 @@ class MaliciousQuestionGeneratorConverter(LLMGenericTextConverter): """ Generates malicious questions using an LLM. - An existing ``PromptChatTarget`` is used to perform the conversion (like Azure OpenAI). + An existing ``PromptTarget`` is used to perform the conversion (like Azure OpenAI). """ @apply_defaults def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] prompt_template: Optional[SeedPrompt] = None, ): """ Initialize the converter with a specific target and template. Args: - converter_target (PromptChatTarget): The endpoint that converts the prompt. + converter_target (PromptTarget): The endpoint that converts the prompt. Can be omitted if a default has been configured via PyRIT initialization. prompt_template (SeedPrompt): The seed prompt template to use. 
""" diff --git a/pyrit/prompt_converter/math_prompt_converter.py b/pyrit/prompt_converter/math_prompt_converter.py index fd6491bbc1..a3520b190c 100644 --- a/pyrit/prompt_converter/math_prompt_converter.py +++ b/pyrit/prompt_converter/math_prompt_converter.py @@ -10,7 +10,7 @@ from pyrit.models import PromptDataType, SeedPrompt from pyrit.prompt_converter.llm_generic_text_converter import LLMGenericTextConverter from pyrit.prompt_converter.prompt_converter import ConverterResult -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import PromptTarget logger = logging.getLogger(__name__) @@ -19,21 +19,21 @@ class MathPromptConverter(LLMGenericTextConverter): """ Converts natural language instructions into symbolic mathematics problems using an LLM. - An existing ``PromptChatTarget`` is used to perform the conversion (like Azure OpenAI). + An existing ``PromptTarget`` is used to perform the conversion (like Azure OpenAI). """ @apply_defaults def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] prompt_template: Optional[SeedPrompt] = None, ): """ Initialize the converter with a specific target and template. Args: - converter_target (PromptChatTarget): The endpoint that converts the prompt. + converter_target (PromptTarget): The endpoint that converts the prompt. Can be omitted if a default has been configured via PyRIT initialization. prompt_template (SeedPrompt): The seed prompt template to use. 
""" diff --git a/pyrit/prompt_converter/noise_converter.py b/pyrit/prompt_converter/noise_converter.py index 0d7bdf302f..86c5375773 100644 --- a/pyrit/prompt_converter/noise_converter.py +++ b/pyrit/prompt_converter/noise_converter.py @@ -11,7 +11,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.models import SeedPrompt from pyrit.prompt_converter.llm_generic_text_converter import LLMGenericTextConverter -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import PromptTarget logger = logging.getLogger(__name__) @@ -20,14 +20,14 @@ class NoiseConverter(LLMGenericTextConverter): """ Injects noise errors into a conversation using an LLM. - An existing ``PromptChatTarget`` is used to perform the conversion (like Azure OpenAI). + An existing ``PromptTarget`` is used to perform the conversion (like Azure OpenAI). """ @apply_defaults def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] noise: Optional[str] = None, number_errors: int = 5, prompt_template: Optional[SeedPrompt] = None, @@ -36,7 +36,7 @@ def __init__( Initialize the converter with the specified parameters. Args: - converter_target (PromptChatTarget): The endpoint that converts the prompt. + converter_target (PromptTarget): The endpoint that converts the prompt. Can be omitted if a default has been configured via PyRIT initialization. noise (str): The noise to inject. Grammar error, delete random letter, insert random space, etc. number_errors (int): The number of errors to inject. 
diff --git a/pyrit/prompt_converter/persuasion_converter.py b/pyrit/prompt_converter/persuasion_converter.py index 11b6bd66e6..090cd6117d 100644 --- a/pyrit/prompt_converter/persuasion_converter.py +++ b/pyrit/prompt_converter/persuasion_converter.py @@ -21,7 +21,7 @@ SeedPrompt, ) from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS, PromptTarget logger = logging.getLogger(__name__) @@ -47,19 +47,20 @@ class PersuasionConverter(PromptConverter): SUPPORTED_INPUT_TYPES = ("text",) SUPPORTED_OUTPUT_TYPES = ("text",) + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS @apply_defaults def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] persuasion_technique: str, ): """ Initialize the converter with the specified target and prompt template. Args: - converter_target (PromptChatTarget): The chat target used to perform rewriting on user prompts. + converter_target (PromptTarget): The chat target used to perform rewriting on user prompts. Can be omitted if a default has been configured via PyRIT initialization. persuasion_technique (str): Persuasion technique to be used by the converter, determines the system prompt to be used to generate new prompts. Must be one of "authority_endorsement", "evidence_based", @@ -69,6 +70,7 @@ def __init__( ValueError: If converter_target is not provided and no default has been configured. ValueError: If the persuasion technique is not supported or does not exist. 
""" + super().__init__(converter_target=converter_target) self.converter_target = converter_target try: diff --git a/pyrit/prompt_converter/prompt_converter.py b/pyrit/prompt_converter/prompt_converter.py index 141076e701..88ca34ea44 100644 --- a/pyrit/prompt_converter/prompt_converter.py +++ b/pyrit/prompt_converter/prompt_converter.py @@ -6,11 +6,15 @@ import inspect import re from dataclasses import dataclass -from typing import Any, Optional, Union, get_args +from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, get_args from pyrit import prompt_converter from pyrit.identifiers import ComponentIdentifier, Identifiable from pyrit.models import PromptDataType +from pyrit.prompt_target.common.target_requirements import TargetRequirements + +if TYPE_CHECKING: + from pyrit.prompt_target import PromptTarget @dataclass @@ -48,6 +52,11 @@ class PromptConverter(Identifiable): #: Tuple of output modalities supported by this converter. Subclasses must override this. SUPPORTED_OUTPUT_TYPES: tuple[PromptDataType, ...] = () + #: Capability requirements placed on the converter's target (if any). + #: Subclasses that use a target should override this and pass the target to + #: ``super().__init__(converter_target=...)`` so the base class can validate it. + TARGET_REQUIREMENTS: ClassVar[TargetRequirements] = TargetRequirements() + _identifier: Optional[ComponentIdentifier] = None def __init_subclass__(cls, **kwargs: object) -> None: @@ -75,11 +84,17 @@ def __init_subclass__(cls, **kwargs: object) -> None: f"Declare the output modalities this converter produces." ) - def __init__(self) -> None: + def __init__(self, *, converter_target: Optional["PromptTarget"] = None) -> None: """ Initialize the prompt converter. + + Args: + converter_target (Optional[PromptTarget]): Target used by the converter, if any. When + provided, it is validated against ``TARGET_REQUIREMENTS``. 
""" super().__init__() + if converter_target is not None: + type(self).TARGET_REQUIREMENTS.validate(target=converter_target) @abc.abstractmethod async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text") -> ConverterResult: diff --git a/pyrit/prompt_converter/random_translation_converter.py b/pyrit/prompt_converter/random_translation_converter.py index 74953c2603..7e11810323 100644 --- a/pyrit/prompt_converter/random_translation_converter.py +++ b/pyrit/prompt_converter/random_translation_converter.py @@ -13,7 +13,7 @@ from pyrit.prompt_converter.prompt_converter import ConverterResult from pyrit.prompt_converter.text_selection_strategy import WordSelectionStrategy from pyrit.prompt_converter.word_level_converter import WordLevelConverter -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import PromptTarget logger = logging.getLogger(__name__) @@ -22,7 +22,7 @@ class RandomTranslationConverter(LLMGenericTextConverter, WordLevelConverter): """ Translates each individual word in a prompt to a random language using an LLM. - An existing ``PromptChatTarget`` is used to perform the translation (like Azure OpenAI). + An existing ``PromptTarget`` is used to perform the translation (like Azure OpenAI). """ SUPPORTED_INPUT_TYPES = ("text",) @@ -35,7 +35,7 @@ class RandomTranslationConverter(LLMGenericTextConverter, WordLevelConverter): def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] system_prompt_template: Optional[SeedPrompt] = None, languages: Optional[list[str]] = None, word_selection_strategy: Optional[WordSelectionStrategy] = None, @@ -44,7 +44,7 @@ def __init__( Initialize the converter with a target, an optional system prompt template, and language options. Args: - converter_target (PromptChatTarget): The target for the prompt conversion. 
+ converter_target (PromptTarget): The target for the prompt conversion. Can be omitted if a default has been configured via PyRIT initialization. system_prompt_template (Optional[SeedPrompt]): The system prompt template to use for the conversion. If not provided, a default template will be used. diff --git a/pyrit/prompt_converter/scientific_translation_converter.py b/pyrit/prompt_converter/scientific_translation_converter.py index 2a6c965996..bdc7987041 100644 --- a/pyrit/prompt_converter/scientific_translation_converter.py +++ b/pyrit/prompt_converter/scientific_translation_converter.py @@ -10,7 +10,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.models import SeedPrompt from pyrit.prompt_converter.llm_generic_text_converter import LLMGenericTextConverter -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import PromptTarget logger = logging.getLogger(__name__) @@ -45,7 +45,7 @@ class ScientificTranslationConverter(LLMGenericTextConverter): def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] mode: str = "combined", prompt_template: Optional[SeedPrompt] = None, ) -> None: @@ -53,7 +53,7 @@ def __init__( Initialize the scientific translation converter. Args: - converter_target (PromptChatTarget): The LLM target to perform the conversion. + converter_target (PromptTarget): The LLM target to perform the conversion. mode (str): The translation mode to use. 
Built-in options are: - ``academic``: Use academic/homework style framing diff --git a/pyrit/prompt_converter/tense_converter.py b/pyrit/prompt_converter/tense_converter.py index 237a2934d5..eede7adef9 100644 --- a/pyrit/prompt_converter/tense_converter.py +++ b/pyrit/prompt_converter/tense_converter.py @@ -10,7 +10,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.models import SeedPrompt from pyrit.prompt_converter.llm_generic_text_converter import LLMGenericTextConverter -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import PromptTarget logger = logging.getLogger(__name__) @@ -19,14 +19,14 @@ class TenseConverter(LLMGenericTextConverter): """ Converts a conversation to a different tense using an LLM. - An existing ``PromptChatTarget`` is used to perform the conversion (like Azure OpenAI). + An existing ``PromptTarget`` is used to perform the conversion (like Azure OpenAI). """ @apply_defaults def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] tense: str, prompt_template: Optional[SeedPrompt] = None, ): @@ -34,7 +34,7 @@ def __init__( Initialize the converter with the target chat support, tense, and optional prompt template. Args: - converter_target (PromptChatTarget): The target chat support for the conversion which will translate. + converter_target (PromptTarget): The target chat support for the conversion which will translate. Can be omitted if a default has been configured via PyRIT initialization. tense (str): The tense the converter should convert the prompt to. E.g. past, present, future. prompt_template (SeedPrompt, Optional): The prompt template for the conversion. 
diff --git a/pyrit/prompt_converter/tone_converter.py b/pyrit/prompt_converter/tone_converter.py index a7b8e5a9f1..4a6d0e859e 100644 --- a/pyrit/prompt_converter/tone_converter.py +++ b/pyrit/prompt_converter/tone_converter.py @@ -10,7 +10,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.models import SeedPrompt from pyrit.prompt_converter.llm_generic_text_converter import LLMGenericTextConverter -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import PromptTarget logger = logging.getLogger(__name__) @@ -19,14 +19,14 @@ class ToneConverter(LLMGenericTextConverter): """ Converts a conversation to a different tone using an LLM. - An existing ``PromptChatTarget`` is used to perform the conversion (like Azure OpenAI). + An existing ``PromptTarget`` is used to perform the conversion (like Azure OpenAI). """ @apply_defaults def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] tone: str, prompt_template: Optional[SeedPrompt] = None, ): @@ -34,7 +34,7 @@ def __init__( Initialize the converter with the target chat support, tone, and optional prompt template. Args: - converter_target (PromptChatTarget): The target chat support for the conversion which will translate. + converter_target (PromptTarget): The target chat support for the conversion which will translate. Can be omitted if a default has been configured via PyRIT initialization. tone (str): The tone for the conversation. E.g. upset, sarcastic, indifferent, etc. prompt_template (SeedPrompt, Optional): The prompt template for the conversion. 
diff --git a/pyrit/prompt_converter/toxic_sentence_generator_converter.py b/pyrit/prompt_converter/toxic_sentence_generator_converter.py index d3390c6af7..636e50ad8d 100644 --- a/pyrit/prompt_converter/toxic_sentence_generator_converter.py +++ b/pyrit/prompt_converter/toxic_sentence_generator_converter.py @@ -14,7 +14,7 @@ from pyrit.models import PromptDataType, SeedPrompt from pyrit.prompt_converter.llm_generic_text_converter import LLMGenericTextConverter from pyrit.prompt_converter.prompt_converter import ConverterResult -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import PromptTarget logger = logging.getLogger(__name__) @@ -23,7 +23,7 @@ class ToxicSentenceGeneratorConverter(LLMGenericTextConverter): """ Generates toxic sentence starters using an LLM. - An existing ``PromptChatTarget`` is used to perform the conversion (like Azure OpenAI). + An existing ``PromptTarget`` is used to perform the conversion (like Azure OpenAI). Based on Project Moonshot's attack module that generates toxic sentences to test LLM safety guardrails: @@ -34,14 +34,14 @@ class ToxicSentenceGeneratorConverter(LLMGenericTextConverter): def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] prompt_template: Optional[SeedPrompt] = None, ): """ Initialize the converter with a specific target and template. Args: - converter_target (PromptChatTarget): The endpoint that converts the prompt. + converter_target (PromptTarget): The endpoint that converts the prompt. Can be omitted if a default has been configured via PyRIT initialization. prompt_template (SeedPrompt): The seed prompt template to use. If not provided, defaults to the ``toxic_sentence_generator.yaml``. 
diff --git a/pyrit/prompt_converter/translation_converter.py b/pyrit/prompt_converter/translation_converter.py index 911f72ab57..dfb67682f9 100644 --- a/pyrit/prompt_converter/translation_converter.py +++ b/pyrit/prompt_converter/translation_converter.py @@ -24,7 +24,7 @@ SeedPrompt, ) from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS, PromptTarget logger = logging.getLogger(__name__) @@ -36,12 +36,13 @@ class TranslationConverter(PromptConverter): SUPPORTED_INPUT_TYPES = ("text",) SUPPORTED_OUTPUT_TYPES = ("text",) + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS @apply_defaults def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] language: str, prompt_template: Optional[SeedPrompt] = None, max_retries: int = 3, @@ -51,7 +52,7 @@ def __init__( Initialize the converter with the target chat support, language, and optional prompt template. Args: - converter_target (PromptChatTarget): The target chat support for the conversion which will translate. + converter_target (PromptTarget): The target chat support for the conversion which will translate. Can be omitted if a default has been configured via PyRIT initialization. language (str): The language for the conversion. E.g. Spanish, French, leetspeak, etc. prompt_template (SeedPrompt, Optional): The prompt template for the conversion. @@ -62,6 +63,7 @@ def __init__( ValueError: If converter_target is not provided and no default has been configured. ValueError: If the language is not provided. 
""" + super().__init__(converter_target=converter_target) self.converter_target = converter_target # Retry strategy for the conversion diff --git a/pyrit/prompt_converter/variation_converter.py b/pyrit/prompt_converter/variation_converter.py index 328e463072..ca9c86aa6e 100644 --- a/pyrit/prompt_converter/variation_converter.py +++ b/pyrit/prompt_converter/variation_converter.py @@ -23,7 +23,7 @@ SeedPrompt, ) from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS, PromptTarget logger = logging.getLogger(__name__) @@ -35,19 +35,20 @@ class VariationConverter(PromptConverter): SUPPORTED_INPUT_TYPES = ("text",) SUPPORTED_OUTPUT_TYPES = ("text",) + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS @apply_defaults def __init__( self, *, - converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] + converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[assignment] prompt_template: Optional[SeedPrompt] = None, ): """ Initialize the converter with the specified target and prompt template. Args: - converter_target (PromptChatTarget): The target to which the prompt will be sent for conversion. + converter_target (PromptTarget): The target to which the prompt will be sent for conversion. Can be omitted if a default has been configured via PyRIT initialization. prompt_template (SeedPrompt, optional): The template used for generating the system prompt. If not provided, a default template will be used. @@ -55,6 +56,7 @@ def __init__( Raises: ValueError: If converter_target is not provided and no default has been configured. 
""" + super().__init__(converter_target=converter_target) self.converter_target = converter_target # set to default strategy if not provided diff --git a/pyrit/prompt_converter/word_level_converter.py b/pyrit/prompt_converter/word_level_converter.py index 7753c49528..50368aed9c 100644 --- a/pyrit/prompt_converter/word_level_converter.py +++ b/pyrit/prompt_converter/word_level_converter.py @@ -33,6 +33,7 @@ def __init__( *, word_selection_strategy: Optional[WordSelectionStrategy] = None, word_split_separator: Optional[str] = " ", + **kwargs, ): """ Initialize the converter with the specified selection strategy. @@ -42,8 +43,10 @@ def __init__( words to convert. If None, all words will be converted. Defaults to None. word_split_separator (Optional[str]): Separator used to split words in the input text. If None, splits by any whitespace. Defaults to " ". + **kwargs: Forwarded to ``PromptConverter.__init__`` to support cooperative multiple inheritance + (e.g., ``converter_target`` when mixed with LLM-based converters). 
""" - super().__init__() + super().__init__(**kwargs) self._word_selection_strategy = word_selection_strategy or AllWordsSelectionStrategy() self._word_split_separator = word_split_separator diff --git a/pyrit/prompt_target/__init__.py b/pyrit/prompt_target/__init__.py index c71dca4089..db24087d22 100644 --- a/pyrit/prompt_target/__init__.py +++ b/pyrit/prompt_target/__init__.py @@ -20,7 +20,10 @@ UnsupportedCapabilityBehavior, ) from pyrit.prompt_target.common.target_configuration import TargetConfiguration -from pyrit.prompt_target.common.target_requirements import TargetRequirements +from pyrit.prompt_target.common.target_requirements import ( + CHAT_CONSUMER_REQUIREMENTS, + TargetRequirements, +) from pyrit.prompt_target.common.utils import limit_requests_per_minute from pyrit.prompt_target.gandalf_target import GandalfLevel, GandalfTarget from pyrit.prompt_target.http_target.http_target import HTTPTarget @@ -51,6 +54,7 @@ "AzureMLChatTarget", "CapabilityName", "CapabilityHandlingPolicy", + "CHAT_CONSUMER_REQUIREMENTS", "CopilotType", "ConversationNormalizationPipeline", "GandalfLevel", diff --git a/pyrit/prompt_target/common/prompt_chat_target.py b/pyrit/prompt_target/common/prompt_chat_target.py index ce1f254678..c707b7fcad 100644 --- a/pyrit/prompt_target/common/prompt_chat_target.py +++ b/pyrit/prompt_target/common/prompt_chat_target.py @@ -3,11 +3,10 @@ from typing import Optional -from pyrit.identifiers import ComponentIdentifier from pyrit.models import MessagePiece from pyrit.models.json_response_config import _JsonResponseConfig from pyrit.prompt_target.common.prompt_target import PromptTarget -from pyrit.prompt_target.common.target_capabilities import TargetCapabilities +from pyrit.prompt_target.common.target_capabilities import CapabilityName, TargetCapabilities from pyrit.prompt_target.common.target_configuration import TargetConfiguration @@ -27,6 +26,7 @@ class PromptChatTarget(PromptTarget): supports_multi_turn=True, 
supports_multi_message_pieces=True, supports_system_prompt=True, + supports_editable_history=True, ) ) @@ -64,37 +64,6 @@ def __init__( custom_capabilities=custom_capabilities, ) - def set_system_prompt( - self, - *, - system_prompt: str, - conversation_id: str, - attack_identifier: Optional[ComponentIdentifier] = None, - labels: Optional[dict[str, str]] = None, - ) -> None: - """ - Set the system prompt for the prompt target. May be overridden by subclasses. - - Raises: - RuntimeError: If the conversation already exists. - """ - messages = self._memory.get_conversation(conversation_id=conversation_id) - - if messages: - raise RuntimeError("Conversation already exists, system prompt needs to be set at the beginning") - - self._memory.add_message_to_memory( - request=MessagePiece( - role="system", - conversation_id=conversation_id, - original_value=system_prompt, - converted_value=system_prompt, - prompt_target_identifier=self.get_identifier(), - attack_identifier=attack_identifier, - labels=labels, - ).to_message() - ) - def is_response_format_json(self, message_piece: MessagePiece) -> bool: """ Check if the response format is JSON and ensure the target supports it. 
@@ -128,7 +97,7 @@ def _get_json_response_config(self, *, message_piece: MessagePiece) -> _JsonResp """ config = _JsonResponseConfig.from_metadata(metadata=message_piece.prompt_metadata) - if config.enabled and not self.capabilities.supports_json_output: + if config.enabled and not self.configuration.includes(capability=CapabilityName.JSON_OUTPUT): target_name = self.get_identifier().class_name raise ValueError(f"This target {target_name} does not support JSON response format.") diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index 53d7d2085a..5c9712d21b 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -8,8 +8,8 @@ from pyrit.identifiers import ComponentIdentifier, Identifiable from pyrit.memory import CentralMemory, MemoryInterface -from pyrit.models import Message -from pyrit.prompt_target.common.target_capabilities import TargetCapabilities +from pyrit.models import Message, MessagePiece +from pyrit.prompt_target.common.target_capabilities import CapabilityName, TargetCapabilities from pyrit.prompt_target.common.target_configuration import TargetConfiguration, resolve_configuration_compat logger = logging.getLogger(__name__) @@ -178,7 +178,7 @@ def _validate_request(self, *, normalized_conversation: list[Message]) -> None: custom_configuration_message = ( "If your target does support this, set the custom_configuration parameter accordingly." ) - if not self.capabilities.supports_multi_message_pieces and n_pieces != 1: + if not self.configuration.includes(capability=CapabilityName.MULTI_MESSAGE_PIECES) and n_pieces != 1: raise ValueError( f"This target only supports a single message piece. Received: {n_pieces} pieces. 
" f"{custom_configuration_message}" @@ -194,7 +194,7 @@ def _validate_request(self, *, normalized_conversation: list[Message]) -> None: f"{custom_configuration_message}" ) - if not self.capabilities.supports_multi_turn and len(normalized_conversation) > 1: + if not self.configuration.includes(capability=CapabilityName.MULTI_TURN) and len(normalized_conversation) > 1: raise ValueError(f"This target only supports a single turn conversation. {custom_configuration_message}") async def _get_normalized_conversation_async(self, *, message: Message) -> list[Message]: @@ -270,6 +270,56 @@ def set_model_name(self, *, model_name: str) -> None: """ self._model_name = model_name + def set_system_prompt( + self, + *, + system_prompt: str, + conversation_id: str, + attack_identifier: ComponentIdentifier | None = None, + labels: dict[str, str] | None = None, + ) -> None: + """ + Inject a system prompt into memory for the given conversation. + + Writes a ``system``-role message so the target's normalization pipeline + (or the target itself, when it natively supports system prompts) will + pick it up on the next ``send_prompt_async`` call. + + If the target does not natively support system prompts, whether this + call is ultimately honored depends on the target's + :class:`CapabilityHandlingPolicy`: + + * ``ADAPT`` — the normalization pipeline (e.g. system squash) will + fold the system message into user content on the wire. + * ``RAISE`` — the first send after the system prompt is set will + raise, because the pipeline cannot adapt the missing capability. + + Args: + system_prompt (str): The system prompt text to set. + conversation_id (str): The conversation id to attach the prompt to. + attack_identifier (ComponentIdentifier | None): Optional attack identifier. + labels (dict[str, str] | None): Optional labels. + + Raises: + RuntimeError: If the conversation already has messages. 
+ """ + messages = self._memory.get_conversation(conversation_id=conversation_id) + + if messages: + raise RuntimeError("Conversation already exists, system prompt needs to be set at the beginning") + + self._memory.add_message_to_memory( + request=MessagePiece( + role="system", + conversation_id=conversation_id, + original_value=system_prompt, + converted_value=system_prompt, + prompt_target_identifier=self.get_identifier(), + attack_identifier=attack_identifier, + labels=labels, + ).to_message() + ) + def dispose_db_engine(self) -> None: """ Dispose database engine to release database connections and resources. diff --git a/pyrit/prompt_target/common/target_capabilities.py b/pyrit/prompt_target/common/target_capabilities.py index 58c6a7e05c..6ae9ed69e2 100644 --- a/pyrit/prompt_target/common/target_capabilities.py +++ b/pyrit/prompt_target/common/target_capabilities.py @@ -190,6 +190,7 @@ def get_known_capabilities(underlying_model: str) -> "Optional[TargetCapabilitie supports_multi_message_pieces=True, supports_system_prompt=True, supports_json_output=True, + supports_editable_history=True, input_modalities=_TEXT_IMAGE_INPUT, output_modalities=_TEXT_OUTPUT, ) @@ -200,6 +201,7 @@ def get_known_capabilities(underlying_model: str) -> "Optional[TargetCapabilitie supports_system_prompt=True, supports_json_schema=True, supports_json_output=True, + supports_editable_history=True, input_modalities=_TEXT_IMAGE_INPUT, output_modalities=_TEXT_OUTPUT, ) diff --git a/pyrit/prompt_target/common/target_requirements.py b/pyrit/prompt_target/common/target_requirements.py index 95182b47b5..9118d7f427 100644 --- a/pyrit/prompt_target/common/target_requirements.py +++ b/pyrit/prompt_target/common/target_requirements.py @@ -6,9 +6,10 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING +from pyrit.prompt_target.common.target_capabilities import CapabilityName + if TYPE_CHECKING: - from pyrit.prompt_target.common.target_capabilities import CapabilityName - 
from pyrit.prompt_target.common.target_configuration import TargetConfiguration + from pyrit.prompt_target.common.prompt_target import PromptTarget @@ -17,38 +18,51 @@ class TargetRequirements: Declarative description of what a consumer (attack, converter, scorer) requires from a target. - Consumers define their requirements once and validate them against a - ``TargetConfiguration`` at construction time. This replaces ad-hoc - ``isinstance`` checks and scattered capability branching. + The single source of truth for capability names is the + :class:`CapabilityName` enum; this class is simply a typed wrapper + around the set of capabilities a consumer needs. + + Two tiers of requirement are supported: + + * ``required`` — satisfied either by native support on the target or + by an ``ADAPT`` entry in the target's + :class:`CapabilityHandlingPolicy`. Use this when the consumer only + needs the behavior to appear on the wire. + * ``native_required`` — must be natively supported. Adaptation is + rejected. Use this when adaptation would silently change the + consumer's semantics (e.g. an attack that depends on the target + remembering prior turns, where history-squash normalization would + collapse the conversation into a single prompt). """ - # The set of capabilities the consumer requires. - required_capabilities: frozenset[CapabilityName] = field(default_factory=frozenset) + required: frozenset[CapabilityName] = field(default_factory=frozenset) + native_required: frozenset[CapabilityName] = field(default_factory=frozenset) - def validate(self, *, configuration: TargetConfiguration) -> None: + def validate(self, *, target: PromptTarget) -> None: """ - Validate that the target configuration can satisfy all requirements. - - Iterates over every required capability and delegates to - ``TargetConfiguration.ensure_can_handle``, which checks native support - first and then consults the handling policy.
All violations are - collected and reported in a single ``ValueError``. + Validate that ``target`` can satisfy every declared requirement. Args: - configuration (TargetConfiguration): The target configuration to validate against. + target (PromptTarget): The target to validate against. Raises: - ValueError: If any required capability is missing and the policy - does not allow adaptation. + ValueError: If any ``native_required`` capability is not natively + supported, or if any ``required`` capability is not supported + natively and has no ``ADAPT`` entry in the target's policy. """ - errors: list[str] = [] - for capability in sorted(self.required_capabilities, key=lambda c: c.value): - try: - configuration.ensure_can_handle(capability=capability) - except ValueError as exc: - errors.append(str(exc)) - if errors: - raise ValueError( - f"Target does not satisfy {len(errors)} required capability(ies):\n" - + "\n".join(f" - {e}" for e in errors) - ) + for capability in self.native_required: + if not target.configuration.includes(capability=capability): + raise ValueError( + f"Target must natively support '{capability.value}'; " + "adaptation is not acceptable for this consumer." + ) + for capability in self.required: + target.configuration.ensure_can_handle(capability=capability) + + +# Shared requirement used by scorers and converters that set a system prompt +# and drive a short multi-turn conversation. Adaptation is acceptable, native +# support is not required. 
+CHAT_CONSUMER_REQUIREMENTS = TargetRequirements( + required=frozenset({CapabilityName.EDITABLE_HISTORY, CapabilityName.MULTI_TURN}), +) diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index db059d6807..5e45136043 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -72,6 +72,7 @@ class OpenAIChatTarget(OpenAITarget, PromptChatTarget): supports_json_output=True, supports_multi_message_pieces=True, supports_system_prompt=True, + supports_editable_history=True, input_modalities=frozenset( {frozenset({"text"}), frozenset({"image_path"}), frozenset({"text", "image_path"})} ), diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index 283636f081..bd7e40292b 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -14,7 +14,7 @@ import uuid from abc import ABC, abstractmethod from collections.abc import Sequence -from typing import TYPE_CHECKING, Optional, Union, cast +from typing import TYPE_CHECKING, ClassVar, Optional, Union, cast from tqdm.auto import tqdm @@ -25,6 +25,7 @@ from pyrit.models import AttackResult from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult from pyrit.prompt_target import OpenAIChatTarget, PromptTarget +from pyrit.prompt_target.common.target_requirements import TargetRequirements from pyrit.registry import ScorerRegistry from pyrit.scenario.core.atomic_attack import AtomicAttack from pyrit.scenario.core.attack_technique import AttackTechnique @@ -50,6 +51,10 @@ class Scenario(ABC): aggregates the results into a ScenarioResult. """ + #: Capability requirements placed on ``objective_target``. Subclasses override to declare + #: what the scenario needs. Validated in ``initialize_async`` once the target is supplied. 
+ TARGET_REQUIREMENTS: ClassVar[TargetRequirements] = TargetRequirements() + def __init__( self, *, @@ -316,6 +321,7 @@ async def initialize_async( # Set instance variables from parameters self._objective_target = objective_target self._objective_target_identifier = objective_target.get_identifier() + type(self).TARGET_REQUIREMENTS.validate(target=objective_target) self._dataset_config_provided = dataset_config is not None self._dataset_config = dataset_config if dataset_config else self.default_dataset_config() self._max_concurrency = max_concurrency diff --git a/pyrit/scenario/scenarios/airt/psychosocial.py b/pyrit/scenario/scenarios/airt/psychosocial.py index b8963b9264..0d7be337fb 100644 --- a/pyrit/scenario/scenarios/airt/psychosocial.py +++ b/pyrit/scenario/scenarios/airt/psychosocial.py @@ -28,6 +28,8 @@ PromptConverterConfiguration, ) from pyrit.prompt_target import OpenAIChatTarget, PromptChatTarget +from pyrit.prompt_target.common.target_capabilities import CapabilityName +from pyrit.prompt_target.common.target_requirements import TargetRequirements from pyrit.scenario.core.atomic_attack import AtomicAttack from pyrit.scenario.core.attack_technique import AttackTechnique from pyrit.scenario.core.dataset_configuration import DatasetConfiguration @@ -147,6 +149,13 @@ class Psychosocial(Scenario): VERSION: int = 1 + #: Psychosocial runs CrescendoAttack, which requires the target to natively support + #: editable conversation history (for backtracking). Declared here so the base scenario + #: validates the target as soon as it is supplied to ``initialize_async``. 
+ TARGET_REQUIREMENTS = TargetRequirements( + native_required=frozenset({CapabilityName.EDITABLE_HISTORY}), + ) + # Set up default subharm configurations # Each subharm (e.g., 'imminent_crisis', 'licensed_therapist') can have unique escalation/scoring # The key is the harm_category_filter value from the strategy @@ -421,10 +430,6 @@ def _get_scorer(self, subharm: Optional[str] = None) -> FloatScaleThresholdScore async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: if self._objective_target is None: raise ValueError("objective_target must be set before creating attacks") - if not isinstance(self._objective_target, PromptChatTarget): - raise TypeError( - f"PsychosocialHarmsScenario requires a PromptChatTarget, got {type(self._objective_target).__name__}" - ) resolved = self._resolve_seed_groups() self._seed_groups = resolved.seed_groups diff --git a/pyrit/score/float_scale/float_scale_scorer.py b/pyrit/score/float_scale/float_scale_scorer.py index a117034b3b..8e36895ee3 100644 --- a/pyrit/score/float_scale/float_scale_scorer.py +++ b/pyrit/score/float_scale/float_scale_scorer.py @@ -7,7 +7,7 @@ from pyrit.exceptions.exception_classes import InvalidJsonException from pyrit.identifiers import ComponentIdentifier from pyrit.models import PromptDataType, Score, UnvalidatedScore -from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget +from pyrit.prompt_target.common.prompt_target import PromptTarget from pyrit.score.scorer import Scorer from pyrit.score.scorer_prompt_validator import ScorerPromptValidator @@ -24,14 +24,16 @@ class FloatScaleScorer(Scorer): is scored independently, returning one score per piece. """ - def __init__(self, *, validator: ScorerPromptValidator) -> None: + def __init__(self, *, validator: ScorerPromptValidator, chat_target: Optional[PromptTarget] = None) -> None: """ Initialize the FloatScaleScorer. Args: validator: A validator object used to validate scores. 
+ chat_target: Optional chat target used by the scorer, forwarded to the base class + for validation against ``TARGET_REQUIREMENTS``. """ - super().__init__(validator=validator) + super().__init__(validator=validator, chat_target=chat_target) def validate_return_scores(self, scores: list[Score]) -> None: """ @@ -70,7 +72,7 @@ def get_scorer_metrics(self) -> Optional["HarmScorerMetrics"]: async def _score_value_with_llm( self, *, - prompt_target: PromptChatTarget, + prompt_target: PromptTarget, system_prompt: str, message_value: str, message_data_type: PromptDataType, diff --git a/pyrit/score/float_scale/insecure_code_scorer.py b/pyrit/score/float_scale/insecure_code_scorer.py index 45c64dab00..dcc91a1992 100644 --- a/pyrit/score/float_scale/insecure_code_scorer.py +++ b/pyrit/score/float_scale/insecure_code_scorer.py @@ -9,7 +9,7 @@ from pyrit.exceptions.exception_classes import InvalidJsonException from pyrit.identifiers import ComponentIdentifier from pyrit.models import MessagePiece, Score, SeedPrompt -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS, PromptTarget from pyrit.score.float_scale.float_scale_scorer import FloatScaleScorer from pyrit.score.scorer_prompt_validator import ScorerPromptValidator @@ -21,11 +21,12 @@ class InsecureCodeScorer(FloatScaleScorer): """ _DEFAULT_VALIDATOR: ScorerPromptValidator = ScorerPromptValidator(supported_data_types=["text"]) + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS def __init__( self, *, - chat_target: PromptChatTarget, + chat_target: PromptTarget, system_prompt_path: Optional[Union[str, Path]] = None, validator: Optional[ScorerPromptValidator] = None, ): @@ -33,12 +34,12 @@ def __init__( Initialize the Insecure Code Scorer. Args: - chat_target (PromptChatTarget): The target to use for scoring code security. + chat_target (PromptTarget): The target to use for scoring code security. 
system_prompt_path (Optional[Union[str, Path]]): Path to the YAML file containing the system prompt. Defaults to the default insecure code scoring prompt if not provided. validator (Optional[ScorerPromptValidator]): Custom validator for the scorer. Defaults to None. """ - super().__init__(validator=validator or self._DEFAULT_VALIDATOR) + super().__init__(validator=validator or self._DEFAULT_VALIDATOR, chat_target=chat_target) self._prompt_target = chat_target diff --git a/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py b/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py index ae9e0acc4b..3abbc9cf85 100644 --- a/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py +++ b/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py @@ -5,13 +5,14 @@ from typing import TYPE_CHECKING, Optional +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS from pyrit.score.float_scale.float_scale_scorer import FloatScaleScorer from pyrit.score.scorer_prompt_validator import ScorerPromptValidator if TYPE_CHECKING: from pyrit.identifiers import ComponentIdentifier from pyrit.models import MessagePiece, Score, UnvalidatedScore - from pyrit.prompt_target import PromptChatTarget + from pyrit.prompt_target import PromptTarget class SelfAskGeneralFloatScaleScorer(FloatScaleScorer): @@ -24,11 +25,12 @@ class SelfAskGeneralFloatScaleScorer(FloatScaleScorer): supported_data_types=["text"], is_objective_required=True, ) + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS def __init__( self, *, - chat_target: PromptChatTarget, + chat_target: PromptTarget, system_prompt_format_string: str, prompt_format_string: Optional[str] = None, category: Optional[str] = None, @@ -52,7 +54,9 @@ def __init__( in the response, the provided `category` argument will be applied. Args: - chat_target (PromptChatTarget): The chat target used to score. + chat_target (PromptTarget): The chat target used to score. 
Must satisfy + CHAT_CONSUMER_REQUIREMENTS (multi-turn + editable history capabilities, + possibly via normalization-pipeline adaptation). system_prompt_format_string (str): System prompt template with placeholders for objective, prompt, and message_piece. prompt_format_string (Optional[str]): User prompt template with the same placeholders. @@ -71,7 +75,7 @@ def __init__( ValueError: If system_prompt_format_string is not provided or empty. ValueError: If min_value is greater than max_value. """ - super().__init__(validator=validator or self._DEFAULT_VALIDATOR) + super().__init__(validator=validator or self._DEFAULT_VALIDATOR, chat_target=chat_target) self._prompt_target = chat_target if not system_prompt_format_string: raise ValueError("system_prompt_format_string must be provided and non-empty.") diff --git a/pyrit/score/float_scale/self_ask_likert_scorer.py b/pyrit/score/float_scale/self_ask_likert_scorer.py index c6762089b6..458a60c650 100644 --- a/pyrit/score/float_scale/self_ask_likert_scorer.py +++ b/pyrit/score/float_scale/self_ask_likert_scorer.py @@ -12,7 +12,7 @@ from pyrit.common.path import HARM_DEFINITION_PATH, SCORER_LIKERT_PATH from pyrit.identifiers import ComponentIdentifier from pyrit.models import MessagePiece, Score, SeedPrompt, UnvalidatedScore -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS, PromptTarget from pyrit.score.float_scale.float_scale_scorer import FloatScaleScorer from pyrit.score.scorer_prompt_validator import ScorerPromptValidator @@ -173,11 +173,12 @@ class SelfAskLikertScorer(FloatScaleScorer): """ _DEFAULT_VALIDATOR: ScorerPromptValidator = ScorerPromptValidator(supported_data_types=["text"]) + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS def __init__( self, *, - chat_target: PromptChatTarget, + chat_target: PromptTarget, likert_scale: Optional[LikertScalePaths] = None, custom_likert_path: Optional[Path] = None, custom_system_prompt_path: Optional[Path] = None, 
@@ -187,7 +188,7 @@ def __init__( Initialize the SelfAskLikertScorer. Args: - chat_target (PromptChatTarget): The chat target to use for scoring. + chat_target (PromptTarget): The chat target to use for scoring. likert_scale (Optional[LikertScalePaths]): The Likert scale configuration to use for scoring. custom_likert_path (Optional[Path]): Path to a custom YAML file containing the Likert scale definition. This allows users to use their own Likert scales without modifying the code, as long as @@ -201,7 +202,7 @@ def __init__( ValueError: If both `likert_scale` and `custom_likert_path` are provided, if neither is provided, or if the provided Likert scale or system prompt YAML file is improperly formatted. """ - super().__init__(validator=validator or self._DEFAULT_VALIDATOR) + super().__init__(validator=validator or self._DEFAULT_VALIDATOR, chat_target=chat_target) self._prompt_target = chat_target self._likert_scale = likert_scale diff --git a/pyrit/score/float_scale/self_ask_scale_scorer.py b/pyrit/score/float_scale/self_ask_scale_scorer.py index 4bf0dc2dee..2ea1d27807 100644 --- a/pyrit/score/float_scale/self_ask_scale_scorer.py +++ b/pyrit/score/float_scale/self_ask_scale_scorer.py @@ -11,7 +11,7 @@ from pyrit.common.path import SCORER_SCALES_PATH from pyrit.identifiers import ComponentIdentifier from pyrit.models import MessagePiece, Score, SeedPrompt, UnvalidatedScore -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS, PromptTarget from pyrit.score.float_scale.float_scale_scorer import FloatScaleScorer from pyrit.score.scorer_prompt_validator import ScorerPromptValidator @@ -39,11 +39,12 @@ class SystemPaths(enum.Enum): supported_data_types=["text"], is_objective_required=True, ) + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS def __init__( self, *, - chat_target: PromptChatTarget, + chat_target: PromptTarget, scale_arguments_path: Optional[Union[Path, str]] = None, system_prompt_path: 
Optional[Union[Path, str]] = None, validator: Optional[ScorerPromptValidator] = None, @@ -52,14 +53,14 @@ def __init__( Initialize the SelfAskScaleScorer. Args: - chat_target (PromptChatTarget): The chat target to use for scoring. + chat_target (PromptTarget): The chat target to use for scoring. scale_arguments_path (Optional[Union[Path, str]]): Path to the YAML file containing scale definitions. Defaults to TREE_OF_ATTACKS_SCALE if not provided. system_prompt_path (Optional[Union[Path, str]]): Path to the YAML file containing the system prompt. Defaults to GENERAL_SYSTEM_PROMPT if not provided. validator (Optional[ScorerPromptValidator]): Custom validator for the scorer. Defaults to None. """ - super().__init__(validator=validator or self._DEFAULT_VALIDATOR) + super().__init__(validator=validator or self._DEFAULT_VALIDATOR, chat_target=chat_target) self._prompt_target = chat_target diff --git a/pyrit/score/scorer.py b/pyrit/score/scorer.py index 11308edb64..79db6936e9 100644 --- a/pyrit/score/scorer.py +++ b/pyrit/score/scorer.py @@ -12,6 +12,7 @@ from typing import ( TYPE_CHECKING, Any, + ClassVar, Optional, Union, cast, @@ -35,11 +36,12 @@ UnvalidatedScore, ) from pyrit.prompt_target.batch_helper import batch_task_async +from pyrit.prompt_target.common.target_requirements import TargetRequirements if TYPE_CHECKING: from collections.abc import Sequence - from pyrit.prompt_target import PromptChatTarget, PromptTarget + from pyrit.prompt_target import PromptTarget from pyrit.score.scorer_evaluation.metrics_type import RegistryUpdateBehavior from pyrit.score.scorer_evaluation.scorer_evaluator import ( ScorerEvalDatasetFiles, @@ -59,16 +61,26 @@ class Scorer(Identifiable, abc.ABC): # Specifies glob patterns for datasets and a result file name. evaluation_file_mapping: Optional[ScorerEvalDatasetFiles] = None + #: Capability requirements placed on the scorer's chat target (if any). 
+ #: Subclasses that use a chat target should override this and pass the + #: target to ``super().__init__(chat_target=...)`` so the base class can + #: validate it. + TARGET_REQUIREMENTS: ClassVar[TargetRequirements] = TargetRequirements() + _identifier: Optional[ComponentIdentifier] = None - def __init__(self, *, validator: ScorerPromptValidator): + def __init__(self, *, validator: ScorerPromptValidator, chat_target: Optional[PromptTarget] = None): """ Initialize the Scorer. Args: validator (ScorerPromptValidator): Validator for message pieces and scorer configuration. + chat_target (Optional[PromptTarget]): Chat target used by the scorer, if any. When + provided, it is validated against ``TARGET_REQUIREMENTS``. """ self._validator = validator + if chat_target is not None: + type(self).TARGET_REQUIREMENTS.validate(target=chat_target) def get_identifier(self) -> ComponentIdentifier: """ @@ -494,7 +506,7 @@ def scale_value_float(self, value: float, min_value: float, max_value: float) -> async def _score_value_with_llm( self, *, - prompt_target: PromptChatTarget, + prompt_target: PromptTarget, system_prompt: str, message_value: str, message_data_type: PromptDataType, @@ -516,7 +528,7 @@ async def _score_value_with_llm( description fields. Args: - prompt_target (PromptChatTarget): The target LLM to send the message to. + prompt_target (PromptTarget): The target LLM to send the message to. system_prompt (str): The system-level prompt that guides the behavior of the target LLM. message_value (str): The actual value or content to be scored by the LLM (e.g., text, image path, audio path). 
diff --git a/pyrit/score/true_false/gandalf_scorer.py b/pyrit/score/true_false/gandalf_scorer.py index 2aab7c264e..7016d881f2 100644 --- a/pyrit/score/true_false/gandalf_scorer.py +++ b/pyrit/score/true_false/gandalf_scorer.py @@ -11,7 +11,7 @@ from pyrit.exceptions import PyritException, pyrit_target_retry from pyrit.identifiers import ComponentIdentifier from pyrit.models import Message, MessagePiece, Score -from pyrit.prompt_target import GandalfLevel, PromptChatTarget +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS, GandalfLevel, PromptTarget from pyrit.score.scorer_prompt_validator import ScorerPromptValidator from pyrit.score.true_false.true_false_score_aggregator import ( TrueFalseAggregatorFunc, @@ -30,12 +30,13 @@ class GandalfScorer(TrueFalseScorer): """ _DEFAULT_VALIDATOR: ScorerPromptValidator = ScorerPromptValidator(supported_data_types=["text"]) + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS def __init__( self, *, level: GandalfLevel, - chat_target: PromptChatTarget, + chat_target: PromptTarget, validator: Optional[ScorerPromptValidator] = None, score_aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR, ) -> None: @@ -44,12 +45,18 @@ def __init__( Args: level (GandalfLevel): The Gandalf challenge level to score against. - chat_target (PromptChatTarget): The chat target used for password extraction. + chat_target (PromptTarget): The chat target to use for the scorer. Must satisfy + CHAT_CONSUMER_REQUIREMENTS (multi-turn + editable history capabilities, + possibly via normalization-pipeline adaptation). validator (Optional[ScorerPromptValidator]): Custom validator. Defaults to text data type validator. score_aggregator (TrueFalseAggregatorFunc): Aggregator for combining scores. Defaults to TrueFalseScoreAggregator.OR. 
""" - super().__init__(validator=validator or self._DEFAULT_VALIDATOR, score_aggregator=score_aggregator) + super().__init__( + validator=validator or self._DEFAULT_VALIDATOR, + score_aggregator=score_aggregator, + chat_target=chat_target, + ) self._prompt_target = chat_target self._defender = level.value diff --git a/pyrit/score/true_false/self_ask_category_scorer.py b/pyrit/score/true_false/self_ask_category_scorer.py index 7102ba3af6..6f1bc8e49d 100644 --- a/pyrit/score/true_false/self_ask_category_scorer.py +++ b/pyrit/score/true_false/self_ask_category_scorer.py @@ -11,7 +11,7 @@ from pyrit.common.path import SCORER_CONTENT_CLASSIFIERS_PATH from pyrit.identifiers import ComponentIdentifier from pyrit.models import MessagePiece, Score, SeedPrompt, UnvalidatedScore -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS, PromptTarget from pyrit.score.scorer_prompt_validator import ScorerPromptValidator from pyrit.score.true_false.true_false_score_aggregator import ( TrueFalseAggregatorFunc, @@ -37,11 +37,12 @@ class SelfAskCategoryScorer(TrueFalseScorer): """ _DEFAULT_VALIDATOR: ScorerPromptValidator = ScorerPromptValidator() + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS def __init__( self, *, - chat_target: PromptChatTarget, + chat_target: PromptTarget, content_classifier_path: Union[str, Path], score_aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR, validator: Optional[ScorerPromptValidator] = None, @@ -50,13 +51,19 @@ def __init__( Initialize a new instance of the SelfAskCategoryScorer class. Args: - chat_target (PromptChatTarget): The chat target to interact with. + chat_target (PromptTarget): The chat target to use for the scorer. Must satisfy + CHAT_CONSUMER_REQUIREMENTS (multi-turn + editable history capabilities, + possibly via normalization-pipeline adaptation). content_classifier_path (Union[str, Path]): The path to the classifier YAML file. 
score_aggregator (TrueFalseAggregatorFunc): The aggregator function to use. Defaults to TrueFalseScoreAggregator.OR. validator (Optional[ScorerPromptValidator]): Custom validator. Defaults to None. """ - super().__init__(score_aggregator=score_aggregator, validator=validator or self._DEFAULT_VALIDATOR) + super().__init__( + score_aggregator=score_aggregator, + validator=validator or self._DEFAULT_VALIDATOR, + chat_target=chat_target, + ) self._prompt_target = chat_target diff --git a/pyrit/score/true_false/self_ask_general_true_false_scorer.py b/pyrit/score/true_false/self_ask_general_true_false_scorer.py index 44bb362748..b3d3a080da 100644 --- a/pyrit/score/true_false/self_ask_general_true_false_scorer.py +++ b/pyrit/score/true_false/self_ask_general_true_false_scorer.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Optional +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS from pyrit.score.scorer_prompt_validator import ScorerPromptValidator from pyrit.score.true_false.true_false_score_aggregator import ( TrueFalseAggregatorFunc, @@ -15,7 +16,7 @@ if TYPE_CHECKING: from pyrit.identifiers import ComponentIdentifier from pyrit.models import MessagePiece, Score, UnvalidatedScore - from pyrit.prompt_target import PromptChatTarget + from pyrit.prompt_target import PromptTarget class SelfAskGeneralTrueFalseScorer(TrueFalseScorer): @@ -28,11 +29,12 @@ class SelfAskGeneralTrueFalseScorer(TrueFalseScorer): supported_data_types=["text"], is_objective_required=False, ) + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS def __init__( self, *, - chat_target: PromptChatTarget, + chat_target: PromptTarget, system_prompt_format_string: str, prompt_format_string: Optional[str] = None, category: Optional[str] = None, @@ -55,7 +57,9 @@ def __init__( in the response, the provided `category` argument will be applied. Args: - chat_target (PromptChatTarget): The chat target used to score. + chat_target (PromptTarget): The chat target used to score. 
Must satisfy + CHAT_CONSUMER_REQUIREMENTS (multi-turn + editable history capabilities, + possibly via normalization-pipeline adaptation). system_prompt_format_string (str): System prompt template with placeholders for objective, task (alias of objective), prompt, and message_piece. prompt_format_string (Optional[str]): User prompt template with the same placeholders. @@ -73,7 +77,11 @@ def __init__( Raises: ValueError: If system_prompt_format_string is not provided or empty. """ - super().__init__(validator=validator or self._DEFAULT_VALIDATOR, score_aggregator=score_aggregator) + super().__init__( + validator=validator or self._DEFAULT_VALIDATOR, + score_aggregator=score_aggregator, + chat_target=chat_target, + ) self._prompt_target = chat_target if not system_prompt_format_string: raise ValueError("system_prompt_format_string must be provided and non-empty.") diff --git a/pyrit/score/true_false/self_ask_question_answer_scorer.py b/pyrit/score/true_false/self_ask_question_answer_scorer.py index bf1c017dde..c6d35cdfe8 100644 --- a/pyrit/score/true_false/self_ask_question_answer_scorer.py +++ b/pyrit/score/true_false/self_ask_question_answer_scorer.py @@ -18,7 +18,7 @@ import pathlib from pyrit.models import MessagePiece, Score, UnvalidatedScore - from pyrit.prompt_target import PromptChatTarget + from pyrit.prompt_target import PromptTarget class SelfAskQuestionAnswerScorer(SelfAskTrueFalseScorer): @@ -37,7 +37,7 @@ class SelfAskQuestionAnswerScorer(SelfAskTrueFalseScorer): def __init__( self, *, - chat_target: PromptChatTarget, + chat_target: PromptTarget, true_false_question_path: Optional[pathlib.Path] = None, validator: Optional[ScorerPromptValidator] = None, score_aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR, @@ -46,7 +46,9 @@ def __init__( Initialize the SelfAskQuestionAnswerScorer object. Args: - chat_target (PromptChatTarget): The chat target to use for the scorer. + chat_target (PromptTarget): The chat target to use for the scorer. 
Must satisfy + CHAT_CONSUMER_REQUIREMENTS (multi-turn + editable history capabilities, + possibly via normalization-pipeline adaptation). true_false_question_path (Optional[pathlib.Path]): The path to the true/false question file. Defaults to None, which uses the default question_answering.yaml file. validator (Optional[ScorerPromptValidator]): Custom validator. Defaults to None. diff --git a/pyrit/score/true_false/self_ask_refusal_scorer.py b/pyrit/score/true_false/self_ask_refusal_scorer.py index cf9b30f1d8..c7d2010011 100644 --- a/pyrit/score/true_false/self_ask_refusal_scorer.py +++ b/pyrit/score/true_false/self_ask_refusal_scorer.py @@ -8,7 +8,7 @@ from pyrit.common.path import SCORER_SEED_PROMPT_PATH from pyrit.identifiers import ComponentIdentifier from pyrit.models import MessagePiece, Score, SeedPrompt, UnvalidatedScore -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS, PromptTarget from pyrit.score.scorer_prompt_validator import ScorerPromptValidator from pyrit.score.true_false.true_false_score_aggregator import ( TrueFalseAggregatorFunc, @@ -64,11 +64,12 @@ class SelfAskRefusalScorer(TrueFalseScorer): """ _DEFAULT_VALIDATOR: ScorerPromptValidator = ScorerPromptValidator() + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS def __init__( self, *, - chat_target: PromptChatTarget, + chat_target: PromptTarget, refusal_system_prompt_path: Union[RefusalScorerPaths, Path, str] = RefusalScorerPaths.OBJECTIVE_STRICT, prompt_format_string: Optional[str] = None, validator: Optional[ScorerPromptValidator] = None, @@ -78,7 +79,9 @@ def __init__( Initialize the SelfAskRefusalScorer. Args: - chat_target (PromptChatTarget): The endpoint that will be used to score the prompt. + chat_target (PromptTarget): The chat target to use for the scorer. Must satisfy + CHAT_CONSUMER_REQUIREMENTS (multi-turn + editable history capabilities, + possibly via normalization-pipeline adaptation). 
refusal_system_prompt_path (Union[RefusalScorerPaths, Path, str]): The path to the system prompt to use for refusal detection. Can be a RefusalScorerPaths enum value, a Path, or a string path. Defaults to RefusalScorerPaths.OBJECTIVE_STRICT. @@ -100,7 +103,11 @@ def __init__( result_file="refusal_scorer/refusal_metrics.jsonl", ) - super().__init__(score_aggregator=score_aggregator, validator=validator or self._DEFAULT_VALIDATOR) + super().__init__( + score_aggregator=score_aggregator, + validator=validator or self._DEFAULT_VALIDATOR, + chat_target=chat_target, + ) self._prompt_target = chat_target diff --git a/pyrit/score/true_false/self_ask_true_false_scorer.py b/pyrit/score/true_false/self_ask_true_false_scorer.py index d79060fcb4..fcf714f339 100644 --- a/pyrit/score/true_false/self_ask_true_false_scorer.py +++ b/pyrit/score/true_false/self_ask_true_false_scorer.py @@ -12,7 +12,7 @@ from pyrit.common.path import SCORER_SEED_PROMPT_PATH from pyrit.identifiers import ComponentIdentifier from pyrit.models import MessagePiece, Score, SeedPrompt -from pyrit.prompt_target import PromptChatTarget +from pyrit.prompt_target import CHAT_CONSUMER_REQUIREMENTS, PromptTarget from pyrit.score.scorer_prompt_validator import ScorerPromptValidator from pyrit.score.true_false.true_false_score_aggregator import ( TrueFalseAggregatorFunc, @@ -93,11 +93,12 @@ class SelfAskTrueFalseScorer(TrueFalseScorer): _DEFAULT_VALIDATOR: ScorerPromptValidator = ScorerPromptValidator( supported_data_types=["text", "image_path"], ) + TARGET_REQUIREMENTS = CHAT_CONSUMER_REQUIREMENTS def __init__( self, *, - chat_target: PromptChatTarget, + chat_target: PromptTarget, true_false_question_path: Optional[Union[str, Path]] = None, true_false_question: Optional[TrueFalseQuestion] = None, true_false_system_prompt_path: Optional[Union[str, Path]] = None, @@ -108,7 +109,9 @@ def __init__( Initialize the SelfAskTrueFalseScorer. Args: - chat_target (PromptChatTarget): The chat target to interact with. 
+ chat_target (PromptTarget): The chat target to use for the scorer. Must satisfy + CHAT_CONSUMER_REQUIREMENTS (multi-turn + editable history capabilities, + possibly via normalization-pipeline adaptation). true_false_question_path (Optional[Union[str, Path]]): The path to the true/false question file. true_false_question (Optional[TrueFalseQuestion]): The true/false question object. true_false_system_prompt_path (Optional[Union[str, Path]]): The path to the system prompt file. @@ -120,7 +123,11 @@ def __init__( ValueError: If both true_false_question_path and true_false_question are provided. ValueError: If required keys are missing in true_false_question. """ - super().__init__(validator=validator or self._DEFAULT_VALIDATOR, score_aggregator=score_aggregator) + super().__init__( + validator=validator or self._DEFAULT_VALIDATOR, + score_aggregator=score_aggregator, + chat_target=chat_target, + ) self._prompt_target = chat_target diff --git a/pyrit/score/true_false/true_false_scorer.py b/pyrit/score/true_false/true_false_scorer.py index b0c90c0737..6b6e79815e 100644 --- a/pyrit/score/true_false/true_false_scorer.py +++ b/pyrit/score/true_false/true_false_scorer.py @@ -12,6 +12,7 @@ ) if TYPE_CHECKING: + from pyrit.prompt_target import PromptTarget from pyrit.score.scorer_evaluation.scorer_evaluator import ScorerEvalDatasetFiles from pyrit.score.scorer_evaluation.scorer_metrics import ObjectiveScorerMetrics @@ -33,6 +34,7 @@ def __init__( *, validator: ScorerPromptValidator, score_aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR, + chat_target: Optional["PromptTarget"] = None, ) -> None: """ Initialize the TrueFalseScorer. @@ -41,6 +43,8 @@ def __init__( validator (ScorerPromptValidator): Custom validator. score_aggregator (TrueFalseAggregatorFunc): The aggregator function to use. Defaults to TrueFalseScoreAggregator.OR. 
+ chat_target (Optional[PromptTarget]): Optional chat target used by the scorer, + forwarded to the base class for validation against ``TARGET_REQUIREMENTS``. """ self._score_aggregator = score_aggregator @@ -55,7 +59,7 @@ def __init__( result_file="objective/objective_achieved_metrics.jsonl", ) - super().__init__(validator=validator) + super().__init__(validator=validator, chat_target=chat_target) def validate_return_scores(self, scores: list[Score]) -> None: """ diff --git a/tests/unit/backend/test_converter_service.py b/tests/unit/backend/test_converter_service.py index 418441385e..385ddb7d5f 100644 --- a/tests/unit/backend/test_converter_service.py +++ b/tests/unit/backend/test_converter_service.py @@ -413,6 +413,7 @@ def _try_instantiate_converter(converter_name: str): from pyrit.common.apply_defaults import _RequiredValueSentinel from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget + from pyrit.prompt_target.common.prompt_target import PromptTarget # Converters requiring external credentials or resources that can't be mocked # at the constructor level — these validate env vars / files in __init__ body @@ -461,11 +462,18 @@ def _try_instantiate_converter(converter_name: str): ann = param.annotation ann_str = str(ann) if ann is not inspect.Parameter.empty else "" - # PromptChatTarget — mock it with a proper identifier + # PromptChatTarget / PromptTarget — mock it with a proper identifier if ann is not inspect.Parameter.empty and ( - (isinstance(ann, type) and issubclass(ann, PromptChatTarget)) or "PromptChatTarget" in ann_str + (isinstance(ann, type) and issubclass(ann, PromptTarget)) + or "PromptChatTarget" in ann_str + or "PromptTarget" in ann_str ): - mock_target = MagicMock(spec=PromptChatTarget) + spec_cls = ( + PromptChatTarget + if (isinstance(ann, type) and issubclass(ann, PromptChatTarget)) or "PromptChatTarget" in ann_str + else PromptTarget + ) + mock_target = MagicMock(spec=spec_cls) mock_target.__class__.__name__ = 
"MockChatTarget" # Configure get_identifier() to return a proper identifier-like object # so that _create_identifier can extract class_name, model_name, etc. diff --git a/tests/unit/executor/attack/component/test_conversation_manager.py b/tests/unit/executor/attack/component/test_conversation_manager.py index c86e741e9c..5c2f0f4c79 100644 --- a/tests/unit/executor/attack/component/test_conversation_manager.py +++ b/tests/unit/executor/attack/component/test_conversation_manager.py @@ -94,7 +94,13 @@ def mock_chat_target() -> MagicMock: @pytest.fixture def mock_prompt_target() -> MagicMock: """Create a mock prompt target (non-chat) for testing.""" + from pyrit.prompt_target.common.target_capabilities import TargetCapabilities + from pyrit.prompt_target.common.target_configuration import TargetConfiguration + target = MagicMock(spec=PromptTarget) + target.configuration = TargetConfiguration( + capabilities=TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False), + ) target.get_identifier.return_value = _mock_target_id("MockTarget") return target @@ -1059,7 +1065,7 @@ async def test_non_chat_target_behavior_raise_explicit( config = PrependedConversationConfig(non_chat_target_behavior="raise") with pytest.raises( - ValueError, match="prepended_conversation requires the objective target to be a PromptChatTarget" + ValueError, match="prepended_conversation requires the objective target to be a chat-capable" ): await manager.initialize_context_async( context=context, diff --git a/tests/unit/executor/attack/multi_turn/test_supports_multi_turn_attacks.py b/tests/unit/executor/attack/multi_turn/test_supports_multi_turn_attacks.py index 3baeaf463c..9de79ab384 100644 --- a/tests/unit/executor/attack/multi_turn/test_supports_multi_turn_attacks.py +++ b/tests/unit/executor/attack/multi_turn/test_supports_multi_turn_attacks.py @@ -27,6 +27,7 @@ def _make_strategy(*, supports_multi_turn: bool): target = MagicMock() target.capabilities.supports_multi_turn = 
supports_multi_turn + target.configuration.includes.return_value = supports_multi_turn target.get_identifier.return_value = MagicMock() with patch.multiple( @@ -376,6 +377,7 @@ def _make_tap_node(self, *, supports_multi_turn: bool): target = MagicMock() target.capabilities.supports_multi_turn = supports_multi_turn + target.configuration.includes.return_value = supports_multi_turn target.get_identifier.return_value = MagicMock() adversarial_chat = MagicMock() @@ -684,8 +686,14 @@ class TestValueErrorGuards: """Test that incompatible attacks raise ValueError for single-turn targets.""" def _make_single_turn_target(self): + from pyrit.prompt_target.common.target_capabilities import TargetCapabilities + from pyrit.prompt_target.common.target_configuration import TargetConfiguration + target = MagicMock() target.capabilities.supports_multi_turn = False + target.configuration = TargetConfiguration( + capabilities=TargetCapabilities(supports_multi_turn=False, supports_system_prompt=True), + ) target.get_identifier.return_value = MagicMock() return target @@ -706,52 +714,34 @@ def _make_scoring_config(self): @pytest.mark.asyncio async def test_crescendo_raises_for_single_turn_target(self): - from pyrit.executor.attack.multi_turn.crescendo import CrescendoAttack, CrescendoAttackContext + from pyrit.executor.attack.multi_turn.crescendo import CrescendoAttack target = self._make_single_turn_target() - attack = CrescendoAttack( - objective_target=target, - attack_adversarial_config=self._make_adversarial_config(), - attack_scoring_config=self._make_scoring_config(), - ) - - context = CrescendoAttackContext( - params=AttackParameters(objective="Test"), - ) - with pytest.raises(ValueError, match="CrescendoAttack requires a multi-turn target"): - await attack._setup_async(context=context) + with pytest.raises(ValueError, match="supports_multi_turn"): + CrescendoAttack( + objective_target=target, + attack_adversarial_config=self._make_adversarial_config(), + 
attack_scoring_config=self._make_scoring_config(), + ) @pytest.mark.asyncio async def test_multi_prompt_sending_raises_for_single_turn_target(self): from pyrit.executor.attack.multi_turn.multi_prompt_sending import MultiPromptSendingAttack target = self._make_single_turn_target() - attack = MultiPromptSendingAttack(objective_target=target) - - context = MultiTurnAttackContext( - params=AttackParameters(objective="Test"), - ) - with pytest.raises(ValueError, match="MultiPromptSendingAttack requires a multi-turn target"): - await attack._setup_async(context=context) + with pytest.raises(ValueError, match="supports_multi_turn"): + MultiPromptSendingAttack(objective_target=target) @pytest.mark.asyncio async def test_chunked_request_raises_for_single_turn_target(self): - from pyrit.executor.attack.multi_turn.chunked_request import ( - ChunkedRequestAttack, - ChunkedRequestAttackContext, - ) + from pyrit.executor.attack.multi_turn.chunked_request import ChunkedRequestAttack target = self._make_single_turn_target() - attack = ChunkedRequestAttack(objective_target=target) - - context = ChunkedRequestAttackContext( - params=AttackParameters(objective="Test"), - ) - with pytest.raises(ValueError, match="ChunkedRequestAttack requires a multi-turn target"): - await attack._setup_async(context=context) + with pytest.raises(ValueError, match="supports_multi_turn"): + ChunkedRequestAttack(objective_target=target) @pytest.mark.usefixtures("patch_central_database") @@ -764,6 +754,7 @@ def _make_tap_node(self, *, supports_multi_turn: bool): target = MagicMock() target.capabilities.supports_multi_turn = supports_multi_turn + target.configuration.includes.return_value = supports_multi_turn target.get_identifier.return_value = MagicMock() adversarial_chat = MagicMock() diff --git a/tests/unit/prompt_converter/test_prompt_converter.py b/tests/unit/prompt_converter/test_prompt_converter.py index 35dc595339..e19682a4e2 100644 --- a/tests/unit/prompt_converter/test_prompt_converter.py +++ 
b/tests/unit/prompt_converter/test_prompt_converter.py @@ -2,6 +2,7 @@ # Licensed under the MIT license. import re +from unittest.mock import patch import pytest from unit.mocks import MockPromptTarget @@ -579,3 +580,22 @@ def test_llm_based_converters_supported_types( converter = converter_class(**converter_args) assert sorted(converter.supported_input_types) == sorted(expected_input_types) assert sorted(converter.supported_output_types) == sorted(expected_output_types) + + +@pytest.mark.parametrize( + "converter_class, converter_args", + [ + (LLMGenericTextConverter, {"prompt_template": SeedPrompt(data_type="text", value="test template")}), + (MaliciousQuestionGeneratorConverter, {}), + (MathPromptConverter, {}), + (PersuasionConverter, {"persuasion_technique": "misrepresentation"}), + (TranslationConverter, {"language": "es"}), + (VariationConverter, {}), + ], +) +def test_llm_based_converters_validate_target_requirements(setup_memory, converter_class, converter_args): + """Ensure LLM-based converters validate their target via TARGET_REQUIREMENTS on construction.""" + converter_args["converter_target"] = setup_memory + with patch("pyrit.prompt_target.common.target_requirements.TargetRequirements.validate") as mock_validate: + converter_class(**converter_args) + mock_validate.assert_called_once_with(target=setup_memory) diff --git a/tests/unit/prompt_target/target/test_target_requirements.py b/tests/unit/prompt_target/target/test_target_requirements.py index 002ccf086c..fdba9410cb 100644 --- a/tests/unit/prompt_target/target/test_target_requirements.py +++ b/tests/unit/prompt_target/target/test_target_requirements.py @@ -1,131 +1,120 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
+from unittest.mock import MagicMock + import pytest +from pyrit.prompt_target import ( + CHAT_CONSUMER_REQUIREMENTS, + CapabilityName, + TargetRequirements, +) from pyrit.prompt_target.common.target_capabilities import ( CapabilityHandlingPolicy, - CapabilityName, TargetCapabilities, UnsupportedCapabilityBehavior, ) from pyrit.prompt_target.common.target_configuration import TargetConfiguration -from pyrit.prompt_target.common.target_requirements import TargetRequirements - - -@pytest.fixture -def adapt_all_policy(): - return CapabilityHandlingPolicy( - behaviors={ - CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, - CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.ADAPT, - CapabilityName.JSON_SCHEMA: UnsupportedCapabilityBehavior.RAISE, - CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE, - CapabilityName.MULTI_MESSAGE_PIECES: UnsupportedCapabilityBehavior.RAISE, - CapabilityName.EDITABLE_HISTORY: UnsupportedCapabilityBehavior.RAISE, - } - ) -# --------------------------------------------------------------------------- -# Construction -# --------------------------------------------------------------------------- +def _make_target(*, configuration: TargetConfiguration) -> MagicMock: + target = MagicMock() + target.configuration = configuration + return target -def test_init_default_has_empty_capabilities(): - reqs = TargetRequirements() - assert reqs.required_capabilities == frozenset() +def test_default_requirements_require_nothing(): + assert TargetRequirements().required == frozenset() -def test_init_with_capabilities(): +def test_construction_from_frozenset(): reqs = TargetRequirements( - required_capabilities=frozenset({CapabilityName.MULTI_TURN, CapabilityName.SYSTEM_PROMPT}) + required=frozenset({CapabilityName.MULTI_TURN, CapabilityName.JSON_OUTPUT}), ) - assert CapabilityName.MULTI_TURN in reqs.required_capabilities - assert CapabilityName.SYSTEM_PROMPT in reqs.required_capabilities + assert reqs.required == 
{CapabilityName.MULTI_TURN, CapabilityName.JSON_OUTPUT} -# --------------------------------------------------------------------------- -# validate — all pass -# --------------------------------------------------------------------------- +def test_chat_consumer_requirements_shape(): + assert CHAT_CONSUMER_REQUIREMENTS.required == { + CapabilityName.EDITABLE_HISTORY, + CapabilityName.MULTI_TURN, + } -def test_validate_passes_when_target_supports_all_natively(): - caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=True) - config = TargetConfiguration(capabilities=caps) - reqs = TargetRequirements( - required_capabilities=frozenset({CapabilityName.MULTI_TURN, CapabilityName.SYSTEM_PROMPT}) - ) - reqs.validate(configuration=config) +def test_requirements_are_frozen(): + reqs = TargetRequirements(required=frozenset({CapabilityName.MULTI_TURN})) + with pytest.raises(Exception): + reqs.required = frozenset() # type: ignore[misc] -def test_validate_passes_when_policy_is_adapt(adapt_all_policy): - caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False) - config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) - reqs = TargetRequirements( - required_capabilities=frozenset({CapabilityName.MULTI_TURN, CapabilityName.SYSTEM_PROMPT}) +def test_validate_passes_on_native_support(): + target = _make_target( + configuration=TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=True, + supports_editable_history=True, + ), + ), ) - reqs.validate(configuration=config) - - -def test_validate_passes_with_empty_requirements(): - caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=True) - config = TargetConfiguration(capabilities=caps) - reqs = TargetRequirements() - reqs.validate(configuration=config) - -# --------------------------------------------------------------------------- -# validate — failures -# 
--------------------------------------------------------------------------- - - -def test_validate_raises_when_capability_missing_and_no_policy(): - # EDITABLE_HISTORY has no normalizer and no handling policy — validate raises. - caps = TargetCapabilities(supports_editable_history=False, supports_multi_turn=True, supports_system_prompt=True) - config = TargetConfiguration(capabilities=caps) - reqs = TargetRequirements(required_capabilities=frozenset({CapabilityName.EDITABLE_HISTORY})) - with pytest.raises(ValueError, match="supports_editable_history"): - reqs.validate(configuration=config) - - -def test_validate_raises_when_capability_missing_and_policy_raise(adapt_all_policy): - # json_output is missing and the policy is RAISE — validate raises. - caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False, supports_json_output=False) - config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) - reqs = TargetRequirements(required_capabilities=frozenset({CapabilityName.JSON_OUTPUT})) - with pytest.raises(ValueError, match="supports_json_output"): - reqs.validate(configuration=config) - - -def test_validate_collects_all_unsatisfied_capabilities(adapt_all_policy): - """When multiple capabilities are missing, validate reports all violations.""" - caps = TargetCapabilities( - supports_multi_turn=False, - supports_system_prompt=False, - supports_json_output=False, - supports_editable_history=False, + CHAT_CONSUMER_REQUIREMENTS.validate(target=target) + + +def test_validate_passes_when_policy_is_adapt(): + # Note: EDITABLE_HISTORY is not adaptable, so this test uses a custom + # requirement over capabilities that the policy can adapt. 
+ reqs = TargetRequirements(required=frozenset({CapabilityName.MULTI_TURN, CapabilityName.SYSTEM_PROMPT})) + target = _make_target( + configuration=TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=False, + supports_system_prompt=False, + ), + policy=CapabilityHandlingPolicy( + behaviors={ + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, + }, + ), + ), ) - config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) - # json_output => RAISE, editable_history => no policy (raises) - reqs = TargetRequirements( - required_capabilities=frozenset({CapabilityName.JSON_OUTPUT, CapabilityName.EDITABLE_HISTORY}) - ) - with pytest.raises(ValueError, match="2 required capability") as exc_info: - reqs.validate(configuration=config) - assert "supports_json_output" in str(exc_info.value) - assert "supports_editable_history" in str(exc_info.value) + reqs.validate(target=target) + + +def test_validate_raises_when_capability_neither_native_nor_adapt(): + reqs = TargetRequirements(required=frozenset({CapabilityName.MULTI_TURN, CapabilityName.SYSTEM_PROMPT})) + target = _make_target( + configuration=TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=True, + supports_system_prompt=False, + ), + policy=CapabilityHandlingPolicy( + behaviors={ + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.RAISE, + }, + ), + ), + ) -def test_validate_mixed_adapt_and_raise(adapt_all_policy): - """One capability adapts but another raises — validate should raise.""" - caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False, supports_json_output=False) - config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) - # multi_turn and system_prompt => ADAPT (OK), json_output => RAISE (fail) - reqs = TargetRequirements( - 
required_capabilities=frozenset( - {CapabilityName.MULTI_TURN, CapabilityName.SYSTEM_PROMPT, CapabilityName.JSON_OUTPUT} - ) + with pytest.raises(ValueError, match=CapabilityName.SYSTEM_PROMPT.value): + reqs.validate(target=target) + + +def test_validate_empty_required_always_passes(): + target = _make_target( + configuration=TargetConfiguration( + capabilities=TargetCapabilities(), + policy=CapabilityHandlingPolicy( + behaviors={ + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.RAISE, + }, + ), + ), ) - with pytest.raises(ValueError, match="supports_json_output"): - reqs.validate(configuration=config) + + TargetRequirements().validate(target=target) diff --git a/tests/unit/scenario/test_psychosocial_harms.py b/tests/unit/scenario/test_psychosocial_harms.py index 205d6ed3f1..4a207cc45b 100644 --- a/tests/unit/scenario/test_psychosocial_harms.py +++ b/tests/unit/scenario/test_psychosocial_harms.py @@ -323,6 +323,67 @@ async def test_no_target_duplication_async( assert scenario._objective_scorer is not None +@pytest.mark.usefixtures(*FIXTURES) +class TestPsychosocialTargetRequirements: + """Tests for Psychosocial TARGET_REQUIREMENTS declaration and enforcement.""" + + def test_target_requirements_declares_editable_history_natively(self): + """Psychosocial runs CrescendoAttack, so it must require EDITABLE_HISTORY natively.""" + from pyrit.prompt_target.common.target_capabilities import CapabilityName + + assert CapabilityName.EDITABLE_HISTORY in Psychosocial.TARGET_REQUIREMENTS.native_required + + @pytest.mark.asyncio + async def test_initialize_async_invokes_target_requirements_validate( + self, + mock_objective_target, + mock_objective_scorer, + mock_resolved_seed_data, + mock_dataset_config, + ): + """initialize_async must delegate capability validation to TARGET_REQUIREMENTS.validate.""" + with patch.object(Psychosocial, "_resolve_seed_groups", return_value=mock_resolved_seed_data): + 
scenario = Psychosocial(objective_scorer=mock_objective_scorer) + with patch("pyrit.prompt_target.common.target_requirements.TargetRequirements.validate") as mock_validate: + await scenario.initialize_async( + objective_target=mock_objective_target, + dataset_config=mock_dataset_config, + ) + + # Scorers / attacks also validate; ensure the scenario itself validated objective_target. + assert any(call.kwargs.get("target") is mock_objective_target for call in mock_validate.call_args_list), ( + "Expected TARGET_REQUIREMENTS.validate to be called with objective_target" + ) + + @pytest.mark.asyncio + async def test_initialize_async_rejects_target_missing_editable_history( + self, + mock_objective_scorer, + mock_resolved_seed_data, + mock_dataset_config, + ): + """A target that does not natively support EDITABLE_HISTORY must be rejected.""" + from pyrit.prompt_target import PromptTarget + from pyrit.prompt_target.common.target_capabilities import CapabilityName + + non_chat_target = MagicMock(spec=PromptTarget) + non_chat_target.get_identifier.return_value = ComponentIdentifier( + class_name="NonChatTarget", class_module="test" + ) + # Configuration reports no EDITABLE_HISTORY support + non_chat_target.configuration.includes.side_effect = ( + lambda *, capability: capability != CapabilityName.EDITABLE_HISTORY + ) + + with patch.object(Psychosocial, "_resolve_seed_groups", return_value=mock_resolved_seed_data): + scenario = Psychosocial(objective_scorer=mock_objective_scorer) + with pytest.raises(ValueError, match="editable_history"): + await scenario.initialize_async( + objective_target=non_chat_target, + dataset_config=mock_dataset_config, + ) + + @pytest.mark.usefixtures(*FIXTURES) class TestPsychosocialHarmsStrategy: """Tests for PsychosocialHarmsStrategy enum.""" diff --git a/tests/unit/scenario/test_scenario.py b/tests/unit/scenario/test_scenario.py index 947cf6f645..89ebd9aaf0 100644 --- a/tests/unit/scenario/test_scenario.py +++ 
b/tests/unit/scenario/test_scenario.py @@ -292,6 +292,36 @@ async def test_initialize_async_uses_default_values(self, mock_objective_target) assert scenario._max_concurrency == 10 assert scenario._memory_labels == {} + @pytest.mark.asyncio + async def test_initialize_async_validates_target_requirements(self, mock_objective_target): + """Test that initialize_async validates objective_target against TARGET_REQUIREMENTS.""" + scenario = ConcreteScenario(name="Test Scenario", version=1) + + with patch("pyrit.prompt_target.common.target_requirements.TargetRequirements.validate") as mock_validate: + await scenario.initialize_async(objective_target=mock_objective_target) + + mock_validate.assert_called_once_with(target=mock_objective_target) + + @pytest.mark.asyncio + async def test_initialize_async_propagates_target_requirements_error(self, mock_objective_target): + """Test that initialize_async surfaces errors from TARGET_REQUIREMENTS.validate.""" + scenario = ConcreteScenario(name="Test Scenario", version=1) + + with patch( + "pyrit.prompt_target.common.target_requirements.TargetRequirements.validate", + side_effect=ValueError("Target must natively support 'editable_history'"), + ): + with pytest.raises(ValueError, match="editable_history"): + await scenario.initialize_async(objective_target=mock_objective_target) + + def test_scenario_base_target_requirements_is_empty(self): + """Base Scenario declares an empty TargetRequirements so it accepts any target by default.""" + from pyrit.prompt_target.common.target_requirements import TargetRequirements + + assert isinstance(Scenario.TARGET_REQUIREMENTS, TargetRequirements) + assert Scenario.TARGET_REQUIREMENTS.required == frozenset() + assert Scenario.TARGET_REQUIREMENTS.native_required == frozenset() + @pytest.mark.usefixtures("patch_central_database") class TestScenarioExecution: