diff --git a/privacy_guard/attacks/extraction/generation_attack.py b/privacy_guard/attacks/extraction/generation_attack.py
index 7a5bad3..7fbc32d 100644
--- a/privacy_guard/attacks/extraction/generation_attack.py
+++ b/privacy_guard/attacks/extraction/generation_attack.py
@@ -91,7 +91,7 @@ def __init__(
         input_column: str = "prompt",
         target_column: str = "target",
         output_column: str = "prediction",
-        batch_size: int = 4,
+        batch_size: int = 1,
         **generation_kwargs: Any,
     ) -> None:
         if output_file is None and output_format is not None:
@@ -133,6 +133,7 @@ def run_attack(self) -> TextInclusionAnalysisInput:
         logger.info(f"Generating text for {len(prompts)} prompts")
         generations = self.predictor.generate(
             prompts=prompts,
+            batch_size=self.batch_size,
             **self.generation_kwargs,
         )

diff --git a/privacy_guard/attacks/extraction/predictors/gpt_oss_predictor.py b/privacy_guard/attacks/extraction/predictors/gpt_oss_predictor.py
new file mode 100644
index 0000000..1ba8c2a
--- /dev/null
+++ b/privacy_guard/attacks/extraction/predictors/gpt_oss_predictor.py
@@ -0,0 +1,141 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pyre-strict
+
+"""
+GPT OSS predictor implementation for extraction attacks on OpenAI GPT OSS models.
+"""
+
+import logging
+from typing import Any, Dict, List
+
+import transformers.utils.import_utils
+
+from privacy_guard.attacks.extraction.predictors.huggingface_predictor import (
+    HuggingFacePredictor,
+)
+from transformers.utils.import_utils import (
+    _is_package_available,
+    is_accelerate_available,
+)
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+class GPTOSSPredictor(HuggingFacePredictor):
+    """
+    Inherits from HuggingFacePredictor and updates the generation logic to match
+    GPT OSS expectations.
+
+    Use this predictor for models like "gpt-oss-20b" and "gpt-oss-120b".
+
+    Note: HuggingFacePredictor "get_logits" and "get_logprobs" behavior is
+    not yet tested with GPTOSSPredictor.
+    """
+
+    def __init__(
+        self,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
+        accelerate_available = self.accelerate_available_workaround()
+        if not accelerate_available:
+            raise ImportError(
+                'Required library "accelerate" for GPT OSS not available'
+            )
+
+        super().__init__(
+            *args,
+            **kwargs,
+        )
+
+    def accelerate_available_workaround(self) -> bool:
+        """
+        In old transformers versions, availability of the required 'accelerate' package
+        is checked once at import time and the result is saved for all future checks.
+
+        For Meta internal packaging this check returns false at import time even when
+        the package is available at runtime.
+
+        This workaround updates the saved values in transformers when this class is
+        initialized.
+
+        See the following link for the relevant code in the old transformers version:
+        https://github.com/huggingface/transformers/blob/e95441bdb586a7c3c9b4f61a41e99178c1becf54/src/transformers/utils/import_utils.py#L126
+        """
+        if is_accelerate_available():
+            return True
+
+        _accelerate_available, _accelerate_version = (  # pyre-ignore
+            _is_package_available("accelerate", return_version=True)
+        )
+
+        if _accelerate_available:
+            transformers.utils.import_utils._accelerate_available = (
+                _accelerate_available
+            )
+            transformers.utils.import_utils._accelerate_version = _accelerate_version
+
+            return is_accelerate_available()
+
+        return False
+
+    def preprocess_batch_messages(self, batch: List[str]) -> List[Dict[str, str]]:
+        """
+        Prepare a batch of messages for prediction.
+
+        Differs from the parent HuggingFacePredictor in that it returns a list of Dicts
+        instead of strs, and includes the "role": "user" field.
+        """
+        clean_batch = []
+        for item in batch:
+            if not isinstance(item, str):
+                # Warn and coerce non-string items, mirroring the parent preprocess_batch
+                logger.warning(f"Found non-string item in batch: {type(item)}")
+                item = str(item) if item is not None else ""
+            clean_batch.append({"role": "user", "content": item})
+        return clean_batch
+
+    # Override
+    def _generate_process_batch(
+        self, batch: List[str], max_new_tokens: int = 512, **generation_kwargs: Any
+    ) -> List[str]:
+        """Process a single batch of prompts.
+
+        apply_chat_template is used to apply the harmony response format, required for
+        GPT OSS models to work properly.
+        """
+        clean_batch: List[Dict[str, str]] = self.preprocess_batch_messages(batch)
+
+        # Different from the parent HuggingFacePredictor class
+        add_generation_prompt = (
+            True
+            if "add_generation_prompt" not in generation_kwargs
+            else generation_kwargs.pop("add_generation_prompt")
+        )
+        reasoning_effort = (
+            "medium"
+            if "reasoning_effort" not in generation_kwargs
+            else generation_kwargs.pop("reasoning_effort")
+        )
+        inputs = self.tokenizer.apply_chat_template(  # pyre-ignore
+            clean_batch,
+            add_generation_prompt=add_generation_prompt,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+            reasoning_effort=reasoning_effort,
+        ).to(self.device)
+
+        return self._generate_decode_logic(
+            inputs=inputs, max_new_tokens=max_new_tokens, **generation_kwargs
+        )
diff --git a/privacy_guard/attacks/extraction/predictors/huggingface_predictor.py b/privacy_guard/attacks/extraction/predictors/huggingface_predictor.py
index 045da99..387e3e5 100644
--- a/privacy_guard/attacks/extraction/predictors/huggingface_predictor.py
+++ b/privacy_guard/attacks/extraction/predictors/huggingface_predictor.py
@@ -36,8 +36,23 @@ def __init__(
         device: str | None = None,
         model_kwargs: Dict[str, Any] | None = None,
         tokenizer_kwargs: Dict[str, Any] | None = None,
+        include_prompt_in_generation_result: bool = True,
         **kwargs: Any,
     ) -> None:
+        """
+        A predictor class that leverages HuggingFace models for text generation and inference.
+        This class wraps a HuggingFace model and tokenizer, allowing for flexible configuration
+        and usage on different devices (CPU / CUDA GPU). It supports passing custom arguments to both
+        the model and tokenizer, and can optionally include the input prompt in the generation result.
+        Args:
+            model_name: The name or path of the HuggingFace model to load (locally or remotely from HuggingFace).
+            device: The device to run the model on (e.g., 'cpu', 'cuda'). If None, selects the device based on CUDA availability.
+            model_kwargs: Additional keyword arguments to pass to the model during initialization. Defaults to None.
+            tokenizer_kwargs: Additional keyword arguments to pass to the tokenizer during initialization. Defaults to None.
+            include_prompt_in_generation_result: If True, the generation results will include the full prompt.
+                If False, the result will decode only the newly generated tokens and not the input prompt.
+            **kwargs: Additional keyword arguments for the base predictor or other custom settings.
+        """
         self.model_name: str = model_name
         self.device: str = (
             device
@@ -51,6 +66,7 @@ def __init__(
         self.tokenizer_kwargs: Dict[str, Any] = tokenizer_kwargs or {}
         self.model: PreTrainedModel
         self.tokenizer: PreTrainedTokenizer
+        self.include_prompt_in_generation_result = include_prompt_in_generation_result
         # Model already loaded on device - now pass the kwargs
         self.model, self.tokenizer = load_model_and_tokenizer(
             model_name,
@@ -73,15 +89,16 @@ def preprocess_batch(self, batch: List[str]) -> List[str]:
                 clean_batch.append(item)
         return clean_batch

-    def _generate_process_batch(
-        self, batch: List[str], max_new_tokens: int = 512, **generation_kwargs: Any
+    def _generate_decode_logic(
+        self,
+        inputs: Dict[str, Any],
+        max_new_tokens: int = 512,
+        **generation_kwargs: Any,
     ) -> List[str]:
-        """Process a single batch of prompts."""
-        clean_batch = self.preprocess_batch(batch)
-
-        inputs = self.tokenizer(
-            clean_batch, return_tensors="pt", padding=True, truncation=True
-        ).to(self.device)
+        """Calls the correct generate method based on the model type.
+        Supports returning only the generated text, or the full sample including
+        the prompt."""
+        include_prompt_in_generation_result = self.include_prompt_in_generation_result

         with torch.no_grad():
             # Handle both regular models and DDP-wrapped models
@@ -94,10 +111,35 @@ def _generate_process_batch(
                 **inputs, max_new_tokens=max_new_tokens, **generation_kwargs
             )

-        batch_results = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        if include_prompt_in_generation_result:
+            batch_results = self.tokenizer.batch_decode(
+                outputs, skip_special_tokens=True
+            )
+        else:
+            trimmed_outputs = []
+            for output, input_val in zip(outputs, inputs["input_ids"]):
+                trimmed_outputs.append(output[len(input_val) :])
+
+            batch_results = self.tokenizer.batch_decode(
+                trimmed_outputs, skip_special_tokens=True
+            )

         return batch_results

+    def _generate_process_batch(
+        self, batch: List[str], max_new_tokens: int = 512, **generation_kwargs: Any
+    ) -> List[str]:
+        """Process a single batch of prompts."""
+        clean_batch = self.preprocess_batch(batch)
+
+        inputs = self.tokenizer(
+            clean_batch, return_tensors="pt", padding=True, truncation=True
+        ).to(self.device)
+
+        return self._generate_decode_logic(
+            inputs=inputs, max_new_tokens=max_new_tokens, **generation_kwargs
+        )
+
     def generate(self, prompts: List[str], **generation_kwargs: Any) -> List[str]:
         """Generate text continuations for given prompts."""
         if not prompts:
diff --git a/privacy_guard/attacks/extraction/predictors/tests/test_gpt_oss_predictor.py b/privacy_guard/attacks/extraction/predictors/tests/test_gpt_oss_predictor.py
new file mode 100644
index 0000000..9f8d43c
--- /dev/null
+++ b/privacy_guard/attacks/extraction/predictors/tests/test_gpt_oss_predictor.py
@@ -0,0 +1,201 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pyre-strict
+import unittest
+from unittest.mock import MagicMock, patch
+
+import torch
+from privacy_guard.attacks.extraction.predictors.gpt_oss_predictor import (
+    GPTOSSPredictor,
+)
+
+
+class TestGPTOSSPredictor(unittest.TestCase):
+    def setUp(self) -> None:
+        self.model_name = "test-model"
+        self.device = "cpu"
+        self.vocab_size = 50257
+
+        # Create simple mocks for model and tokenizer
+        self.mock_model = MagicMock(
+            spec=["generate", "config"]
+        )  # Only allow these attributes
+        self.mock_model.config.vocab_size = self.vocab_size
+        self.mock_model.generate.return_value = torch.tensor([[1, 2, 3, 4, 5]])
+
+        self.mock_tokenizer = MagicMock()
+        self.mock_tokenizer.pad_token = None
+        self.mock_tokenizer.eos_token = "<|endoftext|>"
+        self.mock_tokenizer.pad_token_id = 0
+        self.mock_tokenizer.batch_decode.return_value = ["Generated text"]
+
+        with patch.object(
+            GPTOSSPredictor, "accelerate_available_workaround", return_value=True
+        ), patch(
+            "privacy_guard.attacks.extraction.predictors.huggingface_predictor.load_model_and_tokenizer",
+            return_value=(
+                self.mock_model,
+                self.mock_tokenizer,
+            ),
+        ):
+            self.predictor = GPTOSSPredictor(self.model_name, self.device)
+
+    def test_init(self) -> None:
+        """Test predictor initialization."""
+        self.assertEqual(self.predictor.model_name, self.model_name)
+        self.assertEqual(self.predictor.device, self.device)
+
+    def test_generate(self) -> None:
+        """Test generate functionality."""
+
+        # Mock tokenizer responses
+        mock_inputs = MagicMock()
+        mock_inputs.to.return_value = {
+            "input_ids": torch.tensor([[1, 2, 3]]),
+            "attention_mask": torch.tensor([[1, 1, 1]]),
+        }
+        self.mock_tokenizer.return_value = mock_inputs
+        self.mock_tokenizer.batch_decode.return_value = ["Generated text"]
+
+        # Mock the tqdm within the generate method - patch the specific import
+        with patch(
+            "privacy_guard.attacks.extraction.predictors.huggingface_predictor.tqdm"
+        ) as mock_tqdm:
+            mock_tqdm.side_effect = lambda x, **kwargs: x
+            result = self.predictor.generate(["Test prompt"])
+
+        self.assertEqual(result, ["Generated text"])
+        self.mock_model.generate.assert_called_once()
+
+    def test_generate_with_kwargs(self) -> None:
+        """Test generate functionality specifying add_generation_prompt
+        and reasoning_effort."""
+
+        # Mock tokenizer responses
+        mock_inputs = MagicMock()
+        mock_inputs.to.return_value = {
+            "input_ids": torch.tensor([[1, 2, 3]]),
+            "attention_mask": torch.tensor([[1, 1, 1]]),
+        }
+        self.mock_tokenizer.return_value = mock_inputs
+        self.mock_tokenizer.batch_decode.return_value = ["Generated text"]
+
+        # Mock the tqdm within the generate method - patch the specific import
+        with patch(
+            "privacy_guard.attacks.extraction.predictors.huggingface_predictor.tqdm"
+        ) as mock_tqdm:
+            mock_tqdm.side_effect = lambda x, **kwargs: x
+            result = self.predictor.generate(
+                ["Test prompt"],
+                add_generation_prompt=True,
+                reasoning_effort="medium",
+            )
+
+        self.assertEqual(result, ["Generated text"])
+        self.mock_model.generate.assert_called_once()
+
+    @patch(
+        "privacy_guard.attacks.extraction.predictors.gpt_oss_predictor.is_accelerate_available"
+    )
+    def test_accelerate_available_workaround_when_initially_true(
+        self, mock_is_accelerate_available: MagicMock
+    ) -> None:
+        """Test accelerate_available_workaround when is_accelerate_available is True initially."""
+
+        # Setup: mock is_accelerate_available to return True
+        mock_is_accelerate_available.return_value = True
+
+        # Execute: call the workaround method
+        # accelerate_available_workaround is called in __init__
+        result = self.predictor.accelerate_available_workaround()
+
+        # Assert: method returns True and only checks is_accelerate_available
+        self.assertTrue(result)
+        mock_is_accelerate_available.assert_called_once()
+
+    @patch(
+        "privacy_guard.attacks.extraction.predictors.gpt_oss_predictor._is_package_available"
+    )
+    @patch(
+        "privacy_guard.attacks.extraction.predictors.gpt_oss_predictor.is_accelerate_available"
+    )
+    def test_accelerate_available_workaround_when_package_available(
+        self,
+        mock_is_accelerate_available: MagicMock,
+        mock_is_package_available: MagicMock,
+    ) -> None:
+        """Test when is_accelerate_available is initially false but _is_package_available returns true."""
+
+        # Setup: mock is_accelerate_available to return False initially, then True after workaround
+        mock_is_accelerate_available.side_effect = [False, True]
+
+        # Setup: mock _is_package_available to return True and a version string
+        mock_is_package_available.return_value = (True, "0.21.0")
+
+        # Execute: call the workaround method
+        result = self.predictor.accelerate_available_workaround()
+
+        # Assert: method returns True after setting the accelerate availability
+        self.assertTrue(result)
+        self.assertEqual(mock_is_accelerate_available.call_count, 2)
+        mock_is_package_available.assert_called_once()
+        # mock_import_utils._is_package_available.assert_called_once_with(
+        #     "accelerate", return_version=True
+        # )
+
+    @patch(
+        "privacy_guard.attacks.extraction.predictors.gpt_oss_predictor._is_package_available"
+    )
+    @patch(
+        "privacy_guard.attacks.extraction.predictors.gpt_oss_predictor.is_accelerate_available"
+    )
+    def test_accelerate_available_workaround_when_both_false(
+        self,
+        mock_is_accelerate_available: MagicMock,
+        mock_is_package_available: MagicMock,
+    ) -> None:
+        """Test when both is_accelerate_available and _is_package_available are false."""
+
+        # Setup: mock is_accelerate_available to return False
+        mock_is_accelerate_available.return_value = False
+
+        # Setup: mock _is_package_available to return False
+        mock_is_package_available.return_value = (False, "N/A")
+
+        # Execute: call the workaround method
+        result = self.predictor.accelerate_available_workaround()
+
+        # Assert: method returns False
+        self.assertFalse(result)
+        mock_is_accelerate_available.assert_called_once()
+        mock_is_package_available.assert_called_once()
+        # mock_import_utils._is_package_available.assert_called_once_with(
+        #     "accelerate", return_version=True
+        # )
+
+    def test_init_fails_when_accelerate_not_available(
+        self,
+    ) -> None:
+        """Test that instantiating GPTOSSPredictor when accelerate is not available
+        raises exception."""
+        with self.assertRaises(ImportError):
+            with patch.object(
+                GPTOSSPredictor, "accelerate_available_workaround", return_value=False
+            ):
+                _ = GPTOSSPredictor(self.model_name, self.device)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/privacy_guard/attacks/extraction/predictors/tests/test_huggingface_predictor.py b/privacy_guard/attacks/extraction/predictors/tests/test_huggingface_predictor.py
index 14da9b0..4770b3d 100644
--- a/privacy_guard/attacks/extraction/predictors/tests/test_huggingface_predictor.py
+++ b/privacy_guard/attacks/extraction/predictors/tests/test_huggingface_predictor.py
@@ -90,6 +90,43 @@ def test_generate(self, mock_load_model_and_tokenizer: MagicMock) -> None:
         self.assertEqual(result, ["Generated text"])
         self.mock_model.generate.assert_called_once()

+    @patch(
+        "privacy_guard.attacks.extraction.predictors.huggingface_predictor.load_model_and_tokenizer"
+    )
+    def test_generate_no_prompt_in_result(
+        self, mock_load_model_and_tokenizer: MagicMock
+    ) -> None:
+        """Test generate functionality when include_prompt_in_generation_result is False."""
+        mock_load_model_and_tokenizer.return_value = (
+            self.mock_model,
+            self.mock_tokenizer,
+        )
+
+        # Mock tokenizer responses
+        mock_inputs = MagicMock()
+        mock_inputs.to.return_value = {
+            "input_ids": torch.tensor([[1, 2, 3]]),
+            "attention_mask": torch.tensor([[1, 1, 1]]),
+        }
+        self.mock_tokenizer.return_value = mock_inputs
+        self.mock_tokenizer.batch_decode.return_value = [
+            "Generated text without prompt"
+        ]
+
+        predictor = HuggingFacePredictor(
+            self.model_name, self.device, include_prompt_in_generation_result=False
+        )
+
+        # Mock the tqdm within the generate method - patch the specific import
+        with patch(
+            "privacy_guard.attacks.extraction.predictors.huggingface_predictor.tqdm"
+        ) as mock_tqdm:
+            mock_tqdm.side_effect = lambda x, **kwargs: x
+            result = predictor.generate(["Test prompt"])
+
+        self.assertEqual(result, ["Generated text without prompt"])
+        self.mock_model.generate.assert_called_once()
+
     @patch(
         "privacy_guard.attacks.extraction.predictors.huggingface_predictor.load_model_and_tokenizer"
     )
diff --git a/privacy_guard/attacks/extraction/tests/test_generation_attack.py b/privacy_guard/attacks/extraction/tests/test_generation_attack.py
index c644d6f..c95c614 100644
--- a/privacy_guard/attacks/extraction/tests/test_generation_attack.py
+++ b/privacy_guard/attacks/extraction/tests/test_generation_attack.py
@@ -58,7 +58,7 @@ def test_generation_attack_no_output_file(self) -> None:

         # Verify predictor was called correctly
         self.mock_predictor.generate.assert_called_once_with(
-            prompts=["prompt 1", "prompt 2"], temperature=1, top_k=40
+            prompts=["prompt 1", "prompt 2"], batch_size=1, temperature=1, top_k=40
         )

         # Verify result structure
@@ -84,7 +84,7 @@ def test_generation_attack_with_output_file(self) -> None:

         # Verify predictor was called correctly
         self.mock_predictor.generate.assert_called_once_with(
-            prompts=["prompt 1", "prompt 2"], temperature=1, top_k=40
+            prompts=["prompt 1", "prompt 2"], batch_size=1, temperature=1, top_k=40
         )

         # Verify output file was created
@@ -120,7 +120,7 @@ def test_generation_attack_custom_columns(self) -> None:

         # Verify predictor was called with correct prompts
         self.mock_predictor.generate.assert_called_once_with(
-            prompts=["test prompt 1", "test prompt 2"]
+            prompts=["test prompt 1", "test prompt 2"], batch_size=1
        )

         # Verify custom column names
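Example usage (illustrative sketch, not part of the diff): the snippet below assumes a GPT OSS checkpoint such as "openai/gpt-oss-20b" is reachable locally or from the HuggingFace hub and that a CUDA device is present; the model name, device, prompt, and generation settings are placeholder choices, and it assumes generate forwards max_new_tokens through to _generate_process_batch as the code above suggests.

    from privacy_guard.attacks.extraction.predictors.gpt_oss_predictor import (
        GPTOSSPredictor,
    )

    # Placeholder checkpoint name; substitute the GPT OSS model under test.
    predictor = GPTOSSPredictor(
        model_name="openai/gpt-oss-20b",
        device="cuda",
        # Return only the newly generated tokens, not the echoed prompt.
        include_prompt_in_generation_result=False,
    )

    # reasoning_effort and add_generation_prompt are popped and passed to
    # apply_chat_template by GPTOSSPredictor._generate_process_batch; the
    # remaining kwargs flow through to model.generate.
    generations = predictor.generate(
        ["Complete the following record:"],
        max_new_tokens=128,
        reasoning_effort="medium",
    )
    print(generations[0])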