Calculate adjoint derivatives w.r.t to multiple Constants

Here is a little helper script - derivatives.py, and test for it bellow. It seems to do what it should be (both in serial and parralel). The test is done 2D domain, so it might need some tweaking for 3D.

Edit: It looks like it works in 3D as is.

This might be useful to some people, before new dolfin-adjoint is implemented. The usage is shown in the test bellow. It does not handle constrains.

from mpi4py import MPI
from dolfinx import geometry
from dolfinx import fem
import numpy as np
import ufl
from petsc4py import PETSc

import dolfinx.nls.petsc
import dolfinx.fem.petsc

def wrap_constant_controls(controls):
    vars = {}
    for control in controls:
        for i in range(len(control)):
            vars[control[i]] = ufl.variable(control[i])
    return vars

class UflObjective:
    def __init__(self, J, u, controls):
        self.J = J
        self.J_form = fem.form(J)
        self.u = u
        self.controls = controls

        self.rhs = -ufl.derivative(J, u)
        self.rhs_form = fem.form(self.rhs)
        self.b = dolfinx.fem.petsc.create_vector(self.rhs_form)

        self.vars = wrap_constant_controls(self.controls)
        self.J_var_form = ufl.replace(self.J, self.vars)

        self.dJdk_form = [fem.form(ufl.diff(self.J_var_form, v)) for v in self.vars.values()]

    def assemble_adjoint_rhs_vector(self):
        with self.b.localForm() as b_loc:
            b_loc.set(0.0)
        dolfinx.fem.petsc.assemble_vector(self.b, self.rhs_form)
        self.b.assemble()
        return self.b
    
    def eval_dJdk(self):
        dJdk = np.zeros(len(self.dJdk_form))
        for i in range(len(self.dJdk_form)):
            dJdk[i] = fem.assemble_scalar(self.dJdk_form[i])
            dJdk[i] = MPI.COMM_WORLD.allreduce(dJdk[i], op=MPI.SUM)
        return dJdk
    
    def evaluate(self):
        J_value = fem.assemble_scalar(self.J_form)
        J_value = MPI.COMM_WORLD.allreduce(J_value, op=MPI.SUM)
        return J_value

class Point_wise_lsq_objective:
    def __init__(self, p_coords, f, controls, true_values = None):
        self.p_coords = p_coords
        self.f = f
        self.controls = controls
        self.b = f.copy()
        self.true_values = true_values

        self.vars = wrap_constant_controls(self.controls)

        self.evaluate_dofs_sensitivities()

    def evaluate_dofs_sensitivities(self):
        # sensitivity of values at p_coords w.r.t values at dofs
        domain = self.f.function_space.mesh
        points = np.array(self.p_coords)
        fdim = domain.topology.dim
        bb_tree = geometry.bb_tree(domain, fdim)
        cell_candidates = geometry.compute_collisions_points(bb_tree, points)
        colliding_cells = geometry.compute_colliding_cells(domain, cell_candidates, points)

        # initialize result on each proc
        self.points = []
        self.cells = []
        self.dofs = []
        self.sensitivities = []
        self.true_values_map = []
        for i, point in enumerate(points):
            if len(colliding_cells.links(i)) > 0:
                cell = colliding_cells.links(i)[0]
                dofs = self.f.function_space.dofmap.cell_dofs(cell)
                s = np.zeros(len(dofs))

                # there might exist nicer way but finite diff works too on (CG, 1)
                for j in range(len(dofs)):
                    val_0 = self.f.eval(point, cell)
                    self.f.x.array[dofs[j]] += 1.0
                    val_1 = self.f.eval(point, cell)
                    self.f.x.array[dofs[j]] -= 1.0
                    s[j] = val_1[0] - val_0[0]
            
                self.points.append(point)
                self.cells.append(cell)
                self.dofs.append(dofs)
                self.sensitivities.append(s)
                self.true_values_map.append(i)
    
    def assemble_adjoint_rhs_vector(self):
        with self.b.vector.localForm() as b_loc:
            b_loc.set(0)
        for i in range(len(self.points)):
            # contribution to dJdu of single point
            value = -2*(self.f.eval(self.points[i], self.cells[i])[0]-self.true_values[self.true_values_map[i]])
            self.b.x.array[self.dofs[i]] += self.sensitivities[i]*value
        return self.b.vector

    def eval_points(self):
        values = []
        for i in range(len(self.points)):
            values.append(self.f.eval(self.points[i], self.cells[i])[0])
        return values
    
    def eval_dJdk(self):
        dJdk = np.zeros(len(self.vars.keys()))
        return dJdk
    
    def evaluate(self):
        J_value = 0.0
        for i in range(len(self.points)):
            J_value += (self.f.eval(self.points[i], self.cells[i])[0]-self.true_values[self.true_values_map[i]])**2
        J_value = MPI.COMM_WORLD.allreduce(J_value, op=MPI.SUM)
        return J_value

class AdjointDerivative:
    def __init__(self, J, controls, form, forward_solver, u, bcs=None):
        self.form = form
        self.controls = controls
        self.J = J
        self.forward_solver = forward_solver
        self.u = u
        self.lmbda = u.copy()
        self.bcs = bcs or []

        self.vars = wrap_constant_controls(self.controls)
        self.var_form = ufl.replace(self.form, self.vars)

        self.dFdk_form = [fem.form(ufl.diff(self.var_form, var)) for var in self.vars.values()]
        self.dFdk = [fem.petsc.create_vector(form) for form in self.dFdk_form]

        # adjoint problem solver definition
        self.lhs = ufl.adjoint(ufl.derivative(self.form, u))
        self.lhs_form = fem.form(self.lhs)
        self.A = dolfinx.fem.petsc.create_matrix(self.lhs_form)

        self.adjoint_solver = PETSc.KSP().create(u.function_space.mesh.comm)
        self.adjoint_solver.setOperators(self.A)
        self.adjoint_solver.setType(PETSc.KSP.Type.PREONLY)
        self.adjoint_solver.setTolerances(atol=1e-16)
        self.adjoint_solver.getPC().setType(PETSc.PC.Type.LU)

    def solve_adjoint_problem(self):

        # lhs assembly
        self.A.zeroEntries()
        dolfinx.fem.petsc.assemble_matrix_mat(self.A, self.lhs_form, bcs=self.bcs)
        self.A.assemble()

        # rhs assembly
        b = self.J.assemble_adjoint_rhs_vector()
        dolfinx.fem.petsc.apply_lifting(b, [self.lhs_form], bcs=[self.bcs])
        b.ghostUpdate(addv=PETSc.InsertMode.ADD, mode=PETSc.ScatterMode.REVERSE)
        dolfinx.fem.petsc.set_bc(b, self.bcs)

        self.adjoint_solver.solve(b, self.lmbda.vector)
        self.lmbda.x.scatter_forward()
        return self.lmbda

    def compute_gradient(self, *args, **kwargs):
    
        # run forward solve
        self.forward_solver(*args, **kwargs)

        # solve adjoint vector
        self.solve_adjoint_problem()

        #initialize empty gradient vector
        dJdk = self.J.eval_dJdk()
        for i in range(len(self.dFdk)):

            # reassemble adjoint stuff
            with self.dFdk[i].localForm() as dfdk_loc:
                dfdk_loc.set(0)
            dolfinx.fem.petsc.assemble_vector(self.dFdk[i], self.dFdk_form[i])
            self.dFdk[i].ghostUpdate(addv=PETSc.InsertMode.ADD, mode=PETSc.ScatterMode.REVERSE)
            
            # add dFdk*lmbda
            dJdk[i] += self.dFdk[i].dot(self.lmbda.vector)
            

        J_value = self.J.evaluate()

        return dJdk, J_value
    
def taylor_test(loss, grad, k0, p=1e-3, n=5):
        g0 = grad(k0)
        l0 = loss(k0)
        reminder = []
        perturbance = []
        for i in range(0, n):
            l1 = loss(k0+p)
            reminder.append(l1 - l0 - np.sum(g0*p))
            perturbance.append(p)
            p /= 2
        conv_rate = convergence_rates(reminder, perturbance)
        return reminder, perturbance, conv_rate

def convergence_rates(r, p):
    cr = []
    for i in range(1, len(p)):
        cr.append(np.log(r[i] / r[i - 1])
                 / np.log(p[i] / p[i - 1]))
    return cr

unittest

from mpi4py import MPI
import dolfinx
import ufl
from petsc4py import PETSc
import numpy as np
import unittest

from derivatives import (
    UflObjective, Point_wise_lsq_objective, AdjointDerivative, taylor_test)

import dolfinx.nls.petsc
import dolfinx.fem.petsc

class TestDerivative(unittest.TestCase):
    def setUp(self) -> None:
        # mesh and function spaces
        mesh = dolfinx.mesh.create_rectangle(MPI.COMM_WORLD, [[0, 0], [1, 1]], [64, 64], dolfinx.mesh.CellType.triangle)
        V = dolfinx.fem.FunctionSpace(mesh, ("P", 1)) # solution space
        self.u = dolfinx.fem.Function(V)  # solution trial
        v = ufl.TestFunction(V)  # solution test

        self.k = dolfinx.fem.Constant(mesh, PETSc.ScalarType((1.0, 1.0)))

        self.dx = ufl.Measure("dx")
        self.F = ufl.inner(ufl.grad(self.u), ufl.grad(v))*self.dx # heat equation
        self.F += sum(self.k)*v*self.dx # add source of heat

        # boundary conditions
        def left(x): return np.isclose(x[0], 0)

        fdim = mesh.topology.dim - 1

        u1 = dolfinx.fem.Function(V)
        boundary_facets = dolfinx.mesh.locate_entities_boundary(mesh, fdim, left)
        with u1.vector.localForm() as loc:
            loc.set(0)
        bc1 = dolfinx.fem.dirichletbc(u1, dolfinx.fem.locate_dofs_topological(V, fdim, boundary_facets))

        self.bcs = [bc1]

        # solve the problem
        problem = dolfinx.fem.petsc.NonlinearProblem(self.F, self.u, self.bcs)
        self.solver = dolfinx.nls.petsc.NewtonSolver(MPI.COMM_WORLD, problem)
        self.solver.convergence_criterion = "residual"
        self.solver.rtol = 1e-16
        self.forward = lambda : self.solver.solve(self.u)
    
    def test_ufl_objective(self):
        J = ufl.inner(self.u, self.u)*self.dx

        controls = [self.k]
        Jr = UflObjective(J, self.u, controls)
        adjoint = AdjointDerivative(Jr, controls, self.F, self.forward, self.u, self.bcs)

        def grad(value):
            self.k.value[:] = value
            g, l = adjoint.compute_gradient()
            return g

        def loss(value):
            self.k.value[:] = value
            self.solver.solve(self.u)
            l = Jr.evaluate()
            return l
        
        k0 = np.array([1, 1])
        r = taylor_test(loss, grad, k0)
        self.assertTrue(np.allclose(r[2], 2.0, atol=0.1), f"Taylor test failed - Convergence rate: {r[2]}")


    def test_lsq_objective(self):
        # sum of square errors at specific point in the domain
        points = [[0.0, 0.0, 0.0], [0.5, 0.5, 0.0]]
        true_values = [0.0, 0.0]
        controls = [self.k]
        Jr = Point_wise_lsq_objective(points, self.u, controls, true_values)
        adjoint = AdjointDerivative(Jr, controls, self.F, self.forward, self.u, self.bcs)

        def grad(value):
            self.k.value[:] = value
            g, l = adjoint.compute_gradient()
            return g

        def loss(value):
            self.k.value[:] = value
            self.solver.solve(self.u)
            l = Jr.evaluate()
            return l
        
        k0 = np.array([1, 1])
        r = taylor_test(loss, grad, k0)
        self.assertTrue(np.allclose(r[2], 2.0, atol=0.1), f"Taylor test failed - Convergence rate: {r[2]}")

    def test_optimisation(self):
        points = [[0.0, 0.0, 0.0], [0.5, 0.5, 0.0]]
        true_values = [0.0, 0.0]
        controls = [self.k]
        Jr = Point_wise_lsq_objective(points, self.u, controls, true_values)
        adjoint = AdjointDerivative(Jr, controls, self.F, self.forward, self.u, self.bcs)

        # find zero
        k_value = np.array([1.0, 1.0])
        alpha = 0.5
        # Vanila gradient descent
        for i in range(150):
            g, l = adjoint.compute_gradient()
            self.k.value += -alpha*g
            if MPI.COMM_WORLD.rank == 0:
                print(l, self.k.value)

        self.assertTrue(np.allclose(self.k.value, 0.0, atol=1e-6), f"k_value should be [0.0, 0.0] not {self.k.value}")

    def tearDown(self) -> None:
        pass

if __name__ == '__main__':
    unittest.main()
2 Likes