Dolfinx_external_operators _replace_form returns a FormSum instead of Form

Hi everyone!

I am trying to learn to use dolfinx_external_operators so that I can add some spring-like objects to a mechanical model of beams (Timoshenko beams using a mixed element formulation). I have a simplified version of the problem below with the same degrees of freedom but with a simplified energy function.

When I set up the solver objects, the code fails on this line (F is the residual I am trying to solve, i.e. F = 0):

F_replaced, F_external_operators = replace_external_operators(F)

The error occurs because in _replace_form() (see below), ufl.algorithms.replace is returning a FormSum instead of a Form. This then breaks the next part of the external_operators code because it tries to pull out the base_form_operators from the FormSum and raises this error:
'FormSum' object has no attribute 'base_form_operators'.

Is this happening because of how I am defining the external operator? Is this maybe related to having the external operator in the energy function that is then being differentiated instead of directly using it in the residual? Any help would be greatly appreciated! Thanks!

Code to replicate the error:

import numpy as np
import gmsh
import ufl
import ufl.algorithms
import basix
from mpi4py import MPI
from dolfinx import mesh, fem, io
import dolfinx
import dolfinx.fem.petsc as petsc
from dolfinx.nls.petsc import NewtonSolver
from dolfinx_external_operator import (
    FEMExternalOperator,
    evaluate_external_operators,
    evaluate_operands,
    replace_external_operators,
)
from petsc4py import PETSc

import jax

# enables the work with double precision under float64
jax.config.update("jax_enable_x64", True)



#################################################################
# region        CONSTRUCT GEOMETRY                              #
#################################################################

degree = 3 # cubic basis functions
quad_degree = 4*degree
N = 30 # number of elements per branch

char_len = 1/(N-1)
d = 1
# Initialize gmsh
gmsh.initialize()

gmsh.option.setNumber("General.Terminal", 0)

# Creating points and lines (look into add_bezier for nicer curves)
p0 = gmsh.model.occ.addPoint(0, 0, 0, d)
p1 = gmsh.model.occ.addPoint(0, 1, 0, d)
l1 = gmsh.model.occ.addLine(p0, p1)
gmsh.model.occ.synchronize()


# add_physical_group so we can set up subdomains with different properties
# inputs (topological dimension, [list of object], tag)
gmsh.model.add_physical_group(1, [l1], 1)

gmsh.option.setNumber("Mesh.CharacteristicLengthMax",char_len)

gmsh.model.mesh.generate(dim=1)
domain, _, _ = io.gmshio.model_to_mesh(gmsh.model, MPI.COMM_WORLD, 0)


# Finalize gmsh
gmsh.finalize()

gdim = domain.geometry.dim
tdim = domain.topology.dim
print(gdim,tdim)

#################################################################
# region        CONSTRUCT REFERENCE CONFIGURATION               #
#################################################################

dx_dX = ufl.Jacobian(domain)[:, 0]
t = dx_dX 
t /= ufl.sqrt(ufl.dot(t,t)) # normalize the tangent vector
ez = ufl.as_vector([0, 0, 1])
a1 = ufl.cross(ez, t)
a1 /= ufl.sqrt(ufl.dot(a1, a1))
a2 = ufl.cross(t, a1)
a2 /= ufl.sqrt(ufl.dot(a2, a2))

def dds(u):
    # use this to get derivatives along the rest configuration of the beam (d/ds)
    # and convert it to the local frame where the tangent direction is e3
    return ufl.dot(ufl.grad(u), t)



#################################################################
# region        CONSTRUCT ELEMENTS AND FUNCTION SPACES          #
#################################################################


Ue = basix.ufl.element("P", domain.basix_cell(), degree, shape=(gdim,)) # here "P" means Lagrange polynomial elements (here cubic)
U_EO = basix.ufl.quadrature_element(domain.topology.cell_name(),degree=quad_degree,value_shape=()) # quadrature elements for external operator

W = fem.functionspace(domain, basix.ufl.mixed_element([Ue, Ue]))
W_EO = fem.functionspace(domain,U_EO) # function space for the external operator

dstate = ufl.TestFunction(W)
dr, dtheta = ufl.split(dstate)

state = fem.Function(W)
r, theta = ufl.split(state)     # r: position vector, theta: incremental rotation vector


# setting up initial condition
def r0_func(x):
    return x

state.sub(0).interpolate(r0_func)
state.x.scatter_forward()


#################################################################
# region    GET DEGREES OF FREEDOM FOR BCS AND SPRING ENERGY    #
#################################################################

def left_tip(x):
    return np.logical_and(np.isclose(x[0], 0), np.isclose(x[1], 0))

def right_tip(x):
    return np.logical_and(np.isclose(x[0], 0), np.isclose(x[1], 1))

fdim = domain.topology.dim -1 
left_facets = mesh.locate_entities_boundary(domain,fdim,left_tip)
right_facets = mesh.locate_entities_boundary(domain,fdim,right_tip)

marked_facets = np.hstack([left_facets,right_facets])

marked_values = np.hstack([np.full(len(left_facets),0,dtype=np.int32),
                           np.full(len(right_facets),1,dtype=np.int32)])

sorted_facets = np.argsort(marked_facets)
facet_tag = mesh.meshtags(domain,fdim,marked_facets[sorted_facets],marked_values[sorted_facets])


# integrators
metadata={"quadrature_scheme":"default", "quadrature_degree": quad_degree}
ds = ufl.Measure('ds',domain=domain,subdomain_data=facet_tag,metadata=metadata)
dx = ufl.Measure('dx',domain=domain,subdomain_data=facet_tag,metadata=metadata)

# get sub domains of the different elements of the function space
Vr, Vr_to_W = W.sub(0).collapse()
Vt, Vt_to_W = W.sub(1).collapse()

# get degrees of freedom for each possible boundary condition
r_left_dofs = fem.locate_dofs_geometrical((W.sub(0), Vr), left_tip)
theta_left_dofs = fem.locate_dofs_geometrical((W.sub(1), Vt), left_tip)

r_right_dofs = fem.locate_dofs_geometrical((W.sub(0), Vr), right_tip)
theta_right_dofs = fem.locate_dofs_geometrical((W.sub(1), Vt), right_tip)

r0 = fem.Function(Vr)
r0.interpolate(lambda x: x)

theta0 = fem.Function(Vt)
theta0.interpolate(lambda x: 0*x)

# Boundary conditions: no displacement, no rotation
bcs = [fem.dirichletbc(r0, r_right_dofs, W.sub(0)), 
       fem.dirichletbc(theta0, theta_right_dofs, W.sub(1))]


#################################################################
# region   SETTING UP ENERGY EXPRESSION AND EXTERNAL OPERATOR   #
#################################################################
Psi = ufl.dot(dds(r), dds(r)) + ufl.dot(dds(theta), dds(theta))

One = fem.Constant(domain, dolfinx.default_scalar_type(1.0))
V = fem.assemble_scalar(fem.form(One * dx)) 


# ENERGY FUNCTION FOR EXTERNAL OPERATOR
def spring(r):
    xyz_left = r[r_left_dofs[1]]

    L = ( xyz_left[0] )**2

    n_DOFs = len(r) // 3  # integer division: np.ones needs an integer size
    sef = L * np.ones(n_DOFs) / V # spread the value of the integral over the full domain
    return sef.reshape(-1)

dspring_dr = jax.jacfwd(spring,argnums=(0)) # AD derivative of spring strain energy function
ddspring_dr = jax.jacfwd(dspring_dr,argnums=(0)) # AD second derivative of spring strain energy function

def spring_external(derivatives):
    if derivatives == (0,):  # no derivation, the function itself
        return spring
    elif derivatives == (1,):  # the derivative with respect to the operand `r`
        return dspring_dr
    elif derivatives == (2,):
        return ddspring_dr
    else:
        raise NotImplementedError(f"No external function is defined for the requested derivative {derivatives}.")
    
spring_energy = FEMExternalOperator(r,function_space=W_EO)
spring_energy.external_function = spring_external


Psi += spring_energy    # add the effects of the spring

Psi *= dx   # integrating over the domain



#################################################################
# region     SETTING UP OBJECTS FOR SOLVING                     #
#################################################################

F = ufl.derivative(Psi,state,dstate)


state_hat = ufl.TrialFunction(W)
J = ufl.derivative(F, state, state_hat)
J_expanded = ufl.algorithms.expand_derivatives(J)
J_replaced, J_external_operators = replace_external_operators(J_expanded)

##### ERROR OCCURS HERE: #######
F_replaced, F_external_operators = replace_external_operators(F)


evaluated_operands = evaluate_operands(F_external_operators)

_ = evaluate_external_operators(F_external_operators, evaluated_operands)
_ = evaluate_external_operators(J_external_operators, evaluated_operands)

F_compiled = fem.form(F_replaced)
J_compiled = fem.form(J_replaced)

b_vector = fem.assemble_vector(F_compiled)
A_matrix = fem.assemble_matrix(J_compiled)

The function in external_operators that triggers the issue:

# in external_operators.py
def _replace_form(form: ufl.Form):
    external_operators = form.base_form_operators()
    ex_ops_map = {ex_op: ex_op.ref_coefficient for ex_op in external_operators}
    replaced_form = ufl.algorithms.replace(form, ex_ops_map)
    return replaced_form, external_operators

@micBenn What are the versions of dolfinx and dolfinx_external_operators you use?

Hi @a-latyshev ! I am currently using fenics-dolfinx v0.9.0 and fenics-ufl v2024.2.0. And I am running everything on Ubuntu (not using Docker).

I haven’t run your code yet, but I think the issue is related to "External operators are not detectable in sums of forms" (Issue #24 in a-latyshev/dolfinx-external-operator on GitHub). It was fixed against one of the previous nightly versions of dolfinx, so I think you are on dolfinx_external_operators v0.9.0.

You can try a container with dolfinx-nightly and the external operators from the current main.

I will try to make a patch for v0.9.0.

I double-checked, and it looks like I am running dolfinx_external_operators v0.10.0.dev0 (installed with pip). I will see if I can run the code in the nightly version of dolfinx. Thanks!

Ok, I think I was able to side-step the issue with the following:

  1. First, I directly installed dolfinx_external_operator v0.9.0 (I originally had the development version). This resolved the initial issue, but really just pushed it down the road because a different structure was then a FormSum instead of Form.
  2. Next, I suspected some of the issue may be coming from the directional derivative of the external operator. So, instead of including the external form in the energy function, I took its derivative with jax and applied it like you would a traction.
# ENERGY FUNCTION FOR EXTERNAL OPERATOR
def spring(r):
    xyz_left = r[0] # degrees of freedom of the first element 

    L = np.sum( (xyz_left[0:3] )**2 )

    # n_DOFs = 
    sef = L / V # spread the value of the integral over the full domain
    return sef #.reshape(-1)

dspring_dr = jax.jacfwd(spring,argnums=(0)) # AD derivative of spring strain energy function
ddspring_dr = jax.jacfwd(dspring_dr,argnums=(0)) # AD second derivative of spring strain energy function

# what if we just pass the derivatives instead?
def spring_external(derivatives):
    if derivatives == (0,):  # no derivation, the function itself
        return lambda x: dspring_dr(x).flatten()
    elif derivatives == (1,):  # the derivative with respect to the operand `r`
        return lambda x: ddspring_dr(x)
    else:
        raise NotImplementedError(f"No external function is defined for the requested derivative {derivatives}.")
    
# external operator that returns the "force" from the spring
spring_energy = FEMExternalOperator(r,function_space=W_EO)
spring_energy.external_function = spring_external

# traction contributions
F2 = ufl.dot(spring_energy, dr) * dx

However, if I still took the directional derivative of the energy function (defined entirely in UFL) and tried to add the tractions (in the code below, by setting method=1), I would get an error saying there were different Arguments with the same number and part.

  3. If instead I first took the derivative of the energy function and then integrated (method=0 below), then things proceed properly until the external operators are evaluated. At that point, an error comes up because of shape mismatches in the derivatives (but this, I think, is my misunderstanding of how to keep the jax derivatives local: the output shape is the shape you would expect if you were taking the Jacobian of the entire DOF vector, not just the local stiffness tensor).
#################################################################
# region     SETTING UP OBJECTS FOR SOLVING                     #
#################################################################

method = 1 # 1: directional derivative of energy function, 0: differentiation of the energy function before integration

if method==0:
    dF = ufl.diff(Psi, state)
    F1 = ufl.dot(dF, dstate) * dx  #ufl.derivative(Psi,state,dstate)
    
elif method==1:
    Psi *= dx   # integrating over the domain
    F1 = ufl.derivative(Psi,state,dstate)


F = F1 + F2

state_hat = ufl.TrialFunction(W)
J = ufl.derivative(F, state)

### for method==1, the error occurs at this line
J_expanded = ufl.algorithms.expand_derivatives(J)
J_replaced, J_external_operators = replace_external_operators(J_expanded)

F_replaced, F_external_operators = replace_external_operators(F)
evaluated_operands = evaluate_operands(F_external_operators)

_ = evaluate_external_operators(F_external_operators, evaluated_operands)
### if method==0, the error occurs at this line
_ = evaluate_external_operators(J_external_operators, evaluated_operands)

F_compiled = fem.form(F_replaced)
J_compiled = fem.form(J_replaced)

b_vector = fem.assemble_vector(F_compiled)
A_matrix = fem.assemble_matrix(J_compiled)
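
The shape mismatch described for method=0 can be reproduced without any FEM code. The following is only a minimal sketch: `force_point`, `force_global`, `n_points` and `dim` are hypothetical names and values chosen for illustration, not part of the original script. It contrasts differentiating a function of the whole flat array with differentiating per point and then vectorising.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

n_points = 4  # hypothetical number of quadrature points
dim = 3       # components of r at each point


def force_point(r):
    """Hypothetical per-point force; r has shape (3,)."""
    return 2.0 * jnp.sum(r) * jnp.ones(dim)


def force_global(r_flat):
    """The same force evaluated on the flat array holding all points."""
    r = r_flat.reshape(-1, dim)
    return jax.vmap(force_point)(r).reshape(-1)


r_flat = jnp.arange(n_points * dim, dtype=jnp.float64)

# Differentiating the global function gives a dense matrix that couples
# every entry of the flat array with every other entry.
K_global = jax.jacfwd(force_global)(r_flat)
print(K_global.shape)  # (12, 12)

# Differentiating per point and then vectorising keeps the result local:
# one 3x3 block per point.
K_local = jax.vmap(jax.jacfwd(force_point))(r_flat.reshape(-1, dim))
print(K_local.shape)              # (4, 3, 3)
print(K_local.reshape(-1).shape)  # (36,)
```

In this sketch the per-point blocks are exactly the diagonal blocks of the global Jacobian, and the off-diagonal blocks are zero, which is why the global result has far more entries than a local stiffness tensor per point would.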

Complete code for producing the errors discussed in (2) and (3):

import numpy as np
import gmsh
import ufl
import ufl.algorithms
import basix
from mpi4py import MPI
from dolfinx import mesh, fem, io
import dolfinx
import dolfinx.fem.petsc as petsc
from dolfinx.nls.petsc import NewtonSolver
from dolfinx_external_operator import (
    FEMExternalOperator,
    evaluate_external_operators,
    evaluate_operands,
    replace_external_operators,
)
from petsc4py import PETSc

import jax

# enables the work with double precision under float64
jax.config.update("jax_enable_x64", True)



#################################################################
# region        CONSTRUCT GEOMETRY                              #
#################################################################

degree = 1 # linear basis functions
quad_degree = degree
N = 30 # number of elements per branch

char_len = 1/(N-1)
d = 1
# Initialize gmsh
gmsh.initialize()

gmsh.option.setNumber("General.Terminal", 0)

# Creating points and lines (look into add_bezier for nicer curves)
p0 = gmsh.model.occ.addPoint(0, 0, 0, d)
p1 = gmsh.model.occ.addPoint(0, 1, 0, d)
l1 = gmsh.model.occ.addLine(p0, p1)
gmsh.model.occ.synchronize()


# add_physical_group so we can set up subdomains with different properties
# inputs (topological dimension, [list of object], tag)
gmsh.model.add_physical_group(1, [l1], 1)

gmsh.option.setNumber("Mesh.CharacteristicLengthMax",char_len)

gmsh.model.mesh.generate(dim=1)
domain, _, _ = io.gmshio.model_to_mesh(gmsh.model, MPI.COMM_WORLD, 0)


# Finalize gmsh
gmsh.finalize()

gdim = domain.geometry.dim
tdim = domain.topology.dim
print(gdim,tdim)

#################################################################
# region        CONSTRUCT REFERENCE CONFIGURATION               #
#################################################################

dx_dX = ufl.Jacobian(domain)[:, 0]
t = dx_dX 
t /= ufl.sqrt(ufl.dot(t,t)) # normalize the tangent vector
ez = ufl.as_vector([0, 0, 1])
a1 = ufl.cross(ez, t)
a1 /= ufl.sqrt(ufl.dot(a1, a1))
a2 = ufl.cross(t, a1)
a2 /= ufl.sqrt(ufl.dot(a2, a2))

def dds(u):
    # use this to get derivatives along the rest configuration of the beam (d/ds)
    # and convert it to the local frame where the tangent direction is e3
    return ufl.dot(ufl.grad(u), t)



#################################################################
# region        CONSTRUCT ELEMENTS AND FUNCTION SPACES          #
#################################################################


Ue = basix.ufl.element("P", domain.basix_cell(), degree, shape=(gdim,)) # here "P" means Lagrange polynomial elements (here linear)
U_EO = basix.ufl.quadrature_element(domain.topology.cell_name(),degree=quad_degree,value_shape=(gdim,)) # quadrature elements for external operator

W = fem.functionspace(domain, basix.ufl.mixed_element([Ue, Ue]))
W_EO = fem.functionspace(domain,U_EO) # function space for the external operator

dstate = ufl.TestFunction(W)
dr, dtheta = ufl.split(dstate)

state = fem.Function(W)
r, theta = ufl.split(state)     # r: position vector, theta: incremental rotation vector

print(state.x.array.shape)
# setting up initial condition
def r0_func(x):
    return x

state.sub(0).interpolate(r0_func)
state.x.scatter_forward()


#################################################################
# region    GET DEGREES OF FREEDOM FOR BCS AND SPRING ENERGY    #
#################################################################

def left_tip(x):
    return np.logical_and(np.isclose(x[0], 0), np.isclose(x[1], 0))

def right_tip(x):
    return np.logical_and(np.isclose(x[0], 0), np.isclose(x[1], 1))

fdim = domain.topology.dim -1 
left_facets = mesh.locate_entities_boundary(domain,fdim,left_tip)
right_facets = mesh.locate_entities_boundary(domain,fdim,right_tip)

marked_facets = np.hstack([left_facets,right_facets])

marked_values = np.hstack([np.full(len(left_facets),0,dtype=np.int32),
                           np.full(len(right_facets),1,dtype=np.int32)])

sorted_facets = np.argsort(marked_facets)
facet_tag = mesh.meshtags(domain,fdim,marked_facets[sorted_facets],marked_values[sorted_facets])


# integrators
metadata={"quadrature_scheme":"default", "quadrature_degree": quad_degree}
ds = ufl.Measure('ds',domain=domain,subdomain_data=facet_tag,metadata=metadata)
dx = ufl.Measure('dx',domain=domain,subdomain_data=facet_tag,metadata=metadata)

# get sub domains of the different elements of the function space
Vr, Vr_to_W = W.sub(0).collapse()
Vt, Vt_to_W = W.sub(1).collapse()

# get degrees of freedom for each possible boundary condition
r_left_dofs = fem.locate_dofs_geometrical((W.sub(0), Vr), left_tip)
theta_left_dofs = fem.locate_dofs_geometrical((W.sub(1), Vt), left_tip)

r_right_dofs = fem.locate_dofs_geometrical((W.sub(0), Vr), right_tip)
theta_right_dofs = fem.locate_dofs_geometrical((W.sub(1), Vt), right_tip)

r0 = fem.Function(Vr)
r0.interpolate(lambda x: x)

theta0 = fem.Function(Vt)
theta0.interpolate(lambda x: 0*x)

# Boundary conditions: no displacement, no rotation
bcs = [fem.dirichletbc(r0, r_right_dofs, W.sub(0)), 
       fem.dirichletbc(theta0, theta_right_dofs, W.sub(1))]



#################################################################
# region   SETTING UP ENERGY EXPRESSION AND EXTERNAL OPERATOR   #
#################################################################
Psi = ufl.dot(dds(r), dds(r)) + ufl.dot(dds(theta), dds(theta))

One = fem.Constant(domain, dolfinx.default_scalar_type(1.0))
V = fem.assemble_scalar(fem.form(One * dx)) 

# ENERGY FUNCTION FOR EXTERNAL OPERATOR
def spring(r):
    xyz_left = r[0] # degrees of freedom of the first element 

    L = np.sum( (xyz_left[0:3] )**2 )

    # n_DOFs = 
    sef = L / V # spread the value of the integral over the full domain
    return sef #.reshape(-1)

dspring_dr = jax.jacfwd(spring,argnums=(0)) # AD derivative of spring strain energy function
ddspring_dr = jax.jacfwd(dspring_dr,argnums=(0)) # AD second derivative of spring strain energy function



# what if we just pass the derivatives instead???
def spring_external(derivatives):
    if derivatives == (0,):  # no derivation, the function itself
        return lambda x: dspring_dr(x).flatten()
    elif derivatives == (1,):  # the derivative with respect to the operand `r`
        return lambda x: ddspring_dr(x)
    else:
        raise NotImplementedError(f"No external function is defined for the requested derivative {derivatives}.")
    
spring_energy = FEMExternalOperator(r,function_space=W_EO)
spring_energy.external_function = spring_external
F2 = ufl.dot(spring_energy, dr) * dx

#################################################################
# region     SETTING UP OBJECTS FOR SOLVING                     #
#################################################################

method = 0 # 1: directional derivative of energy function, 0: differentiation of the energy function before integration

if method==0:
    dF = ufl.diff(Psi, state)
    F1 = ufl.dot(dF, dstate) * dx  #ufl.derivative(Psi,state,dstate)
    
elif method==1:
    Psi *= dx   # integrating over the domain
    F1 = ufl.derivative(Psi,state,dstate)


F = F1 + F2

state_hat = ufl.TrialFunction(W)
J = ufl.derivative(F, state)
J_expanded = ufl.algorithms.expand_derivatives(J)
J_replaced, J_external_operators = replace_external_operators(J_expanded)

F_replaced, F_external_operators = replace_external_operators(F)
evaluated_operands = evaluate_operands(F_external_operators)

_ = evaluate_external_operators(F_external_operators, evaluated_operands)
_ = evaluate_external_operators(J_external_operators, evaluated_operands)

F_compiled = fem.form(F_replaced)
J_compiled = fem.form(J_replaced)

b_vector = fem.assemble_vector(F_compiled)
A_matrix = fem.assemble_matrix(J_compiled)

If you have any suggestions or examples about how to set up a Python energy function that is then differentiated with jax to produce a traction vector and stiffness tensor, that would be greatly appreciated!

I don’t understand exactly how you define the spring energy, but from your formulation, you expect it to have three components, since:

F2 = ufl.dot(spring_energy, dr) * dx

so I augmented the computation of your spring energy with extra components. Change this accordingly.

import jax.numpy as jnp
# ENERGY FUNCTION FOR EXTERNAL OPERATOR
def spring(r):
    # r here is supposed to be local, i.e. the value at a single quadrature point
    # xyz_left = r[0] # degrees of freedom of the first element 

    L = jnp.sum(r)**2  # jax.numpy keeps the function traceable by jax.vmap/jax.jacfwd

    # n_DOFs = 
    sef = L / V # spread the value of the integral over the full domain
    
    return jnp.array([sef, 0.0, 0.0])

Then, the correct vectorisation with JAX is the following:

dspring_dr = jax.jacfwd(spring,argnums=(0)) # AD derivative of spring strain energy function
spring_vec = jax.vmap(spring, in_axes=(0))
dspring_dr_vec = jax.vmap(dspring_dr, in_axes=(0))

Then you need to correctly reshape the numpy vector r_global in the following block, which is a projection of your r from a part of the mixed space onto the quadrature space U_EO.

def spring_impl(r_global):
    r_ = r_global.reshape(-1, 3)
    return spring_vec(r_).reshape(-1)

def dspring_dr_impl(r_global):
    r_ = r_global.reshape(-1, 3)
    return dspring_dr_vec(r_).reshape(-1)

Here, I suppose that you work with energy per quadrature point. If your algorithm requires all DOFs, you don’t need the vectorisation and have to change all functions accordingly. Just pay attention to the shapes of the numpy arrays. Basically, r_global has size num_quadratures * (mathematical size of r).
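
As a rough sketch of the per-quadrature-point pattern extended to a second derivative, the example below starts from a scalar energy and lets JAX produce the force and the stiffness. The energy `0.5 * k * |r|^2`, the constant `k`, and the names `energy`, `force_impl` and `stiffness_impl` are assumptions made for illustration only. The `derivatives` tuple convention simply mirrors the `spring_external` functions earlier in this thread, and the dummy `r_global` stands in for the operand values at the quadrature points.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)

k = 2.0  # hypothetical spring stiffness


def energy(r):
    """Hypothetical energy density at one quadrature point; r has shape (3,)."""
    return 0.5 * k * jnp.dot(r, r)


force = jax.grad(energy)       # per point: shape (3,)
stiffness = jax.jacfwd(force)  # per point: shape (3, 3)

# Vectorise over the leading axis, i.e. over quadrature points.
force_vec = jax.vmap(force)
stiffness_vec = jax.vmap(stiffness)


def force_impl(r_global):
    r_ = r_global.reshape(-1, 3)
    return np.asarray(force_vec(r_)).reshape(-1)


def stiffness_impl(r_global):
    r_ = r_global.reshape(-1, 3)
    return np.asarray(stiffness_vec(r_)).reshape(-1)


def spring_external(derivatives):
    if derivatives == (0,):    # the "force" itself
        return force_impl
    elif derivatives == (1,):  # its derivative with respect to r
        return stiffness_impl
    raise NotImplementedError(f"No function for derivative {derivatives}.")


num_quadratures = 5  # dummy value, just to check the array sizes
r_global = np.linspace(0.0, 1.0, num_quadratures * 3)

f = spring_external((0,))(r_global)
K = spring_external((1,))(r_global)
print(f.shape, K.shape)  # (15,) (45,)
```

The flat outputs have sizes num_quadratures * 3 and num_quadratures * 9, so the array sizes can be checked this way before the functions are attached to an external operator.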

As for differentiation, since the bug with sums of forms was spotted after publishing external operators v0.9.0, I am not surprised that there is still an issue in the old version. It seems that you bypassed it with method == 0, well done! Another workaround, for method == 1, is outlined below: just apply ufl.algorithms.expand_derivatives to the derivative prior to summation.

method = 0 # 1: directional derivative of energy function, 0: differentiation of the energy function before integration

if method==0:
    dF = ufl.diff(Psi, state)
    F1 = ufl.dot(dF, dstate) * dx  #ufl.derivative(Psi,state,dstate)
    
elif method==1:
    Psi *= dx   # integrating over the domain
    F1 = ufl.algorithms.expand_derivatives(ufl.derivative(Psi,state,dstate))

Both methods with the correct vectorisation should work.

Thanks so much! I will keep the expand_derivatives version in mind, because I would like to be able to define everything in terms of energy functions. This has been very helpful!
