Commit fe4df99
mgrange1998 authored and facebook-github-bot committed

GenerationAttack and HuggingFace predictor updates
Summary: This change updates GenerationAttack and the HuggingFace predictor with the following functionality:

- Fixes `batch_size=self.batch_size` so that it is propagated downstream properly, and updates the default value to 1 to reflect the intended behavior.
- Adds a `_generate_decode_logic` helper to encapsulate calling `model.generate`.
- Extends generation so that decoding can exclude the prompt from the returned text.

Differential Revision: D87341640
1 parent 97bc5d7 commit fe4df99
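
The new flag can be exercised as in the following minimal usage sketch, pieced together from the constructor and `generate` calls in the diffs below; the model name "gpt2" and device "cpu" are illustrative assumptions, not values from this commit.

```python
# Hedged usage sketch, not the repository's own example code.
from privacy_guard.attacks.extraction.predictors.huggingface_predictor import (
    HuggingFacePredictor,
)

# include_prompt_in_generation_result=False makes decoded generations
# omit the echoed prompt (see _generate_decode_logic below).
predictor = HuggingFacePredictor(
    "gpt2",  # illustrative model name
    "cpu",   # illustrative device
    include_prompt_in_generation_result=False,
)

# GenerationAttack now forwards batch_size=self.batch_size to generate();
# passing it directly here mirrors that call (assumption: generate()
# accepts batch_size as a keyword, as the attack-side call implies).
completions = predictor.generate(["The capital of France is"], batch_size=1)
print(completions)  # continuations only, without the prompt text
```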

File tree

3 files changed: +76 −10 lines changed

privacy_guard/attacks/extraction/generation_attack.py

Lines changed: 2 additions & 1 deletion

@@ -91,7 +91,7 @@ def __init__(
         input_column: str = "prompt",
         target_column: str = "target",
         output_column: str = "prediction",
-        batch_size: int = 4,
+        batch_size: int = 1,
         **generation_kwargs: Any,
     ) -> None:
         if output_file is None and output_format is not None:
@@ -133,6 +133,7 @@ def run_attack(self) -> TextInclusionAnalysisInput:
         logger.info(f"Generating text for {len(prompts)} prompts")
         generations = self.predictor.generate(
             prompts=prompts,
+            batch_size=self.batch_size,
             **self.generation_kwargs,
         )
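
The predictor's `generate` body is not part of this diff, so the chunking below is a hedged sketch of why `batch_size` must reach it; `generate_in_batches` and `process_batch` are hypothetical names, not code from the repository.

```python
# Standalone sketch: chunk prompts into batch_size groups, mirroring what a
# batch-aware generate() would do with the now-propagated batch_size.
from typing import Any, Callable, List

def generate_in_batches(
    prompts: List[str],
    process_batch: Callable[..., List[str]],  # stand-in for _generate_process_batch
    batch_size: int = 1,  # matches the new default in GenerationAttack
    **generation_kwargs: Any,
) -> List[str]:
    results: List[str] = []
    for start in range(0, len(prompts), batch_size):
        batch = prompts[start : start + batch_size]
        results.extend(process_batch(batch, **generation_kwargs))
    return results

# Before the fix, a caller's batch_size never reached the predictor; the
# propagated value now controls how many prompts share one generate call.
echoed = generate_in_batches(["a", "b", "c"], lambda b, **kw: list(b), batch_size=2)
assert echoed == ["a", "b", "c"]
```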

privacy_guard/attacks/extraction/predictors/huggingface_predictor.py

Lines changed: 37 additions & 9 deletions

@@ -36,6 +36,7 @@ def __init__(
         device: str | None = None,
         model_kwargs: Dict[str, Any] | None = None,
         tokenizer_kwargs: Dict[str, Any] | None = None,
+        include_prompt_in_generation_result: bool = True,
         **kwargs: Any,
     ) -> None:
         self.model_name: str = model_name
@@ -51,6 +52,7 @@ def __init__(
         self.tokenizer_kwargs: Dict[str, Any] = tokenizer_kwargs or {}
         self.model: PreTrainedModel
         self.tokenizer: PreTrainedTokenizer
+        self.include_prompt_in_generation_result = include_prompt_in_generation_result
         # Model already loaded on device - now pass the kwargs
         self.model, self.tokenizer = load_model_and_tokenizer(
             model_name,
@@ -73,15 +75,16 @@ def preprocess_batch(self, batch: List[str]) -> List[str]:
             clean_batch.append(item)
         return clean_batch
 
-    def _generate_process_batch(
-        self, batch: List[str], max_new_tokens: int = 512, **generation_kwargs: Any
+    def _generate_decode_logic(
+        self,
+        inputs: Dict[str, Any],
+        max_new_tokens: int = 512,
+        **generation_kwargs: Any,
     ) -> List[str]:
-        """Process a single batch of prompts."""
-        clean_batch = self.preprocess_batch(batch)
-
-        inputs = self.tokenizer(
-            clean_batch, return_tensors="pt", padding=True, truncation=True
-        ).to(self.device)
+        """Calls the correct generate call based on the model type.
+        Supports logic for returning only the generated text, or the full sample including
+        the prompt."""
+        include_prompt_in_generation_result = self.include_prompt_in_generation_result
 
         with torch.no_grad():
             # Handle both regular models and DDP-wrapped models
@@ -94,10 +97,35 @@ def _generate_process_batch(
                 **inputs, max_new_tokens=max_new_tokens, **generation_kwargs
             )
 
-        batch_results = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        if include_prompt_in_generation_result:
+            batch_results = self.tokenizer.batch_decode(
+                outputs, skip_special_tokens=True
+            )
+        else:
+            trimmed_outputs = []
+            for output, input_val in zip(outputs, inputs["input_ids"]):
+                trimmed_outputs.append(output[len(input_val) :])
+
+            batch_results = self.tokenizer.batch_decode(
+                trimmed_outputs, skip_special_tokens=True
+            )
 
         return batch_results
 
+    def _generate_process_batch(
+        self, batch: List[str], max_new_tokens: int = 512, **generation_kwargs: Any
+    ) -> List[str]:
+        """Process a single batch of prompts."""
+        clean_batch = self.preprocess_batch(batch)
+
+        inputs = self.tokenizer(
+            clean_batch, return_tensors="pt", padding=True, truncation=True
+        ).to(self.device)
+
+        return self._generate_decode_logic(
+            inputs=inputs, max_new_tokens=max_new_tokens, **generation_kwargs
+        )
+
     def generate(self, prompts: List[str], **generation_kwargs: Any) -> List[str]:
         """Generate text continuations for given prompts."""
         if not prompts:
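
The `include_prompt_in_generation_result=False` path works because `model.generate` returns the input ids followed by the newly generated tokens, so slicing off the first `len(input_ids)` tokens leaves only the continuation. A self-contained sketch of the same technique outside the predictor; the gpt2 checkpoint and pad-token handling are illustrative assumptions:

```python
# Hedged sketch of prompt trimming before decode, mirroring the
# else-branch of _generate_decode_logic above.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer(["Hello, my name is"], return_tensors="pt", padding=True)
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=8)

# Each output row is [prompt tokens | generated tokens]; dropping the first
# len(input_ids) tokens keeps only the continuation, exactly like
# output[len(input_val):] in the diff. (With multi-prompt batches, padding
# side affects where prompt tokens sit; the predictor trims by input length.)
trimmed = [out[len(inp):] for out, inp in zip(outputs, inputs["input_ids"])]
print(tokenizer.batch_decode(trimmed, skip_special_tokens=True))
```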

privacy_guard/attacks/extraction/predictors/tests/test_huggingface_predictor.py

Lines changed: 37 additions & 0 deletions

@@ -90,6 +90,43 @@ def test_generate(self, mock_load_model_and_tokenizer: MagicMock) -> None:
         self.assertEqual(result, ["Generated text"])
         self.mock_model.generate.assert_called_once()
 
+    @patch(
+        "privacy_guard.attacks.extraction.predictors.huggingface_predictor.load_model_and_tokenizer"
+    )
+    def test_generate_no_prompt_in_result(
+        self, mock_load_model_and_tokenizer: MagicMock
+    ) -> None:
+        """Test generate functionality."""
+        mock_load_model_and_tokenizer.return_value = (
+            self.mock_model,
+            self.mock_tokenizer,
+        )
+
+        # Mock tokenizer responses
+        mock_inputs = MagicMock()
+        mock_inputs.to.return_value = {
+            "input_ids": torch.tensor([[1, 2, 3]]),
+            "attention_mask": torch.tensor([[1, 1, 1]]),
+        }
+        self.mock_tokenizer.return_value = mock_inputs
+        self.mock_tokenizer.batch_decode.return_value = [
+            "Generated text without prompt"
+        ]
+
+        predictor = HuggingFacePredictor(
+            self.model_name, self.device, include_prompt_in_generation_result=False
+        )
+
+        # Mock the tqdm within the generate method - patch the specific import
+        with patch(
+            "privacy_guard.attacks.extraction.predictors.huggingface_predictor.tqdm"
+        ) as mock_tqdm:
+            mock_tqdm.side_effect = lambda x, **kwargs: x
+            result = predictor.generate(["Test prompt"])
+
+        self.assertEqual(result, ["Generated text without prompt"])
+        self.mock_model.generate.assert_called_once()
+
     @patch(
         "privacy_guard.attacks.extraction.predictors.huggingface_predictor.load_model_and_tokenizer"
     )
