Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion src/yadism/coefficient_functions/heavy/n3lo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pathlib

import numpy as np
from scipy.interpolate import RectBivariateSpline
from scipy.interpolate import NearestNDInterpolator, RectBivariateSpline

grid_path = pathlib.Path(__file__).parent / "grids"

Expand All @@ -11,6 +11,21 @@
interpolators = {}


def fill_nans_nearest_neighbor(xi_grid, eta_grid, coeffs):
z_filled = coeffs.copy()
mask = np.isfinite(coeffs)
values = coeffs[mask]

xm, ym = np.meshgrid(xi_grid, eta_grid, indexing="ij")
points = np.column_stack((xm[mask], ym[mask]))

interp_nn = NearestNDInterpolator(points, values)
z_filled[~mask] = interp_nn(xm[~mask], ym[~mask])
assert np.all(np.isfinite(z_filled))

return z_filled


def interpolator(coeff, nf, variation):
grid_name = f"{coeff}_nf{int(nf)}_var{int(variation)}.npy"

Expand All @@ -20,6 +35,8 @@ def interpolator(coeff, nf, variation):

# load grid
coeff = np.load(grid_path / grid_name)
if np.isnan(coeff).any():
coeff = fill_nans_nearest_neighbor(xi_grid, eta_grid, coeff)
grid_interpolator = RectBivariateSpline(xi_grid, eta_grid, coeff)

# store result
Expand Down
2 changes: 1 addition & 1 deletion src/yadism/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def replace_nans_with_0(self, out):
# Loop through each observable in the dictionary
for observable, points in out2.items():
# Skip the keys that are not an observable
if observable not in observable_name.kinds:
if not observable_name.ObservableName.is_valid(observable):
continue

# Loop over the kinematic points
Expand Down