@@ -157,9 +157,10 @@ def test_tokenizer_random_data(use_eos_token: bool):
157157 assert samples .shape == (2 , 10 , 4 )
158158
159159
160- def validate_tensor (samples : torch .Tensor , shape : Tuple [int , ...]) -> None :
161- assert isinstance (samples , torch .Tensor )
162- assert samples .shape == shape
160+ def validate_tensor (a : torch .Tensor , shape : Tuple [int , ...], dtype ) -> None :
161+ assert isinstance (a , torch .Tensor )
162+ assert a .shape == shape
163+ assert a .dtype == dtype
163164
164165
165166@pytest .mark .parametrize ("torch_dtype" , [torch .float32 , torch .bfloat16 ])
@@ -174,20 +175,20 @@ def test_pipeline_predict(torch_dtype: str):
174175 # input: tensor of shape (batch_size, context_length)
175176
176177 samples = pipeline .predict (context , num_samples = 12 , prediction_length = 3 )
177- validate_tensor (samples , (4 , 12 , 3 ))
178+ validate_tensor (samples , shape = (4 , 12 , 3 ), dtype = torch . float32 )
178179
179180 with pytest .raises (ValueError ):
180181 samples = pipeline .predict (context , num_samples = 7 , prediction_length = 65 )
181182
182183 samples = pipeline .predict (
183184 context , num_samples = 7 , prediction_length = 65 , limit_prediction_length = False
184185 )
185- validate_tensor (samples , (4 , 7 , 65 ))
186+ validate_tensor (samples , shape = (4 , 7 , 65 ), dtype = torch . float32 )
186187
187188 # input: batch_size-long list of tensors of shape (context_length,)
188189
189190 samples = pipeline .predict (list (context ), num_samples = 12 , prediction_length = 3 )
190- validate_tensor (samples , (4 , 12 , 3 ))
191+ validate_tensor (samples , shape = (4 , 12 , 3 ), dtype = torch . float32 )
191192
192193 with pytest .raises (ValueError ):
193194 samples = pipeline .predict (list (context ), num_samples = 7 , prediction_length = 65 )
@@ -198,12 +199,12 @@ def test_pipeline_predict(torch_dtype: str):
198199 prediction_length = 65 ,
199200 limit_prediction_length = False ,
200201 )
201- validate_tensor (samples , (4 , 7 , 65 ))
202+ validate_tensor (samples , shape = (4 , 7 , 65 ), dtype = torch . float32 )
202203
203204 # input: tensor of shape (context_length,)
204205
205206 samples = pipeline .predict (context [0 , ...], num_samples = 12 , prediction_length = 3 )
206- validate_tensor (samples , (1 , 12 , 3 ))
207+ validate_tensor (samples , shape = (1 , 12 , 3 ), dtype = torch . float32 )
207208
208209 with pytest .raises (ValueError ):
209210 samples = pipeline .predict (context [0 , ...], num_samples = 7 , prediction_length = 65 )
@@ -214,7 +215,7 @@ def test_pipeline_predict(torch_dtype: str):
214215 prediction_length = 65 ,
215216 limit_prediction_length = False ,
216217 )
217- validate_tensor (samples , (1 , 7 , 65 ))
218+ validate_tensor (samples , shape = (1 , 7 , 65 ), dtype = torch . float32 )
218219
219220
220221@pytest .mark .parametrize ("torch_dtype" , [torch .float32 , torch .bfloat16 ])
@@ -231,19 +232,25 @@ def test_pipeline_embed(torch_dtype: str):
231232 # input: tensor of shape (batch_size, context_length)
232233
233234 embedding , scale = pipeline .embed (context )
234- validate_tensor (embedding , (4 , expected_embed_length , d_model ))
235- validate_tensor (scale , (4 ,))
235+ validate_tensor (
236+ embedding , shape = (4 , expected_embed_length , d_model ), dtype = torch_dtype
237+ )
238+ validate_tensor (scale , shape = (4 ,), dtype = torch .float32 )
236239
237240 # input: batch_size-long list of tensors of shape (context_length,)
238241
239242 embedding , scale = pipeline .embed (list (context ))
240- validate_tensor (embedding , (4 , expected_embed_length , d_model ))
241- validate_tensor (scale , (4 ,))
243+ validate_tensor (
244+ embedding , shape = (4 , expected_embed_length , d_model ), dtype = torch_dtype
245+ )
246+ validate_tensor (scale , shape = (4 ,), dtype = torch .float32 )
242247
243248 # input: tensor of shape (context_length,)
244249 embedding , scale = pipeline .embed (context [0 , ...])
245- validate_tensor (embedding , (1 , expected_embed_length , d_model ))
246- validate_tensor (scale , (1 ,))
250+ validate_tensor (
251+ embedding , shape = (1 , expected_embed_length , d_model ), dtype = torch_dtype
252+ )
253+ validate_tensor (scale , shape = (1 ,), dtype = torch .float32 )
247254
248255
249256@pytest .mark .parametrize ("n_tokens" , [10 , 1000 , 10000 ])
0 commit comments