Hi everyone!
I am trying to learn to use dolfinx_external_operators so that I can add some spring-like objects to a mechanical model of beams (Timoshenko beams using a mixed element formulation). I have a simplified version of the problem below with the same degrees of freedom but with a simplified energy function.
When I go to set up the solver objects, the code breaks when running this line (here, F, is the function I am trying to solve; F=0):
F_replaced, F_external_operators = replace_external_operators(F)
The error occurs because in _replace_form() (see below), ufl.algorithms.replace is returning a FormSum instead of a Form. This then breaks the next part of the external_operators code because it tries to pull out the base_form_operators from the FormSum and raises this error:
'FormSum' object has no attribute 'base_form_operators'.
Is this happening because of how I am defining the external operator? Is this maybe related to having the external operator in the energy function that is then being differentiated instead of directly using it in the residual? Any help would be greatly appreciated! Thanks!
Code to replicate the error:
import numpy as np
import gmsh
import ufl
import ufl.algorithms
import basix
from mpi4py import MPI
from dolfinx import mesh, fem, io
import dolfinx
import dolfinx.fem.petsc as petsc
from dolfinx.nls.petsc import NewtonSolver
from dolfinx_external_operator import (
FEMExternalOperator,
evaluate_external_operators,
evaluate_operands,
replace_external_operators,
)
from petsc4py import PETSc
import jax
# enables the work with double precision under float64
jax.config.update("jax_enable_x64", True)
#################################################################
# region CONSTRUCT GEOMETRY #
#################################################################
degree = 3 # cubic basic functions
quad_degree = 4*degree
N = 30 # number of elements per branch
char_len = 1/(N-1)
d = 1
# Initialize gmsh
gmsh.initialize()
gmsh.option.setNumber("General.Terminal", 0)
# Creating points and lines (look into add_bezier for nicer curves)
p0 = gmsh.model.occ.addPoint(0, 0, 0, d)
p1 = gmsh.model.occ.addPoint(0, 1, 0, d)
l1 = gmsh.model.occ.addLine(p0, p1)
gmsh.model.occ.synchronize()
# add_physical_group so we can set up subdomains with different properties
# inputs (topological dimension, [list of object], tag)
gmsh.model.add_physical_group(1, [l1], 1)
gmsh.option.setNumber("Mesh.CharacteristicLengthMax",char_len)
gmsh.model.mesh.generate(dim=1)
domain, _, _ = io.gmshio.model_to_mesh(gmsh.model, MPI.COMM_WORLD, 0)
# Finalize gmsh
gmsh.finalize()
gdim = domain.geometry.dim
tdim = domain.topology.dim
print(gdim,tdim)
#################################################################
# region CONSTRUCT REFERENCE CONFIGURATION #
#################################################################
dx_dX = ufl.Jacobian(domain)[:, 0]
t = dx_dX
t /= ufl.sqrt(ufl.dot(t,t)) # uncomment to normalize tangent vector
ez = ufl.as_vector([0, 0, 1])
a1 = ufl.cross(ez, t)
a1 /= ufl.sqrt(ufl.dot(a1, a1))
a2 = ufl.cross(t, a1)
a2 /= ufl.sqrt(ufl.dot(a2, a2))
def dds(u):
# use this to get derivatives along the rest configuration of the beam (d/ds)
# and convert it to the local frame where the tangent direction is e3
return ufl.dot(ufl.grad(u), t)
#################################################################
# region CONSTRUCT ELEMENTS AND FUNCTION SPACES #
#################################################################
Ue = basix.ufl.element("P", domain.basix_cell(), degree, shape=(gdim,)) # here "P" means Lagrange polynomial elements (here cubic)
U_EO = basix.ufl.quadrature_element(domain.topology.cell_name(),degree=quad_degree,value_shape=()) # quadrature elements for external operator
W = fem.functionspace(domain, basix.ufl.mixed_element([Ue, Ue]))
W_EO = fem.functionspace(domain,U_EO) # function space for the external operator
dstate = ufl.TestFunction(W)
dr, dtheta = ufl.split(dstate)
state = fem.Function(W)
r, theta = ufl.split(state) # r: position vector, theta: incremental rotation vector
# setting up initial condition
def r0_func(x):
return x
state.sub(0).interpolate(r0_func)
state.x.scatter_forward()
#################################################################
# region GET DEGREES OF FREEDOM FOR BCS AND SPRING ENERGY #
#################################################################
def left_tip(x):
return np.logical_and(np.isclose(x[0], 0), np.isclose(x[1], 0))
def right_tip(x):
return np.logical_and(np.isclose(x[0], 0), np.isclose(x[1], 1))
fdim = domain.topology.dim -1
left_facets = mesh.locate_entities_boundary(domain,fdim,left_tip)
right_facets = mesh.locate_entities_boundary(domain,fdim,right_tip)
marked_facets = np.hstack([left_facets,right_facets])
marked_values = np.hstack([np.full(len(left_facets),0,dtype=np.int32),
np.full(len(right_facets),1,dtype=np.int32)])
sorted_facets = np.argsort(marked_facets)
facet_tag = mesh.meshtags(domain,fdim,marked_facets[sorted_facets],marked_values[sorted_facets])
# integrators
metadata={"quadrature_scheme":"default", "quadrature_degree": quad_degree}
ds = ufl.Measure('ds',domain=domain,subdomain_data=facet_tag,metadata=metadata)
dx = ufl.Measure('dx',domain=domain,subdomain_data=facet_tag,metadata=metadata)
# get sub domains of the different elements of the function space
Vr, Vr_to_W = W.sub(0).collapse()
Vt, Vt_to_W = W.sub(1).collapse()
# get degrees of freedom for each possible boundary condition
r_left_dofs = fem.locate_dofs_geometrical((W.sub(0), Vr), left_tip)
theta_left_dofs = fem.locate_dofs_geometrical((W.sub(1), Vt), left_tip)
r_right_dofs = fem.locate_dofs_geometrical((W.sub(0), Vr), right_tip)
theta_right_dofs = fem.locate_dofs_geometrical((W.sub(1), Vt), right_tip)
r0 = fem.Function(Vr)
r0.interpolate(lambda x: x)
theta0 = fem.Function(Vt)
theta0.interpolate(lambda x: 0*x)
# Boundary conditions: no displacement, no rotation
bcs = [fem.dirichletbc(r0, r_right_dofs, W.sub(0)),
fem.dirichletbc(theta0, theta_right_dofs, W.sub(1))]
#################################################################
# region SETTING UP ENERGY EXPRESSION AND EXTERNAL OPERATOR #
#################################################################
Psi = ufl.dot(dds(r), dds(r)) + ufl.dot(dds(theta), dds(theta))
One = fem.Constant(domain, dolfinx.default_scalar_type(1.0))
V = fem.assemble_scalar(fem.form(One * dx))
# ENERGY FUNCTION FOR EXTERNAL OPERATOR
def spring(r):
xyz_left = r[r_left_dofs[1]]
L = ( xyz_left[0] )**2
n_DOFs = len(r)/3
sef = L * np.ones(n_DOFs) / V # spread the value of the integral over the full domain
return sef.reshape(-1)
dspring_dr = jax.jacfwd(spring,argnums=(0)) # AD derivative of spring strain energy function
ddspring_dr = jax.jacfwd(dspring_dr,argnums=(0)) # AD second derivative of spring strain energy function
def spring_external(derivatives):
if derivatives == (0,): # no derivation, the function itself
return spring
elif derivatives == (1,): # the derivative with respect to the operand `T`
return dspring_dr
elif derivatives == (2,):
return ddspring_dr
else:
raise NotImplementedError(f"No external function is defined for the requested derivative {derivatives}.")
spring_energy = FEMExternalOperator(r,function_space=W_EO)
spring_energy.external_function = spring_external
Psi += spring_energy # add the effects of the spring
Psi *= dx # integrating over the domain
#################################################################
# region SETTING UP OBJECTS FOR SOLVING #
#################################################################
F = ufl.derivative(Psi,state,dstate)
state_hat = ufl.TrialFunction(W)
J = ufl.derivative(F, state, state_hat)
J_expanded = ufl.algorithms.expand_derivatives(J)
J_replaced, J_external_operators = replace_external_operators(J_expanded)
##### ERROR OCCURS HERE: #######
F_replaced, F_external_operators = replace_external_operators(F)
evaluated_operands = evaluate_operands(F_external_operators)
_ = evaluate_external_operators(F_external_operators, evaluated_operands)
_ = evaluate_external_operators(J_external_operators, evaluated_operands)
F_compiled = fem.form(F_replaced)
J_compiled = fem.form(J_replaced)
b_vector = fem.assemble_vector(F_compiled)
A_matrix = fem.assemble_matrix(J_compiled)
Function that generates the issue in external_operators
# in external_operators.py
def _replace_form(form: ufl.Form):
external_operators = form.base_form_operators()
ex_ops_map = {ex_op: ex_op.ref_coefficient for ex_op in external_operators}
replaced_form = ufl.algorithms.replace(form, ex_ops_map)
return replaced_form, external_operators