Skip to content

Commit 9f12bf9

Browse files
mgrange1998facebook-github-bot
authored and committed
Add GPTOSSPredictor to apply correct template format (facebookresearch#82)
Summary: Pull Request resolved: facebookresearch#82 Differential Revision: D86975066
1 parent 6699947 commit 9f12bf9

File tree

2 files changed

+325
-0
lines changed

2 files changed

+325
-0
lines changed
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# pyre-strict
16+
17+
"""
18+
HuggingFace predictor implementation for GenAI extraction attacks.
19+
"""
20+
21+
from typing import Any, Dict, List
22+
23+
import torch
24+
25+
import transformers.utils.import_utils
26+
27+
from privacy_guard.attacks.extraction.predictors.huggingface_predictor import (
28+
HuggingFacePredictor,
29+
)
30+
from transformers.utils.import_utils import (
31+
_is_package_available,
32+
is_accelerate_available,
33+
)
34+
35+
36+
class GPTOSSPredictor(HuggingFacePredictor):
37+
"""
38+
Inherits from HuggingFacePredictor and updates the generation logic to match
39+
GPT OSS expectation.
40+
41+
Use this predictor for models like "gpt-oss-20b" and "gpt-oss-120b"
42+
43+
Note: HuggingFacePredictor "get_logits" and "get_logprobs" behavior is
44+
not yet tested w/ GPTOSSPredictor
45+
"""
46+
47+
def __init__(
48+
self,
49+
model_name: str,
50+
device: str | None = None,
51+
model_kwargs: Dict[str, Any] | None = None,
52+
tokenizer_kwargs: Dict[str, Any] | None = None,
53+
**kwargs: Any,
54+
) -> None:
55+
accelerate_available = self.accelerate_available_workaround()
56+
if not accelerate_available:
57+
raise ImportError(
58+
'Required library "accelerate" for GPT OSS not available'
59+
)
60+
61+
super().__init__(
62+
model_name=model_name,
63+
device=device,
64+
model_kwargs=model_kwargs,
65+
tokenizer_kwargs=tokenizer_kwargs,
66+
**kwargs,
67+
)
68+
69+
def accelerate_available_workaround(self) -> bool:
70+
"""
71+
In old transformers versions, availability for the required 'accelerate' package
72+
is checked once at import time and the result is saved for all future checks.
73+
74+
For Meta internal packaging this check returns as false at import time even when
75+
the package is available at runtime.
76+
77+
This is a workaround which updates the saved values in transformers
78+
when this class is initialized.
79+
80+
See the following link to the old transformers code pointer.
81+
https://github.com/huggingface/transformers/blob/
82+
e95441bdb586a7c3c9b4f61a41e99178c1becf54/src/transformers/utils/import_utils.py#L126
83+
"""
84+
if is_accelerate_available():
85+
return True
86+
87+
_accelerate_available, _accelerate_version = ( # pyre-ignore
88+
_is_package_available("accelerate", return_version=True)
89+
)
90+
91+
if _accelerate_available:
92+
transformers.utils.import_utils._accelerate_available = (
93+
_accelerate_available
94+
)
95+
transformers.utils.import_utils._accelerate_version = _accelerate_version
96+
97+
return is_accelerate_available()
98+
99+
return False
100+
101+
def preprocess_batch_messages(self, batch: List[str]) -> List[Dict[str, str]]:
102+
"""
103+
Prepare a batch of messages for prediction.
104+
105+
Differs than parent HuggingfacePredictor in that it returns a list of Dict
106+
instead of str, and includes "role" user field.
107+
"""
108+
clean_batch = []
109+
for item in batch:
110+
if not isinstance(item, str):
111+
raise Warning(f"Found non-string item in batch: {type(item)}")
112+
clean_batch.append(str(item) if item is not None else "")
113+
else:
114+
clean_batch.append({"role": "user", "content": item})
115+
return clean_batch
116+
117+
# Override
118+
def _generate_process_batch(
119+
self, batch: List[str], max_new_tokens: int = 512, **generation_kwargs: Any
120+
) -> List[str]:
121+
"""Process a single batch of prompts.
122+
apply_chat_template is used to apply the harmony response format, required for
123+
gpt models to work properly.
124+
"""
125+
clean_batch: List[Dict[str, str]] = self.preprocess_batch_messages(batch)
126+
127+
# Different than parent HuggingfacePredictor class
128+
inputs = self.tokenizer.apply_chat_template( # pyre-ignore
129+
clean_batch,
130+
add_generation_prompt=True,
131+
tokenize=True,
132+
return_dict=True,
133+
return_tensors="pt",
134+
).to(self.device)
135+
# Everything after is the same as parent
136+
137+
with torch.no_grad():
138+
# Handle both regular models and DDP-wrapped models
139+
# TODO: identify which of these paths is utilized for GPT OSS
140+
if hasattr(self.model, "module"):
141+
outputs = self.model.module.generate( # pyre-ignore
142+
**inputs, max_new_tokens=max_new_tokens, **generation_kwargs
143+
)
144+
else:
145+
outputs = self.model.generate( # pyre-ignore
146+
**inputs, max_new_tokens=max_new_tokens, **generation_kwargs
147+
)
148+
149+
batch_results = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
150+
151+
return batch_results
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# pyre-strict
16+
import unittest
17+
from unittest.mock import MagicMock, patch
18+
19+
import torch
20+
from privacy_guard.attacks.extraction.predictors.gpt_oss_predictor import (
21+
GPTOSSPredictor,
22+
)
23+
24+
25+
class TestGPTOSSPredictor(unittest.TestCase):
26+
def setUp(self) -> None:
27+
self.model_name = "test-model"
28+
self.device = "cpu"
29+
self.vocab_size = 50257
30+
31+
# Create simple mocks for model and tokenizer
32+
self.mock_model = MagicMock(
33+
spec=["generate", "config"]
34+
) # Only allow these attributes
35+
self.mock_model.config.vocab_size = self.vocab_size
36+
self.mock_model.generate.return_value = torch.tensor([[1, 2, 3, 4, 5]])
37+
38+
self.mock_tokenizer = MagicMock()
39+
self.mock_tokenizer.pad_token = None
40+
self.mock_tokenizer.eos_token = "<|endoftext|>"
41+
self.mock_tokenizer.pad_token_id = 0
42+
self.mock_tokenizer.batch_decode.return_value = ["Generated text"]
43+
44+
with patch.object(
45+
GPTOSSPredictor, "accelerate_available_workaround", return_value=True
46+
), patch(
47+
"privacy_guard.attacks.extraction.predictors.huggingface_predictor.load_model_and_tokenizer",
48+
return_value=(
49+
self.mock_model,
50+
self.mock_tokenizer,
51+
),
52+
):
53+
self.predictor = GPTOSSPredictor(self.model_name, self.device)
54+
55+
def test_init(self) -> None:
56+
"""Test predictor initialization."""
57+
self.assertEqual(self.predictor.model_name, self.model_name)
58+
self.assertEqual(self.predictor.device, self.device)
59+
60+
def test_generate(self) -> None:
61+
"""Test generate functionality."""
62+
63+
# Mock tokenizer responses
64+
mock_inputs = MagicMock()
65+
mock_inputs.to.return_value = {
66+
"input_ids": torch.tensor([[1, 2, 3]]),
67+
"attention_mask": torch.tensor([[1, 1, 1]]),
68+
}
69+
self.mock_tokenizer.return_value = mock_inputs
70+
self.mock_tokenizer.batch_decode.return_value = ["Generated text"]
71+
72+
# Mock the tqdm within the generate method - patch the specific import
73+
with patch(
74+
"privacy_guard.attacks.extraction.predictors.huggingface_predictor.tqdm"
75+
) as mock_tqdm:
76+
mock_tqdm.side_effect = lambda x, **kwargs: x
77+
result = self.predictor.generate(["Test prompt"])
78+
79+
self.assertEqual(result, ["Generated text"])
80+
self.mock_model.generate.assert_called_once()
81+
82+
@patch(
83+
"privacy_guard.attacks.extraction.predictors.gpt_oss_predictor.is_accelerate_available"
84+
)
85+
def test_accelerate_available_workaround_when_initially_true(
86+
self, mock_is_accelerate_available: MagicMock
87+
) -> None:
88+
"""Test accelerate_available_workaround when is_accelerate_available is True initially."""
89+
90+
# Setup: mock is_accelerate_available to return True
91+
mock_is_accelerate_available.return_value = True
92+
93+
# Execute: call the workaround method
94+
# accelerate_available_workaround is called in __init__
95+
result = self.predictor.accelerate_available_workaround()
96+
97+
# Assert: method returns True and only checks is_accelerate_available
98+
self.assertTrue(result)
99+
mock_is_accelerate_available.assert_called_once()
100+
101+
@patch(
102+
"privacy_guard.attacks.extraction.predictors.gpt_oss_predictor._is_package_available"
103+
)
104+
@patch(
105+
"privacy_guard.attacks.extraction.predictors.gpt_oss_predictor.is_accelerate_available"
106+
)
107+
def test_accelerate_available_workaround_when_package_available(
108+
self,
109+
mock_is_accelerate_available: MagicMock,
110+
mock_is_package_available: MagicMock,
111+
) -> None:
112+
"""Test when is_accelerate_available is initially false but _is_package_available returns true."""
113+
114+
# Setup: mock is_accelerate_available to return False initially, then True after workaround
115+
mock_is_accelerate_available.side_effect = [False, True]
116+
117+
# Setup: mock _is_package_available to return True and a version string
118+
mock_is_package_available.return_value = (True, "0.21.0")
119+
120+
# Execute: call the workaround method
121+
result = self.predictor.accelerate_available_workaround()
122+
123+
# Assert: method returns True after setting the accelerate availability
124+
self.assertTrue(result)
125+
self.assertEqual(mock_is_accelerate_available.call_count, 2)
126+
mock_is_package_available.assert_called_once()
127+
# mock_import_utils._is_package_available.assert_called_once_with(
128+
# "accelerate", return_version=True
129+
# )
130+
131+
@patch(
132+
"privacy_guard.attacks.extraction.predictors.gpt_oss_predictor._is_package_available"
133+
)
134+
@patch(
135+
"privacy_guard.attacks.extraction.predictors.gpt_oss_predictor.is_accelerate_available"
136+
)
137+
def test_accelerate_available_workaround_when_both_false(
138+
self,
139+
mock_is_accelerate_available: MagicMock,
140+
mock_is_package_available: MagicMock,
141+
) -> None:
142+
"""Test when both is_accelerate_available and _is_package_available are false."""
143+
144+
# Setup: mock is_accelerate_available to return False
145+
mock_is_accelerate_available.return_value = False
146+
147+
# Setup: mock _is_package_available to return False
148+
mock_is_package_available.return_value = (False, "N/A")
149+
150+
# Execute: call the workaround method
151+
result = self.predictor.accelerate_available_workaround()
152+
153+
# Assert: method returns False
154+
self.assertFalse(result)
155+
mock_is_accelerate_available.assert_called_once()
156+
mock_is_package_available.assert_called_once()
157+
# mock_import_utils._is_package_available.assert_called_once_with(
158+
# "accelerate", return_version=True
159+
# )
160+
161+
def test_init_fails_when_accelerate_not_available(
162+
self,
163+
) -> None:
164+
"""Test that instantiating GPTOSSPredictor when accelerate is not available
165+
raises exception."""
166+
with self.assertRaises(ImportError):
167+
with patch.object(
168+
GPTOSSPredictor, "accelerate_available_workaround", return_value=False
169+
):
170+
_ = GPTOSSPredictor(self.model_name, self.device)
171+
172+
173+
if __name__ == "__main__":
174+
unittest.main()

0 commit comments

Comments
 (0)