Calculate sensitivity when the control is a fem function

Here is a version that works for non-homogeneous Dirichlet boundary conditions and in parallel:

# Adjoint computation for the Poisson problem
# Author: Jørgen S. Dokken

import dolfinx
from ufl import Measure, inner, grad, TestFunction, adjoint, derivative, action
from dolfinx.mesh import CellType, create_rectangle
from dolfinx import fem, nls, la
from mpi4py import MPI
import numpy as np
import dolfinx.nls.petsc
import dolfinx.fem.petsc


class Problem():
    """Forward solve, adjoint solve and sensitivity for a Poisson control problem.

    The state ``u`` satisfies the Poisson residual
    ``inner(grad(u), grad(v))*dx - inner(f, v)*dx = 0`` with the supplied
    Dirichlet conditions, the control is the forcing ``f`` (a function in
    ``V``), and the cost is ``J = 1/2 * inner(u - u_obs, u - u_obs)*dx``.
    """

    def __init__(self, V, bcs, u_obs):
        """Set up all variational forms and solvers.

        Args:
            V: Function space used for the state, the control and the adjoint.
            bcs: Dirichlet boundary conditions of the forward problem.
            u_obs: Observed state that the cost functional compares ``u`` to.
        """
        # Build the test function and the integration measure from V itself,
        # so the class does not rely on module-level `v` and `dx` existing
        # (and matching V) at construction time.
        v = TestFunction(V)
        dx = Measure("dx", domain=V.mesh)

        # Residual of the variational form of the Poisson problem
        self.f = fem.Function(V)  # The forcing f is the design variable
        self.u = fem.Function(V)
        R = inner(grad(self.u), grad(v))*dx - inner(self.f, v)*dx

        # Forward problem, solved with a Newton solver on the residual
        self.forward_problem = fem.petsc.NonlinearProblem(
            R, self.u, bcs)
        self.forward_solver = nls.petsc.NewtonSolver(
            V.mesh.comm, self.forward_problem)
        self.forward_solver.rtol = 1e-16

        # Cost functional, compiled once so it can be re-assembled cheaply
        J = (1/2)*inner(self.u-u_obs, self.u-u_obs)*dx
        self.J_form = fem.form(J)

        # Adjoint problem: (dR/du)^T lmbda = dJ/du
        dRdu = adjoint(derivative(R, self.u))
        adjoint_lhs = fem.form(dRdu)
        adjoint_rhs = derivative(J, self.u)

        # The adjoint variable is constrained on the same dofs as the state,
        # but with value zero (homogenized boundary conditions).
        self.lmbda = fem.Function(V)
        lambda_0 = fem.Function(V)
        lambda_0.x.array[:] = 0
        bcs_adjoint = [fem.dirichletbc(
            lambda_0, bc._cpp_object.dof_indices()[0]) for bc in bcs]
        self.adjoint_problem = dolfinx.fem.petsc.LinearProblem(
            adjoint_lhs, adjoint_rhs, u=self.lmbda, bcs=bcs_adjoint,
            petsc_options={"ksp_type": "preonly", "pc_type": "lu"})

        # Sensitivity: dJ/dm = (partial J / partial m) - (dR/dm)^T lmbda
        dJdm = derivative(J, self.f)
        dRdm = action(adjoint(derivative(R, self.f)), self.lmbda)

        self.dJdm = fem.form(dJdm-dRdm)

    def eval_forward(self, f_data: np.ndarray):
        """Solve the forward problem for a control and return the cost.

        Args:
            f_data: Values copied into the local array of the control ``f``.

        Returns:
            The cost functional, summed over all processes of the mesh
            communicator.
        """
        # Update the control and make its ghost values consistent
        self.f.x.array[:] = f_data
        self.f.x.scatter_forward()

        # Solve for the state and refresh its ghost values before assembly
        self.forward_solver.solve(self.u)
        self.u.x.scatter_forward()

        local_cost = fem.assemble_scalar(self.J_form)
        return self.u.function_space.mesh.comm.allreduce(local_cost, op=MPI.SUM)

    def eval_derivative(self):
        """Solve the adjoint problem and assemble the sensitivity.

        The adjoint and sensitivity forms depend on the current ``self.u`` and
        ``self.f``, so this evaluates the derivative at the control most
        recently passed to ``eval_forward``.

        Returns:
            The entries of the assembled sensitivity vector owned by this
            process (ghost entries are excluded).
        """
        self.adjoint_problem.solve()
        self.lmbda.x.scatter_forward()

        dJdm_local = fem.assemble_vector(self.dJdm)
        # Add contributions assembled on ghost entries to their owners, then
        # push the accumulated values back out to the ghosts.
        dJdm_local.scatter_reverse(la.InsertMode.add)
        dJdm_local.scatter_forward()

        num_owned = dJdm_local.index_map.size_local * dJdm_local.block_size
        return dJdm_local.array[:num_owned]


def taylor_test(cost, grad, m_0, p=1e-2, n=5):
    """Run a Taylor remainder test for a cost function and its gradient.

    For step sizes ``h = 0.5**i`` (``i = 0, ..., n-1``) the first-order
    remainder ``cost(m_0 + h*p) - cost(m_0) - h * dot(gradient, p)`` is
    evaluated. The observed convergence rate of these remainders is returned
    alongside them.

    Args:
        cost: Callable taking a control array and returning the cost.
        grad: Callable taking no arguments and returning the process-local
            gradient at the control last passed to ``cost``. (The name
            shadows ``ufl.grad`` inside this function only.)
        m_0: Local control array at which the test is performed.
        p: Perturbation direction; either an array shaped like ``m_0`` or a
            scalar, which is expanded to a constant direction.
        n: Number of step sizes to evaluate.

    Returns:
        Tuple ``(remainder, perturbance, conv_rate)`` of lists: the
        remainders, the step sizes, and the rates between consecutive steps.
    """
    l0 = cost(m_0)
    local_gradient = grad()
    global_gradient = np.hstack(MPI.COMM_WORLD.allgather(local_gradient))

    # Accept any scalar (int, float or NumPy scalar), not only `float`.
    if np.isscalar(p):
        p = np.full_like(m_0, p)

    # The gradient only holds the leading (owned) entries, so truncate the
    # direction to the same length before gathering it across processes.
    p_global = np.hstack(MPI.COMM_WORLD.allgather(p[:len(local_gradient)]))
    dJdm = np.dot(global_gradient, p_global)

    perturbance = [0.5**i for i in range(n)]
    remainder = [cost(m_0+step*p) - l0 - step*dJdm for step in perturbance]
    conv_rate = convergence_rates(remainder, perturbance)
    return remainder, perturbance, conv_rate


def convergence_rates(r, p):
    """Return observed convergence rates between consecutive refinements.

    Args:
        r: Sequence of remainders (errors), one per step size.
        p: Sequence of step sizes matching ``r``.

    Returns:
        List of ``len(p) - 1`` rates, where entry ``i-1`` is
        ``log(|r[i] / r[i-1]|) / log(p[i] / p[i-1])``.
    """
    # The magnitude of the ratio is used so that remainders which change sign
    # between steps still yield a rate instead of NaN from log(negative).
    return [
        np.log(abs(r[i] / r[i - 1])) / np.log(p[i] / p[i - 1])
        for i in range(1, len(p))
    ]


# Mesh and function space: 64x64 triangulated unit square with first-order
# Lagrange elements. "Lagrange" is the canonical family name for what the
# legacy alias "CG" referred to.
mesh = create_rectangle(MPI.COMM_WORLD, [[0, 0], [1, 1]], [
                        64, 64], CellType.triangle)
V = fem.functionspace(mesh, ("Lagrange", 1))
# Module-level test function and measure on this mesh
v = TestFunction(V)
dx = Measure("dx", domain=mesh)


def left(x):
    """Mark points whose first coordinate is (numerically) 0."""
    first_coord = x[0]
    return np.isclose(first_coord, 0)


def right(x):
    """Mark points whose first coordinate is (numerically) 1."""
    first_coord = x[0]
    return np.isclose(first_coord, 1)


def top(x):
    """Mark points whose second coordinate is (numerically) 1."""
    second_coord = x[1]
    return np.isclose(second_coord, 1)


def bottom(x):
    """Mark points whose second coordinate is (numerically) 0."""
    second_coord = x[1]
    return np.isclose(second_coord, 0)


# Boundary facets have one dimension less than the cells
fdim = mesh.topology.dim - 1

# Function holding the (non-homogeneous) Dirichlet value
uD = fem.Function(V)

left_facets = dolfinx.mesh.locate_entities_boundary(mesh, fdim, left)
right_facets = dolfinx.mesh.locate_entities_boundary(mesh, fdim, right)
top_facets = dolfinx.mesh.locate_entities_boundary(mesh, fdim, top)
bottom_facets = dolfinx.mesh.locate_entities_boundary(mesh, fdim, bottom)

# Set every local entry (owned and ghost) to 1. Writing through `x.array`
# avoids the PETSc-vector accessor and matches the array access used
# elsewhere in this file.
uD.x.array[:] = 1.0

# One Dirichlet condition per side of the square, all sharing the value uD
bc1 = fem.dirichletbc(
    uD, dolfinx.fem.locate_dofs_topological(V, fdim, left_facets))
bc2 = fem.dirichletbc(
    uD, dolfinx.fem.locate_dofs_topological(V, fdim, right_facets))
bc3 = fem.dirichletbc(
    uD, dolfinx.fem.locate_dofs_topological(V, fdim, top_facets))
bc4 = fem.dirichletbc(
    uD, dolfinx.fem.locate_dofs_topological(V, fdim, bottom_facets))

bcs = [bc1, bc2, bc3, bc4]


# Define desired u profile


def profile(x):
    """Evaluate x0*(1 - x0)*x1*(1 - x1), which is zero where either coordinate is 0 or 1."""
    x0 = x[0]
    x1 = x[1]
    return x0*(1-x0)*x1*(1-x1)


# Observed state: the target profile interpolated into V
u_obs = fem.Function(V)
u_obs.interpolate(profile)


problem = Problem(V, bcs, u_obs)


# Random perturbation direction with one entry per local control value,
# and a zero control as the point the Taylor tests expand around.
num_local_entries = len(problem.f.x.array)
f_data = 10*np.random.random(num_local_entries).astype(np.float64)
f_0 = np.zeros_like(f_data)
J_0 = problem.eval_forward(f_0)

dJdf = problem.eval_derivative()

# Only the root process reports results
is_root = mesh.comm.rank == 0


# (1) Taylor test in the random direction
error, perturbance, rate = taylor_test(
    problem.eval_forward, grad=problem.eval_derivative, m_0=f_0, p=f_data)

if is_root:
    print(f"(1) Error: {error}")
    print(f"(1) Perturbance: {perturbance}")
    print(f"(1) Convergence rate: {rate}")


# (2) Taylor test with the default constant direction
error, perturbance, rate = taylor_test(
    problem.eval_forward, grad=problem.eval_derivative, m_0=f_0)
if is_root:
    print(f"(2) Error: {error}")
    print(f"(2) Perturbance: {perturbance}")
    print(f"(2) Convergence rate: {rate}")

Output when run in serial:

root@dokken-XPS-15-9560:~/shared/debug# python3 adjoint.py 
(1) Error: [0.021619627022312904, 0.005404906755620248, 0.0013512266889327135, 0.00033780667226231825, 8.445166809268809e-05]
(1) Perturbance: [1.0, 0.5, 0.25, 0.125, 0.0625]
(1) Convergence rate: [1.9999999999887834, 1.9999999999704765, 1.99999999987555, 1.9999999995369027]
(2) Error: [8.501957100197442e-08, 2.1254875918607242e-08, 5.313705366477095e-09, 1.328413713299931e-09, 3.320927149065762e-10]
(2) Perturbance: [1.0, 0.5, 0.25, 0.125, 0.0625]
(2) Convergence rate: [2.0000011424799404, 2.0000036960337733, 2.0000137146484653, 2.000046541057984]

Output when run in parallel (4 MPI processes):

mpirun -n 4 python3 adjoint.py 
(1) Error: [0.022252050386877914, 0.00556301259674051, 0.0013907531492012223, 0.00034768828731673515, 8.692207184526728e-05]
(1) Perturbance: [1.0, 0.5, 0.25, 0.125, 0.0625]
(1) Convergence rate: [1.9999999999945455, 1.999999999983304, 1.9999999999318272, 1.9999999997330533]
(2) Error: [8.501958598212478e-08, 2.1254894566423823e-08, 5.313721296212382e-09, 1.3284283672612977e-09, 3.321041497124507e-10]
(2) Perturbance: [1.0, 0.5, 0.25, 0.125, 0.0625]
(2) Convergence rate: [2.000000130939888, 2.000000636782914, 2.000002125106489, 2.0000127807449206]
3 Likes