Commit b5c410c

Merge pull request #93 from NVIDIA/tkurth/device-fixes
Tkurth/device fixes

2 parents: 4aaff02 + 3d604f8

9 files changed: 144 additions, 28 deletions
tests/test_convolution.py

Lines changed: 69 additions & 4 deletions

@@ -39,7 +39,7 @@
 from torch_harmonics import quadrature, DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2

 from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes
-from torch_harmonics.convolution import _precompute_convolution_tensor_s2
+from torch_harmonics.convolution import _precompute_convolution_tensor_s2

 _devices = [(torch.device("cpu"),)]
 if torch.cuda.is_available():
@@ -127,7 +127,7 @@ def _precompute_convolution_tensor_dense(
     quad_weights = win.reshape(-1, 1) / nlon_in / 2.0

     # array for accumulating non-zero indices
-    out = torch.zeros(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in, dtype=torch.float64)
+    out = torch.zeros(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in, dtype=torch.float64, device=lons_in.device)

     for t in range(nlat_out):
         for p in range(nlon_out):
@@ -199,9 +199,10 @@ def setUp(self):
             [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "equiangular", "legendre-gauss", True, 1e-4, False],
             [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "equiangular", True, 1e-4, False],
             [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", True, 1e-4, False],
-        ]
+        ],
+        skip_on_empty=True,
     )
-    def test_disco_convolution(
+    def test_forward_backward(
         self,
         batch_size,
         in_channels,
@@ -315,6 +316,70 @@ def test_disco_convolution(
         self.assertTrue(torch.allclose(x_grad, x_ref_grad, rtol=tol, atol=tol))
         self.assertTrue(torch.allclose(conv.weight.grad, w_ref.grad, rtol=tol, atol=tol))

+    @parameterized.expand(
+        [
+            [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
+            [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", False, 1e-4, False],
+            [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
+            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", True, 1e-4, False],
+        ],
+        skip_on_empty=True,
+    )
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is not available")
+    def test_device_instantiation(self, batch_size, in_channels, out_channels, in_shape, out_shape, kernel_shape, basis_type, basis_norm_mode, grid_in, grid_out, transpose, tol, verbose):
+
+        nlat_in, nlon_in = in_shape
+        nlat_out, nlon_out = out_shape
+
+        if isinstance(kernel_shape, int):
+            theta_cutoff = (kernel_shape + 1) * torch.pi / float(nlat_in - 1)
+        else:
+            theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(nlat_in - 1)
+
+        # get handle
+        Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2
+
+        # init on cpu
+        conv_host = Conv(
+            in_channels,
+            out_channels,
+            in_shape,
+            out_shape,
+            kernel_shape,
+            basis_type=basis_type,
+            basis_norm_mode=basis_norm_mode,
+            groups=1,
+            grid_in=grid_in,
+            grid_out=grid_out,
+            bias=False,
+            theta_cutoff=theta_cutoff,
+        )
+
+        # init on device
+        with torch.device(self.device):
+            conv_device = Conv(
+                in_channels,
+                out_channels,
+                in_shape,
+                out_shape,
+                kernel_shape,
+                basis_type=basis_type,
+                basis_norm_mode=basis_norm_mode,
+                groups=1,
+                grid_in=grid_in,
+                grid_out=grid_out,
+                bias=False,
+                theta_cutoff=theta_cutoff,
+            )
+
+        # the precomputed index and value buffers are deterministic, so they
+        # should match regardless of the device the module was built on
+        self.assertTrue(torch.allclose(conv_host.psi_col_idx.cpu(), conv_device.psi_col_idx.cpu()))
+        self.assertTrue(torch.allclose(conv_host.psi_row_idx.cpu(), conv_device.psi_row_idx.cpu()))
+        self.assertTrue(torch.allclose(conv_host.psi_roff_idx.cpu(), conv_device.psi_roff_idx.cpu()))
+        self.assertTrue(torch.allclose(conv_host.psi_vals.cpu(), conv_device.psi_vals.cpu()))
+        self.assertTrue(torch.allclose(conv_host.psi_idx.cpu(), conv_device.psi_idx.cpu()))

 if __name__ == "__main__":
     unittest.main()
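The new test_device_instantiation pins down a simple invariant: building a module under a CUDA default device must produce the same precomputed buffers as building it on the CPU. A minimal sketch of that pattern (assumes a CUDA-enabled install; argument values are taken from the parameter list above, with theta_cutoff left at its default):

import torch
from torch_harmonics import DiscreteContinuousConvS2

conv_host = DiscreteContinuousConvS2(4, 2, (16, 32), (16, 32), 3)
with torch.device("cuda"):
    # factory calls inside __init__ now default to CUDA; the precomputed
    # sparse tensors must still match the CPU-built module
    conv_device = DiscreteContinuousConvS2(4, 2, (16, 32), (16, 32), 3)
assert torch.allclose(conv_host.psi_vals.cpu(), conv_device.psi_vals.cpu())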

tests/test_sht.py

Lines changed: 41 additions & 5 deletions

@@ -101,15 +101,16 @@ def setUp(self):
             [33, 64, 32, "ortho", "equiangular", 1e-9, False],
             [33, 64, 32, "ortho", "legendre-gauss", 1e-9, False],
             [33, 64, 32, "ortho", "lobatto", 1e-9, False],
-            [33, 64, 32, "four-pi", "equiangular", 1e-9, False],
+            [33, 64, 32, "four-pi", "equiangular", 1e-9, False],
             [33, 64, 32, "four-pi", "legendre-gauss", 1e-9, False],
             [33, 64, 32, "four-pi", "lobatto", 1e-9, False],
             [33, 64, 32, "schmidt", "equiangular", 1e-9, False],
             [33, 64, 32, "schmidt", "legendre-gauss", 1e-9, False],
             [33, 64, 32, "schmidt", "lobatto", 1e-9, False],
-        ]
+        ],
+        skip_on_empty=True,
     )
-    def test_sht(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
+    def test_forward_inverse(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
         if verbose:
             print(f"Testing real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization on {self.device.type} device")

@@ -168,9 +169,10 @@ def test_sht(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
             [15, 30, 2, "schmidt", "equiangular", 1e-5, False],
             [15, 30, 2, "schmidt", "legendre-gauss", 1e-5, False],
             [15, 30, 2, "schmidt", "lobatto", 1e-5, False],
-        ]
+        ],
+        skip_on_empty=True,
     )
-    def test_sht_grads(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
+    def test_grads(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
         if verbose:
             print(f"Testing gradients of real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization")

@@ -202,6 +204,40 @@ def test_sht_grads(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
         test_result = gradcheck(err_handle, grad_input, eps=1e-6, atol=tol)
         self.assertTrue(test_result)

+    @parameterized.expand(
+        [
+            # even-even
+            [12, 24, 2, "ortho", "equiangular", 1e-5, False],
+            [12, 24, 2, "ortho", "legendre-gauss", 1e-5, False],
+            [12, 24, 2, "ortho", "lobatto", 1e-5, False],
+        ],
+        skip_on_empty=True,
+    )
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is not available")
+    def test_device_instantiation(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
+        if verbose:
+            print(f"Testing device instantiation of real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization")
+
+        if grid == "equiangular":
+            mmax = nlat // 2
+        elif grid == "lobatto":
+            mmax = nlat - 1
+        else:
+            mmax = nlat
+        lmax = mmax
+
+        # init on cpu
+        sht_host = th.RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm)
+        isht_host = th.InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm)
+
+        # init on device
+        with torch.device(self.device):
+            sht_device = th.RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm)
+            isht_device = th.InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm)
+
+        self.assertTrue(torch.allclose(sht_host.weights.cpu(), sht_device.weights.cpu()))
+        self.assertTrue(torch.allclose(isht_host.pct.cpu(), isht_device.pct.cpu()))

 if __name__ == "__main__":
     unittest.main()
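The same invariant for the SHT modules, in isolation (a minimal sketch assuming a CUDA device is present; th is torch_harmonics, as in the test file):

import torch
import torch_harmonics as th

sht_host = th.RealSHT(12, 24, grid="equiangular", norm="ortho")
with torch.device("cuda"):
    sht_device = th.RealSHT(12, 24, grid="equiangular", norm="ortho")
# the precomputed quadrature weights should agree regardless of where
# the module was instantiated
assert torch.allclose(sht_host.weights.cpu(), sht_device.weights.cpu())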

torch_harmonics/convolution.py

Lines changed: 3 additions & 3 deletions

@@ -92,8 +92,8 @@ def _normalize_convolution_tensor_s2(
     q = quad_weights[ilat_in].reshape(-1)

     # buffer to store intermediate values
-    vnorm = torch.zeros(kernel_size, nlat_out)
-    support = torch.zeros(kernel_size, nlat_out)
+    vnorm = torch.zeros(kernel_size, nlat_out, device=psi_vals.device)
+    support = torch.zeros(kernel_size, nlat_out, device=psi_vals.device)

     # loop through dimensions to compute the norms
     for ik in range(kernel_size):
@@ -207,7 +207,7 @@ def _precompute_convolution_tensor_s2(
     sgamma = torch.sin(gamma)

     # compute row offsets
-    out_roff = torch.zeros(nlat_out + 1, dtype=torch.int64)
+    out_roff = torch.zeros(nlat_out + 1, dtype=torch.int64, device=lons_in.device)
     out_roff[0] = 0
     for t in range(nlat_out):
         # the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
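These one-line fixes all target the same failure mode: under torch.device(...) or torch.set_default_device, factory functions such as torch.zeros allocate on the ambient default device and then clash with tensors living elsewhere. A small illustration of the bug class (lons_in here stands in for the function's input tensors):

import torch

lons_in = torch.linspace(0.0, 2.0 * torch.pi, 8, dtype=torch.float64)      # cpu
with torch.device("cuda"):
    bad = torch.zeros(8, dtype=torch.float64)                              # lands on cuda
    good = torch.zeros(8, dtype=torch.float64, device=lons_in.device)      # follows the input
# bad + lons_in raises a device-mismatch RuntimeError; good + lons_in works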

torch_harmonics/csrc/disco/disco_helpers.cpp

Lines changed: 18 additions & 3 deletions

@@ -104,27 +104,42 @@ torch::Tensor preprocess_psi(const int64_t K, const int64_t Ho, torch::Tensor ke
     CHECK_INPUT_TENSOR(col_idx);
     CHECK_INPUT_TENSOR(val);

+    // get the input device and make sure all tensors are on the same device
+    auto device = ker_idx.device();
+    TORCH_INTERNAL_ASSERT(device.type() == row_idx.device().type() && (device.type() == col_idx.device().type()) && (device.type() == val.device().type()));
+
+    // move to cpu
+    ker_idx = ker_idx.to(torch::kCPU);
+    row_idx = row_idx.to(torch::kCPU);
+    col_idx = col_idx.to(torch::kCPU);
+    val = val.to(torch::kCPU);
+
     int64_t nnz = val.size(0);
     int64_t *ker_h = ker_idx.data_ptr<int64_t>();
     int64_t *row_h = row_idx.data_ptr<int64_t>();
     int64_t *col_h = col_idx.data_ptr<int64_t>();
     int64_t *roff_h = new int64_t[Ho * K + 1];
     int64_t nrows;
-    // float *val_h = val.data_ptr<float>();

     AT_DISPATCH_FLOATING_TYPES(val.scalar_type(), "preprocess_psi", ([&] {
         preprocess_psi_kernel<scalar_t>(nnz, K, Ho, ker_h, row_h, col_h, roff_h,
                                         val.data_ptr<scalar_t>(), nrows);
     }));

     // create output tensor
-    auto options = torch::TensorOptions().dtype(row_idx.dtype());
-    auto roff_idx = torch::empty({nrows + 1}, options);
+    auto roff_idx = torch::empty({nrows + 1}, row_idx.options());
     int64_t *roff_out_h = roff_idx.data_ptr<int64_t>();

     for (int64_t i = 0; i < (nrows + 1); i++) { roff_out_h[i] = roff_h[i]; }
     delete[] roff_h;

+    // move to original device
+    ker_idx = ker_idx.to(device);
+    row_idx = row_idx.to(device);
+    col_idx = col_idx.to(device);
+    val = val.to(device);
+    roff_idx = roff_idx.to(device);
+
     return roff_idx;
 }
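The C++ change is the classic host round-trip: preprocess_psi_kernel reads raw host pointers, so inputs are staged to the CPU and the result is returned on the caller's original device. A Python-level sketch of the same flow (compute_row_offsets is a hypothetical stand-in for the kernel, not part of the library):

import torch

def compute_row_offsets(row):
    # hypothetical stand-in: per-row entry counts turned into CSR-style offsets
    counts = torch.bincount(row, minlength=int(row.max()) + 1)
    return torch.cat([torch.zeros(1, dtype=torch.int64), counts.cumsum(0)])

def preprocess_on_host(ker_idx, row_idx, col_idx, val):
    device = ker_idx.device                 # remember the caller's device
    row = row_idx.cpu()                     # stage to the host
    roff = compute_row_offsets(row)         # host-only computation
    return roff.to(device)                  # hand the result back on the original device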

torch_harmonics/filter_basis.py

Lines changed: 3 additions & 3 deletions

@@ -254,7 +254,7 @@ def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: flo
     mkernel = ikernel // self.kernel_shape[1]

     # get relevant indices
-    iidx = torch.argwhere((r <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool))
+    iidx = torch.argwhere((r <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool, device=r.device))

     # get corresponding r, phi, x and y coordinates
     r = r[iidx[:, 1], iidx[:, 2]] / r_cutoff
@@ -316,10 +316,10 @@ def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: flo
     """

     # enumerator for basis function
-    ikernel = torch.arange(self.kernel_size).reshape(-1, 1, 1)
+    ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1)

     # get relevant indices
-    iidx = torch.argwhere((r <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool))
+    iidx = torch.argwhere((r <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool, device=r.device))

     # indexing logic for zernike polynomials
     # the total index is given by (n * (n + 2) + l ) // 2 which needs to be reversed
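For context, the masked argwhere enumerates (kernel index, latitude, longitude) triples inside the cutoff radius: ikernel broadcasts against r, so pinning both to r.device keeps all operands aligned. A tiny standalone illustration (shapes invented for the example):

import torch

kernel_size, r_cutoff = 3, 0.5
r = torch.rand(4, 8)                        # distances on a 4x8 grid
ikernel = torch.arange(kernel_size, device=r.device).reshape(-1, 1, 1)
mask = (r <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool, device=r.device)
iidx = torch.argwhere(mask)                 # rows of (ik, ilat, ilon)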

torch_harmonics/legendre.py

Lines changed: 4 additions & 4 deletions

@@ -57,10 +57,10 @@ def legpoly(mmax: int, lmax: int, x: torch.Tensor, norm: Optional[str]="ortho",

     # compute the tensor P^m_n:
     nmax = max(mmax,lmax)
-    vdm = torch.zeros((nmax, nmax, len(x)), dtype=torch.float64, requires_grad=False)
+    vdm = torch.zeros((nmax, nmax, len(x)), dtype=torch.float64, device=x.device, requires_grad=False)

-    norm_factor = 1. if norm == "ortho" else math.sqrt(4 * math.pi)
-    norm_factor = 1. / norm_factor if inverse else norm_factor
+    norm_factor = 1.0 if norm == "ortho" else math.sqrt(4 * math.pi)
+    norm_factor = 1.0 / norm_factor if inverse else norm_factor

     # initial values to start the recursion
     vdm[0,0,:] = norm_factor / math.sqrt(4 * math.pi)
@@ -123,7 +123,7 @@ def _precompute_dlegpoly(mmax: int, lmax: int, t: torch.Tensor,

     pct = _precompute_legpoly(mmax+1, lmax+1, t, norm=norm, inverse=inverse, csphase=False)

-    dpct = torch.zeros((2, mmax, lmax, len(t)), dtype=torch.float64, requires_grad=False)
+    dpct = torch.zeros((2, mmax, lmax, len(t)), dtype=torch.float64, device=t.device, requires_grad=False)

     # fill the derivative terms wrt theta
     for l in range(0, lmax):
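With the recursion buffer pinned to x.device, the associated Legendre tables can be precomputed directly on an accelerator. A hedged usage sketch of the internal helper shown above (assumes a CUDA device; default norm and flags):

import torch
from torch_harmonics.legendre import legpoly

cost = torch.linspace(-1.0, 1.0, 16, dtype=torch.float64, device="cuda")
vdm = legpoly(8, 8, cost)   # the P^m_n table is now allocated on cost.device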

torch_harmonics/quadrature.py

Lines changed: 1 addition & 1 deletion

@@ -169,7 +169,7 @@ def clenshaw_curtiss_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]
     tcc = torch.cos(torch.linspace(math.pi, 0, n, dtype=torch.float64, requires_grad=False))

     if n == 2:
-        wcc = torch.tensor([1.0, 1.0], dtype=torch.float64)
+        wcc = torch.as_tensor([1.0, 1.0], dtype=torch.float64)
     else:

         n1 = n - 1
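A quick sanity check on the n == 2 special case: two nodes at the interval endpoints with unit weights reproduce the trapezoidal rule on [-1, 1] (total weight 2):

import torch

wcc = torch.as_tensor([1.0, 1.0], dtype=torch.float64)
assert wcc.sum().item() == 2.0   # integrates constants over [-1, 1] exactly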

torch_harmonics/random_fields.py

Lines changed: 3 additions & 3 deletions

@@ -77,16 +77,16 @@ def __init__(self, nlat, alpha=2.0, tau=3.0, sigma=None, radius=1.0, grid="equia
     self.isht = InverseRealSHT(self.nlat, 2*self.nlat, grid=grid, norm='backward').to(dtype=dtype)

     #Square root of the eigenvalues of C.
-    sqrt_eig = torch.tensor([j*(j+1) for j in range(self.nlat)]).view(self.nlat,1).repeat(1, self.nlat+1)
+    sqrt_eig = torch.as_tensor([j*(j+1) for j in range(self.nlat)]).view(self.nlat,1).repeat(1, self.nlat+1)
     sqrt_eig = torch.tril(sigma*(((sqrt_eig/radius**2) + tau**2)**(-alpha/2.0)))
     sqrt_eig[0,0] = 0.0
     sqrt_eig = sqrt_eig.unsqueeze(0)
     self.register_buffer('sqrt_eig', sqrt_eig)

     #Save mean and var of the standard Gaussian.
     #Need these to re-initialize distribution on a new device.
-    mean = torch.tensor([0.0]).to(dtype=dtype)
-    var = torch.tensor([1.0]).to(dtype=dtype)
+    mean = torch.as_tensor([0.0]).to(dtype=dtype)
+    var = torch.as_tensor([1.0]).to(dtype=dtype)
     self.register_buffer('mean', mean)
     self.register_buffer('var', var)
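The surrounding buffer pattern in miniature: registering mean and var as buffers lets nn.Module.to() carry them along, so the Gaussian can be rebuilt on whatever device the module lands on (GaussianState is a hypothetical toy, not part of torch_harmonics):

import torch
import torch.nn as nn

class GaussianState(nn.Module):
    def __init__(self, dtype=torch.float32):
        super().__init__()
        self.register_buffer("mean", torch.as_tensor([0.0], dtype=dtype))
        self.register_buffer("var", torch.as_tensor([1.0], dtype=dtype))

    def distribution(self):
        # re-created on demand, always on the buffers' current device
        return torch.distributions.Normal(self.mean, self.var.sqrt())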

torch_harmonics/resample.py

Lines changed: 2 additions & 2 deletions

@@ -75,9 +75,9 @@ def __init__(
     # we need to expand the solution to the poles before interpolating
     self.expand_poles = (self.lats_out > self.lats_in[-1]).any() or (self.lats_out < self.lats_in[0]).any()
     if self.expand_poles:
-        self.lats_in = torch.cat([torch.tensor([0.], dtype=torch.float64),
+        self.lats_in = torch.cat([torch.as_tensor([0.], dtype=torch.float64, device=self.lats_in.device),
                                   self.lats_in,
-                                  torch.tensor([math.pi], dtype=torch.float64)]).contiguous()
+                                  torch.as_tensor([math.pi], dtype=torch.float64, device=self.lats_in.device)]).contiguous()

     # prepare the interpolation by computing indices to the left and right of each output latitude
     lat_idx = torch.searchsorted(self.lats_in, self.lats_out, side="right") - 1
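What the pole expansion buys: once lats_in is padded with 0 and pi on the same device, searchsorted can bracket output latitudes that fall outside the original range. A tiny illustration with made-up latitudes:

import math
import torch

lats_in = torch.as_tensor([0.3, 1.0, 2.0, 2.8], dtype=torch.float64)
lats_in = torch.cat([torch.as_tensor([0.0], dtype=torch.float64, device=lats_in.device),
                     lats_in,
                     torch.as_tensor([math.pi], dtype=torch.float64, device=lats_in.device)])
lats_out = torch.as_tensor([0.1, 3.0], dtype=torch.float64)        # beyond the input range
lat_idx = torch.searchsorted(lats_in, lats_out, side="right") - 1  # valid bracket indices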
