Skip to content

Commit a92432f

Browse files
committed
simplify
1 parent 37675ed commit a92432f

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

src/chronos/chronos.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ class ChronosPipeline:
397397
model: ChronosModel
398398

399399
def _prepare_and_validate_context(
400-
self, context: Union[torch.Tensor, List[torch.Tensor]]
400+
self, context: Union[torch.Tensor, List[torch.Tensor]], dtype=torch.float32
401401
):
402402
if isinstance(context, list):
403403
context = left_pad_and_stack_1D(context)
@@ -406,7 +406,7 @@ def _prepare_and_validate_context(
406406
context = context.unsqueeze(0)
407407
assert context.ndim == 2
408408

409-
return context
409+
return context.to(dtype=dtype)
410410

411411
@torch.no_grad()
412412
def embed(
@@ -506,15 +506,12 @@ def predict(
506506
raise ValueError(msg)
507507
warnings.warn(msg)
508508

509-
device = context_tensor.device
510-
dtype = context_tensor.dtype
511-
512509
predictions = []
513510
remaining = prediction_length
514511

515512
while remaining > 0:
516513
token_ids, attention_mask, scale = self.tokenizer.context_input_transform(
517-
context_tensor.to(torch.float32)
514+
context_tensor
518515
)
519516
samples = self.model(
520517
token_ids.to(self.model.device),
@@ -539,7 +536,7 @@ def predict(
539536
[context_tensor, prediction.median(dim=1).values], dim=-1
540537
)
541538

542-
return torch.cat(predictions, dim=-1).to(device, dtype)
539+
return torch.cat(predictions, dim=-1)
543540

544541
@classmethod
545542
def from_pretrained(cls, *args, **kwargs):

0 commit comments

Comments
 (0)