Skip to content

Commit e4de359

Browse files
mgrange1998facebook-github-bot
authored andcommitted
Update PACKAGE transformers library version to support gpt-oss (#81)
Summary: To unblock the tutorial PrivacyGuard Tutorial: Memorization of Online Content in LLMs, updating the transformers version to enable gpt-oss to be used. N8564502 Differential Revision: D86675884
1 parent a8d2782 commit e4de359

File tree

3 files changed

+16
-2
lines changed

3 files changed

+16
-2
lines changed

privacy_guard/attacks/extraction/predictors/huggingface_predictor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ def _get_logits_batch(
284284

285285
with torch.no_grad():
286286
# Get model outputs for the entire batch
287-
outputs = (
287+
outputs = ( # pyre-fixme[29]: `typing.Union[torch._tensor.Tensor, torch.nn.modules.module.Module]` is not a function.
288288
self.model.module # pyre-ignore
289289
if hasattr(self.model, "module")
290290
else self.model

privacy_guard/attacks/extraction/tests/test_generation_attack.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
from unittest.mock import MagicMock
2121

22+
import transformers
23+
from packaging.version import Version
2224
from privacy_guard.attacks.extraction.generation_attack import GenerationAttack
2325

2426

@@ -145,6 +147,17 @@ def test_generation_attack_missing_column(self) -> None:
145147
self.assertIn("Missing required columns", str(context.exception))
146148
bad_input_file.close()
147149

150+
def test_transformers_version_in_generation_attack(self) -> None:
151+
"""Verify that transformers version is greater than or equal to 4.55.0"""
152+
current_version = transformers.__version__
153+
required_version = "4.55.0"
154+
155+
self.assertGreaterEqual(
156+
Version(current_version),
157+
Version(required_version),
158+
f"Transformers version {current_version} must be greater than or equal to 4.55.0",
159+
)
160+
148161
def tearDown(self) -> None:
149162
"""Clean up temporary files."""
150163
self.input_file.close()

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ dependencies = [
3636
"torch",
3737
'tqdm',
3838
'textdistance',
39-
'transformers',
39+
'transformers>=4.55.0',
40+
'accelerate',
4041
'later',
4142
]
4243

0 commit comments

Comments
 (0)