Skip to content

interpolate: sanitise submesh interpolate #4482

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions firedrake/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,8 +826,27 @@ class SameMeshInterpolator(Interpolator):
"""

@no_annotations
def __init__(self, expr, V, subset=None, freeze_expr=False, access=op2.WRITE, bcs=None, **kwargs):
super().__init__(expr, V, subset, freeze_expr, access, bcs)
def __init__(self, expr, V, subset=None, freeze_expr=False, access=op2.WRITE, bcs=None, allow_missing_dofs=False, **kwargs):
if subset is None:
target = V.function_space().mesh().topology if isinstance(V, firedrake.Function) else V.mesh().topology
temp = extract_unique_domain(expr)
source = target if temp is None else temp.topology
if all(isinstance(m, firedrake.mesh.MeshTopology) for m in [target, source]) and target is not source:
composed_map, result_integral_type = source.trans_mesh_entity_map(target, "cell", "everywhere", None)
if result_integral_type != "cell":
raise AssertionError("Only cell-cell interpolation supported")
indices_active = composed_map.indices_active_with_halo
make_subset = not indices_active.all()
make_subset = target.comm.allreduce(make_subset, op=MPI.LOR)
if make_subset:
if not allow_missing_dofs:
raise ValueError("iteration (sub)set unclear: run with `allow_missing_dofs=True`")
subset = op2.Subset(target.cell_set, numpy.where(indices_active))
else:
# Do not need subset as target <= source.
pass
super().__init__(expr, V, subset=subset, freeze_expr=freeze_expr,
access=access, bcs=bcs, allow_missing_dofs=allow_missing_dofs)
try:
self.callable, arguments = make_interpolator(expr, V, subset, access, bcs=bcs)
except FIAT.hdiv_trace.TraceError:
Expand Down
38 changes: 37 additions & 1 deletion pyop2/types/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class Map:
"""

dtype = dtypes.IntType
VALUE_UNDEFINED = -1

@utils.validate_type(('iterset', Set, ex.SetTypeError), ('toset', Set, ex.SetTypeError),
('arity', numbers.Integral, ex.ArityTypeError), ('name', str, ex.NameTypeError))
Expand Down Expand Up @@ -278,7 +279,42 @@ def values(self):

@utils.cached_property
def values_with_halo(self):
raise RuntimeError("ComposedMap does not store values directly")
r = np.empty(self.shape, dtype=Map.dtype)
# Initialise map values.
r[:, 0] = np.arange(r.shape[0])
# Initialise mask values.
mask = np.full(r.shape[0], True, dtype=bool)
temp = np.empty_like(mask)
for m in reversed(self.maps_):
a = m.values_with_halo
# Update mask according to whether map target is defined or not.
temp[:] = mask[:]
mask[temp] &= a[r[:, 0][temp], 0] != Map.VALUE_UNDEFINED
# Update map values (only where targets are defined).
r[mask, :] = a[r[:, 0][mask], :]
r[~mask, :] = Map.VALUE_UNDEFINED
return r

@utils.cached_property
def indices_active_with_halo(self):
"""Return boolean array for active indices.

Returns
-------
numpy.ndarray
Boolean array of size (self._iterset.total_size,), whose values
are `False` if the corresponding entries in the iterset have
no targets, or if the target values are `Map.VALUE_UNDEFINED`.

"""
r = self.values_with_halo[:, 0] != Map.VALUE_UNDEFINED
if (
(self.values_with_halo[r, :] == Map.VALUE_UNDEFINED).any() or not (self.values_with_halo[~r, :] == Map.VALUE_UNDEFINED).all()
):
raise AssertionError(
"target values of a given entry must be all defined or all undefined"
)
return r

def __str__(self):
return "OP2 ComposedMap of Maps: [%s]" % ",".join([str(m) for m in self.maps_])
Expand Down
66 changes: 54 additions & 12 deletions tests/firedrake/submesh/test_submesh_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,20 +45,17 @@ def _test_submesh_interpolate_cell_cell(mesh, subdomain_cond, fe_fesub):
fsub = Function(Vsub).interpolate(f)
assert np.allclose(fsub.dat.data_ro_with_halos, gsub.dat.data_ro_with_halos)
f = Function(V_).interpolate(f)
g = Function(V)
# interpolation on subdomain only makes sense
# if there is no ambiguity on the subdomain boundary.
# For testing, the following suffices.
g.interpolate(f)
temp = Constant(999.*np.ones(V.value_shape))
g.interpolate(temp, subset=mesh.topology.cell_subset(label_value)) # pollute the data
g.interpolate(gsub, subset=mesh.topology.cell_subset(label_value))
v0 = Coargument(V.dual(), 0)
v1 = TrialFunction(Vsub)
interp = Interpolate(v1, v0, allow_missing_dofs=True)
A = assemble(interp)
g = assemble(action(A, gsub))
assert assemble(inner(g - f, g - f) * dx(label_value)).real < 1e-14


@pytest.mark.parametrize('nelem', [2, 4, 8, None])
@pytest.mark.parametrize('fe_fesub', [[("DQ", 0), ("DQ", 0)],
[("Q", 4), ("DQ", 5)]])
[("Q", 4), ("Q", 5)]])
@pytest.mark.parametrize('condx', [LT])
@pytest.mark.parametrize('condy', [LT])
@pytest.mark.parametrize('condz', [LT])
Expand All @@ -78,7 +75,7 @@ def test_submesh_interpolate_cell_cell_hex_1_processes(fe_fesub, nelem, condx, c
@pytest.mark.parallel(nprocs=3)
@pytest.mark.parametrize('nelem', [2, 4, 8, None])
@pytest.mark.parametrize('fe_fesub', [[("DQ", 0), ("DQ", 0)],
[("Q", 4), ("DQ", 5)]])
[("Q", 4), ("Q", 5)]])
@pytest.mark.parametrize('condx', [LT, GT])
@pytest.mark.parametrize('condy', [LT, GT])
@pytest.mark.parametrize('condz', [LT, GT])
Expand All @@ -97,7 +94,7 @@ def test_submesh_interpolate_cell_cell_hex_3_processes(fe_fesub, nelem, condx, c

@pytest.mark.parallel(nprocs=3)
@pytest.mark.parametrize('fe_fesub', [[("DP", 0), ("DP", 0)],
[("P", 4), ("DP", 5)],
[("P", 4), ("P", 5)],
[("BDME", 2), ("BDME", 3)],
[("BDMF", 2), ("BDMF", 3)]])
@pytest.mark.parametrize('condx', [LT, GT])
Expand All @@ -114,7 +111,7 @@ def test_submesh_interpolate_cell_cell_tri_3_processes(fe_fesub, condx, condy, d

@pytest.mark.parallel(nprocs=3)
@pytest.mark.parametrize('fe_fesub', [[("DQ", 0), ("DQ", 0)],
[("Q", 4), ("DQ", 5)]])
[("Q", 4), ("Q", 5)]])
@pytest.mark.parametrize('condx', [LT, GT])
@pytest.mark.parametrize('condy', [LT, GT])
@pytest.mark.parametrize('distribution_parameters', [None, {"overlap_type": (DistributedMeshOverlapType.NONE, 0)}])
Expand All @@ -124,3 +121,48 @@ def test_submesh_interpolate_cell_cell_quad_3_processes(fe_fesub, condx, condy,
cond = conditional(condx(x, 0.5), 1,
conditional(condy(y, 0.5), 1, 0)) # noqa: E128
_test_submesh_interpolate_cell_cell(mesh, cond, fe_fesub)


@pytest.mark.parallel(nprocs=2)
def test_submesh_interpolate_subcell_subcell_2_processes():
# mesh
# rank 0:
# 4---12----6---15---(8)-(18)-(10)
# | | | |
# 11 0 13 1 (17) (2) (19)
# | | | |
# 3---14----5---16---(7)-(20)--(9)
# rank 1:
# (7)-(13)---3----9----5
# | | |
# (12) (1) 8 0 10
# | | | plex points
# (6)-(14)---2---11----4 () = ghost
mesh = RectangleMesh(
3, 1, 3., 1., quadrilateral=True, distribution_parameters={"partitioner_type": "simple"},
)
dim = mesh.topological_dimension()
x, _ = SpatialCoordinate(mesh)
DG0 = FunctionSpace(mesh, "DG", 0)
f_l = Function(DG0).interpolate(conditional(x < 2.0, 1, 0))
f_r = Function(DG0).interpolate(conditional(x > 1.0, 1, 0))
mesh = RelabeledMesh(mesh, [f_l, f_r], [111, 222])
mesh_l = Submesh(mesh, dim, 111)
mesh_r = Submesh(mesh, dim, 222)
V_l = FunctionSpace(mesh_l, "CG", 1)
V_r = FunctionSpace(mesh_r, "CG", 1)
f_l = Function(V_l)
f_r = Function(V_r)
f_l.dat.data_with_halos[:] = 1.0
f_r.dat.data_with_halos[:] = 2.0
f_l.interpolate(f_r, allow_missing_dofs=True)
g_l = Function(V_l).interpolate(conditional(x > 0.999, 2.0, 1.0))
assert np.allclose(f_l.dat.data_with_halos, g_l.dat.data_with_halos)
f_l.dat.data_with_halos[:] = 3.0
v0 = Coargument(V_r.dual(), 0)
v1 = TrialFunction(V_l)
interp = Interpolate(v1, v0, allow_missing_dofs=True)
A = assemble(interp)
f_r = assemble(action(A, f_l))
g_r = Function(V_r).interpolate(conditional(x < 2.001, 3.0, 0.0))
assert np.allclose(f_r.dat.data_with_halos, g_r.dat.data_with_halos)
Loading