Skip to content

Commit 406c5ce

Browse files
mgrange1998facebook-github-bot
authored and committed
Add GPTOSSPredictor to apply correct template format (#82)
Summary: Pull Request resolved: #82 Reviewed By: s-huu Differential Revision: D86975066
1 parent d9b2547 commit 406c5ce

File tree

2 files changed

+342
-0
lines changed

2 files changed

+342
-0
lines changed
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# pyre-strict
16+
17+
"""
18+
GPT OSS predictor implementation for openai extraction attacks.
19+
"""
20+
21+
from typing import Any, Dict, List
22+
23+
import transformers.utils.import_utils
24+
25+
from privacy_guard.attacks.extraction.predictors.huggingface_predictor import (
26+
HuggingFacePredictor,
27+
)
28+
from transformers.utils.import_utils import (
29+
_is_package_available,
30+
is_accelerate_available,
31+
)
32+
33+
34+
class GPTOSSPredictor(HuggingFacePredictor):
    """
    Inherits from HuggingFacePredictor and updates the generation logic to match
    GPT OSS expectation.

    Use this predictor for models like "gpt-oss-20b" and "gpt-oss-120b"

    Note: HuggingFacePredictor "get_logits" and "get_logprobs" behavior is
    not yet tested w/ GPTOSSPredictor
    """

    def __init__(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Initialize the predictor.

        Raises:
            ImportError: if the 'accelerate' package cannot be made available;
                GPT OSS models cannot be loaded without it, so we fail fast
                before the (expensive) parent model-loading step.
        """
        accelerate_available = self.accelerate_available_workaround()
        if not accelerate_available:
            raise ImportError(
                'Required library "accelerate" for GPT OSS not available'
            )

        super().__init__(
            *args,
            **kwargs,
        )

    def accelerate_available_workaround(self) -> bool:
        """
        In old transformers versions, availability for the required 'accelerate' package
        is checked once at import time and the result is saved for all future checks.

        For Meta internal packaging this check returns as false at import time even when
        the package is available at runtime.

        This is a workaround which updates the saved values in transformers
        when this class is initialized.

        See the following link to the old transformers code pointer.
        https://github.com/huggingface/transformers/blob/
        e95441bdb586a7c3c9b4f61a41e99178c1becf54/src/transformers/utils/import_utils.py#L126

        Returns:
            True if 'accelerate' is (or could be made) available, False otherwise.
        """
        if is_accelerate_available():
            return True

        # Re-probe the package directly, bypassing the stale import-time cache.
        _accelerate_available, _accelerate_version = (  # pyre-ignore
            _is_package_available("accelerate", return_version=True)
        )

        if _accelerate_available:
            # Patch the cached module-level flags so subsequent
            # is_accelerate_available() calls (including those made inside
            # transformers itself) see the corrected result.
            transformers.utils.import_utils._accelerate_available = (
                _accelerate_available
            )
            transformers.utils.import_utils._accelerate_version = _accelerate_version
            return is_accelerate_available()

        return False

    def preprocess_batch_messages(self, batch: List[str]) -> List[Dict[str, str]]:
        """
        Prepare a batch of messages for prediction.

        Differs than parent HuggingfacePredictor in that it returns a list of Dict
        instead of str, and includes "role" user field.

        Raises:
            Warning: if any item in the batch is not a string.
        """
        clean_batch = []
        for item in batch:
            if not isinstance(item, str):
                # NOTE: raising the Warning class as an exception preserves the
                # original behavior. The original code had an unreachable
                # str-coercion line after this raise; it has been removed.
                raise Warning(f"Found non-string item in batch: {type(item)}")
            clean_batch.append({"role": "user", "content": item})
        return clean_batch

    # Override
    def _generate_process_batch(
        self, batch: List[str], max_new_tokens: int = 512, **generation_kwargs: Any
    ) -> List[str]:
        """Process a single batch of prompts.

        apply_chat_template is used to apply the harmony response format, required for
        gpt models to work properly.
        """
        clean_batch: List[Dict[str, str]] = self.preprocess_batch_messages(batch)

        # Different than parent HuggingfacePredictor class: these kwargs are
        # consumed here (popped) so they are not forwarded to generate().
        add_generation_prompt = generation_kwargs.pop("add_generation_prompt", True)
        reasoning_effort = generation_kwargs.pop("reasoning_effort", "medium")

        inputs = self.tokenizer.apply_chat_template(  # pyre-ignore
            clean_batch,
            add_generation_prompt=add_generation_prompt,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            reasoning_effort=reasoning_effort,
        ).to(self.device)

        return self._generate_decode_logic(
            inputs=inputs, max_new_tokens=max_new_tokens, **generation_kwargs
        )
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# pyre-strict
16+
import unittest
17+
from unittest.mock import MagicMock, patch
18+
19+
import torch
20+
from privacy_guard.attacks.extraction.predictors.gpt_oss_predictor import (
21+
GPTOSSPredictor,
22+
)
23+
24+
25+
class TestGPTOSSPredictor(unittest.TestCase):
    """Unit tests for GPTOSSPredictor using mocked model/tokenizer objects."""

    def setUp(self) -> None:
        self.model_name = "test-model"
        self.device = "cpu"
        self.vocab_size = 50257

        # Create simple mocks for model and tokenizer
        self.mock_model = MagicMock(
            spec=["generate", "config"]
        )  # Only allow these attributes
        self.mock_model.config.vocab_size = self.vocab_size
        self.mock_model.generate.return_value = torch.tensor([[1, 2, 3, 4, 5]])

        self.mock_tokenizer = MagicMock()
        self.mock_tokenizer.pad_token = None
        self.mock_tokenizer.eos_token = "<|endoftext|>"
        self.mock_tokenizer.pad_token_id = 0
        self.mock_tokenizer.batch_decode.return_value = ["Generated text"]

        # Patch out the accelerate check and model loading so construction
        # succeeds without the real dependencies installed.
        with patch.object(
            GPTOSSPredictor, "accelerate_available_workaround", return_value=True
        ), patch(
            "privacy_guard.attacks.extraction.predictors.huggingface_predictor.load_model_and_tokenizer",
            return_value=(
                self.mock_model,
                self.mock_tokenizer,
            ),
        ):
            self.predictor = GPTOSSPredictor(self.model_name, self.device)

    def test_init(self) -> None:
        """Test predictor initialization."""
        self.assertEqual(self.predictor.model_name, self.model_name)
        self.assertEqual(self.predictor.device, self.device)

    def test_generate(self) -> None:
        """Test generate functionality."""

        # Mock tokenizer responses
        mock_inputs = MagicMock()
        mock_inputs.to.return_value = {
            "input_ids": torch.tensor([[1, 2, 3]]),
            "attention_mask": torch.tensor([[1, 1, 1]]),
        }
        self.mock_tokenizer.return_value = mock_inputs
        self.mock_tokenizer.batch_decode.return_value = ["Generated text"]

        # Mock the tqdm within the generate method - patch the specific import
        with patch(
            "privacy_guard.attacks.extraction.predictors.huggingface_predictor.tqdm"
        ) as mock_tqdm:
            mock_tqdm.side_effect = lambda x, **kwargs: x
            result = self.predictor.generate(["Test prompt"])

        self.assertEqual(result, ["Generated text"])
        self.mock_model.generate.assert_called_once()

    def test_generate_with_kwargs(self) -> None:
        """Test generate functionality specifying add_generation_prompt
        and reasoning_effort"""

        # Mock tokenizer responses
        mock_inputs = MagicMock()
        mock_inputs.to.return_value = {
            "input_ids": torch.tensor([[1, 2, 3]]),
            "attention_mask": torch.tensor([[1, 1, 1]]),
        }
        self.mock_tokenizer.return_value = mock_inputs
        self.mock_tokenizer.batch_decode.return_value = ["Generated text"]

        # Mock the tqdm within the generate method - patch the specific import
        with patch(
            "privacy_guard.attacks.extraction.predictors.huggingface_predictor.tqdm"
        ) as mock_tqdm:
            mock_tqdm.side_effect = lambda x, **kwargs: x
            result = self.predictor.generate(
                ["Test prompt"],
                add_generation_prompt=True,
                reasoning_effort="medium",
            )

        self.assertEqual(result, ["Generated text"])
        self.mock_model.generate.assert_called_once()

    @patch(
        "privacy_guard.attacks.extraction.predictors.gpt_oss_predictor.is_accelerate_available"
    )
    def test_accelerate_available_workaround_when_initially_true(
        self, mock_is_accelerate_available: MagicMock
    ) -> None:
        """Test accelerate_available_workaround when is_accelerate_available is True initially."""

        # Setup: mock is_accelerate_available to return True
        mock_is_accelerate_available.return_value = True

        # Execute: call the workaround method
        # accelerate_available_workaround is called in __init__
        result = self.predictor.accelerate_available_workaround()

        # Assert: method returns True and only checks is_accelerate_available
        self.assertTrue(result)
        mock_is_accelerate_available.assert_called_once()

    @patch(
        "privacy_guard.attacks.extraction.predictors.gpt_oss_predictor._is_package_available"
    )
    @patch(
        "privacy_guard.attacks.extraction.predictors.gpt_oss_predictor.is_accelerate_available"
    )
    def test_accelerate_available_workaround_when_package_available(
        self,
        mock_is_accelerate_available: MagicMock,
        mock_is_package_available: MagicMock,
    ) -> None:
        """Test when is_accelerate_available is initially false but _is_package_available returns true."""

        # Setup: mock is_accelerate_available to return False initially, then True after workaround
        mock_is_accelerate_available.side_effect = [False, True]

        # Setup: mock _is_package_available to return True and a version string
        mock_is_package_available.return_value = (True, "0.21.0")

        # Execute: call the workaround method
        result = self.predictor.accelerate_available_workaround()

        # Assert: method returns True after setting the accelerate availability
        self.assertTrue(result)
        self.assertEqual(mock_is_accelerate_available.call_count, 2)
        mock_is_package_available.assert_called_once_with(
            "accelerate", return_version=True
        )

    @patch(
        "privacy_guard.attacks.extraction.predictors.gpt_oss_predictor._is_package_available"
    )
    @patch(
        "privacy_guard.attacks.extraction.predictors.gpt_oss_predictor.is_accelerate_available"
    )
    def test_accelerate_available_workaround_when_both_false(
        self,
        mock_is_accelerate_available: MagicMock,
        mock_is_package_available: MagicMock,
    ) -> None:
        """Test when both is_accelerate_available and _is_package_available are false."""

        # Setup: mock is_accelerate_available to return False
        mock_is_accelerate_available.return_value = False

        # Setup: mock _is_package_available to return False
        mock_is_package_available.return_value = (False, "N/A")

        # Execute: call the workaround method
        result = self.predictor.accelerate_available_workaround()

        # Assert: method returns False
        self.assertFalse(result)
        mock_is_accelerate_available.assert_called_once()
        mock_is_package_available.assert_called_once_with(
            "accelerate", return_version=True
        )

    def test_init_fails_when_accelerate_not_available(
        self,
    ) -> None:
        """Test that instantiating GPTOSSPredictor when accelerate is not available
        raises exception."""
        with self.assertRaises(ImportError):
            with patch.object(
                GPTOSSPredictor, "accelerate_available_workaround", return_value=False
            ):
                _ = GPTOSSPredictor(self.model_name, self.device)
198+
199+
200+
# Allow running this test module directly (e.g. `python <this_file>.py`)
# instead of via a test runner.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)