@jpdean I’m trying to build on the working example you gave me to the case where several forms each enter the tabulation function. Here is a slightly modified static condensation demo that is not working for me. I tried to replicate what your example does by essentially flattening the lists of coefficients and constants that go into the Form. How can I modify the following to make it work? Thanks for the help.
#Import
from petsc4py import PETSc
from mpi4py import MPI
import cffi
import numba
import numba.core.typing.cffi_utils as cffi_support
import numpy as np
import ufl
from basix.ufl import element
from dolfinx import mesh, geometry
from dolfinx.fem import (
Form,
Function,
Constant,
IntegralType,
dirichletbc,
form,
form_cpp_class,
functionspace,
locate_dofs_topological,
)
from dolfinx.fem.petsc import apply_lifting, assemble_matrix, assemble_vector
from dolfinx.io import XDMFFile
from dolfinx.jit import ffcx_jit
from dolfinx.mesh import locate_entities_boundary, meshtags
from ffcx.codegeneration.utils import numba_ufcx_kernel_signature as ufcx_signature
def extract_coefficients_constants(forms, ufcx_forms):
"""
Collect and deduplicate coefficients and constants used across multiple forms,
and compute the enabled coefficient indices for custom integral definition.
Parameters:
forms: list of dolfinx.fem.Form
The list of Python-level forms (e.g. form(a00), form(a01), etc.)
ufcx_forms: list of compiled UFCx objects
The list of corresponding compiled form objects (e.g. from ffcx_jit)
Returns:
flat_coeffs: list
List of unique coefficient C++ objects in order expected by kernel.
flat_consts: list
List of unique constant C++ objects.
flat_enabled: np.ndarray
List of enabled coefficient indices for the kernel.
"""
# All constants from all forms
all_consts = []
for f in forms:
for c in f.constants():
all_consts.append(c)
# All coefficients from all forms
all_coeffs = []
for f in forms:
for c in f.coefficients():
all_coeffs.append(c)
# Deduplicate by Python object identity
unique_consts = []
seen_consts = set()
for c in all_consts:
if id(c) not in seen_consts:
seen_consts.add(id(c))
unique_consts.append(c)
unique_coeffs = []
seen_coeffs = set()
for c in all_coeffs:
if id(c) not in seen_coeffs:
seen_coeffs.add(id(c))
unique_coeffs.append(c)
# Map each ufcx coefficient index to our deduplicated list
flat_enabled = set()
for i, ufcx in enumerate(ufcx_forms):
for j in range(ufcx.num_coefficients):
if ufcx.form_integrals[0].enabled_coefficients[j]:
coeff_pos = ufcx.original_coefficient_positions[j]
coeff_obj = forms[i].coefficients()[coeff_pos] # index into original coefficient list
flat_index = unique_coeffs.index(coeff_obj)
flat_enabled.add(flat_index)
flat_coeffs = [c._cpp_object for c in unique_coeffs]
flat_consts = [c._cpp_object for c in unique_consts]
flat_enabled = np.array(sorted(flat_enabled), dtype=np.int32)
return flat_coeffs, flat_consts, flat_enabled
#Create mesh
msh = mesh.create_unit_square(MPI.COMM_WORLD, 2, 2)
# Stress (Se) and displacement (Ue) elements
Se = element("DG", msh.basix_cell(), 1, shape=(2, 2), symmetry=True, dtype=PETSc.RealType) # type: ignore
Ue = element("Lagrange", msh.basix_cell(), 2, shape=(2,), dtype=PETSc.RealType) # type: ignore
S = functionspace(msh, Se)
U = functionspace(msh, Ue)
DG0 = functionspace(msh, ('DG', 0))
sigma, tau = ufl.TrialFunction(S), ufl.TestFunction(S)
u, v = ufl.TrialFunction(U), ufl.TestFunction(U)
# Locate all facets at the free end and assign them value 1. Sort the
# facet indices (requirement for constructing MeshTags)
free_end_facets = np.sort(locate_entities_boundary(msh, 1, lambda x: np.isclose(x[0], 1.)))
mt = meshtags(msh, 1, free_end_facets, 1)
ds = ufl.Measure("ds", subdomain_data=mt)
# Homogeneous boundary condition in displacement
u_bc = Function(U)
u_bc.x.array[:] = 0
# Displacement BC is applied to the left side
left_facets = locate_entities_boundary(msh, 1, lambda x: np.isclose(x[0], 0.0))
bdofs = locate_dofs_topological(U, 1, left_facets)
bc = dirichletbc(u_bc, bdofs)
# Elastic stiffness tensor and Poisson ratio
E, nu = 1.0, 1.0 / 3.0
lam_ = E*nu/((1+nu)*(1-2*nu)) #Lame's first parameter Pa
mu_ = E/(2*(1+nu)) #Shear modulus Pa
rho_ = 2.0
#Want to work
lam, mu, rho = Function(DG0), Function(DG0), Function(DG0)
lam.interpolate(lambda x: np.full(x.shape[1], lam_))
lam.x.scatter_forward()
mu.interpolate(lambda x: np.full(x.shape[1], mu_))
mu.x.scatter_forward()
rho.interpolate(lambda x: np.full(x.shape[1], rho_))
rho.x.scatter_forward()
##Want to work
#lam = Constant(msh, lam_)
#mu = Constant(msh, mu_)
#rho = Constant(msh, rho_)
##Works
#lam = lam_
#mu = mu_
#rho = rho_
def sigma_u(u):
"""Constitutive relation for stress-strain. Assuming plane-stress in XY"""
eps = 0.5 * (ufl.grad(u) + ufl.grad(u).T)
sigma = 2 * mu * eps + lam * ufl.Identity(2) * ufl.tr(eps)
return sigma
a00 = ufl.inner(sigma, tau) * ufl.dx
a10 = -ufl.inner(sigma, ufl.grad(v)) * ufl.dx
a01 = -ufl.inner(sigma_u(u), tau) * ufl.dx
a11 = -rho * ufl.inner(u, v) * ufl.dx
f = ufl.as_vector([0.0, 1.0 / 16])
b1 = form(-ufl.inner(f, v) * ds(1), dtype=PETSc.ScalarType) # type: ignore
# JIT compile individual blocks tabulation kernels
ufcx00, _, _ = ffcx_jit(msh.comm, a00, form_compiler_options={"scalar_type": PETSc.ScalarType}) # type: ignore
kernel00 = getattr(ufcx00.form_integrals[0], f"tabulate_tensor_{np.dtype(PETSc.ScalarType).name}") # type: ignore
ufcx01, _, _ = ffcx_jit(msh.comm, a01, form_compiler_options={"scalar_type": PETSc.ScalarType}) # type: ignore
kernel01 = getattr(ufcx01.form_integrals[0], f"tabulate_tensor_{np.dtype(PETSc.ScalarType).name}") # type: ignore
ufcx10, _, _ = ffcx_jit(msh.comm, a10, form_compiler_options={"scalar_type": PETSc.ScalarType}) # type: ignore
kernel10 = getattr(ufcx10.form_integrals[0], f"tabulate_tensor_{np.dtype(PETSc.ScalarType).name}") # type: ignore
ufcx11, _, _ = ffcx_jit(msh.comm, a11, form_compiler_options={"scalar_type": PETSc.ScalarType}) # type: ignore
kernel11 = getattr(ufcx11.form_integrals[0], f"tabulate_tensor_{np.dtype(PETSc.ScalarType).name}") # type: ignore
ffi = cffi.FFI()
cffi_support.register_type(ffi.typeof("double _Complex"), numba.types.complex128)
# Get local dofmap sizes for later local tensor tabulations
Ssize = S.element.space_dimension
Usize = U.element.space_dimension
@numba.cfunc(ufcx_signature(PETSc.ScalarType, PETSc.RealType), nopython=True) # type: ignore
def tabulate_A(A_, w_, c_, coords_, entity_local_index, permutation=ffi.NULL):
"""Element kernel that applies static condensation."""
# Prepare target condensed local element tensor
A = numba.carray(A_, (Usize, Usize), dtype=PETSc.ScalarType)
# Tabulate all sub blocks locally
A00 = np.zeros((Ssize, Ssize), dtype=PETSc.ScalarType)
kernel00(ffi.from_buffer(A00), w_, c_, coords_, entity_local_index, permutation)
A01 = np.zeros((Ssize, Usize), dtype=PETSc.ScalarType)
kernel01(ffi.from_buffer(A01), w_, c_, coords_, entity_local_index, permutation)
A10 = np.zeros((Usize, Ssize), dtype=PETSc.ScalarType)
kernel10(ffi.from_buffer(A10), w_, c_, coords_, entity_local_index, permutation)
A11 = np.zeros((Usize, Usize), dtype=PETSc.ScalarType)
kernel11(ffi.from_buffer(A11), w_, c_, coords_, entity_local_index, permutation)
# A = - A10 * A00^{-1} * A01
A[:, :] = A11 - A10 @ np.linalg.solve(A00, A01)
# Prepare a Form with a condensed tabulation kernel
formtype = form_cpp_class(PETSc.ScalarType) # type: ignore
cells = np.arange(msh.topology.index_map(msh.topology.dim).size_local)
forms = [a00, a01, a10, a11]
ufcx_forms = [ufcx00, ufcx01, ufcx10, ufcx11]
flat_coeffs, flat_consts, flat_enabled = extract_coefficients_constants(forms, ufcx_forms)
integrals = {IntegralType.cell: [(-1, tabulate_A.address, cells, flat_enabled)]}
a_cond = Form(formtype([U._cpp_object, U._cpp_object], integrals, flat_coeffs, flat_consts, False, {}, None))
A_cond = assemble_matrix(a_cond, bcs=[bc])
A_cond.assemble()
# Pure displacement based formulation
a = form(-rho*ufl.inner(u, v) * ufl.dx - ufl.inner(sigma_u(u), ufl.grad(v)) * ufl.dx)
A = assemble_matrix(a, bcs=[bc])
A.assemble()
#Check that the inverse is what we expect
A_cond.convert('dense')
A_cond_np = A_cond.getDenseArray()
A.convert('dense')
A_np = A.getDenseArray()
assert np.allclose(A_cond_np, A_np, atol=1e-8), "Something incorrect with static condensation"