Skip to content

Commit bb76fae

Browse files
committed
refactored laplacian.py's manual implementation of the stencil
this seems to be substantially faster as it can be cached
1 parent 1950435 commit bb76fae

File tree

2 files changed

+273
-6
lines changed

2 files changed

+273
-6
lines changed

src/pybella/flow_solver/physics/low_mach/laplacian.py

Lines changed: 272 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,6 @@ def get_lap2D_stencil(mpv, node, coriolis, diag_inv, ud):
2020

2121
dummy_p = np.zeros((node.isc[1], node.isc[0]))
2222

23-
# if hasattr(ud, "ATMOSPHERIC_EXTENSION"):
24-
# return lambda p: lap2D_extended(
25-
# p, dummy_p, dx, dy, coeffs, diag_inv.T, coriolis, shp
26-
# )
27-
2823
if hasattr(ud, "ATMOSPHERIC_EXTENSION") and ud.ATMOSPHERIC_EXTENSION:
2924
boundary_handler = periodic_x_wall_y
3025
else:
@@ -206,6 +201,278 @@ def kernel_9pt(a, dx, dy, hpx, hpy, hpc, diag_inv, cxx, cyy, cxy, cyx):
206201

207202
return ((Dxx + Dyy + Dyx + Dxy) + hpc[0, 0] * a[0, 0]) * diag_inv[0, 0]
208203

204+
205+
206+
207+
def get_lap2D(mpv, node, coriolis, diag_inv, ud):
    """Return a callable ``p -> lap`` that applies the 2D Laplacian.

    The grid-dependent inputs are sliced and flattened to 1D arrays once,
    here, and then handed to the compiled kernel ``lap2D_gather_new`` on
    every call of the returned function.

    Parameters
    ----------
    mpv : object
        Must provide ``wplus[0]``, ``wplus[1]`` and ``wcenter`` as 2D arrays.
    node : object
        Must provide ``dx``, ``dy``, ``iicx``, ``iicy`` and the index ``i1``.
    coriolis : sequence
        At least four 2D arrays; entries 0..3 are forwarded in that order
        (the kernel unpacks them as ``cxx, cyy, cxy, cyx``).
    diag_inv : array
        2D array indexed with ``node.i1``; flattened on every call.
    ud : object
        Must provide ``bdry_type[0]`` and ``bdry_type[1]``; an optional truthy
        ``ATMOSPHERIC_EXTENSION`` attribute switches on ``y_atmosphere``.

    Returns
    -------
    callable
        Function of a single argument ``p`` returning the result of
        ``lap2D_gather_new``.
    """
    dx = node.dx
    dy = node.dy

    # TODO: needs cleaning up, but note that the Numba stencil is used in the
    # Helmholtz solve for the radiative BC!
    # A missing attribute is treated the same as a falsy one.
    y_atmosphere = bool(getattr(ud, "ATMOSPHERIC_EXTENSION", False))

    x_wall = ud.bdry_type[0] == opts.BdryType.WALL
    y_wall = ud.bdry_type[1] == opts.BdryType.WALL

    # Drops one layer on every side of a 2D array.
    inner = (slice(1, -1), slice(1, -1))

    # wplus is trimmed with the fixed one-layer slice, wcenter with node.i1;
    # all three are transposed before flattening.
    coeffs = (
        mpv.wplus[0][inner].T.reshape(-1),
        mpv.wplus[1][inner].T.reshape(-1),
        mpv.wcenter[node.i1].T.reshape(-1),
    )

    # The Coriolis arrays are trimmed and flattened WITHOUT a transpose.
    # NOTE(review): presumably they already arrive in the transposed layout;
    # confirm against the caller.
    coriolis_flat = tuple(coriolis[k][inner].reshape(-1) for k in range(4))

    # diag_inv and node.iicx / node.iicy are deliberately read at call time,
    # so later in-place changes to them are still picked up.
    # NOTE(review): if diag_inv is never modified after this point, its
    # flattening could be hoisted out of the closure to save a copy per call.
    return lambda p: lap2D_gather_new(
        p,
        node.iicx,
        node.iicy,
        coeffs,
        dx,
        dy,
        x_wall,
        y_wall,
        y_atmosphere,
        diag_inv[node.i1].T.reshape(-1),
        coriolis_flat,
    )
276+
277+
278+
@nb.njit(cache=True)
def lap2D_gather_new(
    p, iicxn, iicyn, coeffs, dx, dy, x_wall, y_wall, y_atmosphere, diag_inv, coriolis
):
    """Apply the nine-point 2D Laplacian to a flattened field ``p``.

    ``p``, ``hcenter`` and ``diag_inv`` are addressed with the flat index
    ``idx = row * iicxn + col``. The ``hplus*`` and Coriolis arrays are
    addressed with a row stride of ``iicxn + 1``; each point reads the four
    entries at (row, col), (row, col + 1), (row + 1, col), (row + 1, col + 1).
    "top" refers to ``row - 1`` and "bot" to ``row + 1``.

    Parameters
    ----------
    p : 1D array
        Input field, at least ``iicxn * iicyn`` entries.
    iicxn, iicyn : int
        Number of columns / rows of the flattened grid.
    coeffs : tuple
        ``(hplusx, hplusy, hcenter)`` as 1D arrays.
    dx, dy : float
        Grid spacings; only their reciprocals are used.
    x_wall, y_wall : bool
        If set, the ``hplus*`` weights on the outer side of the first/last
        column (row) are replaced by zero.
    y_atmosphere : bool
        If set, the neighbours outside the first/last row are taken from the
        row mirrored about the edge (row 1 / row ``iicyn - 2``) and the
        ``y_wall`` zeroing is skipped.
    diag_inv : 1D array
        Point-wise factor multiplied onto the result.
    coriolis : tuple
        ``(cxx, cyy, cxy, cyx)`` as 1D arrays.

    Returns
    -------
    1D float array of length ``iicxn * iicyn``.
    """
    lap = np.zeros(iicxn * iicyn)

    oodx = 1.0 / dx
    oody = 1.0 / dy
    cxx, cyy, cxy, cyx = coriolis
    hplusx, hplusy, hcenter = coeffs

    ne_stride = iicxn + 1
    last_col = iicxn - 1
    last_row = iicyn - 1

    # Correction added to the "one row up" offset on the first row (and
    # subtracted from "one row down" on the last row).
    if y_atmosphere:
        # -iicxn + 2 * iicxn == +iicxn: reuse the row on the other side.
        row_wrap = 2 * iicxn
    else:
        # Wrap around to row iicyn - 2 (resp. row 1).
        # NOTE(review): presumably the first and last rows are periodic
        # images of each other, hence the skipped row; confirm.
        row_wrap = iicxn * (iicyn - 1)

    # The y-wall zeroing is disabled whenever y_atmosphere is set.
    zero_y_edges = y_wall and not y_atmosphere

    for idx in range(iicxn * iicyn):
        row = idx // iicxn
        col = idx - row * iicxn

        # Indices into the (iicxn + 1)-strided coefficient arrays.
        ne_tl = row * ne_stride + col
        ne_tr = ne_tl + 1
        ne_bl = ne_tl + ne_stride
        ne_br = ne_bl + 1

        # Column offsets; on the first/last column they wrap to column
        # iicxn - 2 / column 1 of the same row.
        left = -1
        right = 1
        if col == 0:
            left += last_col
        if col == last_col:
            right -= last_col

        # Row offsets, corrected on the first/last row (see row_wrap).
        up = -iicxn
        down = iicxn
        if row == 0:
            up += row_wrap
        if row == last_row:
            down -= row_wrap

        # Gather the nine stencil values.
        p_tl = p[idx + up + left]
        p_ml = p[idx + left]
        p_bl = p[idx + down + left]

        p_tm = p[idx + up]
        p_mm = p[idx]
        p_bm = p[idx + down]

        p_tr = p[idx + up + right]
        p_mr = p[idx + right]
        p_br = p[idx + down + right]

        hx_tl = hplusx[ne_tl]
        hx_bl = hplusx[ne_bl]
        hy_tl = hplusy[ne_tl]
        hy_bl = hplusy[ne_bl]

        hx_tr = hplusx[ne_tr]
        hx_br = hplusx[ne_br]
        hy_tr = hplusy[ne_tr]
        hy_br = hplusy[ne_br]

        cxx_tl = cxx[ne_tl]
        cxx_tr = cxx[ne_tr]
        cxx_bl = cxx[ne_bl]
        cxx_br = cxx[ne_br]

        cxy_tl = cxy[ne_tl]
        cxy_tr = cxy[ne_tr]
        cxy_bl = cxy[ne_bl]
        cxy_br = cxy[ne_br]

        cyx_tl = cyx[ne_tl]
        cyx_tr = cyx[ne_tr]
        cyx_bl = cyx[ne_bl]
        cyx_br = cyx[ne_br]

        cyy_tl = cyy[ne_tl]
        cyy_tr = cyy[ne_tr]
        cyy_bl = cyy[ne_bl]
        cyy_br = cyy[ne_br]

        # Walls: overwrite (not scale) the weights on the outer side, so the
        # wrapped neighbour values gathered above do not contribute.
        if x_wall and col == 0:
            hx_tl = 0.0
            hy_tl = 0.0
            hx_bl = 0.0
            hy_bl = 0.0

        if x_wall and col == last_col:
            hx_tr = 0.0
            hy_tr = 0.0
            hx_br = 0.0
            hy_br = 0.0

        if zero_y_edges and row == 0:
            hx_tl = 0.0
            hy_tl = 0.0
            hx_tr = 0.0
            hy_tr = 0.0

        if zero_y_edges and row == last_row:
            hx_bl = 0.0
            hy_bl = 0.0
            hx_br = 0.0
            hy_br = 0.0

        # Weighted differences in each of the four quadrants around the point.
        # The operand order is kept as-is so rounding stays unchanged.
        Dx_tl = 0.5 * (p_tm - p_tl + p_mm - p_ml) * hx_tl
        Dx_tr = 0.5 * (p_tr - p_tm + p_mr - p_mm) * hx_tr
        Dx_bl = 0.5 * (p_bm - p_bl + p_mm - p_ml) * hx_bl
        Dx_br = 0.5 * (p_br - p_bm + p_mr - p_mm) * hx_br

        Dy_tl = 0.5 * (p_mm - p_tm + p_ml - p_tl) * hy_tl
        Dy_tr = 0.5 * (p_mr - p_tr + p_mm - p_tm) * hy_tr
        Dy_bl = 0.5 * (p_bm - p_mm + p_bl - p_ml) * hy_bl
        Dy_br = 0.5 * (p_br - p_mr + p_bm - p_mm) * hy_br

        Dxx = (
            0.5
            * (cxx_tr * Dx_tr - cxx_tl * Dx_tl + cxx_br * Dx_br - cxx_bl * Dx_bl)
            * oodx
            * oodx
        )
        Dyy = (
            0.5
            * (cyy_br * Dy_br - cyy_tr * Dy_tr + cyy_bl * Dy_bl - cyy_tl * Dy_tl)
            * oody
            * oody
        )
        Dyx = (
            0.5
            * (cxy_br * Dy_br - cxy_bl * Dy_bl + cxy_tr * Dy_tr - cxy_tl * Dy_tl)
            * oody
            * oodx
        )
        Dxy = (
            0.5
            * (cyx_br * Dx_br - cyx_tr * Dx_tr + cyx_bl * Dx_bl - cyx_tl * Dx_tl)
            * oodx
            * oody
        )

        lap[idx] = (Dxx + Dyy + Dyx + Dxy + hcenter[idx] * p_mm) * diag_inv[idx]

    return lap
473+
474+
475+
209476
def stencil_27pt(elem, node, mpv, ud, diag_inv, dt):
210477
oodxyz = node.dxyz
211478
oodxyz = 1.0 / (oodxyz**2)

src/pybella/flow_solver/physics/low_mach/second_projection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def euler_backward_non_advective_impl_part(
188188
mem.mpv.rhs *= diag_inv
189189

190190
p2 = mem.mpv.p2_nodes[mem.node.i2].T
191-
lap = lm_lp.get_lap2D_stencil(mem.mpv, mem.node, coriolis_params, diag_inv, ud)
191+
lap = lm_lp.get_lap2D(mem.mpv, mem.node, coriolis_params, diag_inv, ud)
192192
sh = p2.shape[0] * p2.shape[1]
193193

194194
elif mem.elem.ndim == 3:

0 commit comments

Comments
 (0)