diff --git a/src/yadism/coefficient_functions/heavy/n3lo/__init__.py b/src/yadism/coefficient_functions/heavy/n3lo/__init__.py index 7f095724..6f78d704 100644 --- a/src/yadism/coefficient_functions/heavy/n3lo/__init__.py +++ b/src/yadism/coefficient_functions/heavy/n3lo/__init__.py @@ -1,7 +1,7 @@ import pathlib import numpy as np -from scipy.interpolate import RectBivariateSpline +from scipy.interpolate import NearestNDInterpolator, RectBivariateSpline grid_path = pathlib.Path(__file__).parent / "grids" @@ -11,6 +11,21 @@ interpolators = {} +def fill_nans_nearest_neighbor(xi_grid, eta_grid, coeffs): + z_filled = coeffs.copy() + mask = np.isfinite(coeffs) + values = coeffs[mask] + + xm, ym = np.meshgrid(xi_grid, eta_grid, indexing="ij") + points = np.column_stack((xm[mask], ym[mask])) + + interp_nn = NearestNDInterpolator(points, values) + z_filled[~mask] = interp_nn(xm[~mask], ym[~mask]) + assert np.all(np.isfinite(z_filled)) + + return z_filled + + def interpolator(coeff, nf, variation): grid_name = f"{coeff}_nf{int(nf)}_var{int(variation)}.npy" @@ -20,6 +35,8 @@ def interpolator(coeff, nf, variation): # load grid coeff = np.load(grid_path / grid_name) + if np.isnan(coeff).any(): + coeff = fill_nans_nearest_neighbor(xi_grid, eta_grid, coeff) grid_interpolator = RectBivariateSpline(xi_grid, eta_grid, coeff) # store result diff --git a/src/yadism/runner.py b/src/yadism/runner.py index cc41f934..794ce5db 100644 --- a/src/yadism/runner.py +++ b/src/yadism/runner.py @@ -216,7 +216,7 @@ def replace_nans_with_0(self, out): # Loop through each observable in the dictionary for observable, points in out2.items(): # Skip the keys that are not an observable - if observable not in observable_name.kinds: + if not observable_name.ObservableName.is_valid(observable): continue # Loop over the kinematic points