diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index 05d3438318..d816ef19ec 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -826,8 +826,27 @@ class SameMeshInterpolator(Interpolator): """ @no_annotations - def __init__(self, expr, V, subset=None, freeze_expr=False, access=op2.WRITE, bcs=None, **kwargs): - super().__init__(expr, V, subset, freeze_expr, access, bcs) + def __init__(self, expr, V, subset=None, freeze_expr=False, access=op2.WRITE, bcs=None, allow_missing_dofs=False, **kwargs): + if subset is None: + target = V.function_space().mesh().topology if isinstance(V, firedrake.Function) else V.mesh().topology + temp = extract_unique_domain(expr) + source = target if temp is None else temp.topology + if all(isinstance(m, firedrake.mesh.MeshTopology) for m in [target, source]) and target is not source: + composed_map, result_integral_type = source.trans_mesh_entity_map(target, "cell", "everywhere", None) + if result_integral_type != "cell": + raise AssertionError("Only cell-cell interpolation supported") + indices_active = composed_map.indices_active_with_halo + make_subset = not indices_active.all() + make_subset = target.comm.allreduce(make_subset, op=MPI.LOR) + if make_subset: + if not allow_missing_dofs: + raise ValueError("iteration (sub)set unclear: run with `allow_missing_dofs=True`") + subset = op2.Subset(target.cell_set, numpy.where(indices_active)) + else: + # Do not need subset as target <= source. + pass + super().__init__(expr, V, subset=subset, freeze_expr=freeze_expr, + access=access, bcs=bcs, allow_missing_dofs=allow_missing_dofs) try: self.callable, arguments = make_interpolator(expr, V, subset, access, bcs=bcs) except FIAT.hdiv_trace.TraceError: diff --git a/pyop2/types/map.py b/pyop2/types/map.py index 81e3865465..13a43143ba 100644 --- a/pyop2/types/map.py +++ b/pyop2/types/map.py @@ -30,6 +30,7 @@ class Map: """ dtype = dtypes.IntType + VALUE_UNDEFINED = -1 @utils.validate_type(('iterset', Set, ex.SetTypeError), ('toset', Set, ex.SetTypeError), ('arity', numbers.Integral, ex.ArityTypeError), ('name', str, ex.NameTypeError)) @@ -278,7 +279,42 @@ def values(self): @utils.cached_property def values_with_halo(self): - raise RuntimeError("ComposedMap does not store values directly") + r = np.empty(self.shape, dtype=Map.dtype) + # Initialise map values. + r[:, 0] = np.arange(r.shape[0]) + # Initialise mask values. + mask = np.full(r.shape[0], True, dtype=bool) + temp = np.empty_like(mask) + for m in reversed(self.maps_): + a = m.values_with_halo + # Update mask according to whether map target is defined or not. + temp[:] = mask[:] + mask[temp] &= a[r[:, 0][temp], 0] != Map.VALUE_UNDEFINED + # Update map values (only where targets are defined). + r[mask, :] = a[r[:, 0][mask], :] + r[~mask, :] = Map.VALUE_UNDEFINED + return r + + @utils.cached_property + def indices_active_with_halo(self): + """Return boolean array for active indices. + + Returns + ------- + numpy.ndarray + Boolean array of size (self._iterset.total_size,), whose values + are `False` if the corresponding entries in the iterset have + no targets, or if the target values are `Map.VALUE_UNDEFINED`. + + """ + r = self.values_with_halo[:, 0] != Map.VALUE_UNDEFINED + if ( + (self.values_with_halo[r, :] == Map.VALUE_UNDEFINED).any() or not (self.values_with_halo[~r, :] == Map.VALUE_UNDEFINED).all() + ): + raise AssertionError( + "target values of a given entry must be all defined or all undefined" + ) + return r def __str__(self): return "OP2 ComposedMap of Maps: [%s]" % ",".join([str(m) for m in self.maps_]) diff --git a/tests/firedrake/submesh/test_submesh_interpolate.py b/tests/firedrake/submesh/test_submesh_interpolate.py index cc6dac2231..c2cf9cad0e 100644 --- a/tests/firedrake/submesh/test_submesh_interpolate.py +++ b/tests/firedrake/submesh/test_submesh_interpolate.py @@ -45,20 +45,17 @@ def _test_submesh_interpolate_cell_cell(mesh, subdomain_cond, fe_fesub): fsub = Function(Vsub).interpolate(f) assert np.allclose(fsub.dat.data_ro_with_halos, gsub.dat.data_ro_with_halos) f = Function(V_).interpolate(f) - g = Function(V) - # interpolation on subdomain only makes sense - # if there is no ambiguity on the subdomain boundary. - # For testing, the following suffices. - g.interpolate(f) - temp = Constant(999.*np.ones(V.value_shape)) - g.interpolate(temp, subset=mesh.topology.cell_subset(label_value)) # pollute the data - g.interpolate(gsub, subset=mesh.topology.cell_subset(label_value)) + v0 = Coargument(V.dual(), 0) + v1 = TrialFunction(Vsub) + interp = Interpolate(v1, v0, allow_missing_dofs=True) + A = assemble(interp) + g = assemble(action(A, gsub)) assert assemble(inner(g - f, g - f) * dx(label_value)).real < 1e-14 @pytest.mark.parametrize('nelem', [2, 4, 8, None]) @pytest.mark.parametrize('fe_fesub', [[("DQ", 0), ("DQ", 0)], - [("Q", 4), ("DQ", 5)]]) + [("Q", 4), ("Q", 5)]]) @pytest.mark.parametrize('condx', [LT]) @pytest.mark.parametrize('condy', [LT]) @pytest.mark.parametrize('condz', [LT]) @@ -78,7 +75,7 @@ def test_submesh_interpolate_cell_cell_hex_1_processes(fe_fesub, nelem, condx, c @pytest.mark.parallel(nprocs=3) @pytest.mark.parametrize('nelem', [2, 4, 8, None]) @pytest.mark.parametrize('fe_fesub', [[("DQ", 0), ("DQ", 0)], - [("Q", 4), ("DQ", 5)]]) + [("Q", 4), ("Q", 5)]]) @pytest.mark.parametrize('condx', [LT, GT]) @pytest.mark.parametrize('condy', [LT, GT]) @pytest.mark.parametrize('condz', [LT, GT]) @@ -97,7 +94,7 @@ def test_submesh_interpolate_cell_cell_hex_3_processes(fe_fesub, nelem, condx, c @pytest.mark.parallel(nprocs=3) @pytest.mark.parametrize('fe_fesub', [[("DP", 0), ("DP", 0)], - [("P", 4), ("DP", 5)], + [("P", 4), ("P", 5)], [("BDME", 2), ("BDME", 3)], [("BDMF", 2), ("BDMF", 3)]]) @pytest.mark.parametrize('condx', [LT, GT]) @@ -114,7 +111,7 @@ def test_submesh_interpolate_cell_cell_tri_3_processes(fe_fesub, condx, condy, d @pytest.mark.parallel(nprocs=3) @pytest.mark.parametrize('fe_fesub', [[("DQ", 0), ("DQ", 0)], - [("Q", 4), ("DQ", 5)]]) + [("Q", 4), ("Q", 5)]]) @pytest.mark.parametrize('condx', [LT, GT]) @pytest.mark.parametrize('condy', [LT, GT]) @pytest.mark.parametrize('distribution_parameters', [None, {"overlap_type": (DistributedMeshOverlapType.NONE, 0)}]) @@ -124,3 +121,48 @@ def test_submesh_interpolate_cell_cell_quad_3_processes(fe_fesub, condx, condy, cond = conditional(condx(x, 0.5), 1, conditional(condy(y, 0.5), 1, 0)) # noqa: E128 _test_submesh_interpolate_cell_cell(mesh, cond, fe_fesub) + + +@pytest.mark.parallel(nprocs=2) +def test_submesh_interpolate_subcell_subcell_2_processes(): + # mesh + # rank 0: + # 4---12----6---15---(8)-(18)-(10) + # | | | | + # 11 0 13 1 (17) (2) (19) + # | | | | + # 3---14----5---16---(7)-(20)--(9) + # rank 1: + # (7)-(13)---3----9----5 + # | | | + # (12) (1) 8 0 10 + # | | | plex points + # (6)-(14)---2---11----4 () = ghost + mesh = RectangleMesh( + 3, 1, 3., 1., quadrilateral=True, distribution_parameters={"partitioner_type": "simple"}, + ) + dim = mesh.topological_dimension() + x, _ = SpatialCoordinate(mesh) + DG0 = FunctionSpace(mesh, "DG", 0) + f_l = Function(DG0).interpolate(conditional(x < 2.0, 1, 0)) + f_r = Function(DG0).interpolate(conditional(x > 1.0, 1, 0)) + mesh = RelabeledMesh(mesh, [f_l, f_r], [111, 222]) + mesh_l = Submesh(mesh, dim, 111) + mesh_r = Submesh(mesh, dim, 222) + V_l = FunctionSpace(mesh_l, "CG", 1) + V_r = FunctionSpace(mesh_r, "CG", 1) + f_l = Function(V_l) + f_r = Function(V_r) + f_l.dat.data_with_halos[:] = 1.0 + f_r.dat.data_with_halos[:] = 2.0 + f_l.interpolate(f_r, allow_missing_dofs=True) + g_l = Function(V_l).interpolate(conditional(x > 0.999, 2.0, 1.0)) + assert np.allclose(f_l.dat.data_with_halos, g_l.dat.data_with_halos) + f_l.dat.data_with_halos[:] = 3.0 + v0 = Coargument(V_r.dual(), 0) + v1 = TrialFunction(V_l) + interp = Interpolate(v1, v0, allow_missing_dofs=True) + A = assemble(interp) + f_r = assemble(action(A, f_l)) + g_r = Function(V_r).interpolate(conditional(x < 2.001, 3.0, 0.0)) + assert np.allclose(f_r.dat.data_with_halos, g_r.dat.data_with_halos)