Hello everyone,
I’m currently working on implementing hyperelasticity, following this example:
https://jsdokken.com/dolfinx-tutorial/chapter2/hyperelasticity.html
My goal is to integrate a constitutive model using JAX, allowing for constitutive updates, similar to the method described here: Linear viscoelasticity with JAX — Computational Mechanics Numerical Tours with FEniCSx
Code: (I hope it is MWE, because I can only produce the error like this):
import numpy as np
import jax
import jax.numpy as jnp
import basix
import ufl
import pyvista
from dolfinx import mesh, fem, default_scalar_type, plot
from dolfinx.log import set_log_level, LogLevel
from mpi4py import MPI
from dolfinx.common import Timer
from solvers import CustomNonLinearProblem, CustomNewtonSolver
jax.config.update("jax_enable_x64", True) # use double-precision
L = 20.0
domain = mesh.create_box(MPI.COMM_WORLD, [[0.0, 0.0, 0.0], [L, 1, 1]], [20, 5, 5], mesh.CellType.hexahedron)
dim = domain.topology.dim
degree = 1
shape = (dim,)
V = fem.functionspace(domain, ("Lagrange", 2, (domain.geometry.dim, )))
Vtensorpost = fem.functionspace(domain, ("Lagrange", 2, (domain.geometry.dim, domain.geometry.dim)))
deg_quad = 2
vdim = 3
WSe = basix.ufl.quadrature_element(
domain.basix_cell(), value_shape=(), scheme="default", degree=deg_quad
)
WS = fem.functionspace(domain, WSe)
We = basix.ufl.quadrature_element(
domain.basix_cell(), value_shape=(vdim,), scheme="default", degree=deg_quad
)
W = fem.functionspace(domain, We)
WTe = basix.ufl.quadrature_element(
domain.basix_cell(), value_shape=(vdim,vdim), scheme="default", degree=deg_quad
)
WT = fem.functionspace(domain, WTe)
WTTe = basix.ufl.quadrature_element(
domain.basix_cell(), value_shape=(vdim,vdim,vdim,vdim), scheme="default", degree=deg_quad
)
WTT = fem.functionspace(domain, WTTe)
P = fem.Function(WT, name="Piola_kirchhoff_stress")
F = fem.Function(WT, name="Deformation_gradient")
F_old = fem.Function(WT, name="Previous_deformation_gradient")
Ct = fem.Function(WTT, name="Tangent_operator") # delta_P/delta_F
B = fem.Constant(domain, default_scalar_type((0, 0, 0)))
T = fem.Constant(domain, default_scalar_type((0, 0, 0)))
# %%
# Material properties (Neo-Hookean material)
E = 1e4
nu = 0.3
mu = fem.Constant(domain, E / 2 / (1 + nu))
lmbda = fem.Constant(domain, E * nu / (1 - 2 * nu) / (1 + nu))
# %%
v = ufl.TestFunction(V)
du = ufl.TrialFunction(V)
u = fem.Function(V, name="Displacement")
delta_u = fem.Function(V, name="Iteration_correction")
# P = FEMExternalOperator(u, function_space=WT)
# %%
def deformation_gradient(u):
I = ufl.Identity(domain.geometry.dim)
F = I + ufl.grad(u)
return F
def calculate_psi(F):
# Right Cauchy-Green tensor
C = ufl.dot(F.T, F)
# Invariants of deformation tensors
Ic = ufl.tr(C) # First invariant
J = ufl.det(F) # Determinant of F
# Strain energy density function
psi = (mu / 2) * (Ic - 3) - mu * ufl.ln(J) + (lmbda / 2) * (ufl.ln(J))**2
return psi
def piola_kirchhoff_stress(u):
I = ufl.Identity(len(u))
F = I + ufl.grad(u)
F = ufl.variable(F)
psi = calculate_psi(F)
P = ufl.diff(psi, F)
return P
# %%
def left(x):
return np.isclose(x[0], 0)
def right(x):
return np.isclose(x[0], L)
fdim = domain.topology.dim - 1
left_facets = mesh.locate_entities_boundary(domain, fdim, left)
right_facets = mesh.locate_entities_boundary(domain, fdim, right)
# Concatenate and sort the arrays based on facet indices. Left facets marked with 1, right facets with two
marked_facets = np.hstack([left_facets, right_facets])
marked_values = jnp.hstack([jnp.full_like(left_facets, 1), np.full_like(right_facets, 2)])
sorted_facets = np.argsort(marked_facets)
facet_tag = mesh.meshtags(domain, fdim, marked_facets[sorted_facets], marked_values[sorted_facets])
u_bc = np.array((0,) * domain.geometry.dim, dtype=default_scalar_type)
left_dofs = fem.locate_dofs_topological(V, facet_tag.dim, facet_tag.find(1))
bcs = [fem.dirichletbc(u_bc, left_dofs, V)]
# %%
metadata = {"quadrature_degree": 4}
dx = ufl.Measure(
"dx",
domain=domain,
metadata=metadata,
)
ds = ufl.Measure('ds',
domain=domain,
subdomain_data=facet_tag,
metadata=metadata
)
# %%
basix_celltype = getattr(basix.CellType, domain.topology.cell_type.name)
quadrature_points, weights = basix.make_quadrature(basix_celltype, deg_quad)
map_c = domain.topology.index_map(domain.topology.dim)
num_cells = map_c.size_local + map_c.num_ghosts
cells = np.arange(0, num_cells, dtype=np.int32)
ngauss = num_cells * len(weights)
F_expr = fem.Expression(deformation_gradient(u), quadrature_points)
def eval_at_quadrature_points(expression):
return expression.eval(domain, cells).reshape(ngauss, -1)
def interpolate_quadrature(ufl_expr, function):
expr_expr = fem.Expression(ufl_expr, quadrature_points)
expr_eval = expr_expr.eval(domain, cells)
function.x.array[:] = expr_eval.flatten()[:]
# %%
Residual = ufl.inner(ufl.grad(v), P) * dx - ufl.inner(v, B) * dx - ufl.inner(v, T) * ds(2)
# %%
i, j, k, l = ufl.indices(4)
A = ufl.as_tensor(Ct[i,j,k,l] * ufl.grad(du)[k,l], (i,j))
tangent_form = ufl.inner(ufl.grad(v), A)*dx
# %%
def strain_energy(F, mu, lambda_):
J = jnp.linalg.det(F)
Ic = jnp.trace(jnp.dot(F.T, F))
return (mu / 2) * (Ic - 3) - mu * jnp.log(J) + (lambda_ / 2) * (jnp.log(J))**2
def tangent_stiffness(F, mu, lambda_):
return jax.jacfwd(first_piola_kirchhoff)(F, mu, lambda_)
def first_piola_kirchhoff(F, mu, lambda_):
return jax.grad(strain_energy, argnums=0)(F, mu, lambda_)
def local_constitutive_update(dF, F_old, reg_param=1e-12):
mu_ = float(mu.value)
lambda_ = float(lmbda.value)
F_new = jnp.matmul(dF, F_old)
J = jnp.linalg.det(F_new)
F_new = jnp.where(J < 1e-8, jnp.eye(3), F_new)
P_new = first_piola_kirchhoff(F_new, mu_, lambda_)
Ct = tangent_stiffness(F_new, mu_, lambda_)
return P_new, (P_new, F_new, Ct)
# %%
tangent_operator_and_state = jax.jacfwd(
local_constitutive_update, argnums=0, has_aux=True
)
# %%
batched_constitutive_update = jax.jit(
jax.vmap(tangent_operator_and_state, in_axes=(0, 0))
)
# %%
def constitutive_update(u, P, F_old, F):
with Timer("Constitutive update"):
F_expr_local = fem.Expression(deformation_gradient(u), quadrature_points)
F_values = eval_at_quadrature_points(F_expr_local).reshape(ngauss, 3, 3)
F_old_values = F_old.x.array.reshape(ngauss, 3, 3)
det_F_old = np.linalg.det(F_old_values)
if np.any(det_F_old < 1e-8):
idx = det_F_old < 1e-8
F_old_values[idx] = np.tile(np.eye(3).flatten(), (np.count_nonzero(idx), 1)).reshape(-1, 3, 3)
dF_values = jnp.matmul(F_values, jnp.linalg.inv(F_old_values))
Ct_values, states = batched_constitutive_update(dF_values, F_old_values)
P_values = states[0]
F_new_values = states[1]
P.x.array[:] = P_values.ravel()
F.x.array[:] = F_new_values.ravel()
Ct.x.array[:] = Ct_values.ravel()
P.x.scatter_forward()
F.x.scatter_forward()
Ct.x.scatter_forward()
arr_u = u.x.array
print(f"[Constitutive Update] max|u| = {np.max(np.abs(arr_u)):.4e}, min|u| = {np.min(arr_u):.4e}")
# %%
petsc_options={
"ksp_type": "preonly",
"pc_type": "lu",
"pc_factor_mat_solver_type": "mumps",
"ksp_monitor": None,
}
nonlinear_problem = CustomNonLinearProblem(
residual_form=Residual,
jacobian_form=tangent_form,
u=u,
bcs=bcs, # Boundary conditions
petsc_options=petsc_options,
)
# %%
newton = CustomNewtonSolver(domain.comm, nonlinear_problem, bcs=bcs)
newton.callback = constitutive_update
# %%
pyvista.start_xvfb()
plotter = pyvista.Plotter()
plotter.open_gif("deformation.gif", fps=3)
topology, cells, geometry = plot.vtk_mesh(u.function_space)
function_grid = pyvista.UnstructuredGrid(topology, cells, geometry)
values = np.zeros((geometry.shape[0], 3))
values[:, :len(u)] = u.x.array.reshape(geometry.shape[0], len(u))
function_grid["u"] = values
function_grid.set_active_vectors("u")
# Warp mesh by deformation
warped = function_grid.warp_by_vector("u", factor=1)
warped.set_active_vectors("u")
# Add mesh to plotter and visualize
actor = plotter.add_mesh(warped, show_edges=True, lighting=False, clim=[0, 10])
# Compute magnitude of displacement to visualize in GIF
Vs = fem.functionspace(domain, ("Lagrange", 2))
magnitude = fem.Function(Vs)
us = fem.Expression(ufl.sqrt(sum([u[i]**2 for i in range(len(u))])), Vs.element.interpolation_points())
magnitude.interpolate(us)
warped["mag"] = magnitude.x.array
# %%
# Set log level to INFO
set_log_level(LogLevel.INFO)
P.x.petsc_vec.set(0.0)
P_values = P.x.array.reshape(ngauss, -1)
Ct.x.petsc_vec.set(0.0)
Ct_values = Ct.x.array.reshape(ngauss, -1)
# Anzahl der Quadraturpunkte
ngauss = F_old.x.array.size // 9
F_old.x.array[:] = np.tile(np.eye(3).flatten(), ngauss)
u.x.petsc_vec.set(0.0)
tval0 = -1.5
for n in range(1, 10):
T.value[2] = n * tval0
state = (P, F_old, F)
num_its, converged = newton.solve(u, *state)
assert converged
u.x.scatter_forward()
# Update the previous total and viscous strain
F_expr_local = fem.Expression(deformation_gradient(u), quadrature_points)
F_values = F_expr_local.eval(domain, np.arange(ngauss)).reshape(ngauss, 9)
F_old.x.array[:] = F_values.ravel()
function_grid["u"][:, :len(u)] = u.x.array.reshape(geometry.shape[0], len(u))
magnitude.interpolate(us)
warped.set_active_scalars("mag")
warped_n = function_grid.warp_by_vector(factor=1)
warped.points[:, :] = warped_n.points
warped.point_data["mag"][:] = magnitude.x.array
plotter.update_scalar_bar_range([0, 10])
plotter.write_frame()
plotter.close()
After running the code, I will get nan in the resudual-norm like:
[Constitutive Update] max|u| = 0.0000e+00, min|u| = 0.0000e+00
Iteration 1/50: Residual-Norm = 1.63333e-01
Initiale Residual-Norm: 1.63333e-01
[Constitutive Update] max|u| = 1.2317e+02, min|u| = -9.8413e+01
Iteration 2/50: Residual-Norm = 8.75109e+03
...
[Constitutive Update] max|u| = 1.6255e+12, min|u| = -1.1264e+12
Iteration 12/50: Residual-Norm = 2.25652e+13
[Constitutive Update] max|u| = inf, min|u| = -inf
Iteration 13/50: Residual-Norm = nan
I tried adding more elements to the mesh, lowering the load increment factor, but nothing really helped me.
Thank you in advance