@@ -397,7 +397,7 @@ class ChronosPipeline:
397397 model : ChronosModel
398398
399399 def _prepare_and_validate_context (
400- self , context : Union [torch .Tensor , List [torch .Tensor ]]
400+ self , context : Union [torch .Tensor , List [torch .Tensor ]], dtype = torch . float32
401401 ):
402402 if isinstance (context , list ):
403403 context = left_pad_and_stack_1D (context )
@@ -406,7 +406,7 @@ def _prepare_and_validate_context(
406406 context = context .unsqueeze (0 )
407407 assert context .ndim == 2
408408
409- return context
409+ return context . to ( dtype = dtype )
410410
411411 @torch .no_grad ()
412412 def embed (
@@ -506,15 +506,12 @@ def predict(
506506 raise ValueError (msg )
507507 warnings .warn (msg )
508508
509- device = context_tensor .device
510- dtype = context_tensor .dtype
511-
512509 predictions = []
513510 remaining = prediction_length
514511
515512 while remaining > 0 :
516513 token_ids , attention_mask , scale = self .tokenizer .context_input_transform (
517- context_tensor . to ( torch . float32 )
514+ context_tensor
518515 )
519516 samples = self .model (
520517 token_ids .to (self .model .device ),
@@ -539,7 +536,7 @@ def predict(
539536 [context_tensor , prediction .median (dim = 1 ).values ], dim = - 1
540537 )
541538
542- return torch .cat (predictions , dim = - 1 ). to ( device , dtype )
539+ return torch .cat (predictions , dim = - 1 )
543540
544541 @classmethod
545542 def from_pretrained (cls , * args , ** kwargs ):
0 commit comments