Skip to content

Commit 3bec67b

Browse files
mikaylagawareckipytorchmergebot
authored andcommitted
Fix tests in test/test_serialization that were failing if run individually (pytorch#141300)
pytorch#140739 and pytorch#140740 made it such that `get_safe_globals` no longer return an empty List by default This caused some tests that check the content of `get_safe_globals` to fail, in particular when run individually (they didn't fail in test suite as other tests ran before them called `clear_safe_globals`) but will fail when tests are run individually [T208186010](https://www.internalfb.com/intern/tasks/?t=208186010) test_safe_globals_for_weights_only test_safe_globals_context_manager_weights_only This PR fixes that and also makes most tests calling `clear_safe_globals` use the `safe_globals` context manager rather than try: finally Pull Request resolved: pytorch#141300 Approved by: https://github.com/awgu
1 parent dbe6fce commit 3bec67b

File tree

1 file changed

+23
-24
lines changed

1 file changed

+23
-24
lines changed

test/test_serialization.py

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,12 +1099,9 @@ def __reduce__(self):
10991099
# Safe load should assert
11001100
with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported global: GLOBAL builtins.print"):
11011101
torch.load(f, weights_only=True)
1102-
try:
1103-
torch.serialization.add_safe_globals([print])
1102+
with torch.serialization.safe_globals([print]):
11041103
f.seek(0)
11051104
torch.load(f, weights_only=True)
1106-
finally:
1107-
torch.serialization.clear_safe_globals()
11081105

11091106
def test_weights_only_safe_globals_newobj(self):
11101107
# This will use NEWOBJ
@@ -1116,12 +1113,9 @@ def test_weights_only_safe_globals_newobj(self):
11161113
"GLOBAL __main__.Point was not an allowed global by default"):
11171114
torch.load(f, weights_only=True)
11181115
f.seek(0)
1119-
try:
1120-
torch.serialization.add_safe_globals([Point])
1116+
with torch.serialization.safe_globals([Point]):
11211117
loaded_p = torch.load(f, weights_only=True)
11221118
self.assertEqual(loaded_p, p)
1123-
finally:
1124-
torch.serialization.clear_safe_globals()
11251119

11261120
def test_weights_only_safe_globals_build(self):
11271121
counter = 0
@@ -1138,21 +1132,20 @@ def fake_set_state(obj, *args):
11381132
"GLOBAL __main__.ClassThatUsesBuildInstruction was not an allowed global by default"):
11391133
torch.load(f, weights_only=True)
11401134
try:
1141-
torch.serialization.add_safe_globals([ClassThatUsesBuildInstruction])
1142-
# Test dict update path
1143-
f.seek(0)
1144-
loaded_c = torch.load(f, weights_only=True)
1145-
self.assertEqual(loaded_c.num, 2)
1146-
self.assertEqual(loaded_c.foo, 'bar')
1147-
# Test setstate path
1148-
ClassThatUsesBuildInstruction.__setstate__ = fake_set_state
1149-
f.seek(0)
1150-
loaded_c = torch.load(f, weights_only=True)
1151-
self.assertEqual(loaded_c.num, 2)
1152-
self.assertEqual(counter, 1)
1153-
self.assertFalse(hasattr(loaded_c, 'foo'))
1135+
with torch.serialization.safe_globals([ClassThatUsesBuildInstruction]):
1136+
# Test dict update path
1137+
f.seek(0)
1138+
loaded_c = torch.load(f, weights_only=True)
1139+
self.assertEqual(loaded_c.num, 2)
1140+
self.assertEqual(loaded_c.foo, 'bar')
1141+
# Test setstate path
1142+
ClassThatUsesBuildInstruction.__setstate__ = fake_set_state
1143+
f.seek(0)
1144+
loaded_c = torch.load(f, weights_only=True)
1145+
self.assertEqual(loaded_c.num, 2)
1146+
self.assertEqual(counter, 1)
1147+
self.assertFalse(hasattr(loaded_c, 'foo'))
11541148
finally:
1155-
torch.serialization.clear_safe_globals()
11561149
ClassThatUsesBuildInstruction.__setstate__ = None
11571150

11581151
@parametrize("slots", ['some', 'all'])
@@ -4629,10 +4622,12 @@ def test_safe_globals_for_weights_only(self):
46294622
sd = torch.load(f, weights_only=True)
46304623

46314624
# Loading tensor subclass should work if the class is marked safe
4625+
safe_globals_before = torch.serialization.get_safe_globals()
46324626
f.seek(0)
46334627
try:
46344628
torch.serialization.add_safe_globals([TwoTensor])
4635-
self.assertTrue(torch.serialization.get_safe_globals() == [TwoTensor])
4629+
expected_safe_globals = set(safe_globals_before + [TwoTensor])
4630+
self.assertEqual(set(torch.serialization.get_safe_globals()), expected_safe_globals)
46364631
sd = torch.load(f, weights_only=True)
46374632
self.assertEqual(sd['t'], t)
46384633
self.assertEqual(sd['p'], p)
@@ -4645,6 +4640,7 @@ def test_safe_globals_for_weights_only(self):
46454640
torch.load(f, weights_only=True)
46464641
finally:
46474642
torch.serialization.clear_safe_globals()
4643+
torch.serialization.add_safe_globals(safe_globals_before)
46484644

46494645
def test_safe_globals_context_manager_weights_only(self):
46504646
'''
@@ -4654,20 +4650,23 @@ def test_safe_globals_context_manager_weights_only(self):
46544650
p = torch.nn.Parameter(t)
46554651
sd = OrderedDict([('t', t), ('p', p)])
46564652

4653+
safe_globals_before = torch.serialization.get_safe_globals()
46574654
try:
46584655
torch.serialization.add_safe_globals([TestEmptySubclass])
46594656
with tempfile.NamedTemporaryFile() as f:
46604657
torch.save(sd, f)
46614658
with safe_globals([TwoTensor]):
46624659
f.seek(0)
46634660
torch.load(f, weights_only=True)
4664-
self.assertTrue(torch.serialization.get_safe_globals() == [TestEmptySubclass])
4661+
expected_safe_globals = set(safe_globals_before + [TestEmptySubclass])
4662+
self.assertEqual(set(torch.serialization.get_safe_globals()), expected_safe_globals)
46654663
f.seek(0)
46664664
with self.assertRaisesRegex(pickle.UnpicklingError,
46674665
"Unsupported global: GLOBAL torch.testing._internal.two_tensor.TwoTensor"):
46684666
torch.load(f, weights_only=True)
46694667
finally:
46704668
torch.serialization.clear_safe_globals()
4669+
torch.serialization.add_safe_globals(safe_globals_before)
46714670

46724671
def test_sets_are_loadable_with_weights_only(self):
46734672
s = {1, 2, 3}

0 commit comments

Comments
 (0)