Skip to content

Commit d95e8b6

Browse files
committed
Preallocate GPU interpolant
1 parent bcb6ce0 commit d95e8b6

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

ext/InterpolationsRegridderExt.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@ import Interpolations as Intp
55
import ClimaCore
66
import ClimaCore.Fields: Adapt
77
import ClimaCore.Fields: ClimaComms
8+
import ClimaCore.Fields: zeros
89

910
import ClimaUtilities.Regridders
1011

1112
struct InterpolationsRegridder{
1213
SPACE <: ClimaCore.Spaces.AbstractSpace,
1314
FIELD <: ClimaCore.Fields.Field,
1415
BC,
16+
GITP,
1517
} <: Regridders.AbstractRegridder
1618

1719
"""ClimaCore.Space where the output Field will be defined"""
@@ -22,6 +24,10 @@ struct InterpolationsRegridder{
2224

2325
"""Tuple of extrapolation conditions as accepted by Interpolations.jl"""
2426
extrapolation_bc::BC
27+
28+
# This is needed because Adapt moves from CPU to GPU and allocates new memory
29+
"""Preallocated area of memory where to store the GPU interpolant (if needed)"""
30+
_gpuitp::GITP
2531
end
2632

2733
# Note, we swap Lat and Long! This is because according to the CF conventions longitude
@@ -69,7 +75,12 @@ function Regridders.InterpolationsRegridder(
6975
end
7076
end
7177

72-
return InterpolationsRegridder(target_space, coordinates, extrapolation_bc)
78+
return InterpolationsRegridder(
79+
target_space,
80+
coordinates,
81+
extrapolation_bc,
82+
zeros(target_space),
83+
)
7384
end
7485

7586
"""
@@ -90,7 +101,8 @@ function Regridders.regrid(regridder::InterpolationsRegridder, data, dimensions)
90101
)
91102

92103
# Move it to GPU (if needed)
93-
gpuitp = Adapt.adapt(ClimaComms.array_type(regridder.target_space), itp)
104+
itp._gpuitp .=
105+
Adapt.adapt(ClimaComms.array_type(regridder.target_space), itp)
94106

95107
return map(regridder.coordinates) do coord
96108
gpuitp(totuple(coord)...)

0 commit comments

Comments
 (0)