Here is a version that works for non-homogeneous Dirichlet boundary conditions and in parallel:
# Adjoint computation for the Poisson problem
# Author: Jørgen S. Dokken
import dolfinx
from ufl import Measure, inner, grad, TestFunction, adjoint, derivative, action
from dolfinx.mesh import CellType, create_rectangle
from dolfinx import fem, nls, la
from mpi4py import MPI
import numpy as np
import dolfinx.nls.petsc
import dolfinx.fem.petsc
class Problem:
    """Forward, adjoint and gradient evaluation for a Poisson control problem.

    The forcing ``f`` (a Function on ``V``) is the design variable. The state
    ``u`` satisfies the weak Poisson equation
    ``inner(grad(u), grad(v))*dx - inner(f, v)*dx = 0`` subject to ``bcs``.
    The cost is ``J = 1/2 * inner(u - u_obs, u - u_obs)*dx``.
    """

    def __init__(self, V, bcs, u_obs):
        """Initialize all variational forms for the given problem.

        Parameters
        ----------
        V : function space shared by the control ``f``, the state ``u`` and
            the adjoint ``lmbda``.
        bcs : Dirichlet conditions for the state (values may be non-zero);
            homogenized copies on the same dofs are used for the adjoint.
        u_obs : observed/target state entering the cost functional.
        """
        # Build the test function and cell measure from V itself instead of
        # relying on module-level ``v``/``dx`` globals, so the class is
        # self-contained and always integrates over V's own mesh.
        v = TestFunction(V)
        dx = Measure("dx", domain=V.mesh)

        # Residual of the variational form of the Poisson problem
        self.f = fem.Function(V)  # The forcing f is the design variable
        self.u = fem.Function(V)
        R = inner(grad(self.u), grad(v)) * dx - inner(self.f, v) * dx

        # Forward problem
        self.forward_problem = fem.petsc.NonlinearProblem(R, self.u, bcs)
        self.forward_solver = nls.petsc.NewtonSolver(
            V.mesh.comm, self.forward_problem)
        self.forward_solver.rtol = 1e-16

        # Cost function
        J = (1 / 2) * inner(self.u - u_obs, self.u - u_obs) * dx
        self.J_form = fem.form(J)

        # Adjoint problem: adjoint(dR/du) lmbda = dJ/du
        dRdu = adjoint(derivative(R, self.u))
        adjoint_lhs = fem.form(dRdu)
        adjoint_rhs = derivative(J, self.u)

        self.lmbda = fem.Function(V)
        lambda_0 = fem.Function(V)
        lambda_0.x.array[:] = 0
        # Homogenize the bcs: same constrained dofs, zero value. dof_indices()
        # presumably returns (indices, first-ghost position); [0] selects the
        # index array — verify against the DOLFINx version in use.
        bcs_adjoint = [
            fem.dirichletbc(lambda_0, bc._cpp_object.dof_indices()[0])
            for bc in bcs
        ]
        self.adjoint_problem = dolfinx.fem.petsc.LinearProblem(
            adjoint_lhs, adjoint_rhs, u=self.lmbda, bcs=bcs_adjoint,
            petsc_options={"ksp_type": "preonly", "pc_type": "lu"})

        # Sensitivity: dJ/dm = partial J/partial m - adjoint(dR/dm) applied to lmbda
        dJdm = derivative(J, self.f)
        dRdm = action(adjoint(derivative(R, self.f)), self.lmbda)
        self.dJdm = fem.form(dJdm - dRdm)

    def eval_forward(self, f_data: np.ndarray):
        """Compute the functional for a given control.

        Parameters
        ----------
        f_data : values written into the local dof array of ``self.f``.
            Ghost entries are subsequently overwritten with their owners'
            values by ``scatter_forward``.

        Returns
        -------
        The cost J summed over all processes of the mesh communicator.
        """
        self.f.x.array[:] = f_data
        self.f.x.scatter_forward()
        self.forward_solver.solve(self.u)
        self.u.x.scatter_forward()
        comm = self.u.function_space.mesh.comm
        return comm.allreduce(fem.assemble_scalar(self.J_form), op=MPI.SUM)

    def eval_derivative(self):
        """Compute the derivative of the functional w.r.t. the control dofs.

        The adjoint right-hand side depends on ``self.u``, so this uses the
        state from the most recent ``eval_forward`` call — call that first.

        Returns
        -------
        Array holding the process-owned entries of the gradient vector
        (ghost entries excluded).
        """
        self.adjoint_problem.solve()
        self.lmbda.x.scatter_forward()
        dJdm_local = fem.assemble_vector(self.dJdm)
        # Add contributions assembled on ghost dofs into their owners, then
        # refresh ghosts with the summed owner values.
        dJdm_local.scatter_reverse(la.InsertMode.add)
        dJdm_local.scatter_forward()
        n_owned = dJdm_local.index_map.size_local * dJdm_local.block_size
        return dJdm_local.array[:n_owned]
def taylor_test(cost, grad, m_0, p=1e-2, n=5, comm=None):
    """
    Compute a Taylor test for a cost function and gradient function from `m_0` in direction `p`.

    Parameters
    ----------
    cost : callable taking a local control array and returning the global cost.
    grad : callable with no arguments returning the owned entries of the
        gradient at the control last passed to ``cost`` (it is called right
        after ``cost(m_0)``). The name shadows ``ufl.grad`` only locally.
    m_0 : local control array (owned entries first, then ghosts).
    p : perturbation direction; either an array laid out like ``m_0`` or a
        scalar (int/float/NumPy scalar) meaning a constant direction.
    n : number of perturbation steps, with step sizes 0.5**i.
    comm : communicator over which the directional derivative is summed;
        defaults to ``MPI.COMM_WORLD``.

    Returns
    -------
    (remainder, perturbance, conv_rate), where
    remainder[i] = cost(m_0 + h_i*p) - cost(m_0) - h_i * dJ.p, with h_i = 0.5**i
    and perturbance[i] = h_i. conv_rate comes from ``convergence_rates``; it
    should approach 2 when the gradient is correct.
    """
    if comm is None:
        comm = MPI.COMM_WORLD
    l0 = cost(m_0)
    local_gradient = grad()
    if np.ndim(p) == 0:
        p = np.full_like(m_0, p)
    else:
        p = np.asarray(p)
    # Owned dofs come first in the local arrays. Restricting to them counts
    # each global dof exactly once, so a local dot product plus an allreduce
    # equals the global dot product without gathering whole vectors.
    n_owned = len(local_gradient)
    dJdm = comm.allreduce(float(np.dot(local_gradient, p[:n_owned])), op=MPI.SUM)
    remainder = []
    perturbance = []
    for i in range(n):
        step = 0.5**i
        l1 = cost(m_0 + step * p)
        remainder.append(l1 - l0 - step * dJdm)
        perturbance.append(step)
    conv_rate = convergence_rates(remainder, perturbance)
    return remainder, perturbance, conv_rate
def convergence_rates(r, p):
    """Return observed convergence rates between consecutive perturbations.

    rate_i = log(|r_i / r_{i-1}|) / log(p_i / p_{i-1}) for i = 1 .. len(p)-1.

    Parameters
    ----------
    r : sequence of Taylor remainders (non-zero).
    p : sequence of perturbation sizes, same length as ``r``.

    Returns
    -------
    list of plain Python floats (one fewer than ``len(p)``). Plain floats keep
    printed output free of ``np.float64(...)`` wrappers on NumPy >= 2.

    Raises
    ------
    ValueError
        If ``r`` and ``p`` have different lengths.
    """
    if len(r) != len(p):
        raise ValueError(
            f"r and p must have the same length, got {len(r)} and {len(p)}")
    # Use the magnitude of the remainder ratio (the usual Taylor-test
    # definition): if consecutive remainders differ in sign, the signed ratio
    # is negative and its logarithm would be NaN.
    return [
        float(np.log(abs(r[i] / r[i - 1])) / np.log(p[i] / p[i - 1]))
        for i in range(1, len(p))
    ]
# Mesh of the unit square [0, 1] x [0, 1] with 64 x 64 subdivisions and
# triangular cells, distributed over all processes of COMM_WORLD.
mesh = create_rectangle(
    MPI.COMM_WORLD,
    [[0, 0], [1, 1]],
    [64, 64],
    cell_type=CellType.triangle,
)
# First-order continuous Lagrange space on that mesh.
V = fem.functionspace(mesh, ("CG", 1))
# Module-level test function on V and integration measure over the mesh cells.
v = TestFunction(V)
dx = Measure("dx", domain=mesh)
def left(x):
    """Flag points whose first coordinate is (numerically) 0 — the left edge."""
    first_coordinate = x[0]
    return np.isclose(first_coordinate, 0)
def right(x):
    """Flag points whose first coordinate is (numerically) 1 — the right edge."""
    first_coordinate = x[0]
    return np.isclose(first_coordinate, 1)
def top(x):
    """Flag points whose second coordinate is (numerically) 1 — the top edge."""
    second_coordinate = x[1]
    return np.isclose(second_coordinate, 1)
def bottom(x):
    """Flag points whose second coordinate is (numerically) 0 — the bottom edge."""
    second_coordinate = x[1]
    return np.isclose(second_coordinate, 0)
fdim = mesh.topology.dim - 1
# Dirichlet data: the constant 1.0 on every local entry (owned and ghost).
# Writing through ``x.array`` avoids ``Function.vector``, which is deprecated
# in recent DOLFINx releases.
uD = fem.Function(V)
uD.x.array[:] = 1.0
# One Dirichlet condition per side of the square, all sharing the data uD.
bcs = []
for marker in (left, right, top, bottom):
    facets = dolfinx.mesh.locate_entities_boundary(mesh, fdim, marker)
    dofs = dolfinx.fem.locate_dofs_topological(V, fdim, facets)
    bcs.append(fem.dirichletbc(uD, dofs))
# Desired (observed) state used in the cost functional
def profile(x):
    """Return the bubble s(1 - s) t(1 - t) with s = x[0], t = x[1].

    It vanishes whenever either coordinate equals 0 or 1.
    """
    s, t = x[0], x[1]
    return s * (1 - s) * t * (1 - t)
u_obs = fem.Function(V)
u_obs.interpolate(profile)
problem = Problem(V, bcs, u_obs)

# Random direction for test (1). Entries cover this process's local dofs
# (owned + ghost); eval_forward overwrites ghost entries from their owners.
f_data = 10 * np.random.random(len(problem.f.x.array))
f_0 = np.zeros_like(f_data)

# taylor_test evaluates the cost and the gradient at m_0 itself, so no
# separate eval_forward/eval_derivative call is needed beforehand.
# Case (1): random direction f_data; case (2): taylor_test's default p=1e-2.
for label, direction in (("1", {"p": f_data}), ("2", {})):
    error, perturbance, rate = taylor_test(
        problem.eval_forward, grad=problem.eval_derivative, m_0=f_0,
        **direction)
    if mesh.comm.rank == 0:
        print(f"({label}) Error: {error}")
        print(f"({label}) Perturbance: {perturbance}")
        print(f"({label}) Convergence rate: {rate}")
Output in serial:
root@dokken-XPS-15-9560:~/shared/debug# python3 adjoint.py
(1) Error: [0.021619627022312904, 0.005404906755620248, 0.0013512266889327135, 0.00033780667226231825, 8.445166809268809e-05]
(1) Perturbance: [1.0, 0.5, 0.25, 0.125, 0.0625]
(1) Convergence rate: [1.9999999999887834, 1.9999999999704765, 1.99999999987555, 1.9999999995369027]
(2) Error: [8.501957100197442e-08, 2.1254875918607242e-08, 5.313705366477095e-09, 1.328413713299931e-09, 3.320927149065762e-10]
(2) Perturbance: [1.0, 0.5, 0.25, 0.125, 0.0625]
(2) Convergence rate: [2.0000011424799404, 2.0000036960337733, 2.0000137146484653, 2.000046541057984]
Output in parallel (4 MPI processes):
mpirun -n 4 python3 adjoint.py
(1) Error: [0.022252050386877914, 0.00556301259674051, 0.0013907531492012223, 0.00034768828731673515, 8.692207184526728e-05]
(1) Perturbance: [1.0, 0.5, 0.25, 0.125, 0.0625]
(1) Convergence rate: [1.9999999999945455, 1.999999999983304, 1.9999999999318272, 1.9999999997330533]
(2) Error: [8.501958598212478e-08, 2.1254894566423823e-08, 5.313721296212382e-09, 1.3284283672612977e-09, 3.321041497124507e-10]
(2) Perturbance: [1.0, 0.5, 0.25, 0.125, 0.0625]
(2) Convergence rate: [2.000000130939888, 2.000000636782914, 2.000002125106489, 2.0000127807449206]