@@ -101,15 +101,16 @@ def setUp(self):
101101 [33 , 64 , 32 , "ortho" , "equiangular" , 1e-9 , False ],
102102 [33 , 64 , 32 , "ortho" , "legendre-gauss" , 1e-9 , False ],
103103 [33 , 64 , 32 , "ortho" , "lobatto" , 1e-9 , False ],
104- [33 , 64 , 32 , "four-pi" , "equiangular" , 1e-9 , False ],
104+ [33 , 64 , 32 , "four-pi" , "equiangular" , 1e-9 , False ],
105105 [33 , 64 , 32 , "four-pi" , "legendre-gauss" , 1e-9 , False ],
106106 [33 , 64 , 32 , "four-pi" , "lobatto" , 1e-9 , False ],
107107 [33 , 64 , 32 , "schmidt" , "equiangular" , 1e-9 , False ],
108108 [33 , 64 , 32 , "schmidt" , "legendre-gauss" , 1e-9 , False ],
109109 [33 , 64 , 32 , "schmidt" , "lobatto" , 1e-9 , False ],
110- ]
110+ ],
111+ skip_on_empty = True ,
111112 )
112- def test_sht (self , nlat , nlon , batch_size , norm , grid , tol , verbose ):
113+ def test_forward_inverse (self , nlat , nlon , batch_size , norm , grid , tol , verbose ):
113114 if verbose :
114115 print (f"Testing real-valued SHT on { nlat } x{ nlon } { grid } grid with { norm } normalization on { self .device .type } device" )
115116
@@ -168,9 +169,10 @@ def test_sht(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
168169 [15 , 30 , 2 , "schmidt" , "equiangular" , 1e-5 , False ],
169170 [15 , 30 , 2 , "schmidt" , "legendre-gauss" , 1e-5 , False ],
170171 [15 , 30 , 2 , "schmidt" , "lobatto" , 1e-5 , False ],
171- ]
172+ ],
173+ skip_on_empty = True ,
172174 )
173- def test_sht_grads (self , nlat , nlon , batch_size , norm , grid , tol , verbose ):
175+ def test_grads (self , nlat , nlon , batch_size , norm , grid , tol , verbose ):
174176 if verbose :
175177 print (f"Testing gradients of real-valued SHT on { nlat } x{ nlon } { grid } grid with { norm } normalization" )
176178
@@ -202,6 +204,40 @@ def test_sht_grads(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
202204 test_result = gradcheck (err_handle , grad_input , eps = 1e-6 , atol = tol )
203205 self .assertTrue (test_result )
204206
207+ @parameterized .expand (
208+ [
209+ # even-even
210+ [12 , 24 , 2 , "ortho" , "equiangular" , 1e-5 , False ],
211+ [12 , 24 , 2 , "ortho" , "legendre-gauss" , 1e-5 , False ],
212+ [12 , 24 , 2 , "ortho" , "lobatto" , 1e-5 , False ],
213+ ],
214+ skip_on_empty = True ,
215+ )
216+ @unittest .skipIf (not torch .cuda .is_available (), "CUDA is not available" )
217+ def test_device_instantiation (self , nlat , nlon , batch_size , norm , grid , tol , verbose ):
218+ if verbose :
219+ print (f"Testing device instantiation of real-valued SHT on { nlat } x{ nlon } { grid } grid with { norm } normalization" )
220+
221+ if grid == "equiangular" :
222+ mmax = nlat // 2
223+ elif grid == "lobatto" :
224+ mmax = nlat - 1
225+ else :
226+ mmax = nlat
227+ lmax = mmax
228+
229+ # init on cpu
230+ sht_host = th .RealSHT (nlat , nlon , mmax = mmax , lmax = lmax , grid = grid , norm = norm )
231+ isht_host = th .InverseRealSHT (nlat , nlon , mmax = mmax , lmax = lmax , grid = grid , norm = norm )
232+
233+ # init on device
234+ with torch .device (self .device ):
235+ sht_device = th .RealSHT (nlat , nlon , mmax = mmax , lmax = lmax , grid = grid , norm = norm )
236+ isht_device = th .InverseRealSHT (nlat , nlon , mmax = mmax , lmax = lmax , grid = grid , norm = norm )
237+
238+ self .assertTrue (torch .allclose (sht_host .weights .cpu (), sht_device .weights .cpu ()))
239+ self .assertTrue (torch .allclose (isht_host .pct .cpu (), isht_device .pct .cpu ()))
240+
# Allow running this test module directly (e.g. `python this_file.py`);
# unittest.main() discovers and runs the TestCase methods defined above.
if __name__ == "__main__":
    unittest.main()
0 commit comments