3 changes: 2 additions & 1 deletion privacy_guard/attacks/extraction/generation_attack.py
@@ -91,7 +91,7 @@ def __init__(
input_column: str = "prompt",
target_column: str = "target",
output_column: str = "prediction",
batch_size: int = 4,
batch_size: int = 1,
**generation_kwargs: Any,
) -> None:
if output_file is None and output_format is not None:
@@ -133,6 +133,7 @@ def run_attack(self) -> TextInclusionAnalysisInput:
logger.info(f"Generating text for {len(prompts)} prompts")
generations = self.predictor.generate(
prompts=prompts,
batch_size=self.batch_size,
**self.generation_kwargs,
)

141 changes: 141 additions & 0 deletions privacy_guard/attacks/extraction/predictors/gpt_oss_predictor.py
@@ -0,0 +1,141 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pyre-strict

"""
GPT OSS predictor implementation for extraction attacks on OpenAI GPT OSS models.
"""

import warnings
from typing import Any, Dict, List

import transformers.utils.import_utils

from privacy_guard.attacks.extraction.predictors.huggingface_predictor import (
HuggingFacePredictor,
)
from transformers.utils.import_utils import (
_is_package_available,
is_accelerate_available,
)


class GPTOSSPredictor(HuggingFacePredictor):
"""
    Inherits from HuggingFacePredictor and updates the generation logic to match
    GPT OSS expectations.

    Use this predictor for models such as "gpt-oss-20b" and "gpt-oss-120b".

    Note: HuggingFacePredictor "get_logits" and "get_logprobs" behavior is
    not yet tested with GPTOSSPredictor.
"""

def __init__(
self,
*args: Any,
**kwargs: Any,
) -> None:
accelerate_available = self.accelerate_available_workaround()
if not accelerate_available:
raise ImportError(
'Required library "accelerate" for GPT OSS not available'
)

super().__init__(
*args,
**kwargs,
)

def accelerate_available_workaround(self) -> bool:
"""
In old transformers versions, availability for the required 'accelerate' package
is checked once at import time and the result is saved for all future checks.

        For Meta internal packaging this check returns False at import time even when
        the package is available at runtime.

        This workaround updates the cached values in transformers when this
        class is initialized.

        See the old transformers code for reference:
https://github.com/huggingface/transformers/blob/
e95441bdb586a7c3c9b4f61a41e99178c1becf54/src/transformers/utils/import_utils.py#L126
"""
if is_accelerate_available():
return True

_accelerate_available, _accelerate_version = ( # pyre-ignore
_is_package_available("accelerate", return_version=True)
)

if _accelerate_available:
transformers.utils.import_utils._accelerate_available = (
_accelerate_available
)
transformers.utils.import_utils._accelerate_version = _accelerate_version

return is_accelerate_available()

return False

def preprocess_batch_messages(self, batch: List[str]) -> List[Dict[str, str]]:
"""
Prepare a batch of messages for prediction.

        Differs from the parent HuggingFacePredictor in that it returns a list of Dict
        instead of str, and wraps each prompt in a "user" role message.
        """
        clean_batch = []
        for item in batch:
            if not isinstance(item, str):
                # Warn and coerce non-string items instead of silently dropping them.
                warnings.warn(f"Found non-string item in batch: {type(item)}")
                item = str(item) if item is not None else ""
            clean_batch.append({"role": "user", "content": item})
        return clean_batch

# Override
def _generate_process_batch(
self, batch: List[str], max_new_tokens: int = 512, **generation_kwargs: Any
) -> List[str]:
"""Process a single batch of prompts.
apply_chat_template is used to apply the harmony response format, required for
gpt models to work properly.
"""
clean_batch: List[Dict[str, str]] = self.preprocess_batch_messages(batch)

        # Differs from the parent HuggingFacePredictor class
        add_generation_prompt = generation_kwargs.pop("add_generation_prompt", True)
        reasoning_effort = generation_kwargs.pop("reasoning_effort", "medium")
inputs = self.tokenizer.apply_chat_template( # pyre-ignore
clean_batch,
add_generation_prompt=add_generation_prompt,
tokenize=True,
return_dict=True,
return_tensors="pt",
reasoning_effort=reasoning_effort,
).to(self.device)

return self._generate_decode_logic(
inputs=inputs, max_new_tokens=max_new_tokens, **generation_kwargs
)
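
For context, a minimal usage sketch of the new predictor (not part of the diff); the "openai/gpt-oss-20b" checkpoint name, the prompt, and the generation settings are illustrative assumptions:

# Minimal usage sketch, assuming the "openai/gpt-oss-20b" checkpoint is
# available locally or via the HuggingFace hub and that "accelerate" is installed.
from privacy_guard.attacks.extraction.predictors.gpt_oss_predictor import (
    GPTOSSPredictor,
)

predictor = GPTOSSPredictor(
    model_name="openai/gpt-oss-20b",
    include_prompt_in_generation_result=False,
)
generations = predictor.generate(
    prompts=["Finish this sentence: the quick brown fox"],
    max_new_tokens=64,
    reasoning_effort="low",
)
print(generations[0])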
privacy_guard/attacks/extraction/predictors/huggingface_predictor.py
@@ -36,8 +36,23 @@ def __init__(
device: str | None = None,
model_kwargs: Dict[str, Any] | None = None,
tokenizer_kwargs: Dict[str, Any] | None = None,
include_prompt_in_generation_result: bool = True,
**kwargs: Any,
) -> None:
"""
        A predictor class that leverages HuggingFace models for text generation and inference.
        This class wraps a HuggingFace model and tokenizer, allowing for flexible configuration
        and usage on different devices (CPU/CUDA GPU). It supports passing custom arguments to both
        the model and tokenizer, and can optionally include the input prompt in the generation result.
        Args:
            model_name: The name or path of the HuggingFace model to load (locally or remotely from the HuggingFace hub).
            device: The device to run the model on (e.g., 'cpu', 'cuda'). If None, the device is selected based on CUDA availability.
            model_kwargs: Additional keyword arguments to pass to the model during initialization. Defaults to None.
            tokenizer_kwargs: Additional keyword arguments to pass to the tokenizer during initialization. Defaults to None.
            include_prompt_in_generation_result: If True, the generation results will include the full prompt.
                If False, the result will decode only the newly generated tokens and not the input prompt.
            **kwargs: Additional keyword arguments for the base predictor or other custom settings.
"""
self.model_name: str = model_name
self.device: str = (
device
@@ -51,6 +66,7 @@
self.tokenizer_kwargs: Dict[str, Any] = tokenizer_kwargs or {}
self.model: PreTrainedModel
self.tokenizer: PreTrainedTokenizer
        self.include_prompt_in_generation_result: bool = include_prompt_in_generation_result
# Model already loaded on device - now pass the kwargs
self.model, self.tokenizer = load_model_and_tokenizer(
model_name,
@@ -73,15 +89,16 @@ def preprocess_batch(self, batch: List[str]) -> List[str]:
clean_batch.append(item)
return clean_batch

def _generate_process_batch(
self, batch: List[str], max_new_tokens: int = 512, **generation_kwargs: Any
def _generate_decode_logic(
self,
inputs: Dict[str, Any],
max_new_tokens: int = 512,
**generation_kwargs: Any,
) -> List[str]:
"""Process a single batch of prompts."""
clean_batch = self.preprocess_batch(batch)

inputs = self.tokenizer(
clean_batch, return_tensors="pt", padding=True, truncation=True
).to(self.device)
"""Calls the correct generate call based on the model type.
Supports logic for returning only the generated text, or the full sample including
the prompt."""
include_prompt_in_generation_result = self.include_prompt_in_generation_result

with torch.no_grad():
# Handle both regular models and DDP-wrapped models
@@ -94,10 +111,35 @@ def _generate_process_batch(
**inputs, max_new_tokens=max_new_tokens, **generation_kwargs
)

batch_results = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
if include_prompt_in_generation_result:
batch_results = self.tokenizer.batch_decode(
outputs, skip_special_tokens=True
)
else:
trimmed_outputs = []
for output, input_val in zip(outputs, inputs["input_ids"]):
trimmed_outputs.append(output[len(input_val) :])

batch_results = self.tokenizer.batch_decode(
trimmed_outputs, skip_special_tokens=True
)

return batch_results

def _generate_process_batch(
self, batch: List[str], max_new_tokens: int = 512, **generation_kwargs: Any
) -> List[str]:
"""Process a single batch of prompts."""
clean_batch = self.preprocess_batch(batch)

inputs = self.tokenizer(
clean_batch, return_tensors="pt", padding=True, truncation=True
).to(self.device)

return self._generate_decode_logic(
inputs=inputs, max_new_tokens=max_new_tokens, **generation_kwargs
)

def generate(self, prompts: List[str], **generation_kwargs: Any) -> List[str]:
"""Generate text continuations for given prompts."""
if not prompts:
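
To make the new include_prompt_in_generation_result=False path concrete, here is a toy illustration (not part of the diff) of the prompt-trimming slice used in _generate_decode_logic above; the token ids are made up:

# Toy illustration of output[len(input_val):] from _generate_decode_logic.
# Each generated sequence begins with the prompt's input_ids, so slicing by the
# prompt length keeps only the newly generated token ids before decoding.
input_ids = [[101, 7592, 2088]]                 # tokenized prompt (3 tokens)
outputs = [[101, 7592, 2088, 2023, 2003, 102]]  # prompt + 3 generated tokens

trimmed = [out[len(inp):] for out, inp in zip(outputs, input_ids)]
assert trimmed == [[2023, 2003, 102]]  # only the continuation is decoded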