@@ -85,7 +85,7 @@ def run_2d(self):
8585 bench = tasks .FunctionDescent ('mycs1' )
8686 self .run_bench (bench , '2D - mycs1' , passes = 500 , sec = 60 , metrics = 'train loss' , vid_scale = 1 )
8787
88- bench = tasks .SimultaneousFunctionDescent ('rosen' ).cuda ( )
88+ bench = tasks .SimultaneousFunctionDescent ('rosen' ).to ( CUDA_IF_AVAILABLE )
8989 self .run_bench (bench , '2D simultaneous - rosenbrock' , passes = 1000 , sec = 60 , metrics = 'train loss' , vid_scale = 3 )
9090
9191 def run_projected (self ):
@@ -171,11 +171,13 @@ def run_visual(self):
171171 self .run_bench (bench , 'Visual - LinesDrawer SSIM' , passes = 2000 , sec = 60 , metrics = 'train loss' , vid_scale = 4 , fps = 30 )
172172
173173 # -------------------------- deformable registration ------------------------- #
174- bench = tasks .DeformableRegistration (data .FROG96 , grid_size = (5 ,5 )).cuda ( )
174+ bench = tasks .DeformableRegistration (data .FROG96 , grid_size = (5 ,5 )).to ( CUDA_IF_AVAILABLE )
175175 self .run_bench (bench , 'Visual - DeformableRegistration' , passes = 2_000 , sec = 60 , metrics = 'train loss' , vid_scale = 2 )
176176
177177 def run_linalg (self ):
178178 # ---------------------------------- inverse --------------------------------- #
179- bench = tasks .Inverse (data .SANIC96 ).cuda ( )
179+ bench = tasks .Inverse (data .SANIC96 ).to ( CUDA_IF_AVAILABLE )
180180 self .run_bench (bench , 'Linalg - Inverse' , passes = 2_000 , sec = 60 , metrics = 'train loss' , vid_scale = 2 )
181181
182+ bench = tasks .StochasticInverse (data .SANIC96 ).to (CUDA_IF_AVAILABLE )
183+ self .run_bench (bench , 'Linalg - StochasticInverse' , passes = 2_000 , sec = 60 , metrics = 'train loss' , vid_scale = 2 )
0 commit comments