@@ -1099,12 +1099,9 @@ def __reduce__(self):
1099
1099
# Safe load should assert
1100
1100
with self .assertRaisesRegex (pickle .UnpicklingError , "Unsupported global: GLOBAL builtins.print" ):
1101
1101
torch .load (f , weights_only = True )
1102
- try :
1103
- torch .serialization .add_safe_globals ([print ])
1102
+ with torch .serialization .safe_globals ([print ]):
1104
1103
f .seek (0 )
1105
1104
torch .load (f , weights_only = True )
1106
- finally :
1107
- torch .serialization .clear_safe_globals ()
1108
1105
1109
1106
def test_weights_only_safe_globals_newobj (self ):
1110
1107
# This will use NEWOBJ
@@ -1116,12 +1113,9 @@ def test_weights_only_safe_globals_newobj(self):
1116
1113
"GLOBAL __main__.Point was not an allowed global by default" ):
1117
1114
torch .load (f , weights_only = True )
1118
1115
f .seek (0 )
1119
- try :
1120
- torch .serialization .add_safe_globals ([Point ])
1116
+ with torch .serialization .safe_globals ([Point ]):
1121
1117
loaded_p = torch .load (f , weights_only = True )
1122
1118
self .assertEqual (loaded_p , p )
1123
- finally :
1124
- torch .serialization .clear_safe_globals ()
1125
1119
1126
1120
def test_weights_only_safe_globals_build (self ):
1127
1121
counter = 0
@@ -1138,21 +1132,20 @@ def fake_set_state(obj, *args):
1138
1132
"GLOBAL __main__.ClassThatUsesBuildInstruction was not an allowed global by default" ):
1139
1133
torch .load (f , weights_only = True )
1140
1134
try :
1141
- torch .serialization .add_safe_globals ([ClassThatUsesBuildInstruction ])
1142
- # Test dict update path
1143
- f .seek (0 )
1144
- loaded_c = torch .load (f , weights_only = True )
1145
- self .assertEqual (loaded_c .num , 2 )
1146
- self .assertEqual (loaded_c .foo , 'bar' )
1147
- # Test setstate path
1148
- ClassThatUsesBuildInstruction .__setstate__ = fake_set_state
1149
- f .seek (0 )
1150
- loaded_c = torch .load (f , weights_only = True )
1151
- self .assertEqual (loaded_c .num , 2 )
1152
- self .assertEqual (counter , 1 )
1153
- self .assertFalse (hasattr (loaded_c , 'foo' ))
1135
+ with torch .serialization .safe_globals ([ClassThatUsesBuildInstruction ]):
1136
+ # Test dict update path
1137
+ f .seek (0 )
1138
+ loaded_c = torch .load (f , weights_only = True )
1139
+ self .assertEqual (loaded_c .num , 2 )
1140
+ self .assertEqual (loaded_c .foo , 'bar' )
1141
+ # Test setstate path
1142
+ ClassThatUsesBuildInstruction .__setstate__ = fake_set_state
1143
+ f .seek (0 )
1144
+ loaded_c = torch .load (f , weights_only = True )
1145
+ self .assertEqual (loaded_c .num , 2 )
1146
+ self .assertEqual (counter , 1 )
1147
+ self .assertFalse (hasattr (loaded_c , 'foo' ))
1154
1148
finally :
1155
- torch .serialization .clear_safe_globals ()
1156
1149
ClassThatUsesBuildInstruction .__setstate__ = None
1157
1150
1158
1151
@parametrize ("slots" , ['some' , 'all' ])
@@ -4629,10 +4622,12 @@ def test_safe_globals_for_weights_only(self):
4629
4622
sd = torch .load (f , weights_only = True )
4630
4623
4631
4624
# Loading tensor subclass should work if the class is marked safe
4625
+ safe_globals_before = torch .serialization .get_safe_globals ()
4632
4626
f .seek (0 )
4633
4627
try :
4634
4628
torch .serialization .add_safe_globals ([TwoTensor ])
4635
- self .assertTrue (torch .serialization .get_safe_globals () == [TwoTensor ])
4629
+ expected_safe_globals = set (safe_globals_before + [TwoTensor ])
4630
+ self .assertEqual (set (torch .serialization .get_safe_globals ()), expected_safe_globals )
4636
4631
sd = torch .load (f , weights_only = True )
4637
4632
self .assertEqual (sd ['t' ], t )
4638
4633
self .assertEqual (sd ['p' ], p )
@@ -4645,6 +4640,7 @@ def test_safe_globals_for_weights_only(self):
4645
4640
torch .load (f , weights_only = True )
4646
4641
finally :
4647
4642
torch .serialization .clear_safe_globals ()
4643
+ torch .serialization .add_safe_globals (safe_globals_before )
4648
4644
4649
4645
def test_safe_globals_context_manager_weights_only (self ):
4650
4646
'''
@@ -4654,20 +4650,23 @@ def test_safe_globals_context_manager_weights_only(self):
4654
4650
p = torch .nn .Parameter (t )
4655
4651
sd = OrderedDict ([('t' , t ), ('p' , p )])
4656
4652
4653
+ safe_globals_before = torch .serialization .get_safe_globals ()
4657
4654
try :
4658
4655
torch .serialization .add_safe_globals ([TestEmptySubclass ])
4659
4656
with tempfile .NamedTemporaryFile () as f :
4660
4657
torch .save (sd , f )
4661
4658
with safe_globals ([TwoTensor ]):
4662
4659
f .seek (0 )
4663
4660
torch .load (f , weights_only = True )
4664
- self .assertTrue (torch .serialization .get_safe_globals () == [TestEmptySubclass ])
4661
+ expected_safe_globals = set (safe_globals_before + [TestEmptySubclass ])
4662
+ self .assertEqual (set (torch .serialization .get_safe_globals ()), expected_safe_globals )
4665
4663
f .seek (0 )
4666
4664
with self .assertRaisesRegex (pickle .UnpicklingError ,
4667
4665
"Unsupported global: GLOBAL torch.testing._internal.two_tensor.TwoTensor" ):
4668
4666
torch .load (f , weights_only = True )
4669
4667
finally :
4670
4668
torch .serialization .clear_safe_globals ()
4669
+ torch .serialization .add_safe_globals (safe_globals_before )
4671
4670
4672
4671
def test_sets_are_loadable_with_weights_only (self ):
4673
4672
s = {1 , 2 , 3 }
0 commit comments