Skip to content

Commit 37675ed

Browse files
committed
update
1 parent d45c468 commit 37675ed

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

src/chronos/chronos.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,9 @@ def predict(
506506
raise ValueError(msg)
507507
warnings.warn(msg)
508508

509+
device = context_tensor.device
510+
dtype = context_tensor.dtype
511+
509512
predictions = []
510513
remaining = prediction_length
511514

@@ -524,7 +527,7 @@ def predict(
524527
)
525528
prediction = self.tokenizer.output_transform(
526529
samples.to(scale.device), scale
527-
).to(context_tensor)
530+
)
528531

529532
predictions.append(prediction)
530533
remaining -= prediction.shape[-1]
@@ -536,7 +539,7 @@ def predict(
536539
[context_tensor, prediction.median(dim=1).values], dim=-1
537540
)
538541

539-
return torch.cat(predictions, dim=-1)
542+
return torch.cat(predictions, dim=-1).to(device, dtype)
540543

541544
@classmethod
542545
def from_pretrained(cls, *args, **kwargs):

0 commit comments

Comments
 (0)