I have recently been working with [dolfinx-external-operator](https://github.com/a-latyshev/dolfinx-external-operator) at v0.9.0 and have run into strange behavior in parallel.
I have attached an MWE that solves a nonlinear [St. Venant–Kirchhoff hyperelastic](https://en.wikipedia.org/wiki/Hyperelastic_material) material model using external operators and a custom Newton solver (similar to the [Custom Newton solvers](https://jsdokken.com/dolfinx-tutorial/chapter4/newton-solver.html) chapter of the FEniCSx tutorial).
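For reference, the constitutive law implemented in the MWE below (and differentiated with JAX to obtain the consistent tangent) is

$$E = \tfrac{1}{2}\left(F^\top F - I\right), \qquad S = \lambda \operatorname{tr}(E)\, I + 2\mu E, \qquad P = F S, \qquad F = I + \nabla u,$$

with $\lambda = 100$ and $\mu = 80$.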
My problem:
- The residual and correction norms (global 1-norms, see the definitions after this list) differ between serial and parallel runs, and the parallel runs take more iterations to converge.
- However, I found no differences in the output deformation file, so after convergence they produce the same results.
- I am working on a much more complex model (not shared here) that shows similar behavior: in serial the code converges, whereas on multiple cores it struggles to converge or diverges.
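For clarity, each iteration of the custom solver performs the standard Newton update

$$J(u_k)\,\delta u_k = -F(u_k), \qquad u_{k+1} = u_k + \delta u_k,$$

where the `res_norm` and `correction_norm` reported below are the global 1-norms $\lVert F(u_{k+1}) \rVert_1$ and $\lVert \delta u_k \rVert_1$.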
I would like to know whether I am doing anything wrong in this code. I strongly suspect the parallel communication, but I cannot quite pin it down.
Versions: DOLFINx v0.9.0 (conda), dolfinx-external-operator v0.9.0
Norms when run with 1 core:

```
Load: 0.0
Iter:0 res_norm: 0.0 correction_norm: 0.0
Load: 0.01
Iter:0 res_norm: 0.0011645655700496462 correction_norm: 0.03405680111929685
Iter:1 res_norm: 1.3752055063070807e-08 correction_norm: 1.3426438184578667e-05
Iter:2 res_norm: 4.890900405460301e-15 correction_norm: 2.012602023127826e-10
Load: 0.02
Iter:0 res_norm: 0.0010800780829752143 correction_norm: 0.03406127555606496
Iter:1 res_norm: 1.2142982020414133e-08 correction_norm: 1.2668553588944848e-05
Iter:2 res_norm: 7.198501084021153e-15 correction_norm: 1.8168006344051814e-10
```

Norms when run with 2 cores:

```
Load: 0.0
Iter:0 res_norm: 0.0 correction_norm: 0.0
Load: 0.01
Iter:0 res_norm: 0.8930699302443738 correction_norm: 0.033854985881255226
Iter:1 res_norm: 0.00277598379917335 correction_norm: 0.007258805661231315
Iter:2 res_norm: 8.5285456113812e-08 correction_norm: 2.976174386216571e-05
Iter:3 res_norm: 5.5992304977974396e-15 correction_norm: 1.3421445319803384e-09
Load: 0.02
Iter:0 res_norm: 0.9160841037720723 correction_norm: 0.03385713388935493
Iter:1 res_norm: 0.0027668283620376597 correction_norm: 0.007231808363626835
Iter:2 res_norm: 8.417046869946899e-08 correction_norm: 2.9638927069016592e-05
Iter:3 res_norm: 3.2397806339452964e-15 correction_norm: 1.323708209311654e-09
```

Norms when run with 3 cores:

```
Load: 0.0
Iter:0 res_norm: 0.0 correction_norm: 0.0
Load: 0.01
Iter:0 res_norm: 1.6039310082055045 correction_norm: 0.034674963015875325
Iter:1 res_norm: 0.009230115330149384 correction_norm: 0.014410243023101827
Iter:2 res_norm: 1.0417828487614108e-06 correction_norm: 0.00011104545528349307
Iter:3 res_norm: 2.0919133683030973e-14 correction_norm: 1.595314573696844e-08
Load: 0.02
Iter:0 res_norm: 1.6465132421118773 correction_norm: 0.03470099329171745
Iter:1 res_norm: 0.009241721171174552 correction_norm: 0.014446538899800547
Iter:2 res_norm: 1.0379626322396358e-06 correction_norm: 0.00011123026437007292
Iter:3 res_norm: 1.4106714151888706e-14 correction_norm: 1.592229392352619e-08
```
MWE:

```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
###################### LIBRARIES ###############################
import jax
import jax.lax
import jax.numpy as jnp
import numpy as np
from petsc4py.PETSc import ScalarType
from mpi4py import MPI
import ufl, basix, os
from dolfinx import fem
import dolfinx.fem.petsc
from petsc4py import PETSc
from dolfinx.fem.petsc import apply_lifting, assemble_matrix, assemble_vector, set_bc
from dolfinx.io import VTXWriter
import dolfinx.nls.petsc
from dolfinx_external_operator import (
    FEMExternalOperator,
    evaluate_external_operators,
    evaluate_operands,
    replace_external_operators,
)
jax.config.update("jax_enable_x64", True)
simulation_number = "p=1"
outputfoldername = f"./results_JAX_autodiff-{simulation_number}/" # path to save outputfiles
if MPI.COMM_WORLD.rank == 0:
    os.makedirs(outputfoldername, exist_ok=True)
######################## MESH ###############################
dmesh = dolfinx.mesh.create_box(MPI.COMM_WORLD, [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]], [2, 2, 2], dolfinx.mesh.CellType.hexahedron)#,ghost_mode=mesh.GhostMode.shared_facet)
######################## FUNCTION SPACES ######################
p = 1  # shape-function polynomial degree
V = fem.functionspace(dmesh,("Lagrange", p, (3,)))
u_trial = ufl.TrialFunction(V) # defining trial Function
u_test = ufl.TestFunction(V) # defining test function
u = fem.Function(V)
delta_U = fem.Function(V, name="Newton_correction")
deg_quad = 2*p+1
QF_3ten_e = basix.ufl.quadrature_element(dmesh.basix_cell(), value_shape=(3,3), scheme="default", degree=deg_quad)
QF_3ten = fem.functionspace(dmesh, QF_3ten_e)
def deformation_grad(u):
    return ufl.Identity(3) + ufl.grad(u)
P = FEMExternalOperator(deformation_grad(u), function_space=QF_3ten) # Dolfinx external operator
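# NOTE: P lives in the quadrature function space QF_3ten (one 3x3 tensor per
# quadrature point) and is re-evaluated from F(u) after every Newton update below.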
######################## MATERIAL PARAMETERS ###################
lamda = 100.0
mu = 80.0
######################## BCS ###################################
def left(x):
    return np.isclose(x[0], 0.0)

def right(x):
    return np.isclose(x[0], 1.0)

def front(x):
    return np.isclose(x[1], 0.0)

def bottom(x):
    return np.isclose(x[2], 0.0)
left_facets = dolfinx.mesh.locate_entities_boundary(dmesh, 2, left)
left_dofs_x = fem.locate_dofs_topological(V, 2, left_facets)
right_facets = dolfinx.mesh.locate_entities_boundary(dmesh, 2, right)
right_dofs_x = fem.locate_dofs_topological(V, 2, right_facets)
u_Dirichlet = fem.Constant(dmesh,ScalarType([0.0,0.0,0.0]))
bcs = [
    fem.dirichletbc(np.array([0.0, 0.0, 0.0], dtype=ScalarType), left_dofs_x, V),
    fem.dirichletbc(u_Dirichlet, right_dofs_x, V),
]
###################### MATERIAL ROUTINE #########################
I = jnp.identity(3)
def return_mapping(F_local):
    E = 0.5 * (F_local.T @ F_local - I)            # E = 0.5*(F.T @ F - I)
    S_2PK = lamda * jnp.trace(E) * I + 2 * mu * E  # S = lambda*tr(E)*I + 2*mu*E
    P_1PK = F_local @ S_2PK                        # P = F S
    return P_1PK, (P_1PK)  # primary output is differentiated; aux output returns P itself
autodiff_Ctang = jax.jacfwd(return_mapping, argnums=0, has_aux=True)  # dP/dF, with P as aux output
autodiff_Ctang_vec = jax.jit(jax.vmap(autodiff_Ctang, in_axes=0))     # vectorized over quadrature points
def C_tang_impl(F_local_):
    F_local_vectorized = F_local_.reshape((-1, 3, 3))
    (C_tang_, P_1PK_) = autodiff_Ctang_vec(F_local_vectorized)
    return C_tang_.reshape(-1), P_1PK_.reshape(-1)
def P_external(derivatives):
    if derivatives == (1,):
        return C_tang_impl
    else:
        raise NotImplementedError(f"There is no external function for the derivative {derivatives}.")
P.external_function = P_external
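# Optional sanity check (illustrative addition, not part of the original MWE):
# at F = I we have E = 0, so the model must return zero stress. This tests the
# JAX routine independently of any FEniCSx/MPI machinery:
# _, P_check = autodiff_Ctang(jnp.identity(3))
# assert jnp.allclose(P_check, 0.0)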
######################## VARIATIONAL FORM #########################
dx = ufl.Measure("dx", domain=dmesh, metadata={"quadrature_degree": deg_quad, "quadrature_scheme": "default"})
residual = ufl.inner(P, ufl.grad(u_test)) * dx
Jacobian = ufl.derivative(residual, u, u_trial)
J_expanded = ufl.algorithms.expand_derivatives(Jacobian)
res_replaced, res_external_operators = replace_external_operators(residual)
J_replaced, J_external_operators = replace_external_operators(J_expanded)
res_form = fem.form(res_replaced)
J_form = fem.form(J_replaced)
A = fem.petsc.create_matrix(J_form)
L = fem.petsc.create_vector(res_form)
######################## SOLVER ######################################
solver = PETSc.KSP().create(MPI.COMM_WORLD)
solver.setType(PETSc.KSP.Type.PREONLY)
solver.getPC().setType(PETSc.PC.Type.LU)
solver.getPC().setFactorSolverType("mumps")
solver.setOperators(A)
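# PREONLY applies the preconditioner exactly once, so this is a pure direct
# solve via the MUMPS LU factorization (no Krylov iterations).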
######################### OUTPUT #####################################
u_bp = VTXWriter(MPI.COMM_WORLD, f"{outputfoldername}displacement.bp", [u], engine="BP4")
######################### LOAD STEPPING ###############################
def mpi_print(x):
    if MPI.COMM_WORLD.rank == 0:
        print(x)
# initial evaluation of the external operator (u = 0, so F = I here)
evaluated_operands = evaluate_operands(res_external_operators)
((_, P_new),) = evaluate_external_operators(J_external_operators, evaluated_operands)
global_Nitermax, tol = 50, 1e-6  # parameters of the Newton solver
loadsteps = np.arange(0.0, 0.03, 0.01)
for i in loadsteps:
    u_Dirichlet.value = [float(i), 0.0, 0.0]
    mpi_print(f"Load: {i}")
    delta_U.x.array[:] = 0.0
    delta_U.x.scatter_forward()
    niter = 0
    while niter < global_Nitermax:
        # ASSEMBLE SYSTEM AND APPLY BCS
        with L.localForm() as loc_L:
            loc_L.set(0)
        A.zeroEntries()
        assemble_matrix(A, J_form, bcs=bcs)
        A.assemble()
        assemble_vector(L, res_form)
        L.ghostUpdate(addv=PETSc.InsertMode.ADD_VALUES, mode=PETSc.ScatterMode.REVERSE)
        L.scale(-1)
        apply_lifting(L, [J_form], [bcs], x0=[u.x.petsc_vec], alpha=1.0)
        L.ghostUpdate(addv=PETSc.InsertMode.INSERT_VALUES, mode=PETSc.ScatterMode.FORWARD)
        set_bc(L, bcs, u.x.petsc_vec, 1.0)
        # SOLVE
        solver.solve(L, delta_U.x.petsc_vec)
        delta_U.x.scatter_forward()
        u.x.array[:] += delta_U.x.array
        u.x.scatter_forward()
        # UPDATE STRESS
        evaluated_operands = evaluate_operands(res_external_operators)
        ((_, P_new),) = evaluate_external_operators(J_external_operators, evaluated_operands)
        P.ref_coefficient.x.array[:] = P_new
        P.ref_coefficient.x.scatter_forward()
        # CALCULATE RES_NORM WITH UPDATED STRESS (Displacement)
        with L.localForm() as loc_L:
            loc_L.set(0)
        assemble_vector(L, res_form)
        L.ghostUpdate(addv=PETSc.InsertMode.ADD_VALUES, mode=PETSc.ScatterMode.REVERSE)
        L.scale(-1)
        apply_lifting(L, [J_form], [bcs], x0=[u.x.petsc_vec], alpha=1.0)
        L.ghostUpdate(addv=PETSc.InsertMode.INSERT_VALUES, mode=PETSc.ScatterMode.FORWARD)
        set_bc(L, bcs, u.x.petsc_vec, 1.0)
        res_norm = L.norm(1)
        correction_norm = delta_U.x.petsc_vec.norm(1)
        mpi_print(f"Iter:{niter} res_norm: {res_norm} correction_norm: {correction_norm}")
        if res_norm < tol and correction_norm < tol:
            break
        elif np.isnan(res_norm) or np.isnan(correction_norm) or np.isinf(res_norm) or np.isinf(correction_norm):
            mpi_print("inf or nan detected in correction or residual norm !!!")
            break
        niter += 1
    if np.isnan(res_norm) or np.isnan(correction_norm) or np.isinf(res_norm) or np.isinf(correction_norm):
        break
    u_bp.write(i)
u_bp.close()
```