diff --git a/firedrake/assemble.py b/firedrake/assemble.py index c55466086d..ba7b9d39a2 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -1745,7 +1745,7 @@ def _as_global_kernel_arg_output(_, self): if rank == 0: return op2.GlobalKernelArg((1,)) elif rank == 1 or rank == 2 and self._diagonal: - V, = Vs + V = Vs[0] if V.ufl_element().family() == "Real": return op2.GlobalKernelArg((1,)) else: @@ -2052,7 +2052,7 @@ def _as_parloop_arg_output(_, self): if rank == 0: return op2.GlobalParloopArg(self._tensor) elif rank == 1 or rank == 2 and self._diagonal: - V, = Vs + V = Vs[0] if V.ufl_element().family() == "Real": return op2.GlobalParloopArg(self._tensor) else: diff --git a/firedrake/slate/slate.py b/firedrake/slate/slate.py index 942faf2bd8..f5af0d6cf3 100644 --- a/firedrake/slate/slate.py +++ b/firedrake/slate/slate.py @@ -1387,8 +1387,7 @@ def arg_function_spaces(self): """Returns a tuple of function spaces that the tensor is defined on. """ - tensor, = self.operands - return tuple(arg.function_space() for arg in tensor.arguments()) + return tuple(arg.function_space() for arg in self.arguments()) def arguments(self): """Returns a tuple of arguments associated with the tensor.""" diff --git a/tests/firedrake/slate/test_linear_algebra.py b/tests/firedrake/slate/test_linear_algebra.py index 23a2670217..d5c7dde997 100644 --- a/tests/firedrake/slate/test_linear_algebra.py +++ b/tests/firedrake/slate/test_linear_algebra.py @@ -152,10 +152,8 @@ def test_inverse_action(mat_type, rhs_type): assert np.allclose(x.dat.data, f.dat.data, rtol=1.e-13) -@pytest.mark.parametrize("mat_type, rhs_type", [ - ("slate", "slate"), ("slate", "form"), ("slate", "cofunction"), - ("aij", "cofunction"), ("aij", "form"), - ("matfree", "cofunction"), ("matfree", "form")]) +@pytest.mark.parametrize("rhs_type", ["slate", "form", "cofunction"]) +@pytest.mark.parametrize("mat_type", ["slate", "aij", "matfree"]) def test_solve_interface(mat_type, rhs_type): mesh = UnitSquareMesh(1, 1) V = FunctionSpace(mesh, "HDivT", 0) @@ -180,12 +178,8 @@ def test_solve_interface(mat_type, rhs_type): else: raise ValueError("Invalid rhs type") - sp = None - if mat_type == "matfree": - sp = {"pc_type": "none"} - x = Function(V) problem = LinearVariationalProblem(A, b, x, bcs=bcs) - solver = LinearVariationalSolver(problem, solver_parameters=sp) + solver = LinearVariationalSolver(problem) solver.solve() assert np.allclose(x.dat.data, f.dat.data, rtol=1.e-13)