Different boundary integral values between FEniCS and JAX: recreating a FEniCS-consistent boundary integral in JAX

Howdy,

Building on a previous comment (How are boundary integrals implemented in FENICS? - #3 by Ryan_Vogt2), I implemented the normal vector.

I’m trying to make my JAX-friendly implementation of a boundary integral over a square (decomposed into four 1-D integrals) consistent with FEniCS. The semi-discretization approaches I am developing need derivatives of the boundary integral with respect to time via automatic differentiation, so I cannot rely on the FEniCS implementation of the boundary integral.

I traverse the boundary counterclockwise, but the orientation does not seem to be the problem.
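For a scalar boundary integral of the form ∫ f ds, the traversal direction should not change the value: the arc-length factor |r'(l)| is positive whichever way an edge is parametrized. The NumPy sketch below (the helpers `trapezoid` and `segment_integral`, and its `jacobian` argument, are illustrative names, not part of FEniCS or JAX) integrates cos(x + y) along the top edge y = 1 in both directions, and then with a signed factor -a_x of the kind the posted `boundary_integrand` uses on the top and left edges.

```python
import numpy as np


def trapezoid(y, x):
    # Composite trapezoidal rule, written out to avoid np.trapz / np.trapezoid version differences
    return float(np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(x)))


def segment_integral(f, p0, p1, jacobian=None, n=2001):
    """Approximate the integral of f(x, y) ds over the segment p0 -> p1.

    The segment is parametrized as r(l) = p0 + l * (p1 - p0), l in [0, 1].
    By default the Jacobian is the arc-length factor |r'(l)| = |p1 - p0| >= 0;
    passing `jacobian` explicitly lets you try a signed factor instead.
    """
    p0 = np.asarray(p0, dtype=float)
    p1 = np.asarray(p1, dtype=float)
    if jacobian is None:
        jacobian = np.linalg.norm(p1 - p0)
    l = np.linspace(0.0, 1.0, n)
    x = p0[0] + (p1[0] - p0[0]) * l
    y = p0[1] + (p1[1] - p0[1]) * l
    return jacobian * trapezoid(f(x, y), l)


f = lambda x, y: np.cos(x + y)
a_x = 1.0  # length of the top edge of the unit square

left_to_right = segment_integral(f, (0.0, 1.0), (1.0, 1.0))
right_to_left = segment_integral(f, (1.0, 1.0), (0.0, 1.0))  # counterclockwise direction
signed = segment_integral(f, (1.0, 1.0), (0.0, 1.0), jacobian=-a_x)

print(left_to_right, right_to_left, signed)
```

The first two values agree (both approximate sin 2 - sin 1 ≈ 0.0678), while the third has the opposite sign: a negative factor in front of an edge flips that edge's whole contribution rather than encoding orientation.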

When I compute the boundary integral against all 25 basis functions in FEniCS and in JAX, I get:

fenics boundary vector
[-0.03479013 -0.36843907  0.16037663 -0.44233786  0.          0.03698622
 -0.48873423  0.          0.         -0.08870382 -0.50719482  0.
  0.          0.         -0.17113571 -0.48873423  0.          0.
 -0.08870382 -0.44233786  0.          0.03698622 -0.36843907  0.16037663
 -0.03479013]
jax boundary vector
 [ 3.4922317e-02  3.6381820e-01 -1.5679288e-01  4.3623355e-01
  3.8644311e-18 -3.5164107e-02  4.8160055e-01  0.0000000e+00
  1.2756327e-19  8.8573724e-02  0.0000000e+00  0.0000000e+00
  0.0000000e+00 -6.6702707e-20 -7.4505806e-09 -4.8160055e-01
  0.0000000e+00 -1.1649350e-19 -8.8573724e-02 -4.3623355e-01
 -3.3880987e-18  3.5164103e-02 -3.6381817e-01  1.5679288e-01
 -3.4922298e-02]

The entries follow a pattern: each one either agrees (up to discretization error), differs by a sign, or is reported as zero by JAX while being nonzero in FEniCS.
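To make the pattern explicit, the entries can be sorted into categories with a small helper. `classify_entries` and its tolerances (10% relative, 1e-6 absolute) are illustrative choices, not part of either library; the two arrays are the vectors printed above.

```python
from collections import Counter

import numpy as np

fenics_vec = np.array([
    -0.03479013, -0.36843907, 0.16037663, -0.44233786, 0.0, 0.03698622,
    -0.48873423, 0.0, 0.0, -0.08870382, -0.50719482, 0.0,
    0.0, 0.0, -0.17113571, -0.48873423, 0.0, 0.0,
    -0.08870382, -0.44233786, 0.0, 0.03698622, -0.36843907, 0.16037663,
    -0.03479013,
])
jax_vec = np.array([
    3.4922317e-02, 3.6381820e-01, -1.5679288e-01, 4.3623355e-01,
    3.8644311e-18, -3.5164107e-02, 4.8160055e-01, 0.0,
    1.2756327e-19, 8.8573724e-02, 0.0, 0.0,
    0.0, -6.6702707e-20, -7.4505806e-09, -4.8160055e-01,
    0.0, -1.1649350e-19, -8.8573724e-02, -4.3623355e-01,
    -3.3880987e-18, 3.5164103e-02, -3.6381817e-01, 1.5679288e-01,
    -3.4922298e-02,
])


def classify_entries(ref, other, rtol=0.1, atol=1e-6):
    """Label each entry of `other` relative to the matching entry of `ref`."""
    labels = []
    for r, o in zip(ref, other):
        if abs(r) < atol and abs(o) < atol:
            labels.append("both zero")
        elif abs(o) < atol:
            labels.append("zero only in other")
        elif abs(o - r) <= rtol * abs(r):
            labels.append("agree")
        elif abs(o + r) <= rtol * abs(r):
            labels.append("sign flip")
        else:
            labels.append("mismatch")
    return labels


labels = classify_entries(fenics_vec, jax_vec)
print(Counter(labels))
for i, lab in enumerate(labels):
    if lab in ("sign flip", "zero only in other"):
        print(i, lab, fenics_vec[i], jax_vec[i])
```

With these vectors it reports 7 entries that agree, 7 that differ only in sign, 2 (indices 10 and 14) that JAX reports as zero, and 9 that are zero in both. The sign flips all sit at indices 0 to 9 and the agreements at indices 15 to 24, so checking the contribution of each edge separately may be a useful next step.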

In FEniCS I calculate the boundary integral as follows:

def fenics_boundary_integral(t):
    true_fun.t=t
    ac = Constant(alpha)
    n = FacetNormal(mesh)
    # n_proj = project(n, V)
    # print(n_proj.eval(np.asarray([0,0]),mesh.coordinates().flatten()))
    # quit()
    true_fun_prog = project(true_fun,V_quad)
    boundary = ac*dot(grad(true_fun_prog), n)*phi_i*ds  # UFL dot (not np.dot) keeps the form symbolic
    boundary_vec = assemble(boundary)
    #print(boundary_vec)
    boundary_vec = boundary_vec.get_local()
    print("fenics boundary vector")
    print(boundary_vec)

The only explanation I can think of is an orientation convention that I am not taking into account, so I am clearly violating some assumption that FEniCS makes. For the JAX boundary integral, I use the exact gradient, take its dot product with the (projected) FEniCS normal vector, multiply by each basis function, and integrate with the trapezoidal rule.
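As an independent reference for the approach just described, here is a minimal NumPy sketch under simplifying assumptions: exact outward normals instead of the projected `nh`, P1 hat functions restricted to the boundary of a uniform N x N grid, and boundary nodes numbered counterclockwise from (0, 0) rather than by the FEniCS dofmap. Unlike the posted `boundary_integrand`, every edge is weighted by its positive length. `boundary_vector`, `hat`, and `trapezoid` are hypothetical helper names.

```python
import numpy as np


def trapezoid(y, x):
    # Composite trapezoidal rule along the last axis of y
    return np.sum(0.5 * (y[..., 1:] + y[..., :-1]) * np.diff(x), axis=-1)


def hat(s, center, h):
    # 1-D P1 hat function with peak at `center` and support [center - h, center + h]
    return np.clip(1.0 - np.abs(s - center) / h, 0.0, None)


def boundary_vector(N=4, alpha=2.0, t=0.18577451194110492, samples=4001):
    """alpha * integral over the unit-square boundary of grad(u).n * phi_k ds,
    for u = (1 - exp(-40 t)) sin(x + y) and P1 hats on a uniform N x N grid.

    Entry k belongs to the k-th boundary node, counted counterclockwise from (0, 0).
    """
    h = 1.0 / N
    c = 1.0 - np.exp(-40.0 * t)

    def grad_u(x, y):
        return c * np.array([np.cos(x + y), np.cos(x + y)])

    # (start point, unit direction of traversal, outward unit normal), counterclockwise
    edges = [
        ((0.0, 0.0), (1.0, 0.0), (0.0, -1.0)),   # bottom
        ((1.0, 0.0), (0.0, 1.0), (1.0, 0.0)),    # right
        ((1.0, 1.0), (-1.0, 0.0), (0.0, 1.0)),   # top
        ((0.0, 1.0), (0.0, -1.0), (-1.0, 0.0)),  # left
    ]
    n_nodes = 4 * N
    b = np.zeros(n_nodes)
    s = np.linspace(0.0, 1.0, samples)  # arc length along an edge of length 1
    for k, (p0, d, normal) in enumerate(edges):
        x = p0[0] + d[0] * s
        y = p0[1] + d[1] * s
        flux = np.asarray(normal) @ grad_u(x, y)  # grad(u) . n at each sample
        for j in range(N + 1):
            node = (k * N + j) % n_nodes
            # Jacobian |dr/ds| = 1 on every edge: never negative, whatever the direction
            b[node] += alpha * trapezoid(flux * hat(s, j * h, h), s)
    return b


b = boundary_vector()
print(b)
```

With these assumptions the (0, 0) corner entry comes out near -0.497, and the corners (1, 0) and (0, 1) receive equal values by symmetry. If the left edge were weighted by -1 instead, the two equal contributions at (0, 0) would cancel, which is one way an entry can come out as exactly zero.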

Here is my complete bare-bones code for calculating the boundary integral in JAX and FEniCS:



import sys
import os
import math
import numpy as np
from fenics import *
import jax.numpy as jnp
import jax
from jax import config  # "from jax.config import config" is deprecated in recent JAX releases



true_fun = Expression("(1-exp(-40*t))*sin(x[0]+x[1])",t=0,degree=2)
alpha=2.0
xlims=(0,1)
ylims=(0,1)
N=4
trapz_samples=100
limit_points = (Point(xlims[0], ylims[0]), Point(xlims[1], ylims[1]))
mesh = RectangleMesh(limit_points[0], limit_points[1], N, N)



V = FunctionSpace(mesh, "CG", 1)

V_quad = FunctionSpace(mesh, "CG", 2)
phi_i = TestFunction(V)
phi_j = TrialFunction(V)



tree = mesh.bounding_box_tree()
basis_func_temp  = lambda i,x: basis_func(i,x)
def basis_func(i,pt):
    # Evaluate global CG1 basis function i at pt by locating the cell that contains pt
    cell_index = tree.compute_first_entity_collision(Point(*pt))
    cell_global_dofs = V.dofmap().cell_dofs(cell_index)
    for local_dof in range(0,len(cell_global_dofs)):
        if(i==cell_global_dofs[local_dof]):
            cell = Cell(mesh, cell_index)
            return V.element().evaluate_basis(local_dof,pt,
                                            cell.get_vertex_coordinates(),
                                            cell.orientation())[0]
    # Basis function i has no support on this cell
    return 0.0

def basis_func_list():
    list_of_basis_functions=[]

    for i in range(0,V.dim()):
        list_of_basis_functions +=[(lambda i: lambda x : basis_func_temp(i,x))(i),]
    return list_of_basis_functions
my_basis_list =basis_func_list()

def basis_funcs(pt):
    # Values of every global basis function at pt, as a JAX array
    return jnp.asarray([my_basis_list[i](pt) for i in range(V.dim())])

def trapz(f, lims, n, t):
    xs = np.linspace(lims[0], lims[1], n, endpoint=True)

    return jnp.trapz(jnp.asarray([f(x, t) for x in xs]), xs, axis=0)


trapz_JIT = jax.jit(trapz, static_argnums=(0, 1, 2))
def boundary_fun(pt, t):
    return jnp.dot(grad_fun(pt[0], pt[1], t), normal(pt)) * basis_funcs(pt)

# Because the boundary is made of piecewise functions, we parametrize the boundary integral as an integral over a
# single variable l in the domain [0, 1]
a_x = xlims[1] - xlims[0]
a_y = ylims[1] - ylims[0]

# Calculates the parametrized line integral of gradient and basis function
def boundary_integrand(l, t):
    return a_x * boundary_fun((a_x * l + xlims[0], ylims[0]), t) + \
        a_y * boundary_fun((xlims[1], a_y * l + ylims[0]), t) + \
        -a_x * boundary_fun((-a_x * l + xlims[1], ylims[1]), t) + \
        -a_y * boundary_fun((xlims[0], -a_y * l + ylims[1]), t)
   

def boundary_integral(t):
    return alpha * trapz_JIT(boundary_integrand, (0, 1), trapz_samples, t)

boundary_mesh = BoundaryMesh(mesh, "exterior", True)
D= VectorFunctionSpace(boundary_mesh, "CG", 1)
# Interpolate mesh values into boundary function
u_boundary = Function(D)
u_boundary.set_allow_extrapolation(True)

# Approximate facet normal in a suitable space using projection
n = FacetNormal(mesh)
V_ex = VectorFunctionSpace(mesh, "CG", 2)
u_ = TrialFunction(V_ex)
v_ = TestFunction(V_ex)
a_ = inner(u_,v_)*ds
l_ = inner(n, v_)*ds
A_ = assemble(a_, keep_diagonal=True)
L_ = assemble(l_)

A_.ident_zeros()
nh = Function(V_ex)

solve(A_, nh.vector(), L_)



def normal(pt):
    my_pt = Point(pt[0], pt[1])
    my_norm = nh(my_pt)
    return jnp.array([my_norm[0], my_norm[1]])

def grad_fun(x, y, t):
    return jnp.array([(1-jnp.exp(-40*t))*jnp.cos(x+y), (1-jnp.exp(-40*t))*jnp.cos(x+y)])

def fenics_boundary_integral(t):
    true_fun.t=t
    ac = Constant(alpha)
    n = FacetNormal(mesh)
    # n_proj = project(n, V)
    # print(n_proj.eval(np.asarray([0,0]),mesh.coordinates().flatten()))
    # quit()
    true_fun_prog = project(true_fun,V_quad)
    boundary = ac*dot(grad(true_fun_prog), n)*phi_i*ds  # UFL dot (not np.dot) keeps the form symbolic
    boundary_vec = assemble(boundary)
    #print(boundary_vec)
    boundary_vec = boundary_vec.get_local()
    print("fenics boundary vector")
    print(boundary_vec)

def jax_boundary_integral(t):
    bi = boundary_integral(t)
    print("jax boundary vector")
    jax.debug.print("bi: {}", bi)

print("Running test")
fenics_boundary_integral(0.18577451194110492)
jax_boundary_integral(0.18577451194110492)

I would strongly suggest you take a step back. Create a UnitSquare consisting of two cells, and only integrate over a single facet (on either left, right, top or bottom).
You can compute this integral analytically and compare it to what you get with JAX.
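For example, on the bottom edge y = 0 the outward normal is (0, -1), so for the manufactured solution in this thread grad(u)·n = -(1 - exp(-40t)) cos x. For the P1 hat function of the vertex at (0, 0) on a facet of length h, integration by parts gives ∫_0^h cos(x)(1 - x/h) dx = [sin(x)(1 - x/h)]_0^h + (1/h)∫_0^h sin(x) dx = (1 - cos h)/h. The sketch below compares that closed form with the trapezoidal rule at a few sample counts; `facet_exact`, `facet_trapezoid`, and `trapezoid` are illustrative names.

```python
import numpy as np


def trapezoid(y, x):
    # Composite trapezoidal rule
    return float(np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(x)))


def facet_exact(h):
    # Closed form of the integral of cos(x) * (1 - x/h) over [0, h]
    return (1.0 - np.cos(h)) / h


def facet_trapezoid(h, n):
    # Same integral with n equally spaced samples on the facet
    x = np.linspace(0.0, h, n)
    return trapezoid(np.cos(x) * (1.0 - x / h), x)


h = 0.25  # facet length for the N = 4 mesh in the post
alpha, t = 2.0, 0.18577451194110492
scale = -alpha * (1.0 - np.exp(-40.0 * t))  # alpha * grad(u).n on y = 0, without the cos x

exact = facet_exact(h)
print("exact facet contribution:", scale * exact)
for n in (3, 11, 101, 1001):
    approx = facet_trapezoid(h, n)
    print(n, scale * approx, abs(approx - exact))
```

The error shrinks roughly by the square of the sample spacing, so any remaining gap to a single-facet FEniCS assembly would point at the integrand (gradient, normal, or basis function) rather than at the quadrature.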

Thanks for the suggestion, Dokken, I will do that. To compute the integral analytically, I would need to know the form of the basis functions. A while ago I asked a question along these lines: Documentation on definition of basis functions over elements for CG and DG for arbitary degree of polynomial? Do the links you listed give the functional form of an n-th degree polynomial over an element, or is there FEniCS documentation for this?
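For Lagrange elements of degree 1 and 2, the basis functions on the reference triangle with vertices (0, 0), (1, 0), (0, 1) can be written in barycentric coordinates λ0 = 1 - x - y, λ1 = x, λ2 = y: the P1 basis is (λ0, λ1, λ2), and the P2 basis is λi(2λi - 1) at each vertex plus 4λjλk at the midpoint of the edge opposite vertex i. The sketch below evaluates them and checks the nodal property. The local ordering used here (vertices, then edges opposite vertices 0, 1, 2) is an assumption and need not match DOLFIN's dofmap.

```python
import numpy as np


def barycentric(x, y):
    return np.array([1.0 - x - y, x, y])


def p1_basis(x, y):
    # Degree-1 Lagrange basis on the reference triangle: one function per vertex
    return barycentric(x, y)


def p2_basis(x, y):
    # Degree-2 Lagrange basis: vertex functions l_i (2 l_i - 1), then edge functions
    # 4 l_j l_k for the edge opposite vertex i (assumed ordering: i = 0, 1, 2)
    l0, l1, l2 = barycentric(x, y)
    return np.array([
        l0 * (2 * l0 - 1), l1 * (2 * l1 - 1), l2 * (2 * l2 - 1),
        4 * l1 * l2, 4 * l0 * l2, 4 * l0 * l1,
    ])


p1_nodes = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
p2_nodes = p1_nodes + [(0.5, 0.5), (0.0, 0.5), (0.5, 0.0)]

# Row i holds all basis functions evaluated at node i; the nodal property makes it the identity
print(np.array([p1_basis(*p) for p in p1_nodes]))
print(np.array([p2_basis(*p) for p in p2_nodes]))
```

On a physical cell these are composed with the inverse of the affine map from the reference triangle; restricted to a boundary edge, the P1 functions reduce to 1-D piecewise-linear hat functions.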

Symfem should be able to provide basis functions of arbitrary order.
For your test case, you are using Lagrange 1 and 2, which you can find at