Performance issues using external operators for finite elasto-plasticity

Dear FEniCS community,

for the last couple of months, I’ve been utilizing external operators to implement some material models. I’ve managed to achive great results/performance in the case of hyperelasticity, and small strain elasto-plasticity. The latter one was done in a slightly different algorithmic manner than the referenced example.

My next goal was finite elasto-plasticity. While the routine is working, I’ve noticed a significant drop in performance. I’m aware that the finite elasto-plasticity is more computationally intensive than hyperelasticity or its small strain counterpart, still I find such performance drop off a bit surprising.

I’ve installed FEniCSx via conda-forge. The following is an example of uniaxial tension test of a 3D-Cube, which I use as a benchmark:

# Import modules
import numpy as np

import jax
import jax.lax
import jax.numpy as jnp
from jax.scipy.linalg import expm

from mpi4py import MPI

from petsc4py import PETSc

from dolfinx import fem, mesh
from dolfinx.io import XDMFFile

import ufl

import basix

from dolfinx_external_operator import (
    FEMExternalOperator,
    evaluate_external_operators,
    evaluate_operands,
    replace_external_operators,
)

from solvers import NonlinearProblemWithCallback
from dolfinx.nls.petsc import NewtonSolver

import matplotlib.pyplot as plt

jax.config.update("jax_enable_x64", True)   # Set 64-bit arithmetic in JAX

# Tensor- to Nye-notation and vice versa
def jaxt2n(A, x):
    return jnp.array([A[0,0], A[1,1], A[2,2], x * A[0,1], x * A[0,2], x * A[1,2]])

def jaxn2t(v, x):
    return jnp.array([[v[0], v[3] / x, v[4] / x],
                      [v[3] / x, v[1], v[5] / x],
                      [v[4] / x, v[5] / x, v[2]]])

def uflt2n(A, x):
    return ufl.as_vector([A[0,0], A[1,1], A[2,2], x * A[0,1], x * A[0,2], x * A[1,2]])

def ufln2t(v, x):
    return ufl.as_tensor([[v[0], v[3] / x, v[4] / x],
                          [v[3] / x, v[1], v[5] / x],
                          [v[4] / x, v[5] / x, v[2]]])

# A 3-D cube
length = 1.0
width  = 1.0
height = 1.0 
N = 5
domain = mesh.create_box(MPI.COMM_WORLD, [[0.0, 0.0, 0.0], [length, width, height]],\
                         [N, N, N], mesh.CellType.hexahedron)

gdim = domain.topology.dim    # Dimension of the body 
fdim = gdim - 1               # Dimension of the facets

domain.topology.create_connectivity(fdim, gdim)   # Establish topological connections (needed for boundary conditions)

# Identify the planar boundaries of the cube mesh
def xBot(x):
    return np.isclose(x[0], 0)
def xTop(x):
    return np.isclose(x[0], length)
def yBot(x):
    return np.isclose(x[1], 0)
def yTop(x):
    return np.isclose(x[1], width)
def zBot(x):
    return np.isclose(x[2], 0)
def zTop(x):
    return np.isclose(x[2], height)
    
boundaries = [(1, xBot),(2,xTop),(3,yBot),(4,yTop),(5,zBot),(6,zTop)]   # Mark the sub-domains

# Build collections of facets on each subdomain and mark them appropriately.
facet_indices, facet_markers = [], []                     # Initalize empty collections of indices and markers.                    
for (marker, locator) in boundaries:
    facets = mesh.locate_entities(domain, fdim, locator)    # An array of all the facets in a given subdomain ("locator")
    facet_indices.append(facets)                          # Add these facets to the collection.
    facet_markers.append(np.full_like(facets, marker))    # Mark them with the appropriate index.

# Format the facet indices and markers as required for use in dolfinx.
facet_indices = np.hstack(facet_indices).astype(np.int32)
facet_markers = np.hstack(facet_markers).astype(np.int32)
sorted_facets = np.argsort(facet_indices)
 
facet_tags = mesh.meshtags(domain, fdim, facet_indices[sorted_facets], facet_markers[sorted_facets])  # Add these marked facets as "mesh tags" for later use in BCs.

# Define element formulationd and function spaced
V1e = basix.ufl.element("Lagrange", domain.basix_cell(), 1)                              # For scalars
V3e = basix.ufl.element("Lagrange", domain.basix_cell(), 1, shape=(3,))                  # For vectors
G6e = basix.ufl.quadrature_element(domain.basix_cell(), degree=2, value_shape=(6, ))     # For tensors in Nye-notation (at the quadrature point) 
G8e = basix.ufl.quadrature_element(domain.basix_cell(), degree=2, value_shape=(8, ))     # For internal variables (at the quadrature point) 

V1 = fem.functionspace(domain, V1e)
V3 = fem.functionspace(domain, V3e)
G6 = fem.functionspace(domain, G6e)
G8 = fem.functionspace(domain, G8e)

u = fem.Function(V3, name="Displacement")
hist = fem.Function(G8, name="History variables")    

# Initialize internal variables
init_array = np.array([1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0])
num_quadpoints = hist.x.array.size // 8
init_values = np.tile(init_array, num_quadpoints)
hist.x.array[:] = init_values

# Kinematics
Id = ufl.Identity(3)
F = Id + ufl.grad(u)   
GLS = 0.5 * (F.T * F - Id) 
GLS_nye = uflt2n(GLS, 2)

PK2_nye = FEMExternalOperator(GLS_nye, function_space=G6)   # Declare the second Piola-Kirchhoff stress tensor as an external operator

def psi(RCGe, kappa):
    E  = 200e3   # E-Modulus
    nu = 0.3    # Poisson's ratio
    H  = 0   # Hardening modulus
    lmbda = E * nu / (1 - 2 * nu) / (1 + nu)    # 1st Lame parameter     
    mu = E / 2 / (1 + nu)                       # 2nd Lame parameter

    I1 = jnp.trace(RCGe)    # 1st invariant
    I3 = jnp.linalg.det(RCGe)   # 3rd invariant

    psi_e = 0.5 * mu * (I1 - 3 - jnp.log(I3)) + 0.25 * lmbda * (I3 - 1 - jnp.log(I3))   # Elastic part of the free energy   
    psi_h = 0.5 * H * kappa**2  # Hardening part of the free energy
    return psi_e + psi_h 

dpsidRCGe = jax.grad(psi, argnums=0)
dpsidkappa = jax.grad(psi, argnums=1)

def Phi(Y, R):
    sigy_0 = 300    # Initial yield stress
    devY = Y - jnp.trace(Y) / 3 * jnp.identity(3)    # Deviatoric part of the plastic driving force
    Phi = jnp.sqrt(3 / 2) * jnp.linalg.norm(devY, ord='fro') - (sigy_0 + R)    # Von Mises yield criterion/function
    return Phi

dPhidY = jax.grad(Phi, argnums=0)
dPhidR = jax.grad(Phi, argnums=1)

# Local residuum
def r(L, hist_old, RCG): 
    Up_n_nye = hist_old[:6]
    kappa_n = hist_old[6] 
    Up_nye = L[:6]
    kappa = L[6]
    dlambda = L[7]

    Up = jaxn2t(Up_nye, 2)
    Up_n = jaxn2t(Up_n_nye, 2)

    invUp = jnp.linalg.inv(Up) #precompute_inverses(Up)
    invUp_n = jnp.linalg.inv(Up_n) #precompute_inverses(Up_n)
    RCGe = invUp @ RCG @ invUp

    dpsidRCGe_ = dpsidRCGe(RCGe, kappa)
    PK2 = 2.0 * invUp @ dpsidRCGe_ @ invUp
    Y = 2.0 * RCGe @ dpsidRCGe_
    R = dpsidkappa(RCGe, kappa)
    Z = 2.0 * dlambda * dPhidY(Y, R)
    
    res_Phi = Phi(Y, R)
    res_Up = invUp @ expm(Z) @ invUp - invUp_n @ invUp_n
    res_Up_nye = jaxt2n(res_Up, 1)
    res_kappa = kappa_n - kappa - dlambda * dPhidR(Y, R)
    res = jnp.concatenate((res_Phi.reshape(1), res_Up_nye, res_kappa.reshape(1)))
    return res, PK2

drdL = jax.jacfwd(r)    # Local jacobian

TOL = 1.0e-8     # Tolerance for local Newton scheme
NITER_MAX = 25   # Maximal number of local Newton iterations

def PK2_local(GLS_nye, hist_old):
    """
    Core function. Performs the return-mapping 
    algorithm on a Gauss-Point level. Takes the 
    Green-Lagrange strain tensor and the history 
    variables from the previos time step as inputs
    and calulates the 2nd Piola-Kirchoff stress 
    tensors and the new history variables, alongside
    some additional control parameters.
    """
    Up_n_nye = hist_old[:6]    # Plastic stretch from previous time step
    kappa_n = hist_old[6]      # Accumulated plastic strain from previous time step

    GLS = jaxn2t(GLS_nye, 2)
    Up_n = jaxn2t(Up_n_nye, 2)
    RCG = 2.0 * GLS + jnp.identity(3)
    invUp_n = jnp.linalg.inv(Up_n) # precompute_inverses(Up_n)

    # Trial step
    RCGe_trial = invUp_n @ RCG @ invUp_n
    dpsidRCGe_trial = dpsidRCGe(RCGe_trial, kappa_n)

    PK2_trial = 2.0 * invUp_n @ dpsidRCGe_trial @ invUp_n
    Y_trial = 2.0 * RCGe_trial @ dpsidRCGe_trial
    R_trial = dpsidkappa(RCGe_trial, kappa_n)

    Phi_trial = Phi(Y_trial, R_trial)

    state_cond_initial = (PK2_trial, hist_old, RCG)
    
    def elastic_step(state_cond):
        return state_cond[0], state_cond[1]
    
    def plastic_step(state_cond_n):
        hist_old = state_cond_n[1]
        RCG = state_cond_n[2]
        
        L0 = hist_old.at[7].set(0.0)    # Initial local solution vector
        
        niter = 0             # Initiate counter
        
        res_0, PK2_0 = r(L0, hist_old, RCG)
        norm_res_0 = jnp.linalg.norm(res_0)

        state_loop_initial = {"iteration_number": niter,
                              "intvar_current": L0,
                              "intvar_previous": hist_old,
                              "residuum_norm": norm_res_0,
                              "residuum": res_0,
                              "stress": PK2_0,
                              "deformation": RCG}

        def cond_fun(state_loop):
            norm_res = state_loop["residuum_norm"] 
            niter = state_loop["iteration_number"]
            return jnp.logical_and(norm_res > TOL, niter < NITER_MAX)

        def body_fun(state_loop_n): 
            niter = state_loop_n["iteration_number"]
            res_n = state_loop_n["residuum"]
            norm_res = state_loop_n["residuum_norm"]
            Ln = state_loop_n["intvar_current"]
            h0 = state_loop_n["intvar_previous"]
            RCG = state_loop_n["deformation"]

            j = drdL(Ln, h0, RCG)[0]
            dL = jnp.linalg.solve(j, -res_n)
            L = Ln + dL

            res, PK2 = r(L, h0, RCG)
            norm_res = jnp.linalg.norm(res)

            niter += 1

            state_loop = {"iteration_number": niter,
                          "intvar_current": L,
                          "intvar_previous": h0,
                          "residuum_norm": norm_res,
                          "residuum": res,
                          "stress": PK2,
                          "deformation": RCG}

            return state_loop
        
        state_loop_final = jax.lax.while_loop(cond_fun, body_fun, state_loop_initial)

        PK2 = state_loop_final["stress"]
        h1 = state_loop_final["intvar_current"]

        return PK2, h1
    
    PK2, h1 = jax.lax.cond(Phi_trial <= TOL, elastic_step, plastic_step, state_cond_initial)

    return jaxt2n(PK2, 1), h1  
     
dPK2dGLS_local = jax.jacfwd(PK2_local, has_aux=True)    # Material tangent

PK2_global = jax.jit(jax.vmap(PK2_local, in_axes=(0, 0)))   # Global stress values

dPK2dGLS_global = jax.jit(jax.vmap(dPK2dGLS_local, in_axes=(0, 0)))   # Global material tangent

def exof_stress(GLS):
    GLS_ = GLS.reshape((-1, 6))
    hist_ = hist.x.array[:].reshape((-1, 8))
    PK2_, _ = PK2_global(GLS_, hist_)
    return PK2_.reshape(-1)

def exof_tangent(GLS):
    GLS_ = GLS.reshape((-1, 6))
    hist_ = hist.x.array[:].reshape((-1, 8))
    dPK2dGLS_, hist_new = dPK2dGLS_global(GLS_, hist_)
    return dPK2dGLS_.reshape(-1), hist_new.reshape(-1) 

def PK2_external(derivatives):
    if derivatives == (0,):
        return exof_stress
    elif derivatives == (1,):
        return exof_tangent
    else:
        return NotImplementedError

PK2_nye.external_function = PK2_external

dx = ufl.Measure('dx', domain=domain, metadata={'quadrature_degree': 2, 'quadrature_rule':'default'})   # Volume integration measure 

v = ufl.TestFunction(V3)
virtGLS = ufl.sym(F.T * ufl.grad(v)) 
virtGLS_nye = uflt2n(virtGLS, 2)

du = ufl.TrialFunction(V3)  

R = ufl.inner(PK2_nye, virtGLS_nye) * dx   # Residuum
J = ufl.derivative(R, u, du)           # Jacobian 
J_expanded = ufl.algorithms.expand_derivatives(J)

R_replaced, R_external_operators = replace_external_operators(R)
J_replaced, J_external_operators = replace_external_operators(J_expanded)

R_form = fem.form(R_replaced)
J_form = fem.form(J_replaced)

# Find the specific constrained DOFs
xBot_ux_dofs = fem.locate_dofs_topological(V3.sub(0), facet_tags.dim, facet_tags.find(1))
yBot_uy_dofs = fem.locate_dofs_topological(V3.sub(1), facet_tags.dim, facet_tags.find(3))
zBot_uz_dofs = fem.locate_dofs_topological(V3.sub(2), facet_tags.dim, facet_tags.find(5))
xTop_ux_dofs = fem.locate_dofs_topological(V3.sub(0), facet_tags.dim, facet_tags.find(2))

ux_disp = fem.Constant(domain, 0.0) # Initialize displacements of xTop facet

bcs_1 = fem.dirichletbc(0.0, xBot_ux_dofs, V3.sub(0))      # ux fix - xBot
bcs_2 = fem.dirichletbc(0.0, yBot_uy_dofs, V3.sub(1))      # uy fix - yBot
bcs_3 = fem.dirichletbc(0.0, zBot_uz_dofs, V3.sub(2))      # uz fix - zBot
bcs_4 = fem.dirichletbc(ux_disp, xTop_ux_dofs, V3.sub(0))  # ux disp - xTop

bcs = [bcs_1, bcs_2, bcs_3, bcs_4]

# Callback function for updating the external operators in the Newton-loop
def constitutive_update():
    evaluated_operands = evaluate_operands(J_external_operators)
    PK2_coeff, _ = evaluate_external_operators(J_external_operators, evaluated_operands)
    PK2_nye.ref_coefficient.x.array[:] = PK2_coeff
 
# Set up the problem and initialize the solver 
problem = NonlinearProblemWithCallback(R_replaced, u, bcs=bcs, J=J_replaced, external_callback=constitutive_update)

# Global Newton-Raphson solver and its options
solver = NewtonSolver(domain.comm, problem)  
solver.max_it = 200
solver.rtol = 1e-8
solver.convergence_criterion = "incremental"
ksp = solver.krylov_solver
opts = PETSc.Options()  # type: ignore
option_prefix = ksp.getOptionsPrefix()
opts[f"{option_prefix}ksp_type"] = "preonly"
opts[f"{option_prefix}pc_type"] = "lu"
opts[f"{option_prefix}pc_factor_mat_solver_type"] = "mumps"
ksp.setFromOptions()

ux_disp_max = 0.05   # Maximum displacements of x-Top facet
Nsteps = 100         # Number of loading steps

u.x.array[:] = np.finfo(PETSc.ScalarType).eps   # Initial displacements

# Loading loop
for n, disp in enumerate(np.linspace(0, ux_disp_max, Nsteps + 1)[1:]):

    ux_disp.value = disp  # Apply displacement 
    
    num_its, converged = solver.solve(u)  # Solve the problem
    assert converged                     

    print(f"Time step {n}, Displacement {disp}, Number of iterations {num_its} .")

    u.x.scatter_forward()  # Update ghost values (parallel computing)

    # Extract history values
    evaluated_operands = evaluate_operands(J_external_operators)
    _, (_, hist_coeff) = evaluate_external_operators(J_external_operators, evaluated_operands)

    hist.x.array[:] = hist_coeff  # Update history variables

As mentioned previously, the routine produces valid results (deformed shape, stress response, evolution of internal variables, local convergenve etc.). The core of the performance issues might lie in how I structure/utilize the JAX functions. Still, I would like to rule out other potential issues which I might be missing. If someone is profficient in JAX, or has stumbled upon similiar performance issues using external operators, I would really appriciate your help. Please let me know if I should provide further crucial information.

Best regards
Nadir

@a-latyshev maybe you have some ideas ?:slight_smile:

Hi, I suspect that the issue is due to the derivation of the material tangent via direct automatic differentiation of the Newton solver in the plastic step.
I suggest to look at the approach proposed in dolfinx_materials such as the Finite-strain \(\boldsymbol{F}^\text{e}\boldsymbol{F}^\text{p}\) plasticity — dolfinx_materials
It is not yet implemented with external operators but shares a similar implementation. The main difference is that we define a custom JVP rule for the Newton solver using implicit differentiation rather than AD of unrolled Newton iterations, see also https://bleyerj.github.io/teaching/Bleyer_advanced_computational_methods_2025.pdf slides 41 to 45

3 Likes

Dear @Nadir.Kopic

@bleyerj is right! It looks like passing AD through the Newton loop is quite expensive. Using the implicit function theorem (IFT) will improve performance.

Actually, there is an implementation of the Mohr-Coulomb tutorial, where we apply IFT and indeed obtain a better performance. You may find this implementation in supplementary materials to the paper on external operators (see appendix A.1), which will be published soon in JTCAM. In supplementary materials, you need to find the file demo_plasticity_mohr_coulomb_mpi.py, which contains this more advanced implementation with IFT.

Since you are familiar with our tutorials on external operators, it will be very straightforward for you to understand the script.

Essentially, you need these lines of code

def C_tang_local(sigma_local, dlambda_local, deps_local, sigma_n_local):
    # implicit function theorem and automatic differentiation:
    # http://implicit-layers-tutorial.org/implicit_functions/ 
    y_local = jnp.c_["0,1,-1", sigma_local, dlambda_local]
    j = drdy(y_local, deps_local, sigma_n_local)
    return jnp.linalg.inv(j)[:4,:4] @ C_elas

return_mapping_vec = jax.jit(jax.vmap(return_mapping, in_axes=(0, 0)))
C_tang_vec = jax.jit(jax.vmap(C_tang_local, in_axes=(0, 0, 0, 0)))

def C_tang_impl(deps):
    deps_ = deps.reshape((-1, stress_dim))
    sigma_n_ = sigma_n.x.array.reshape((-1, stress_dim))

    (sigam_global, state) = return_mapping_vec(deps_, sigma_n_)
    sigma_global, niter, yielding, norm_res, dlambda = state

    C_tang_global = C_tang_vec(sigma_global, dlambda, deps_, sigma_n_)

    unique_iters, counts = jnp.unique(niter, return_counts=True)

    max_yielding = jnp.max(yielding)
    max_yielding_rank = MPI.COMM_WORLD.allreduce((max_yielding, MPI.COMM_WORLD.rank), op=MPI.MAXLOC)[1]

    if MPI.COMM_WORLD.rank == max_yielding_rank:
        print("\tInner Newton summary:", flush=True)
        print(f"\t\tUnique number of iterations: {unique_iters}", flush=True)
        print(f"\t\tCounts of unique number of iterations: {counts}", flush=True)
        print(f"\t\tMaximum f: {max_yielding}", flush=True)
        print(f"\t\tMaximum residual: {jnp.max(norm_res)}", flush=True)

    return C_tang_global.reshape(-1), sigma_global.reshape(-1)

that replace the code from the original tutorials, where we apply AD to the Newton algorithm:

dsigma_ddeps = jax.jacfwd(return_mapping, has_aux=True)

Consider applying IFT to your dPK2dGLS_local.

I hope this information will help you improve performance. :slightly_smiling_face:

Cheers,
Andrey

I have an additional remark on your implementation concerning the use of NonlinearProblemWithCallback.

The function constitutive_update will be called on every iteration of the Newton solver, or more explicitly, on this line

    num_its, converged = solver.solve(u)  # Solve the problem

The motivation behind the classes NonlinearProblemWithCallback and PETScNonlinearProblem is that we want to make additional computations on every iteration. For example, updating the history variables which is the case of nonlinear solid mechanics problems (see how we update cumulative plastic strain increment dp in the von Mises tutorial).

In this regard, are you sure that you don’t need to update your history variables in the constitutive_update function?

Another remark, in your code

    PK2_coeff, _ = evaluate_external_operators(J_external_operators, evaluated_operands)
    PK2_nye.ref_coefficient.x.array[:] = PK2_coeff

normally evaluate_external_operators(J_external_operators, evaluated_operands) returns the output of the exof_tangent function, the first output of which is values of the derivative dPK2dGLS_. So basically, you do PK2_nye.ref_coefficient.x.array[:] = dPK2dGLS_.reshape(-1), which is not supposed to be correct.

Probably, it is not clear, but in the plasticity tutorials, to get the derivative of the external operator, we have to pass through the return-mapping procedure that results in computing the external operator itself. That’s why, to save numerical resources, the external function returns values for both the external operator and its derivatives (e.g. like here Plasticity of Mohr-Coulomb with apex-smoothing — External operators within FEniCSx/DOLFINx).

At the same time, the first argument of the external function is the values of an appropriate external operator, and they will be used automatically by the algorithm via evaluate_external_operators, which is why we typically ignore this output in our code.

After saying this, I’d change your actual code so that PK2_local should return jaxt2n(PK2, 1) twice and then you have to catch those values throughout the rest of your code (and you don’t need exof_stress). Basically, you have to reproduce the same structure of the code as in the Mohr-Coulomb tutorial. Please, take a look at it again and track how we manipulate the data throughout the code.

Hi everyone,

@bleyerj utilizing implicit differentiation improved the performance massively. I was familiar with the method, but didn’t expect it would have such a huge impact on the performance. Thank you for highlighting it!

@a-latyshev Thank you for the additional implementation example! Regarding your questions.

I don’t think it’s necessary to update the history variables for each global newton iteration. I do it once the global residuum has converged. I didn’t experience any convergence issues (local and global) in any simulations I ran. The reason could be on how we treat the history part of our models. While you consider a rate formulation of the constitutive law in your examples, I declare the plastic stretch as a history variable and proceed to formulate the constitutive relationships in a more conventional way. Still, I might be wrong. I will take a more detailed look into the literature in the near future :smiley: .

Regarding your second suggestion, I have already tried this with hyperelasticity, yet errors occurred. Here is how I did it:

import numpy as np

import jax
import jax.lax
import jax.numpy as jnp

from mpi4py import MPI

from petsc4py import PETSc

from dolfinx import fem, mesh

import ufl

import basix

from dolfinx_external_operator import (
    FEMExternalOperator,
    evaluate_external_operators,
    evaluate_operands,
    replace_external_operators,
)

from solvers import NonlinearProblemWithCallback
from dolfinx.nls.petsc import NewtonSolver

jax.config.update("jax_enable_x64", True)   

def jaxt2n(A, x):
    return jnp.array([A[0,0], A[1,1], A[2,2], x * A[0,1], x * A[0,2], x * A[1,2]])

def jaxn2t(v, x):
    return jnp.array([[v[0], v[3] / x, v[4] / x],
                      [v[3] / x, v[1], v[5] / x],
                      [v[4] / x, v[5] / x, v[2]]])

length = 1.0
width  = 1.0
height = 1.0 
N = 10
domain = mesh.create_box(MPI.COMM_WORLD, [[0.0, 0.0, 0.0], [length, width, height]], 
                         [N, N, N], mesh.CellType.hexahedron)

gdim = domain.topology.dim 
fdim = gdim - 1             

domain.topology.create_connectivity(fdim, gdim)   

def xBot(x):
    return np.isclose(x[0], 0)
def xTop(x):
    return np.isclose(x[0], length)
def yBot(x):
    return np.isclose(x[1], 0)
def yTop(x):
    return np.isclose(x[1], width)
def zBot(x):
    return np.isclose(x[2], 0)
def zTop(x):
    return np.isclose(x[2], height)
    
boundaries = [(1, xBot),(2,xTop),(3,yBot),(4,yTop),(5,zBot),(6,zTop)]

facet_indices, facet_markers = [], []                     
for (marker, locator) in boundaries:
    facets = mesh.locate_entities(domain, fdim, locator)  
                                                          
    facet_indices.append(facets)                          
    facet_markers.append(np.full_like(facets, marker))    

facet_indices = np.hstack(facet_indices).astype(np.int32)
facet_markers = np.hstack(facet_markers).astype(np.int32)
sorted_facets = np.argsort(facet_indices)
 
facet_tags = mesh.meshtags(domain, fdim, facet_indices[sorted_facets], facet_markers[sorted_facets])


V3 = basix.ufl.element("Lagrange", domain.basix_cell(), 1, shape=(3,))                
T3 = basix.ufl.quadrature_element(domain.basix_cell(), degree=2, value_shape=(6,))      
V  = fem.functionspace(domain, V3)
T  = fem.functionspace(domain, T3)

u = fem.Function(V, name="Displacement")
Id = ufl.Identity(3)
F = Id + ufl.grad(u) 
RCG = F.T * F
GLS = 0.5 * (RCG - Id) 
GLS_nye = ufl.as_vector([GLS[0,0], GLS[1,1], GLS[2,2], 2.0 * GLS[0,1], 2.0 * GLS[0,2], 2.0 * GLS[1,2]])

PK2 = FEMExternalOperator(GLS_nye, function_space=T)     

def psi_local(GLS_nye):
    E = 200e3                                   
    nu = 0.3                                    
    lmbda = E * nu / (1 - 2 * nu) / (1 + nu)   
    mu = E / 2 / (1 + nu)                       
    GLS = jaxn2t(GLS_nye, 2)
    Id = jnp.identity(3)
    RCG = 2.0 * GLS + Id
    I1 = jnp.trace(RCG)
    I3 = jnp.linalg.det(RCG)  
    return mu / 2 * (I1 - 3 - jnp.log(I3)) + lmbda / 4 * (I3 - 1 - jnp.log(I3))

def PK2_local(GLS_nye):    
    return jax.grad(psi_local)(GLS_nye), jax.grad(psi_local)(GLS_nye)

dPK2dGLS_local = jax.jacfwd(PK2_local, has_aux=True) 

PK2_global = jax.jit(jax.vmap(PK2_local, in_axes=0))

dPK2dGLS_global = jax.jit(jax.vmap(dPK2dGLS_local, in_axes=0))

def exof_tangent(GLS_glo):
    GLS_ = GLS_glo.reshape((-1, 6))
    dPK2dGLS_, PK2_ = dPK2dGLS_global(GLS_)
    return dPK2dGLS_.reshape(-1), PK2_.reshape(-1)

def PK2_external(derivatives):
    if derivatives == (1,):
        return exof_tangent
    else:
        return NotImplementedError

PK2.external_function = PK2_external

dx = ufl.Measure('dx', domain=domain, metadata={'quadrature_degree': 2, 'quadrature_rule':'default'})  # Volume integration measure 

v = ufl.TestFunction(V)
virGLS = ufl.sym(F.T * ufl.grad(v))
virGLS_nye = ufl.as_vector([virGLS[0,0], virGLS[1,1], virGLS[2,2], 2.0 * virGLS[0,1], 2.0 * virGLS[0,2], 2.0 * virGLS[1,2]])

du = ufl.TrialFunction(V)  

R = ufl.inner(PK2, virGLS_nye) * dx
J = ufl.derivative(R, u, du) 
J_expanded = ufl.algorithms.expand_derivatives(J)

R_replaced, R_external_operators = replace_external_operators(R)
J_replaced, J_external_operators = replace_external_operators(J_expanded)

R_form = fem.form(R_replaced)
J_form = fem.form(J_replaced)

xBot_ux_dofs = fem.locate_dofs_topological(V.sub(0), facet_tags.dim, facet_tags.find(1))
yBot_uy_dofs = fem.locate_dofs_topological(V.sub(1), facet_tags.dim, facet_tags.find(3))
zBot_uz_dofs = fem.locate_dofs_topological(V.sub(2), facet_tags.dim, facet_tags.find(5))
xTop_ux_dofs = fem.locate_dofs_topological(V.sub(0), facet_tags.dim, facet_tags.find(2))

ux_disp = fem.Constant(domain, 0.0)     

bcs_1 = fem.dirichletbc(0.0, xBot_ux_dofs, V.sub(0))      
bcs_2 = fem.dirichletbc(0.0, yBot_uy_dofs, V.sub(1))      
bcs_3 = fem.dirichletbc(0.0, zBot_uz_dofs, V.sub(2))      
bcs_4 = fem.dirichletbc(ux_disp, xTop_ux_dofs, V.sub(0))  

bcs = [bcs_1, bcs_2, bcs_3, bcs_4]

def constitutive_update():
    evaluated_operands = evaluate_operands(J_external_operators)
    ((_, PK2_coeff), ) = evaluate_external_operators(J_external_operators, evaluated_operands)
    PK2.ref_coefficient.x.array[:] = PK2_coeff
 
problem = NonlinearProblemWithCallback(R_replaced, u, bcs=bcs, J=J_replaced, external_callback=constitutive_update)

solver = NewtonSolver(domain.comm, problem)
solver.max_it = 200
solver.rtol = 1e-8
solver.convergence_criterion = "incremental"
ksp = solver.krylov_solver
opts = PETSc.Options()  # type: ignore
option_prefix = ksp.getOptionsPrefix()
opts[f"{option_prefix}ksp_type"] = "preonly"
opts[f"{option_prefix}pc_type"] = "lu"
opts[f"{option_prefix}pc_factor_mat_solver_type"] = "mumps"
ksp.setFromOptions()

ux_disp_max = 0.1                                 
Nsteps = 100                                      

u.x.array[:] = np.finfo(PETSc.ScalarType).eps   

# Loading loop
for n, disp in enumerate(np.linspace(0, ux_disp_max, Nsteps + 1)[1:]):

    ux_disp.value = disp    
    
    num_its, converged = solver.solve(u) 
    assert converged                     

    print(f"Time step {n}, Displacement {disp}, Number od iterations {num_its} .")

    u.x.scatter_forward()    

This is the error message I get:

Traceback (most recent call last):
  File "/home/nadir/FEM/cam-fenicsx/elasticity/singelem_hypelas_exo_batched.py", line 210, in <module>
    num_its, converged = solver.solve(u)
                         ~~~~~~~~~~~~^^^
  File "/home/nadir/miniconda3/envs/fenicsx-envnew/lib/python3.13/site-packages/dolfinx/nls/petsc.py", line 51, in solve
    n, converged = super().solve(u.x.petsc_vec)
                   ~~~~~~~~~~~~~^^^^^^^^^^^^^^^
  File "/home/nadir/FEM/cam-fenicsx/elasticity/solvers.py", line 113, in form
    self.external_callback()
    ~~~~~~~~~~~~~~~~~~~~~~^^
  File "/home/nadir/FEM/cam-fenicsx/elasticity/singelem_hypelas_exo_batched.py", line 181, in constitutive_update
    ((_, PK2_coeff), ) = evaluate_external_operators(J_external_operators, evaluated_operands)
                         ~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/nadir/dolfinx-external-operator/src/dolfinx_external_operator/external_operator.py", line 169, in evaluate_external_operators
    np.copyto(external_operator.ref_coefficient.x.array, external_operator_eval)
    ~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Cannot cast scalar from dtype('O') to dtype('float64') according to the rule 'same_kind'

The same approach worked for small-strain elasticity. Could it be due to the non-linearity of the operand? We already discussed this once. Your suggestion was to define separate functions for the stress and tangent.

Best regards
Nadir