For a project, I need to do lower level interpolation in the facet DOFs on RT spaces. I started to interpolate a contious vector linear function into RT2 space through the build-in interpolation function and a lower level interpolation in Fenicsx. But I get different interpolations on facets. Can anyone possibly identify where the error comes from the code? Many thanks for helping!! Below please find my code.
import numpy as np
from mpi4py import MPI
from dolfinx import mesh, fem
import basix
def v_func(x):
"""Return vector flux (v(x)·nF) nF at point x."""
v = np.array([2*x[0] + x[1] + 1,
-x[0] + 3*x[1] - 2])
return v
# def v_func(x):
# if x.ndim == 1: # single point
# return np.array([1.0, -2.0])
# else: # batch of points
# N = x.shape[1]
# return np.tile(np.array([[1.0], [-2.0]]), (1, N))
# 1) Mesh (triangles)
domain = mesh.create_unit_square(MPI.COMM_WORLD, 1, 1, cell_type=mesh.CellType.triangle)
tdim = domain.topology.dim
gdim = 2
# Build the needed connectivities
domain.topology.create_connectivity(tdim, tdim-1) # cell → facet
domain.topology.create_connectivity(tdim-1, tdim) # facet → cell
domain.topology.create_connectivity(tdim-1,0) # facet → vertex
facet2cell = domain.topology.connectivity(tdim-1, tdim)
cell2facet = domain.topology.connectivity(tdim, tdim-1)
cells2vert = domain.topology.connectivity(tdim, 0)
facet_to_vertices = domain.topology.connectivity(tdim - 1, 0)
s = 2
V_rt = fem.functionspace(domain, ("RT", s))
el = basix.create_element(basix.ElementFamily.RT, basix.CellType.triangle, s)
entity_dofs = el.entity_dofs
I = el.interpolation_matrix
# print("the interpolation matrix shape", I.shape)
# print("The interpolation matrix I", I)
n_facets = domain.topology.index_map(tdim-1).size_local
facet_normals = np.zeros((n_facets, gdim))
# 2) Function to hold coefficients
sigma = fem.Function(V_rt)
sigma.x.array[:] = 0.0
# Loop over facets and compute normals from vertices
X = domain.geometry.x[:, :tdim]
for f in range(n_facets):
verts = facet_to_vertices.links(f)
# print(f"Facet {f}, vertices {verts}")
xv = X[verts]
if gdim == 2:
# Each facet is an edge with 2 vertices
v = xv[1] - xv[0]
n = np.array([v[1], -v[0]])
elif gdim == 3:
# Each facet is a triangle with 3 vertices
v1 = xv[1] - xv[0]
v2 = xv[2] - xv[0]
n = np.cross(v1, v2)
else:
raise ValueError("Only 2D and 3D supported")
n /= np.linalg.norm(n)
facet_normals[f] = n
# Keep only first tdim components (here tdim=2)
for f in range(n_facets):
# Example f = 1 → right edge of the square ([1,0]–[1,1])
cells = facet2cell.links(f)
# facet2cell: which cells share facet f
# e.g. f=1 (right edge) → cells = [0]
# f=4 (diagonal) → cells = [0, 1]
if len(cells) == 0:
continue # skip if boundary facet not owned by this rank
cell = int(min(cells))
# pick smallest cell index for consistency
# e.g. f=4 (diagonal) → cells = [0,1], so cell = 0
global_facets = cell2facet.links(cell)
# list of facet indices (global) for this cell
# e.g. for cell=0, global_facets might be [0, 1, 4]
lf = np.where(global_facets == f)[0][0] #why [0][0]? Because np.where returns a tuple of arrays, and we want the first index of the first array.
# find which local facet corresponds to global facet f
# e.g. if f=1 and global_facets=[0,1,4] → lf = 1
# e.g. if f=4 and global_facets=[0,1,4] → lf = 2
#get facet interpolation points (from Basix RT element)
# entity_dofs[1][lf] = local DOF indices for this facet in this cell
# e.g. for RT1 on triangle, lf=1 → [2,3] (2 DOFs on this edge)
facet_dofs_local = entity_dofs[1][lf]
# V_rt.dofmap.cell_dofs(cell) = mapping of all local DOFs of this cell to global DOF numbers
# e.g. for cell 0, might be [0,1,2,3,4,5,6,7]
# facet_dofs_global = global DOF numbers corresponding to this facet’s local DOFs
# e.g. facet_dofs_local=[2,3] → facet_dofs_global=[2,3] (global indices in cell 0)
facet_dofs_global = V_rt.dofmap.cell_dofs(cell)[facet_dofs_local]
# evaluate g_F(x) at facet interpolation points
# find the DOFs points on the facet indexed with facet_dofs_local in the reference element
facet_points_ref = el.points[facet_dofs_local]
# map to physical points
verts = X[cells2vert.links(cell)]
x0, x1, x2 = verts
# print("verts", verts)
J = np.column_stack((x1 - x0, x2 - x0))
invJ = np.linalg.inv(J)
facet_points_phys = x0 + facet_points_ref @ J.T
# nF = facet_normals[f]
# evaluate target flux function at these points
num_points = el.points.shape[0] # e.g. 9 for RT2
Interpolation_full = np.zeros(num_points * gdim)
# print("J", J)
detJ = np.linalg.det(J)
num_points = el.points.shape[0]
Interpolation_full = np.zeros((gdim, num_points))
for i, dof in enumerate(facet_dofs_local):
p = facet_points_ref[i]
x_p = x0 + J @ p
val = v_func(x_p)
val_ref = detJ * (invJ @ val) # Piola pullback
Interpolation_full[:, dof] = val_ref
# Flatten in component-major order for any gdim
Interpolation_flat = np.concatenate([Interpolation_full[d, :] for d in range(gdim)])
dof_vals = I[facet_dofs_local, :] @ Interpolation_flat
# Fill only the entries for this facet’s interpolation points
# for i, dof in enumerate(facet_dofs_local):
# p = facet_points_ref[i]
# x_p = x0 + J @ p
# val = v_func(x_p) #v_func(x_p, nF) * nF # vector at this point
# val_ref = detJ * (invJ @ val) # Piola pullback
# offset = dof * gdim
# Interpolation_full[offset:offset+gdim] = val_ref
# Now Interpolation_full has length 18, with nonzeros only at this facet’s points
if len(cells) == 1:
# print("Boundary facet", f, "cell", cell)
# print(" facet_points_phys", facet_points_phys)
# # print("facet_dofs_local", facet_dofs_local)
# # print(" The interpolation matrix I on facet", I[facet_dofs_local, :])
# # print(" Interpolation_full", Interpolation_full)
# print(" dof_vals", dof_vals)
print("facet_dofs_global", facet_dofs_global)
# 4) assign values into global RT vector
# ref_normal = basix.facet_normals(el)[lf] # for local facet lf
# nF = facet_normals[f]
# if np.dot(J @ ref_normal, nF) < 0:
# dof_vals *= -1
sigma.x.array[facet_dofs_global] = dof_vals
v = fem.Function(V_rt)
v.interpolate(v_func) # built-in interpolation with Piola transform
tol = 1e-5
diff_idx = np.where(np.abs(sigma.x.array - v.x.array) > tol)[0]
# print("Mismatched indices:", diff_idx)
print("sigma.x at mismatched:", sigma.x.array[diff_idx])
print("v.x at mismatched:", v.x.array[diff_idx])
print("difference:", sigma.x.array[diff_idx] - v.x.array[diff_idx])
print(v.x.array[:])