Not converging with manually formulated Jacobian

Hello everyone,

I’m currently working on implementing hyperelasticity, following this example:
https://jsdokken.com/dolfinx-tutorial/chapter2/hyperelasticity.html

My goal is to integrate a constitutive model using JAX, allowing for constitutive updates, similar to the method described here: Linear viscoelasticity with JAX — Computational Mechanics Numerical Tours with FEniCSx

Code: (I hope it is MWE, because I can only produce the error like this):

import numpy as np
import jax
import jax.numpy as jnp
import basix
import ufl
import pyvista
from dolfinx import mesh, fem, default_scalar_type, plot
from dolfinx.log import set_log_level, LogLevel
from mpi4py import MPI
from dolfinx.common import Timer
from solvers import CustomNonLinearProblem, CustomNewtonSolver

jax.config.update("jax_enable_x64", True)  # use double-precision

L = 20.0
domain = mesh.create_box(MPI.COMM_WORLD, [[0.0, 0.0, 0.0], [L, 1, 1]], [20, 5, 5], mesh.CellType.hexahedron)
dim = domain.topology.dim
degree = 1
shape = (dim,)
V = fem.functionspace(domain, ("Lagrange", 2, (domain.geometry.dim, )))
Vtensorpost = fem.functionspace(domain, ("Lagrange", 2, (domain.geometry.dim, domain.geometry.dim)))

deg_quad = 2
vdim = 3

WSe = basix.ufl.quadrature_element(
    domain.basix_cell(), value_shape=(), scheme="default", degree=deg_quad
)
WS = fem.functionspace(domain, WSe)

We = basix.ufl.quadrature_element(
    domain.basix_cell(), value_shape=(vdim,), scheme="default", degree=deg_quad
)
W = fem.functionspace(domain, We)

WTe = basix.ufl.quadrature_element(
    domain.basix_cell(), value_shape=(vdim,vdim), scheme="default", degree=deg_quad
)
WT = fem.functionspace(domain, WTe)

WTTe = basix.ufl.quadrature_element(
    domain.basix_cell(), value_shape=(vdim,vdim,vdim,vdim), scheme="default", degree=deg_quad
)
WTT = fem.functionspace(domain, WTTe)

P = fem.Function(WT, name="Piola_kirchhoff_stress")
F = fem.Function(WT, name="Deformation_gradient")
F_old = fem.Function(WT, name="Previous_deformation_gradient")
Ct = fem.Function(WTT, name="Tangent_operator") # delta_P/delta_F

B = fem.Constant(domain, default_scalar_type((0, 0, 0)))
T = fem.Constant(domain, default_scalar_type((0, 0, 0)))

# %%
# Material properties (Neo-Hookean material)
E = 1e4
nu = 0.3
mu = fem.Constant(domain, E / 2 / (1 + nu))
lmbda = fem.Constant(domain, E * nu / (1 - 2 * nu) / (1 + nu))

# %%
v = ufl.TestFunction(V)
du = ufl.TrialFunction(V)
u = fem.Function(V, name="Displacement")
delta_u = fem.Function(V, name="Iteration_correction")
# P = FEMExternalOperator(u, function_space=WT) 

# %%
def deformation_gradient(u):
    I = ufl.Identity(domain.geometry.dim) 
    F = I + ufl.grad(u)       
    return F

def calculate_psi(F):

    # Right Cauchy-Green tensor
    C = ufl.dot(F.T, F)

    # Invariants of deformation tensors
    Ic = ufl.tr(C)  # First invariant
    J = ufl.det(F)  # Determinant of F
    
    # Strain energy density function
    psi = (mu / 2) * (Ic - 3) - mu * ufl.ln(J) + (lmbda / 2) * (ufl.ln(J))**2

    return psi

def piola_kirchhoff_stress(u):

    I = ufl.Identity(len(u)) 
    F = I + ufl.grad(u)
    F = ufl.variable(F)

    psi = calculate_psi(F)

    P = ufl.diff(psi, F)

    return P

# %%
def left(x):
    return np.isclose(x[0], 0)


def right(x):
    return np.isclose(x[0], L)


fdim = domain.topology.dim - 1
left_facets = mesh.locate_entities_boundary(domain, fdim, left)
right_facets = mesh.locate_entities_boundary(domain, fdim, right)

# Concatenate and sort the arrays based on facet indices. Left facets marked with 1, right facets with two
marked_facets = np.hstack([left_facets, right_facets])
marked_values = jnp.hstack([jnp.full_like(left_facets, 1), np.full_like(right_facets, 2)])
sorted_facets = np.argsort(marked_facets)
facet_tag = mesh.meshtags(domain, fdim, marked_facets[sorted_facets], marked_values[sorted_facets])

u_bc = np.array((0,) * domain.geometry.dim, dtype=default_scalar_type)
left_dofs = fem.locate_dofs_topological(V, facet_tag.dim, facet_tag.find(1))
bcs = [fem.dirichletbc(u_bc, left_dofs, V)]

# %%
metadata = {"quadrature_degree": 4}

dx = ufl.Measure(
    "dx",
    domain=domain,
    metadata=metadata,
)

ds = ufl.Measure('ds',
    domain=domain, 
    subdomain_data=facet_tag, 
    metadata=metadata
)

# %%
basix_celltype = getattr(basix.CellType, domain.topology.cell_type.name)
quadrature_points, weights = basix.make_quadrature(basix_celltype, deg_quad)

map_c = domain.topology.index_map(domain.topology.dim)
num_cells = map_c.size_local + map_c.num_ghosts  
cells = np.arange(0, num_cells, dtype=np.int32)  
ngauss = num_cells * len(weights)  

F_expr = fem.Expression(deformation_gradient(u), quadrature_points)

def eval_at_quadrature_points(expression):
    return expression.eval(domain, cells).reshape(ngauss, -1)

def interpolate_quadrature(ufl_expr, function):
    expr_expr = fem.Expression(ufl_expr, quadrature_points)
    expr_eval = expr_expr.eval(domain, cells)
    function.x.array[:] = expr_eval.flatten()[:]

# %%
Residual = ufl.inner(ufl.grad(v), P) * dx - ufl.inner(v, B) * dx - ufl.inner(v, T) * ds(2)

# %%
i, j, k, l = ufl.indices(4)
A = ufl.as_tensor(Ct[i,j,k,l] * ufl.grad(du)[k,l], (i,j))
tangent_form = ufl.inner(ufl.grad(v), A)*dx

# %%
def strain_energy(F, mu, lambda_):
    J = jnp.linalg.det(F)  
    Ic = jnp.trace(jnp.dot(F.T, F))  
    return (mu / 2) * (Ic - 3) - mu * jnp.log(J) + (lambda_ / 2) * (jnp.log(J))**2

def tangent_stiffness(F, mu, lambda_):
    return jax.jacfwd(first_piola_kirchhoff)(F, mu, lambda_)

def first_piola_kirchhoff(F, mu, lambda_):
    return jax.grad(strain_energy, argnums=0)(F, mu, lambda_)

def local_constitutive_update(dF, F_old, reg_param=1e-12):
    mu_ = float(mu.value)
    lambda_ = float(lmbda.value)
    
    F_new = jnp.matmul(dF, F_old)  
    J = jnp.linalg.det(F_new)
    F_new = jnp.where(J < 1e-8, jnp.eye(3), F_new)

    P_new = first_piola_kirchhoff(F_new, mu_, lambda_)

    Ct = tangent_stiffness(F_new, mu_, lambda_)

    return P_new, (P_new, F_new, Ct)


# %%
tangent_operator_and_state = jax.jacfwd(
    local_constitutive_update, argnums=0, has_aux=True
)

# %%
batched_constitutive_update = jax.jit(
    jax.vmap(tangent_operator_and_state, in_axes=(0, 0))
)

# %%
def constitutive_update(u, P, F_old, F):

    with Timer("Constitutive update"):
        F_expr_local = fem.Expression(deformation_gradient(u), quadrature_points)
        F_values = eval_at_quadrature_points(F_expr_local).reshape(ngauss, 3, 3)

        F_old_values = F_old.x.array.reshape(ngauss, 3, 3)

        det_F_old = np.linalg.det(F_old_values)
        if np.any(det_F_old < 1e-8):
            idx = det_F_old < 1e-8
            F_old_values[idx] = np.tile(np.eye(3).flatten(), (np.count_nonzero(idx), 1)).reshape(-1, 3, 3)

        dF_values = jnp.matmul(F_values, jnp.linalg.inv(F_old_values))
        
        Ct_values, states = batched_constitutive_update(dF_values, F_old_values)
        P_values = states[0]
        F_new_values = states[1]  

        
        P.x.array[:] = P_values.ravel()
        F.x.array[:] = F_new_values.ravel() 
        Ct.x.array[:] = Ct_values.ravel()

        
        P.x.scatter_forward()
        F.x.scatter_forward()
        Ct.x.scatter_forward()

        arr_u = u.x.array
        print(f"[Constitutive Update] max|u| = {np.max(np.abs(arr_u)):.4e}, min|u| = {np.min(arr_u):.4e}")


# %%
petsc_options={
        "ksp_type": "preonly",
        "pc_type": "lu",
        "pc_factor_mat_solver_type": "mumps",
        "ksp_monitor": None,
    }
nonlinear_problem = CustomNonLinearProblem(
    residual_form=Residual,
    jacobian_form=tangent_form,
    u=u,
    bcs=bcs,  # Boundary conditions
    petsc_options=petsc_options,
)

# %%
newton = CustomNewtonSolver(domain.comm, nonlinear_problem, bcs=bcs)
newton.callback = constitutive_update


# %%
pyvista.start_xvfb()
plotter = pyvista.Plotter()
plotter.open_gif("deformation.gif", fps=3)

topology, cells, geometry = plot.vtk_mesh(u.function_space)
function_grid = pyvista.UnstructuredGrid(topology, cells, geometry)

values = np.zeros((geometry.shape[0], 3))
values[:, :len(u)] = u.x.array.reshape(geometry.shape[0], len(u))
function_grid["u"] = values
function_grid.set_active_vectors("u")

# Warp mesh by deformation
warped = function_grid.warp_by_vector("u", factor=1)
warped.set_active_vectors("u")

# Add mesh to plotter and visualize
actor = plotter.add_mesh(warped, show_edges=True, lighting=False, clim=[0, 10])

# Compute magnitude of displacement to visualize in GIF
Vs = fem.functionspace(domain, ("Lagrange", 2))
magnitude = fem.Function(Vs)
us = fem.Expression(ufl.sqrt(sum([u[i]**2 for i in range(len(u))])), Vs.element.interpolation_points())
magnitude.interpolate(us)
warped["mag"] = magnitude.x.array


# %%
# Set log level to INFO
set_log_level(LogLevel.INFO)

P.x.petsc_vec.set(0.0)
P_values = P.x.array.reshape(ngauss, -1)

Ct.x.petsc_vec.set(0.0)
Ct_values = Ct.x.array.reshape(ngauss, -1)
# Anzahl der Quadraturpunkte
ngauss = F_old.x.array.size // 9  

F_old.x.array[:] = np.tile(np.eye(3).flatten(), ngauss)
u.x.petsc_vec.set(0.0)    

tval0 = -1.5  


for n in range(1, 10):
    T.value[2] = n * tval0 
    state = (P, F_old, F) 
    num_its, converged = newton.solve(u, *state) 
    assert converged
    u.x.scatter_forward()

    # Update the previous total and viscous strain
    F_expr_local = fem.Expression(deformation_gradient(u), quadrature_points)
    F_values = F_expr_local.eval(domain, np.arange(ngauss)).reshape(ngauss, 9)
    F_old.x.array[:] = F_values.ravel()
 
    function_grid["u"][:, :len(u)] = u.x.array.reshape(geometry.shape[0], len(u))
    magnitude.interpolate(us)
    warped.set_active_scalars("mag")
    warped_n = function_grid.warp_by_vector(factor=1)
    warped.points[:, :] = warped_n.points
    warped.point_data["mag"][:] = magnitude.x.array
    plotter.update_scalar_bar_range([0, 10])
    plotter.write_frame()

plotter.close()

After running the code, I will get nan in the resudual-norm like:

[Constitutive Update] max|u| = 0.0000e+00, min|u| = 0.0000e+00
Iteration 1/50: Residual-Norm = 1.63333e-01
Initiale Residual-Norm: 1.63333e-01
[Constitutive Update] max|u| = 1.2317e+02, min|u| = -9.8413e+01
Iteration 2/50: Residual-Norm = 8.75109e+03
...
[Constitutive Update] max|u| = 1.6255e+12, min|u| = -1.1264e+12
Iteration 12/50: Residual-Norm = 2.25652e+13
[Constitutive Update] max|u| = inf, min|u| = -inf
Iteration 13/50: Residual-Norm = nan

I tried adding more elements to the mesh, lowering the load increment factor, but nothing really helped me.

Thank you in advance

Hi,
I did not look at your code but if you just need hyperelasticity, pure UFL strategy much easier than custom constitutive update with JAX for instance. If you need something more complex such as finite strain plasticity for instance I suggest you to use Documentation — dolfinx_materials
In particular I have a FeFp plasticity implementation in JAX Finite-strain \(\boldsymbol{F}^\text{e}\boldsymbol{F}^\text{p}\) plasticity — dolfinx_materials

1 Like

Hi there,

Thanks for your quick response!

I’ve also studied your library dolfinx_materials and the example FeFp.

Since I plan to customize the material model later, and it closely resembles the hyperelasticity case, the examples you provided didn’t quite take me any further.

Thank you in advance!