Skip to content

Commit 43e478a

Browse files
committed
Preallocate GPU interpolant
The interpolants we in `Interpolations.jl` are described by two arrays: the knots and the coeffs. When `Adapt` is called on these interpolants, CuArrays are allocated on the GPU. For large data, this is inefficient. In this commit, I add a system to avoid these allocations. This is accomplished by add a dictionary to `InterpolationsRegridder`. This dictionary has keys that identify the size of the knots and coefficients and values the adapted splines. When `regrid` is called, we check if we have already allocated some suitable space in this dictionary, if not, we create a new spline, if we do, we write in place. This removes GPU allocations in the hot path (ie, the regridder is used in a time evolution with always the same data and dimensions), while also keeping the flexibility of reusing the same regridder with any input data.
1 parent bcb6ce0 commit 43e478a

File tree

1 file changed

+63
-10
lines changed

1 file changed

+63
-10
lines changed

ext/InterpolationsRegridderExt.jl

Lines changed: 63 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ struct InterpolationsRegridder{
1212
SPACE <: ClimaCore.Spaces.AbstractSpace,
1313
FIELD <: ClimaCore.Fields.Field,
1414
BC,
15+
GITP,
1516
} <: Regridders.AbstractRegridder
1617

1718
"""ClimaCore.Space where the output Field will be defined"""
@@ -22,6 +23,14 @@ struct InterpolationsRegridder{
2223

2324
"""Tuple of extrapolation conditions as accepted by Interpolations.jl"""
2425
extrapolation_bc::BC
26+
27+
# This is needed because Adapt moves from CPU to GPU and allocates new memory.
28+
"""Dictionary of preallocated areas of memory where to store the GPU interpolant (if
29+
needed). Every time new data/dimensions are used in regrid, a new entry in the
30+
dictionary is created. The keys of the dictionary a tuple of tuple
31+
`(size(dimensions), size(data))`, with `dimensions` and `data` defined in `regrid`.
32+
"""
33+
_gpuitps::GITP
2534
end
2635

2736
# Note, we swap Lat and Long! This is because according to the CF conventions longitude
@@ -69,9 +78,38 @@ function Regridders.InterpolationsRegridder(
6978
end
7079
end
7180

72-
return InterpolationsRegridder(target_space, coordinates, extrapolation_bc)
81+
# Let's figure out the type of _gpuitps by creating a simple spline
82+
FT = ClimaCore.Spaces.undertype(target_space)
83+
dimensions = ntuple(_ -> [zero(FT), one(FT)], length(extrapolation_bc))
84+
data = zeros(FT, ntuple(_ -> 2, length(dimensions)))
85+
itp = _create_linear_spline(FT, data, dimensions, extrapolation_bc)
86+
fake_gpuitp = Adapt.adapt(ClimaComms.array_type(target_space), itp)
87+
gpuitps = Dict((size.(dimensions), size(data)) => fake_gpuitp)
88+
89+
return InterpolationsRegridder(
90+
target_space,
91+
coordinates,
92+
extrapolation_bc,
93+
gpuitps,
94+
)
95+
end
96+
97+
"""
98+
_create_linear_spline(regridder::InterpolationsRegridder, data, dimensions)
99+
100+
Create a linear spline for the given data on the given dimension (on the CPU).
101+
"""
102+
function _create_linear_spline(FT, data, dimensions, extrapolation_bc)
103+
dimensions_FT = map(d -> FT.(d), dimensions)
104+
105+
# Make a linear spline
106+
return Intp.extrapolate(
107+
Intp.interpolate(dimensions_FT, FT.(data), Intp.Gridded(Intp.Linear())),
108+
extrapolation_bc,
109+
)
73110
end
74111

112+
75113
"""
76114
regrid(regridder::InterpolationsRegridder, data, dimensions)::Field
77115
@@ -81,16 +119,31 @@ This function is allocating.
81119
"""
82120
function Regridders.regrid(regridder::InterpolationsRegridder, data, dimensions)
83121
FT = ClimaCore.Spaces.undertype(regridder.target_space)
84-
dimensions_FT = map(d -> FT.(d), dimensions)
85-
86-
# Make a linear spline
87-
itp = Intp.extrapolate(
88-
Intp.interpolate(dimensions_FT, FT.(data), Intp.Gridded(Intp.Linear())),
89-
regridder.extrapolation_bc,
90-
)
122+
itp =
123+
_create_linear_spline(FT, data, dimensions, regridder.extrapolation_bc)
124+
125+
key = (size.(dimensions), size(data))
126+
127+
if haskey(regridder._gpuitps, key)
128+
for (k, k_new) in zip(
129+
regridder._gpuitps[key].itp.knots,
130+
Adapt.adapt(
131+
ClimaComms.array_type(regridder.target_space),
132+
itp.itp.knots,
133+
),
134+
)
135+
k .= k_new
136+
end
137+
regridder._gpuitps[key].itp.coefs .= Adapt.adapt(
138+
ClimaComms.array_type(regridder.target_space),
139+
itp.itp.coefs,
140+
)
141+
else
142+
regridder._gpuitps[key] =
143+
Adapt.adapt(ClimaComms.array_type(regridder.target_space), itp)
144+
end
91145

92-
# Move it to GPU (if needed)
93-
gpuitp = Adapt.adapt(ClimaComms.array_type(regridder.target_space), itp)
146+
gpuitp = regridder._gpuitps[key]
94147

95148
return map(regridder.coordinates) do coord
96149
gpuitp(totuple(coord)...)

0 commit comments

Comments
 (0)