Skip to content

Commit 975172a

Browse files
committed
update tests
1 parent a92432f commit 975172a

File tree

1 file changed

+22
-15
lines changed

1 file changed

+22
-15
lines changed

test/test_chronos.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,10 @@ def test_tokenizer_random_data(use_eos_token: bool):
157157
assert samples.shape == (2, 10, 4)
158158

159159

160-
def validate_tensor(samples: torch.Tensor, shape: Tuple[int, ...]) -> None:
161-
assert isinstance(samples, torch.Tensor)
162-
assert samples.shape == shape
160+
def validate_tensor(a: torch.Tensor, shape: Tuple[int, ...], dtype) -> None:
161+
assert isinstance(a, torch.Tensor)
162+
assert a.shape == shape
163+
assert a.dtype == dtype
163164

164165

165166
@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16])
@@ -174,20 +175,20 @@ def test_pipeline_predict(torch_dtype: str):
174175
# input: tensor of shape (batch_size, context_length)
175176

176177
samples = pipeline.predict(context, num_samples=12, prediction_length=3)
177-
validate_tensor(samples, (4, 12, 3))
178+
validate_tensor(samples, shape=(4, 12, 3), dtype=torch.float32)
178179

179180
with pytest.raises(ValueError):
180181
samples = pipeline.predict(context, num_samples=7, prediction_length=65)
181182

182183
samples = pipeline.predict(
183184
context, num_samples=7, prediction_length=65, limit_prediction_length=False
184185
)
185-
validate_tensor(samples, (4, 7, 65))
186+
validate_tensor(samples, shape=(4, 7, 65), dtype=torch.float32)
186187

187188
# input: batch_size-long list of tensors of shape (context_length,)
188189

189190
samples = pipeline.predict(list(context), num_samples=12, prediction_length=3)
190-
validate_tensor(samples, (4, 12, 3))
191+
validate_tensor(samples, shape=(4, 12, 3), dtype=torch.float32)
191192

192193
with pytest.raises(ValueError):
193194
samples = pipeline.predict(list(context), num_samples=7, prediction_length=65)
@@ -198,12 +199,12 @@ def test_pipeline_predict(torch_dtype: str):
198199
prediction_length=65,
199200
limit_prediction_length=False,
200201
)
201-
validate_tensor(samples, (4, 7, 65))
202+
validate_tensor(samples, shape=(4, 7, 65), dtype=torch.float32)
202203

203204
# input: tensor of shape (context_length,)
204205

205206
samples = pipeline.predict(context[0, ...], num_samples=12, prediction_length=3)
206-
validate_tensor(samples, (1, 12, 3))
207+
validate_tensor(samples, shape=(1, 12, 3), dtype=torch.float32)
207208

208209
with pytest.raises(ValueError):
209210
samples = pipeline.predict(context[0, ...], num_samples=7, prediction_length=65)
@@ -214,7 +215,7 @@ def test_pipeline_predict(torch_dtype: str):
214215
prediction_length=65,
215216
limit_prediction_length=False,
216217
)
217-
validate_tensor(samples, (1, 7, 65))
218+
validate_tensor(samples, shape=(1, 7, 65), dtype=torch.float32)
218219

219220

220221
@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16])
@@ -231,19 +232,25 @@ def test_pipeline_embed(torch_dtype: str):
231232
# input: tensor of shape (batch_size, context_length)
232233

233234
embedding, scale = pipeline.embed(context)
234-
validate_tensor(embedding, (4, expected_embed_length, d_model))
235-
validate_tensor(scale, (4,))
235+
validate_tensor(
236+
embedding, shape=(4, expected_embed_length, d_model), dtype=torch_dtype
237+
)
238+
validate_tensor(scale, shape=(4,), dtype=torch.float32)
236239

237240
# input: batch_size-long list of tensors of shape (context_length,)
238241

239242
embedding, scale = pipeline.embed(list(context))
240-
validate_tensor(embedding, (4, expected_embed_length, d_model))
241-
validate_tensor(scale, (4,))
243+
validate_tensor(
244+
embedding, shape=(4, expected_embed_length, d_model), dtype=torch_dtype
245+
)
246+
validate_tensor(scale, shape=(4,), dtype=torch.float32)
242247

243248
# input: tensor of shape (context_length,)
244249
embedding, scale = pipeline.embed(context[0, ...])
245-
validate_tensor(embedding, (1, expected_embed_length, d_model))
246-
validate_tensor(scale, (1,))
250+
validate_tensor(
251+
embedding, shape=(1, expected_embed_length, d_model), dtype=torch_dtype
252+
)
253+
validate_tensor(scale, shape=(1,), dtype=torch.float32)
247254

248255

249256
@pytest.mark.parametrize("n_tokens", [10, 1000, 10000])

0 commit comments

Comments
 (0)