Dear FEniCS community,
For the last couple of months, I’ve been using external operators to implement some material models. I’ve managed to achieve great results/performance for hyperelasticity and small-strain elasto-plasticity. The latter was implemented in a slightly different algorithmic manner than the referenced example.
My next goal was finite elasto-plasticity. While the routine is working, I’ve noticed a significant drop in performance. I’m aware that finite elasto-plasticity is more computationally intensive than hyperelasticity or its small-strain counterpart; still, I find such a drop in performance a bit surprising.
I’ve installed FEniCSx via conda-forge. The following is an example of a uniaxial tension test of a 3D cube, which I use as a benchmark:
# Import modules
import numpy as np
import jax
import jax.lax
import jax.numpy as jnp
from jax.scipy.linalg import expm
from mpi4py import MPI
from petsc4py import PETSc
from dolfinx import fem, mesh
from dolfinx.io import XDMFFile
import ufl
import basix
from dolfinx_external_operator import (
FEMExternalOperator,
evaluate_external_operators,
evaluate_operands,
replace_external_operators,
)
from solvers import NonlinearProblemWithCallback
from dolfinx.nls.petsc import NewtonSolver
import matplotlib.pyplot as plt
jax.config.update("jax_enable_x64", True) # Set 64-bit arithmetic in JAX
# Tensor- to Nye-notation and vice versa
def jaxt2n(A, x):
    """Pack a 3x3 tensor into a 6-component Nye (Voigt-type) vector.

    The entries are ordered (11, 22, 33, 12, 13, 23). Only the upper
    triangle of ``A`` is read, so ``A`` is presumably symmetric.

    Args:
        A: 3x3 array-like indexed as ``A[i, j]``.
        x: Factor applied to the three off-diagonal (shear) entries,
            e.g. 2 for strain-like and 1 for stress-like quantities.

    Returns:
        A ``jnp`` array with 6 entries.
    """
    normal = [A[0, 0], A[1, 1], A[2, 2]]
    shear = [x * A[0, 1], x * A[0, 2], x * A[1, 2]]
    return jnp.array(normal + shear)
def jaxn2t(v, x):
    """Unpack a 6-component Nye (Voigt-type) vector into a symmetric 3x3 tensor.

    Inverse of ``jaxt2n`` for the same factor ``x``: the entries are read
    in the order (11, 22, 33, 12, 13, 23) and the shear entries are
    divided by ``x``.

    Args:
        v: Array-like with (at least) 6 entries.
        x: Factor the shear entries were multiplied with when packing.

    Returns:
        A symmetric 3x3 ``jnp`` array.
    """
    # Each shear component appears twice in the symmetric tensor.
    t01 = v[3] / x
    t02 = v[4] / x
    t12 = v[5] / x
    return jnp.array([[v[0], t01, t02],
                      [t01, v[1], t12],
                      [t02, t12, v[2]]])
def uflt2n(A, x):
    """Pack a 3x3 UFL tensor expression into a 6-component Nye vector.

    UFL counterpart of ``jaxt2n``: entries are ordered
    (11, 22, 33, 12, 13, 23) and the shear entries are multiplied by ``x``.
    Only the upper triangle of ``A`` is read.

    Args:
        A: UFL expression indexable as ``A[i, j]``.
        x: Factor applied to the three shear entries.

    Returns:
        The vector built by ``ufl.as_vector``.
    """
    normal = [A[0, 0], A[1, 1], A[2, 2]]
    shear = [x * A[0, 1], x * A[0, 2], x * A[1, 2]]
    return ufl.as_vector(normal + shear)
def ufln2t(v, x):
    """Unpack a 6-component UFL Nye vector into a symmetric 3x3 tensor.

    UFL counterpart of ``jaxn2t``: entries are read in the order
    (11, 22, 33, 12, 13, 23) and the shear entries are divided by ``x``.

    Args:
        v: UFL expression indexable as ``v[i]`` for ``i`` in 0..5.
        x: Factor the shear entries were multiplied with when packing.

    Returns:
        The tensor built by ``ufl.as_tensor``.
    """
    # Each shear component appears twice in the symmetric tensor.
    t01 = v[3] / x
    t02 = v[4] / x
    t12 = v[5] / x
    return ufl.as_tensor([[v[0], t01, t02],
                          [t01, v[1], t12],
                          [t02, t12, v[2]]])
# Geometry and discretisation of the 3-D cube
length, width, height = 1.0, 1.0, 1.0
N = 5  # Passed as the cell count in each of the three directions
domain = mesh.create_box(
    MPI.COMM_WORLD,
    [[0.0, 0.0, 0.0], [length, width, height]],
    [N, N, N],
    mesh.CellType.hexahedron,
)
gdim = domain.topology.dim  # Topological dimension of the mesh
fdim = gdim - 1  # Dimension of the facets
# Establish facet-to-cell connectivity (needed for boundary conditions)
domain.topology.create_connectivity(fdim, gdim)
# Identify the planar boundaries of the cube mesh
def xBot(x):
    """Mark points lying on the plane x = 0 (``x[0]`` is the x-coordinate row)."""
    return np.isclose(x[0], 0)


def xTop(x):
    """Mark points lying on the plane x = ``length``."""
    return np.isclose(x[0], length)


def yBot(x):
    """Mark points lying on the plane y = 0 (``x[1]`` is the y-coordinate row)."""
    return np.isclose(x[1], 0)


def yTop(x):
    """Mark points lying on the plane y = ``width``."""
    return np.isclose(x[1], width)


def zBot(x):
    """Mark points lying on the plane z = 0 (``x[2]`` is the z-coordinate row)."""
    return np.isclose(x[2], 0)


def zTop(x):
    """Mark points lying on the plane z = ``height``."""
    return np.isclose(x[2], height)
# Marker id -> locator function for each planar face of the cube
boundaries = [(1, xBot), (2, xTop), (3, yBot), (4, yTop), (5, zBot), (6, zTop)]

# Build collections of facets on each subdomain and mark them appropriately.
facet_indices, facet_markers = [], []  # Initialize empty collections of indices and markers.
for marker, locator in boundaries:
    # All facets that the locator assigns to this subdomain
    facets = mesh.locate_entities(domain, fdim, locator)
    facet_indices.append(facets)
    # One marker value per located facet
    facet_markers.append(np.full_like(facets, marker))

# Format the facet indices and markers as required for use in dolfinx:
# flat int32 arrays, ordered by facet index.
facet_indices = np.hstack(facet_indices).astype(np.int32)
facet_markers = np.hstack(facet_markers).astype(np.int32)
sorted_facets = np.argsort(facet_indices)
# Store the marked facets as "mesh tags" for later use in the BCs.
facet_tags = mesh.meshtags(domain, fdim, facet_indices[sorted_facets], facet_markers[sorted_facets])
# Define element formulations and function spaces
V1e = basix.ufl.element("Lagrange", domain.basix_cell(), 1) # For scalars
V3e = basix.ufl.element("Lagrange", domain.basix_cell(), 1, shape=(3,)) # For vectors
G6e = basix.ufl.quadrature_element(domain.basix_cell(), degree=2, value_shape=(6, )) # For tensors in Nye-notation (at the quadrature point)
G8e = basix.ufl.quadrature_element(domain.basix_cell(), degree=2, value_shape=(8, )) # For internal variables (at the quadrature point)
V1 = fem.functionspace(domain, V1e)
V3 = fem.functionspace(domain, V3e)
G6 = fem.functionspace(domain, G6e)
G8 = fem.functionspace(domain, G8e)
u = fem.Function(V3, name="Displacement")
hist = fem.Function(G8, name="History variables")
# Initialize internal variables
# Layout of the 8 values per point as read by r() and PK2_local():
# entries 0-5 = plastic stretch in Nye notation (initialised to the identity),
# entry 6 = kappa, entry 7 = dlambda.
init_array = np.array([1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0])
# Number of 8-value records held in the local array of hist
num_quadpoints = hist.x.array.size // 8
init_values = np.tile(init_array, num_quadpoints)
hist.x.array[:] = init_values
# Kinematics
Id = ufl.Identity(3)
F = Id + ufl.grad(u)
GLS = 0.5 * (F.T * F - Id)
# Nye vector of GLS with the shear entries multiplied by 2 (see uflt2n)
GLS_nye = uflt2n(GLS, 2)
PK2_nye = FEMExternalOperator(GLS_nye, function_space=G6) # Declare the second Piola-Kirchhoff stress tensor as an external operator
def psi(RCGe, kappa):
    """Free energy: elastic part in terms of ``RCGe`` plus a quadratic hardening part.

    Args:
        RCGe: 3x3 array; its trace and determinant enter the elastic part
            (the callers pass ``invUp @ RCG @ invUp``).
        kappa: Scalar hardening variable.

    Returns:
        Scalar value ``psi_e + psi_h``.
    """
    E = 200e3  # E-Modulus
    nu = 0.3  # Poisson's ratio
    H = 0  # Hardening modulus
    lmbda = E * nu / (1 - 2 * nu) / (1 + nu)  # 1st Lame parameter
    mu = E / 2 / (1 + nu)  # 2nd Lame parameter
    I1 = jnp.trace(RCGe)  # 1st invariant
    I3 = jnp.linalg.det(RCGe)  # 3rd invariant
    # Elastic part of the free energy
    psi_e = 0.5 * mu * (I1 - 3 - jnp.log(I3)) + 0.25 * lmbda * (I3 - 1 - jnp.log(I3))
    # Hardening part of the free energy
    psi_h = 0.5 * H * kappa**2
    return psi_e + psi_h


# Derivatives of the free energy obtained by automatic differentiation
dpsidRCGe = jax.grad(psi, argnums=0)  # w.r.t. RCGe (3x3)
dpsidkappa = jax.grad(psi, argnums=1)  # w.r.t. kappa (scalar)
def Phi(Y, R):
    """Von Mises yield function.

    Args:
        Y: 3x3 array, the plastic driving force.
        R: Scalar added to the initial yield stress (hardening stress).

    Returns:
        ``sqrt(3/2) * ||dev(Y)||_F - (sigy_0 + R)`` as a scalar.
    """
    sigy_0 = 300  # Initial yield stress
    # Deviatoric part of the plastic driving force
    devY = Y - jnp.trace(Y) / 3 * jnp.identity(3)
    # The result gets its own name so the function name is not shadowed inside its body.
    yield_value = jnp.sqrt(3 / 2) * jnp.linalg.norm(devY, ord='fro') - (sigy_0 + R)
    return yield_value


# Derivatives of the yield function obtained by automatic differentiation
dPhidY = jax.grad(Phi, argnums=0)  # w.r.t. Y (3x3)
dPhidR = jax.grad(Phi, argnums=1)  # w.r.t. R (scalar)
# Local residuum
def r(L, hist_old, RCG):
    """Evaluate the local residual of the return-mapping system.

    Args:
        L: Local unknowns, 8 entries: ``L[:6]`` plastic stretch in Nye
            notation (shear factor 2), ``L[6]`` kappa, ``L[7]`` dlambda.
        hist_old: Values of the previous step with the same layout; only
            entries 0-6 are read here.
        RCG: 3x3 right Cauchy-Green tensor (``2 * GLS + I`` in the caller).

    Returns:
        Tuple ``(res, PK2)``: ``res`` has 8 entries ordered as
        [yield condition, 6 entries of the plastic-stretch update in Nye
        notation (shear factor 1), kappa update]; ``PK2`` is the 3x3
        second Piola-Kirchhoff stress for the given ``L``.
    """
    # Split previous and current local variables
    Up_n_nye = hist_old[:6]
    kappa_n = hist_old[6]
    Up_nye = L[:6]
    kappa = L[6]
    dlambda = L[7]

    Up = jaxn2t(Up_nye, 2)
    Up_n = jaxn2t(Up_n_nye, 2)
    invUp = jnp.linalg.inv(Up)
    invUp_n = jnp.linalg.inv(Up_n)

    # Elastic quantities for the current plastic stretch
    RCGe = invUp @ RCG @ invUp
    dpsidRCGe_ = dpsidRCGe(RCGe, kappa)
    PK2 = 2.0 * invUp @ dpsidRCGe_ @ invUp
    Y = 2.0 * RCGe @ dpsidRCGe_
    R = dpsidkappa(RCGe, kappa)

    # Argument of the matrix exponential in the plastic-stretch update
    Z = 2.0 * dlambda * dPhidY(Y, R)

    # Residual contributions
    res_Phi = Phi(Y, R)
    res_Up = invUp @ expm(Z) @ invUp - invUp_n @ invUp_n
    res_Up_nye = jaxt2n(res_Up, 1)
    res_kappa = kappa_n - kappa - dlambda * dPhidR(Y, R)
    res = jnp.concatenate((res_Phi.reshape(1), res_Up_nye, res_kappa.reshape(1)))
    return res, PK2


# Local jacobian. jacfwd differentiates both outputs of r w.r.t. its first
# argument L, so callers pick element [0] for d(res)/dL.
drdL = jax.jacfwd(r)
# TOL serves two purposes in PK2_local: it is the threshold on the trial
# yield function (Phi_trial <= TOL selects the elastic branch) and the
# convergence tolerance on the residual norm of the local Newton scheme.
TOL = 1.0e-8 # Tolerance for local Newton scheme
NITER_MAX = 25 # Maximal number of local Newton iterations
def PK2_local(GLS_nye, hist_old):
    """Perform the return-mapping algorithm at a single Gauss point.

    Takes the Green-Lagrange strain tensor and the history variables from
    the previous time step and calculates the 2nd Piola-Kirchhoff stress
    tensor and the new history variables.

    Args:
        GLS_nye: Green-Lagrange strain in Nye notation (6 entries, shear
            factor 2).
        hist_old: History variables of the previous step (8 entries:
            plastic stretch in Nye notation, kappa, dlambda).

    Returns:
        Tuple ``(PK2_nye, h1)``: the stress in Nye notation (shear factor 1)
        and the 8 local variables. In the elastic branch ``h1`` is
        ``hist_old`` unchanged; in the plastic branch it is the last
        iterate of the local Newton scheme.

    NOTE(review): the JAX documentation states that ``lax.cond`` is
    converted to a select when it is vmapped with a batched predicate, so
    in ``PK2_global`` / ``dPK2dGLS_global`` the plastic branch (including
    the Newton ``while_loop``) is presumably evaluated for every point,
    elastic or not. Worth confirming with a profiler.
    """
    Up_n_nye = hist_old[:6]  # Plastic stretch from previous time step
    kappa_n = hist_old[6]  # Accumulated plastic strain from previous time step
    GLS = jaxn2t(GLS_nye, 2)
    Up_n = jaxn2t(Up_n_nye, 2)
    RCG = 2.0 * GLS + jnp.identity(3)
    invUp_n = jnp.linalg.inv(Up_n)

    # Trial step: evaluate stress and yield function with frozen plastic variables
    RCGe_trial = invUp_n @ RCG @ invUp_n
    dpsidRCGe_trial = dpsidRCGe(RCGe_trial, kappa_n)
    PK2_trial = 2.0 * invUp_n @ dpsidRCGe_trial @ invUp_n
    Y_trial = 2.0 * RCGe_trial @ dpsidRCGe_trial
    R_trial = dpsidkappa(RCGe_trial, kappa_n)
    Phi_trial = Phi(Y_trial, R_trial)

    state_cond_initial = (PK2_trial, hist_old, RCG)

    def elastic_step(state_cond):
        # Trial state is admissible: keep trial stress and old history.
        return state_cond[0], state_cond[1]

    def plastic_step(state_cond_n):
        hist_old = state_cond_n[1]
        RCG = state_cond_n[2]

        # Initial local solution vector: old history with dlambda reset to zero
        L0 = hist_old.at[7].set(0.0)
        res_0, PK2_0 = r(L0, hist_old, RCG)
        state_loop_initial = {"iteration_number": 0,
                              "intvar_current": L0,
                              "intvar_previous": hist_old,
                              "residuum_norm": jnp.linalg.norm(res_0),
                              "residuum": res_0,
                              "stress": PK2_0,
                              "deformation": RCG}

        def cond_fun(state_loop):
            # Iterate while not converged and the iteration budget is not used up.
            norm_res = state_loop["residuum_norm"]
            niter = state_loop["iteration_number"]
            return jnp.logical_and(norm_res > TOL, niter < NITER_MAX)

        def body_fun(state_loop_n):
            # One Newton update of the local unknowns
            niter = state_loop_n["iteration_number"]
            res_n = state_loop_n["residuum"]
            Ln = state_loop_n["intvar_current"]
            h0 = state_loop_n["intvar_previous"]
            RCG = state_loop_n["deformation"]

            j = drdL(Ln, h0, RCG)[0]  # d(res)/dL
            dL = jnp.linalg.solve(j, -res_n)
            L = Ln + dL
            res, PK2 = r(L, h0, RCG)

            return {"iteration_number": niter + 1,
                    "intvar_current": L,
                    "intvar_previous": h0,
                    "residuum_norm": jnp.linalg.norm(res),
                    "residuum": res,
                    "stress": PK2,
                    "deformation": RCG}

        state_loop_final = jax.lax.while_loop(cond_fun, body_fun, state_loop_initial)
        PK2 = state_loop_final["stress"]
        h1 = state_loop_final["intvar_current"]
        return PK2, h1

    PK2, h1 = jax.lax.cond(Phi_trial <= TOL, elastic_step, plastic_step, state_cond_initial)
    return jaxt2n(PK2, 1), h1
# Forward-mode derivative of the first output of PK2_local (stress) w.r.t. its
# first argument (GLS_nye); has_aux=True passes the second output (history
# variables) through undifferentiated, so the call returns (tangent, history).
dPK2dGLS_local = jax.jacfwd(PK2_local, has_aux=True) # Material tangent
# Both arguments are mapped over their leading axis (one row per point), then jit-compiled.
PK2_global = jax.jit(jax.vmap(PK2_local, in_axes=(0, 0))) # Global stress values
dPK2dGLS_global = jax.jit(jax.vmap(dPK2dGLS_local, in_axes=(0, 0))) # Global material tangent
def exof_stress(GLS):
    """Evaluate the stress for all points at once.

    Args:
        GLS: Flat array of strains in Nye notation, 6 values per point.

    Returns:
        Flat array of stresses as returned by ``PK2_global``.
        The history variables computed alongside are discarded here.
    """
    GLS_ = GLS.reshape((-1, 6))
    # Current content of the history function, 8 values per point
    hist_ = hist.x.array[:].reshape((-1, 8))
    PK2_, _ = PK2_global(GLS_, hist_)
    return PK2_.reshape(-1)
def exof_tangent(GLS):
    """Evaluate the material tangent and the new history for all points.

    Args:
        GLS: Flat array of strains in Nye notation, 6 values per point.

    Returns:
        Tuple ``(tangent, hist_new)``, both flattened: the derivative of the
        stress w.r.t. the strain and the local variables returned by the
        return mapping. ``hist`` itself is not modified here.
    """
    GLS_ = GLS.reshape((-1, 6))
    # Current content of the history function, 8 values per point
    hist_ = hist.x.array[:].reshape((-1, 8))
    dPK2dGLS_, hist_new = dPK2dGLS_global(GLS_, hist_)
    return dPK2dGLS_.reshape(-1), hist_new.reshape(-1)
def PK2_external(derivatives):
    """Select the implementation of the external operator for a derivative order.

    Args:
        derivatives: Derivative multi-index of the external operator;
            ``(0,)`` requests the stress, ``(1,)`` the material tangent.

    Returns:
        ``exof_stress`` for ``(0,)`` and ``exof_tangent`` for ``(1,)``.

    Raises:
        NotImplementedError: For any other derivative multi-index. The
            exception is raised (not returned) so that an unsupported
            request fails here instead of later, when the caller would try
            to call the returned object as an evaluation function.
    """
    if derivatives == (0,):
        return exof_stress
    if derivatives == (1,):
        return exof_tangent
    raise NotImplementedError(
        f"No external function implemented for derivatives={derivatives}."
    )
# Attach the dispatcher that maps a derivative multi-index to its implementation
PK2_nye.external_function = PK2_external
dx = ufl.Measure('dx', domain=domain, metadata={'quadrature_degree': 2, 'quadrature_rule':'default'}) # Volume integration measure
v = ufl.TestFunction(V3)
# Virtual Green-Lagrange strain, packed with shear factor 2 like GLS_nye
virtGLS = ufl.sym(F.T * ufl.grad(v))
virtGLS_nye = uflt2n(virtGLS, 2)
du = ufl.TrialFunction(V3)
R = ufl.inner(PK2_nye, virtGLS_nye) * dx # Residuum
J = ufl.derivative(R, u, du) # Jacobian
J_expanded = ufl.algorithms.expand_derivatives(J)
# Replace the external operators in both forms; the second return value is the list of operators found
R_replaced, R_external_operators = replace_external_operators(R)
J_replaced, J_external_operators = replace_external_operators(J_expanded)
# NOTE(review): R_form and J_form are not referenced again in the visible script — confirm whether they are still needed
R_form = fem.form(R_replaced)
J_form = fem.form(J_replaced)
# Find the specific constrained DOFs
# (facet tags 1, 3, 5 and 2 are the xBot, yBot, zBot and xTop faces respectively)
xBot_ux_dofs = fem.locate_dofs_topological(V3.sub(0), facet_tags.dim, facet_tags.find(1))
yBot_uy_dofs = fem.locate_dofs_topological(V3.sub(1), facet_tags.dim, facet_tags.find(3))
zBot_uz_dofs = fem.locate_dofs_topological(V3.sub(2), facet_tags.dim, facet_tags.find(5))
xTop_ux_dofs = fem.locate_dofs_topological(V3.sub(0), facet_tags.dim, facet_tags.find(2))
ux_disp = fem.Constant(domain, 0.0) # Initialize displacements of xTop facet
bcs_1 = fem.dirichletbc(0.0, xBot_ux_dofs, V3.sub(0)) # ux fix - xBot
bcs_2 = fem.dirichletbc(0.0, yBot_uy_dofs, V3.sub(1)) # uy fix - yBot
bcs_3 = fem.dirichletbc(0.0, zBot_uz_dofs, V3.sub(2)) # uz fix - zBot
bcs_4 = fem.dirichletbc(ux_disp, xTop_ux_dofs, V3.sub(0)) # ux disp - xTop
bcs = [bcs_1, bcs_2, bcs_3, bcs_4]
# Callback function for updating the external operators in the Newton-loop
def constitutive_update():
    """Re-evaluate the external operators for the current displacement.

    Evaluates the operands of the operators collected from the Jacobian,
    evaluates those operators and writes the first returned entry into the
    reference coefficient of ``PK2_nye``. The second entry (tangent and new
    history) is not used here.

    NOTE(review): the unpacking assumes that ``J_external_operators`` holds
    exactly two operators with the stress operator first — confirm against
    the order returned by ``replace_external_operators``.
    """
    evaluated_operands = evaluate_operands(J_external_operators)
    PK2_coeff, _ = evaluate_external_operators(J_external_operators, evaluated_operands)
    PK2_nye.ref_coefficient.x.array[:] = PK2_coeff
# Set up the problem and initialize the solver
# (constitutive_update is handed over as external_callback)
problem = NonlinearProblemWithCallback(R_replaced, u, bcs=bcs, J=J_replaced, external_callback=constitutive_update)
# Global Newton-Raphson solver and its options
solver = NewtonSolver(domain.comm, problem)
solver.max_it = 200
solver.rtol = 1e-8
solver.convergence_criterion = "incremental"
# Linear solver: direct LU factorisation via MUMPS, no Krylov iterations ("preonly")
ksp = solver.krylov_solver
opts = PETSc.Options() # type: ignore
option_prefix = ksp.getOptionsPrefix()
opts[f"{option_prefix}ksp_type"] = "preonly"
opts[f"{option_prefix}pc_type"] = "lu"
opts[f"{option_prefix}pc_factor_mat_solver_type"] = "mumps"
ksp.setFromOptions()
ux_disp_max = 0.05 # Maximum displacements of x-Top facet
Nsteps = 100 # Number of loading steps
# Start from a tiny non-zero displacement (machine epsilon) instead of exactly zero
u.x.array[:] = np.finfo(PETSc.ScalarType).eps # Initial displacements
# Loading loop
# Apply the displacement in Nsteps equal increments; the first entry of the
# linspace (zero displacement) is skipped.
for n, disp in enumerate(np.linspace(0, ux_disp_max, Nsteps + 1)[1:]):
    ux_disp.value = disp  # Apply displacement
    num_its, converged = solver.solve(u)  # Solve the problem
    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if not converged:
        raise RuntimeError(
            f"Newton solver did not converge in time step {n} (displacement {disp})."
        )
    print(f"Time step {n}, Displacement {disp}, Number of iterations {num_its} .")
    u.x.scatter_forward()  # Update ghost values (parallel computing)

    # Extract the history values of the converged state: the second operator
    # evaluation returns (tangent, new history); only the history is needed.
    evaluated_operands = evaluate_operands(J_external_operators)
    _, (_, hist_coeff) = evaluate_external_operators(J_external_operators, evaluated_operands)
    hist.x.array[:] = hist_coeff  # Update history variables for the next step
As mentioned previously, the routine produces valid results (deformed shape, stress response, evolution of internal variables, local convergence, etc.). The core of the performance issues might lie in how I structure/utilize the JAX functions. Still, I would like to rule out other potential issues that I might be missing. If someone is proficient in JAX, or has stumbled upon similar performance issues using external operators, I would really appreciate your help. Please let me know if I should provide further crucial information.
Best regards
Nadir