Implementing a custom Newton solver with a Lagrange multiplier

Hey, I am trying to solve a problem in a mixed space in which one of the elements is a tensor-valued Raviart–Thomas element whose trace must have zero integral over the domain. Following the recent discussions on implementing a Lagrange multiplier (Lagrange Multiplier for the NS equation with additional constrain · Issue #3341 · FEniCS/dolfinx · GitHub and Lagrange Multiplier for the Poisson equation with pure Neumann boundary conditions · Issue #3121 · FEniCS/dolfinx · GitHub), I have tried to write a custom Newton solver. I am facing two issues:

  • The residual norm computed when running in parallel (more than one MPI rank) differs by many orders of magnitude from the value computed in serial.
  • In serial, the residual norm drops nicely (to ~1e-7) within the first few iterations, then stays constant for a while, and eventually the solver diverges.

Could there be a fundamental problem in my implementation? I have tried quite a few things by now and have not been able to narrow the problem down, so any help would be appreciated! Note that the Dirichlet conditions are enforced through the variational form itself, so I do not need to account for them in the solver.
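
For context, the constraint I am imposing is that the trace of the tensor-valued unknown integrates to zero over the domain, added through a single scalar Lagrange multiplier. The snippet below is only an illustrative sketch of how the constraint functional lm is meant to be built and what the bordered Newton system looks like; the unit-square mesh and the tensor-valued Lagrange space are placeholders standing in for my actual mesh and Raviart–Thomas element.

from mpi4py import MPI
import dolfinx as dfx
import ufl

# Placeholder mesh and a tensor-valued Lagrange space standing in for my actual
# mesh and tensor-valued Raviart-Thomas element.
mesh = dfx.mesh.create_unit_square(MPI.COMM_WORLD, 8, 8)
V = dfx.fem.functionspace(mesh, ("Lagrange", 1, (2, 2)))

tau = ufl.TestFunction(V)

# Constraint functional: assembling lm gives a vector bt such that
# bt . x equals the integral of tr(sigma_h) over the domain, so one scalar
# Lagrange multiplier lam is enough to enforce that this integral is zero.
lm = ufl.tr(tau) * ufl.dx

# A Newton step on the stationarity conditions F(x) + lam * bt = 0, bt . x = 0
# is, schematically,
#   [ J     bt ] [ dx   ]     [ F(x) + lam * bt ]
#   [ bt^T   0 ] [ dlam ] = - [ bt . x          ]
# which is the kind of bordered system the class below tries to assemble by hand.

Here is the full solver code: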

from mpi4py import MPI
import dolfinx as dfx
import ufl
import dolfinx.fem.petsc
from petsc4py import PETSc
import numpy as np
import matplotlib.pyplot as plt
import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()

PETSc.Log.begin()

class DirectNewtonRLMSolver():
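    # Newton solver for F(x) = 0 with one extra scalar Lagrange multiplier enforcing
    # the linear constraint bt . x = 0 (here: zero integral of the trace).
    # The bordered system [[J, bt], [bt^T, 0]] is assembled by hand, with the extra
    # row/column owned by the last MPI rank, and solved directly with MUMPS.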
    
    def __init__(self, solution: dfx.fem.Function, a: ufl.Form, L: ufl.Form, lm: ufl.Form, solver, comm=None, verbose=True):
        self.comm = comm
        self.verbose = verbose
        self.sol = solution
        self.a = a
        self.L = L
        self.F = self.a - self.L 
        self.residual = dfx.fem.form(self.F)
        self.J = ufl.derivative(self.F, self.sol)
        self.jacobian = dfx.fem.form(self.J)
        self.lm_form = dfx.fem.form(lm)
        
        self._base_setup()
        self.solver = self.config_solver(solver)
        
    def _base_setup(self):
        self.bt = self.make_extension_vector(self.lm_form)
        local_size = self.bt.getLocalSize()
        # The last rank owns one extra row/column for the scalar Lagrange multiplier.
        aug_local_size = local_size + (1 if self.comm.rank == self.comm.size - 1 else 0)
        self.A_aug = PETSc.Mat().createAIJ(((aug_local_size, PETSc.DECIDE), (aug_local_size, PETSc.DECIDE)), comm=self.comm)
        self.b_aug = PETSc.Vec().createMPI((aug_local_size, PETSc.DECIDE), comm=self.comm)
        self._obj_vec = PETSc.Vec().createMPI((aug_local_size, PETSc.DECIDE), comm=self.comm)
        self.x = self.b_aug.duplicate()
        self.update_solution(self.x)
        self.A = dfx.fem.petsc.create_matrix(self.jacobian)
        self.b = dfx.fem.petsc.create_vector(self.residual)
        self._last_rank_debug(f"matrix A global size {self.A.getSize()}")
    
    def _last_rank_debug(self, message):
        if self.verbose:
            if self.comm.rank == self.comm.size - 1:
                logger.info(f"Rank {self.comm.rank}: {message}")
    
    def make_extension_vector(self, lm_form):
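        # Assemble the constraint functional lm into a vector bt; bt forms the extra
        # row/column (the "border") of the augmented system, and bt . x is the
        # current value of the constraint.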
        bt = dfx.fem.petsc.create_vector(lm_form)
        dfx.fem.petsc.assemble_vector(bt, self.lm_form)
        bt.ghostUpdate(addv=PETSc.InsertMode.ADD_VALUES, mode=PETSc.ScatterMode.REVERSE)
        bt.assemble()
        return bt
    
    def residual_norm(self, x: PETSc.Vec, obj: PETSc.Vec) -> np.float64:
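        # Rebuild the augmented residual into obj and return its norm; solve() uses
        # this value for the convergence check.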
        self._last_rank_debug(f"assembling objective vector")
        if self.comm.rank == self.comm.size - 1:
            lam_l = x.getValue(x.getOwnershipRange()[1] - 1)
        else:
            lam_l = 0
        lam = self.comm.allreduce(lam_l, op=MPI.SUM)
        self._last_rank_debug(f"lam = {lam}")
        self.b.axpy(lam, self.bt)
        self.b.assemble()
        obj.setValues(range(*(self.b.getOwnershipRange())), self.b)
        
        cx = self.bt.dot(self.sol.vector)
        self._last_rank_debug(f"cx = {cx}")
        if self.comm.rank == self.comm.size - 1:
            obj.setValue(self.b.getSize(), -cx)
        
        obj.assemble()
        self._last_rank_debug(f"assembly of objective vector complete!")
        
        return self._obj_vec.norm()
    
    def assemble_b_aug(self, x: PETSc.Vec, b_aug: PETSc.Vec):
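        # Assemble the augmented right-hand side: the negated residual with the
        # current multiplier contribution lam * bt added in the first N rows, and
        # the constraint value -bt . x in the last row.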
        self._last_rank_debug(f"assembling nested b")
        with self.b.localForm() as b_loc:
            b_loc.set(0)
        dfx.fem.petsc.assemble_vector(self.b, self.residual)
        self.b.ghostUpdate(addv=PETSc.InsertMode.ADD_VALUES, mode=PETSc.ScatterMode.REVERSE)
        self.b.scale(-1)
        self.b.ghostUpdate(addv=PETSc.InsertMode.INSERT_VALUES, mode=PETSc.ScatterMode.FORWARD)
        
        if self.comm.rank == self.comm.size - 1:
            lam_l = x.getValue(x.getOwnershipRange()[1] - 1)
        else:
            lam_l = 0
        lam = self.comm.allreduce(lam_l, op=MPI.SUM)
        self._last_rank_debug(f"lam = {lam}")
        self.b.axpy(lam, self.bt)
        self.b.assemble()
        b_aug.setValues(range(*(self.b.getOwnershipRange())), self.b)

        cx = self.bt.dot(self.sol.vector)
        self._last_rank_debug(f"cx = {cx}")
        if self.comm.rank == self.comm.size - 1:
            b_aug.setValue(self.b.getSize(), -cx)
        
        b_aug.assemble()
        self._last_rank_debug(f"assembly of nested b complete!")
    
    def assemble_A_aug(self, A_aug: PETSc.Mat):
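        # Assemble the augmented (bordered) matrix: the Jacobian block in the
        # top-left, bt as the extra last column and last row, and a zero in the
        # bottom-right corner.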
        A_aug.zeroEntries()
        dfx.fem.petsc.assemble_matrix(self.A, self.jacobian)
        self.A.assemble()
        self._last_rank_debug("assembling nested A")
        
        global_size = self.bt.getSize()
        local_range = range(*(self.bt.getOwnershipRange()))
        global_range = range(global_size)
        
        global_is = PETSc.IS().createGeneral(global_range, comm=PETSc.COMM_WORLD)
        local_is = PETSc.IS().createGeneral(local_range, comm=PETSc.COMM_WORLD)
        
        A_aug.setValues(local_is, global_size, self.bt.getArray(), addv=PETSc.InsertMode.INSERT_VALUES)
        y = self.gather_to_the_last_process(self.bt)
        if self.comm.rank == self.comm.size - 1:
            A_aug.setValues(global_size, global_is, y.getArray(), addv=PETSc.InsertMode.INSERT_VALUES)
            A_aug.setValues(global_size, global_size, 0, addv=PETSc.InsertMode.INSERT_VALUES)
            
        values = self.A.getValuesCSR()
        if self.comm.rank == self.comm.size - 1:
            values = (np.append(values[0], values[0][-1]), ) + values[1:]
        A_aug.setValuesCSR(*values, addv=PETSc.InsertMode.INSERT_VALUES)
        A_aug.assemble()
        self._last_rank_debug("assembly of nested A complete!")
    
    def gather_to_the_last_process(self, b: PETSc.Vec):
        # Scatter the full (global) content of the distributed vector b into a
        # sequential vector y owned by the last rank.
        global_rows = b.getSize()
        if self.comm.rank == self.comm.size - 1:
            y = PETSc.Vec().createSeq(global_rows, comm=PETSc.COMM_SELF)
            send_index_set = PETSc.IS().createStride(global_rows, first=0, step=1, comm=PETSc.COMM_SELF)
            recv_index_set = PETSc.IS().createStride(global_rows, first=0, step=1, comm=PETSc.COMM_SELF)
        else:
            y = PETSc.Vec().createSeq(0, comm=PETSc.COMM_SELF)
            send_index_set = PETSc.IS().createStride(0, first=0, step=1, comm=PETSc.COMM_SELF)
            recv_index_set = PETSc.IS().createStride(0, first=0, step=1, comm=PETSc.COMM_SELF)
        scatter = PETSc.Scatter().create(b, send_index_set, y, recv_index_set)
        scatter.scatter(b, y, mode=PETSc.ScatterMode.FORWARD, addv=PETSc.InsertMode.INSERT_VALUES)
        return y
    
    def update_solution(self, x: PETSc.Vec) -> None:
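        # Add the locally owned part of the Newton update to the solution vector
        # and refresh the ghost values.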
        is_xt = PETSc.IS().createGeneral(range(*(self.sol.vector.getOwnershipRange())), comm=PETSc.COMM_WORLD)
        self.xt = x.getSubVector(is_xt)
        
        vals = self.sol.vector.getValues(is_xt)
        self.sol.vector.setValues(is_xt, vals + self.xt.getArray())
        self.sol.vector.assemble()
        self.sol.vector.ghostUpdate(addv=PETSc.InsertMode.INSERT, mode=PETSc.ScatterMode.FORWARD)
    
    def solve(self, max_iterations=25, atol=1e-8):
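        # Newton loop: assemble the augmented system, solve it with the direct
        # solver, add the update to the solution, and check the norm of the
        # augmented residual.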
        for i in range(max_iterations):
            self.assemble_b_aug(self.x, self.b_aug)
            self.assemble_A_aug(self.A_aug)
            self.solver.solve(self.b_aug, self.x)
            self.update_solution(self.x)
            
            reason = self.solver.getConvergedReason()
            self._last_rank_debug(f"iteration {i}: converged with status {reason}")
            
            resid_norm = self.residual_norm(self.x, self._obj_vec)
            self._last_rank_debug(f"iteration {i}: residual norm = {resid_norm}")
            self._test_get_serial_norm()
            
            global_resid_norm = self.comm.allreduce(resid_norm, op=MPI.MAX)
            self._last_rank_debug(f"iteration {i}: global residual norm = {global_resid_norm}")
            if global_resid_norm < atol:
                break
            
        return self.sol, global_resid_norm
    
    def config_solver(self, solver):
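        # Configure a direct (LU + MUMPS) solve of the augmented system;
        # ICNTL(14) increases the workspace estimate and ICNTL(24) enables
        # null-pivot detection.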
        solver.setOperators(self.A_aug)
        
        solver.setType("preonly")
        solver.setTolerances(
            atol=1e-10, 
            rtol=1e-8
        )
        pc = solver.getPC()
        pc.setType("lu")
        pc.setFactorSolverType("mumps")
        F = pc.getFactorMatrix()
        F.setMumpsIcntl(14, 100)
        F.setMumpsIcntl(24, 1)
        F.setMumpsIcntl(25, 0)
        solver.setFromOptions()
        return solver

    def _test_get_serial_norm(self):
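        # Debug helper: gather the Newton update across ranks onto rank 0 and log
        # its 1-norm, to compare parallel runs against the serial result.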
        self.comm.Barrier()
        global_xt_array = self.comm.gather(self.xt.getArray(), root=0)
        
        if self.comm.rank == 0:
            logger.debug(f"Rank {self.comm.rank}: global xt: {np.concatenate(global_xt_array).shape}")
            logger.debug(f"Rank {self.comm.rank}: global xt norm = {np.linalg.norm(np.concatenate(global_xt_array), 1)}")