Residual norm differs between serial and parallel runs

I have recently been working with dolfinx-external-operator (GitHub: a-latyshev/dolfinx-external-operator, v0.9.0) and have run into some strange behavior in parallel.
I have attached an MWE which solves a nonlinear hyperelastic material model (St. Venant–Kirchhoff, see the Hyperelastic material article on Wikipedia) using external operators and a custom Newton solver (similar to the Custom Newton solvers chapter of the FEniCSx tutorial).

My problem:

  • The residual and correction norms differ between serial and parallel runs, and the parallel run takes more iterations to converge.
  • However, I did not find any differences in the output deformation file, so after convergence both runs produce the same result.
  • I am working on a much more complex model (not shared here) where similar behavior is seen: in serial the code converges, whereas on multiple cores it struggles to converge or diverges.

I would like to know if I am doing anything wrong in this code. I strongly suspect the parallel communication, but I cannot quite figure it out.
Versions: DOLFINx v0.9.0 (conda), dolfinx-external-operator v0.9.0

Norms when run with 1 core:

Load: 0.0
Iter:0 res_norm: 0.0 correction_norm: 0.0
Load: 0.01
Iter:0 res_norm: 0.0011645655700496462 correction_norm: 0.03405680111929685
Iter:1 res_norm: 1.3752055063070807e-08 correction_norm: 1.3426438184578667e-05
Iter:2 res_norm: 4.890900405460301e-15 correction_norm: 2.012602023127826e-10
Load: 0.02
Iter:0 res_norm: 0.0010800780829752143 correction_norm: 0.03406127555606496
Iter:1 res_norm: 1.2142982020414133e-08 correction_norm: 1.2668553588944848e-05
Iter:2 res_norm: 7.198501084021153e-15 correction_norm: 1.8168006344051814e-10

Norms when run with 2 cores:

Load: 0.0
Iter:0 res_norm: 0.0 correction_norm: 0.0
Load: 0.01
Iter:0 res_norm: 0.8930699302443738 correction_norm: 0.033854985881255226
Iter:1 res_norm: 0.00277598379917335 correction_norm: 0.007258805661231315
Iter:2 res_norm: 8.5285456113812e-08 correction_norm: 2.976174386216571e-05
Iter:3 res_norm: 5.5992304977974396e-15 correction_norm: 1.3421445319803384e-09
Load: 0.02
Iter:0 res_norm: 0.9160841037720723 correction_norm: 0.03385713388935493
Iter:1 res_norm: 0.0027668283620376597 correction_norm: 0.007231808363626835
Iter:2 res_norm: 8.417046869946899e-08 correction_norm: 2.9638927069016592e-05
Iter:3 res_norm: 3.2397806339452964e-15 correction_norm: 1.323708209311654e-09

Norms when run with 3 cores:

Load: 0.0
Iter:0 res_norm: 0.0 correction_norm: 0.0
Load: 0.01
Iter:0 res_norm: 1.6039310082055045 correction_norm: 0.034674963015875325
Iter:1 res_norm: 0.009230115330149384 correction_norm: 0.014410243023101827
Iter:2 res_norm: 1.0417828487614108e-06 correction_norm: 0.00011104545528349307
Iter:3 res_norm: 2.0919133683030973e-14 correction_norm: 1.595314573696844e-08
Load: 0.02
Iter:0 res_norm: 1.6465132421118773 correction_norm: 0.03470099329171745
Iter:1 res_norm: 0.009241721171174552 correction_norm: 0.014446538899800547
Iter:2 res_norm: 1.0379626322396358e-06 correction_norm: 0.00011123026437007292
Iter:3 res_norm: 1.4106714151888706e-14 correction_norm: 1.592229392352619e-08

MWE:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

###################### LIBRARIES ###############################
import jax
import jax.lax
import jax.numpy as jnp
import numpy as np
from petsc4py.PETSc import ScalarType
from mpi4py import MPI
import ufl, basix, os
from dolfinx import fem
import dolfinx.fem.petsc
from petsc4py import PETSc
from dolfinx.fem.petsc import apply_lifting, assemble_matrix, assemble_vector, set_bc
from dolfinx.io import VTXWriter
import dolfinx.nls.petsc
from dolfinx_external_operator import (
    FEMExternalOperator,
    evaluate_external_operators,
    evaluate_operands,
    replace_external_operators,
)

jax.config.update("jax_enable_x64", True)
simulation_number = "p=1"
outputfoldername = f"./results_JAX_autodiff-{simulation_number}/"  # path to save outputfiles
if (MPI.COMM_WORLD.rank == 0):
    os.makedirs(outputfoldername, exist_ok=True)  # create the output folder if it does not already exist

########################  MESH  ###############################
dmesh = dolfinx.mesh.create_box(MPI.COMM_WORLD, [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]], [2, 2, 2], dolfinx.mesh.CellType.hexahedron)#,ghost_mode=mesh.GhostMode.shared_facet)

######################## FUNCTION SPACES ######################
p=1 # shape function polynomial
V = fem.functionspace(dmesh,("Lagrange", p, (3,)))
u_trial = ufl.TrialFunction(V) # defining trial Function
u_test = ufl.TestFunction(V) # defining test function
u = fem.Function(V)
delta_U = fem.Function(V, name="Newton_correction")
deg_quad = 2*p+1
QF_3ten_e = basix.ufl.quadrature_element(dmesh.basix_cell(), value_shape=(3,3), scheme="default", degree=deg_quad)
QF_3ten = fem.functionspace(dmesh, QF_3ten_e)
def deformation_grad(u):
    return ufl.Identity(3) + ufl.grad(u)
P = FEMExternalOperator(deformation_grad(u), function_space=QF_3ten) # Dolfinx external operator

######################## MATERIAL PARAMETERS ###################
lamda = 100.0
mu=80.0

######################## BCS ###################################
def left(x):
    return np.isclose(x[0], 0.0)
def right(x):
    return np.isclose(x[0], 1.0)
def front(x):
    return np.isclose(x[1], 0.0)
def bottom(x):
    return np.isclose(x[2], 0.0)

left_facets = dolfinx.mesh.locate_entities_boundary(dmesh, 2, left)
left_dofs_x = fem.locate_dofs_topological(V, 2, left_facets)

right_facets = dolfinx.mesh.locate_entities_boundary(dmesh, 2, right)
right_dofs_x = fem.locate_dofs_topological(V, 2, right_facets)

u_Dirichlet = fem.Constant(dmesh,ScalarType([0.0,0.0,0.0]))
bcs = [ fem.dirichletbc(np.array([0.0,0.0,0.0], dtype=ScalarType),left_dofs_x, V),
        fem.dirichletbc(u_Dirichlet,right_dofs_x, V),]

###################### MATERIAL ROUTINE #########################
I = jnp.identity(3)
def return_mapping(F_local):   
    E = 0.5* (F_local.T @ F_local - I  ) # E = 0.5*(F.T @ F - I)
    S_2PK =  lamda *jnp.trace(E)*I + 2*mu*E  # S = lambda*tr(E) *I + 2*mu*E
    P_1PK = F_local @ S_2PK # P = FS
    return P_1PK, (P_1PK)   # stress as both output and aux, so jacfwd(has_aux=True) returns (tangent, stress)

autodiff_Ctang = jax.jacfwd(return_mapping, argnums=0, has_aux = True)
autodiff_Ctang_vec = jax.jit(jax.vmap(autodiff_Ctang, in_axes=(0)))

def C_tang_impl(F_local_):
    F_local_vectorized = F_local_.reshape((-1,3,3))
    (C_tang_, P_1PK_) = autodiff_Ctang_vec(F_local_vectorized)
    return C_tang_.reshape(-1), P_1PK_.reshape(-1)

def P_external(derivatives):
    if derivatives == (1,):
        return C_tang_impl
    else:
        raise NotImplementedError(f'There is no external function for the derivative {derivatives}.')
                
P.external_function = P_external 
    
######################## VARIATIONAL FORM #########################
dx = ufl.Measure("dx", domain=dmesh, metadata={"quadrature_degree": deg_quad, "quadrature_scheme": "default"},)
residual = ufl.inner(P, ufl.grad(u_test)) * dx
Jacobian = ufl.derivative(residual, u, u_trial)
J_expanded = ufl.algorithms.expand_derivatives(Jacobian)
res_replaced, res_external_operators = replace_external_operators(residual)
J_replaced, J_external_operators = replace_external_operators(J_expanded)
res_form = fem.form(res_replaced)
J_form = fem.form(J_replaced)
A = fem.petsc.create_matrix(J_form)
L = fem.petsc.create_vector(res_form)

######################## SOLVER ######################################
solver = PETSc.KSP().create(MPI.COMM_WORLD)
solver.setType(PETSc.KSP.Type.PREONLY)
solver.getPC().setType(PETSc.PC.Type.LU)  
solver.getPC().setFactorSolverType("mumps")
solver.setOperators(A)

######################### OUTPUT #####################################
u_bp = VTXWriter(MPI.COMM_WORLD, f"{outputfoldername}displacement.bp", [u], engine="BP4")
######################### LOAD STEPPING ###############################
def mpi_print(x):
    if (MPI.COMM_WORLD.rank ==0):
        print(x)
evaluated_operands = evaluate_operands(res_external_operators)
((_ ,P_new), ) = evaluate_external_operators(J_external_operators, evaluated_operands)

global_Nitermax, tol = 50, 1e-6  # parameters of the Newton solver
loadsteps = np.arange(0.0,0.03,0.01)
       
for i in loadsteps:
    u_Dirichlet.value = [float(i),0.0,0.0]
    mpi_print(f"Load: {i}")
    delta_U.x.array[:] = 0.0
    delta_U.x.scatter_forward()
    niter = 0
    while niter < global_Nitermax:

        with L.localForm() as loc_L:
            loc_L.set(0)
        A.zeroEntries()
        assemble_matrix(A, J_form, bcs=bcs)
        A.assemble()
        assemble_vector(L, res_form)
        L.ghostUpdate(addv=PETSc.InsertMode.ADD_VALUES, mode=PETSc.ScatterMode.REVERSE)
        L.scale(-1)
        apply_lifting(L, [J_form], [bcs], x0=[u.x.petsc_vec], alpha=1.0)
        L.ghostUpdate(addv=PETSc.InsertMode.INSERT_VALUES, mode=PETSc.ScatterMode.FORWARD)
        set_bc(L, bcs, u.x.petsc_vec, 1.0)

        # SOLVE
        solver.solve(L, delta_U.x.petsc_vec)
        delta_U.x.scatter_forward()
        u.x.array[:] += delta_U.x.array
        u.x.scatter_forward()

        # UPDATE STRESS
        evaluated_operands = evaluate_operands(res_external_operators)
        ((_ ,P_new ), ) = evaluate_external_operators(J_external_operators, evaluated_operands)
        P.ref_coefficient.x.array[:] = P_new
        P.ref_coefficient.x.scatter_forward()

        # CALCULATE RES_NORM WITH UPDATED STRESS (Displacement)
        with L.localForm() as loc_L:
            loc_L.set(0)
        assemble_vector(L, res_form)
        L.ghostUpdate(addv=PETSc.InsertMode.ADD_VALUES, mode=PETSc.ScatterMode.REVERSE)
        L.scale(-1)
        apply_lifting(L, [J_form], [bcs], x0=[u.x.petsc_vec], alpha=1.0)
        L.ghostUpdate(addv=PETSc.InsertMode.INSERT_VALUES, mode=PETSc.ScatterMode.FORWARD)
        set_bc(L, bcs, u.x.petsc_vec, 1.0)
        res_norm = L.norm(1)
        correction_norm = delta_U.x.petsc_vec.norm(1)

        mpi_print(f"Iter:{niter}  res_norm: {res_norm}  correction_norm: {correction_norm}")
        if res_norm < tol and correction_norm < tol:
            break
        elif np.isnan(res_norm) or np.isnan(correction_norm) or np.isinf(res_norm) or np.isinf(correction_norm):
            mpi_print("inf or nan detected in correction or residual norm !!!")
            break
        niter += 1
    if np.isnan(res_norm) or np.isnan(correction_norm) or np.isinf(res_norm) or np.isinf(correction_norm):
        break
  
    u_bp.write(i)
u_bp.close()

The order here is not correct.
It should be:

    with L.localForm() as loc_L:
        loc_L.set(0)
    A.zeroEntries()
    assemble_matrix(A, J_form, bcs=bcs)
    A.assemble()
    assemble_vector(L, res_form)
    L.scale(-1)
    apply_lifting(L, [J_form], [bcs], x0=[u.x.petsc_vec], alpha=1.0)
    L.ghostUpdate(addv=PETSc.InsertMode.ADD_VALUES, mode=PETSc.ScatterMode.REVERSE)
    set_bc(L, bcs, u.x.petsc_vec, 1.0)
    L.ghostUpdate(addv=PETSc.InsertMode.INSERT_VALUES, mode=PETSc.ScatterMode.FORWARD)

    # SOLVE
    solver.solve(L, delta_U.x.petsc_vec)
    delta_U.x.scatter_forward()
    print(delta_U.x.petsc_vec.norm(1))

The same applies at the later stage of the solver, i.e. the reverse ghost update (scatter reverse) should happen after apply_lifting, and the forward ghost update (scatter forward) should happen after set_bc.
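
Applied to the second assembly in your loop (the residual-norm computation after the stress update), that reordering would look roughly like this, reusing the variable names from your MWE:

    with L.localForm() as loc_L:
        loc_L.set(0)
    assemble_vector(L, res_form)
    L.scale(-1)
    apply_lifting(L, [J_form], [bcs], x0=[u.x.petsc_vec], alpha=1.0)
    L.ghostUpdate(addv=PETSc.InsertMode.ADD_VALUES, mode=PETSc.ScatterMode.REVERSE)   # reverse update after lifting
    set_bc(L, bcs, u.x.petsc_vec, 1.0)
    L.ghostUpdate(addv=PETSc.InsertMode.INSERT_VALUES, mode=PETSc.ScatterMode.FORWARD)  # forward update after set_bc
    res_norm = L.norm(1)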

Thank you for your reply.
I did try your suggestion, but unfortunately it did not work.

The solver hits the maximum number of iterations (global_Nitermax) without converging; the residual norm gradually diverges.
The solutions at loads 0.01 and 0.02 are therefore incorrect, since they are not the result of a converged solve.

If global_Nitermax = 25, the norms are (when run with 2 cores):

Load: 0.0
Iter:0 res_norm: 0.0 correction_norm: 0.0
Load: 0.01
Iter:0 res_norm: 0.17242702220851552 correction_norm: 0.03405680111929684
Iter:1 res_norm: 0.1723522837072458 correction_norm: 0.0024310013521334686
Iter:2 res_norm: 0.1741336239668787 correction_norm: 0.0024317732046898774
Iter:3 res_norm: 0.17763788481722487 correction_norm: 0.0024886858967655464
Iter:4 res_norm: 0.18429897195859485 correction_norm: 0.0025991687473128763
Iter:5 res_norm: 0.19669453920767585 correction_norm: 0.002781201356324795
Iter:6 res_norm: 0.2189833223543194 correction_norm: 0.00306573494943784
Iter:7 res_norm: 0.25717146119397205 correction_norm: 0.0034991973574346946
Iter:8 res_norm: 0.319109874270506 correction_norm: 0.004146356425929821
Iter:9 res_norm: 0.4146701881102281 correction_norm: 0.0050938906347248155
Iter:10 res_norm: 0.5567465791440417 correction_norm: 0.00645633079708771
Iter:11 res_norm: 0.7632761193019997 correction_norm: 0.008386241281516797
Iter:12 res_norm: 1.0602682073981393 correction_norm: 0.011089719795261835
Iter:13 res_norm: 1.4862352972193245 correction_norm: 0.014847958730278112
Iter:14 res_norm: 2.099021049891553 correction_norm: 0.020046592567222517
Iter:15 res_norm: 2.9867248142275984 correction_norm: 0.027217115928599017
Iter:16 res_norm: 4.28534119125191 correction_norm: 0.03710020753358929
Iter:17 res_norm: 6.206653852462845 correction_norm: 0.05075329374001919
Iter:18 res_norm: 9.079535744313747 correction_norm: 0.06975130749084214
Iter:19 res_norm: 13.402449137815578 correction_norm: 0.09657923585578101
Iter:20 res_norm: 19.89007815153467 correction_norm: 0.13537352511055636
Iter:21 res_norm: 29.514035772071832 correction_norm: 0.19274746529196
Iter:22 res_norm: 44.138859454300935 correction_norm: 0.2697457469834709
Iter:23 res_norm: 72.25380503219031 correction_norm: 0.3357014977842791
Iter:24 res_norm: 2021.6793778389945 correction_norm: 2.879129541876163
Load: 0.02
Iter:0 res_norm: 604.1692972817705 correction_norm: 0.9736373138444921
Iter:1 res_norm: 211.59223978978395 correction_norm: 0.6542777779456821
Iter:2 res_norm: 187.82003313614018 correction_norm: 0.4997138007336679
Iter:3 res_norm: 298.10940117904187 correction_norm: 0.4938699610142547
Iter:4 res_norm: 475.42147080663335 correction_norm: 0.7047789273622335
Iter:5 res_norm: 668.6520764841489 correction_norm: 0.7521295089300651
Iter:6 res_norm: 937.7291504752852 correction_norm: 0.6006001261310144
Iter:7 res_norm: 1245.2175665091436 correction_norm: 0.4954950570473062
Iter:8 res_norm: 1593.8068710666357 correction_norm: 0.44932603379145186
Iter:9 res_norm: 1991.0881626050145 correction_norm: 0.4549921202506294
Iter:10 res_norm: 2434.521440655406 correction_norm: 0.45314332532067025
Iter:11 res_norm: 2924.9887496650204 correction_norm: 0.45363975431103265
Iter:12 res_norm: 3462.157341479085 correction_norm: 0.4536125851062404
Iter:13 res_norm: 4046.082438250212 correction_norm: 0.4537222156259773
Iter:14 res_norm: 4676.718124548277 correction_norm: 0.4538029702004225
Iter:15 res_norm: 5354.049412078222 correction_norm: 0.45388275420094076
Iter:16 res_norm: 6078.058803510661 correction_norm: 0.4539542081762236
Iter:17 res_norm: 6848.732786085744 correction_norm: 0.4540183682958904
Iter:18 res_norm: 7666.0598169555005 correction_norm: 0.45407557737450666
Iter:19 res_norm: 8530.030155560808 correction_norm: 0.45412656321517286
Iter:20 res_norm: 9440.635478863289 correction_norm: 0.4541720516883583
Iter:21 res_norm: 10397.868624274523 correction_norm: 0.4542127294847582
Iter:22 res_norm: 11401.723381223503 correction_norm: 0.4542492126648207
Iter:23 res_norm: 12452.19432740999 correction_norm: 0.45428204122467836
Iter:24 res_norm: 13549.276698234804 correction_norm: 0.45431168241580017

PETSc's Vec.scale should be done after the ghost updates, i.e. I guess something along the lines of:

        with L.localForm() as loc_L:
            loc_L.set(0)
        A.zeroEntries()
        assemble_matrix(A, J_form, bcs=bcs)
        A.assemble()
        assemble_vector(L, res_form)
        apply_lifting(L, [J_form], [bcs], x0=[u.x.petsc_vec], alpha=-1.0)
        L.ghostUpdate(addv=PETSc.InsertMode.ADD_VALUES, mode=PETSc.ScatterMode.REVERSE)
        set_bc(L, bcs, u.x.petsc_vec, alpha=-1.0)
        L.ghostUpdate(addv=PETSc.InsertMode.INSERT_VALUES, mode=PETSc.ScatterMode.FORWARD)
        L.scale(-1)

should do the trick (and likewise for the residual computation after the stress update).
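
Applied to that second assembly, it would look roughly like this (same variable names as in the MWE, with the norm computed afterwards):

        with L.localForm() as loc_L:
            loc_L.set(0)
        assemble_vector(L, res_form)
        apply_lifting(L, [J_form], [bcs], x0=[u.x.petsc_vec], alpha=-1.0)
        L.ghostUpdate(addv=PETSc.InsertMode.ADD_VALUES, mode=PETSc.ScatterMode.REVERSE)
        set_bc(L, bcs, u.x.petsc_vec, alpha=-1.0)
        L.ghostUpdate(addv=PETSc.InsertMode.INSERT_VALUES, mode=PETSc.ScatterMode.FORWARD)
        L.scale(-1)
        res_norm = L.norm(1)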

Thank you for the solution. With your suggestion, the norms in both the serial and parallel runs are now equal.
Does this imply that the order shown in the Custom Newton solvers chapter of the FEniCSx tutorial is incorrect for parallel cases?