File tree Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Original file line number Diff line number Diff line change @@ -506,6 +506,9 @@ def predict(
506506 raise ValueError (msg )
507507 warnings .warn (msg )
508508
509+ device = context_tensor .device
510+ dtype = context_tensor .dtype
511+
509512 predictions = []
510513 remaining = prediction_length
511514
@@ -524,7 +527,7 @@ def predict(
524527 )
525528 prediction = self .tokenizer .output_transform (
526529 samples .to (scale .device ), scale
527- ). to ( context_tensor )
530+ )
528531
529532 predictions .append (prediction )
530533 remaining -= prediction .shape [- 1 ]
@@ -536,7 +539,7 @@ def predict(
536539 [context_tensor , prediction .median (dim = 1 ).values ], dim = - 1
537540 )
538541
539- return torch .cat (predictions , dim = - 1 )
542+ return torch .cat (predictions , dim = - 1 ). to ( device , dtype )
540543
541544 @classmethod
542545 def from_pretrained (cls , * args , ** kwargs ):
You can’t perform that action at this time.
0 commit comments