Howdy,
Building on a previous comment in the thread “How are boundary integrals implemented in FEniCS?” (post #3 by Ryan_Vogt2), I implemented the normal vector.
I’m trying to make my JAX-friendly implementation of a boundary integral over a square (decomposed into four 1-D integrals) consistent with FEniCS. I need to differentiate a boundary integral with respect to time via automatic differentiation for some semi-discretization approaches I am developing. Because of that, I cannot rely on the FEniCS implementation of the boundary integral.
I traverse the boundary counterclockwise. However, the orientation doesn’t seem to be the problem.
When I compute the boundary integral against the 25 basis functions in FEniCS and in JAX, I get:
fenics boundary vector
[-0.03479013 -0.36843907 0.16037663 -0.44233786 0. 0.03698622
-0.48873423 0. 0. -0.08870382 -0.50719482 0.
0. 0. -0.17113571 -0.48873423 0. 0.
-0.08870382 -0.44233786 0. 0.03698622 -0.36843907 0.16037663
-0.03479013]
jax boundary vector
[ 3.4922317e-02 3.6381820e-01 -1.5679288e-01 4.3623355e-01
3.8644311e-18 -3.5164107e-02 4.8160055e-01 0.0000000e+00
1.2756327e-19 8.8573724e-02 0.0000000e+00 0.0000000e+00
0.0000000e+00 -6.6702707e-20 -7.4505806e-09 -4.8160055e-01
0.0000000e+00 -1.1649350e-19 -8.8573724e-02 -4.3623355e-01
-3.3880987e-18 3.5164103e-02 -3.6381817e-01 1.5679288e-01
-3.4922298e-02]
There seems to be a pattern: each entry either agrees (up to discretization error), is off by a sign, or is reported as zero by JAX while being non-zero in FEniCS.
In FEniCS, I compute the boundary integral as follows:
def fenics_boundary_integral(t):
    """Assemble ``alpha * (grad u . n) * phi_i * ds`` with FEniCS at time ``t``.

    ``true_fun`` is updated to time ``t`` and projected onto ``V_quad`` before
    it is differentiated. The assembled vector (one entry per dof of ``V``) is
    printed and also returned, so it can be compared numerically.
    """
    true_fun.t = t
    ac = Constant(alpha)
    n = FacetNormal(mesh)  # outward unit normal on exterior facets
    true_fun_prog = project(true_fun, V_quad)
    # ``np.dot`` is a NumPy function, not a UFL operator; UFL's ``dot`` builds
    # the form expression directly.
    boundary = ac * dot(grad(true_fun_prog), n) * phi_i * ds
    boundary_vec = assemble(boundary).get_local()
    print("fenics boundary vector")
    print(boundary_vec)
    return boundary_vec
The only explanation I can think of is an orientation issue that I am not accounting for, which would keep me from staying consistent with FEniCS. Clearly, I am violating some assumption. For the JAX boundary integral, I take the exact gradient and dot it with the FEniCS normal vector. I then multiply by each basis function and integrate with the trapezoid rule.
Here is my complete, bare-bones code for calculating the boundary integral in JAX and FEniCS:
import sys
import os
import math
import numpy as np
from fenics import *
import jax.numpy as jnp
import jax
from jax.config import config
# Manufactured solution u(x, y, t) = (1 - exp(-40 t)) sin(x + y); its ``t``
# attribute is overwritten before each FEniCS assembly.
true_fun = Expression("(1-exp(-40*t))*sin(x[0]+x[1])", t=0, degree=2)
alpha = 2.0  # coefficient multiplying the boundary flux term

# Rectangle [xlims] x [ylims] meshed with N x N rectangles; the edge parameter
# of the JAX boundary integral is sampled at ``trapz_samples`` points.
xlims = (0, 1)
ylims = (0, 1)
N = 4
trapz_samples = 100
limit_points = tuple(Point(x, y) for x, y in zip(xlims, ylims))
mesh = RectangleMesh(*limit_points, N, N)

V = FunctionSpace(mesh, "CG", 1)       # P1 space supplying the test functions
V_quad = FunctionSpace(mesh, "CG", 2)  # P2 space true_fun is projected onto
phi_i = TestFunction(V)
phi_j = TrialFunction(V)
# Bounding-box tree used to find the cell that contains a given point.
tree = mesh.bounding_box_tree()
def basis_func_temp(i, x):
    """Forward to ``basis_func(i, x)``; the indirection used by basis_func_list."""
    return basis_func(i, x)
def basis_func(i, pt):
    """Evaluate the global basis function of ``V`` with dof index ``i`` at ``pt``.

    The cell containing ``pt`` is located with the mesh bounding-box tree. If
    dof ``i`` belongs to that cell, the element basis function is evaluated
    there; otherwise the basis function is zero at ``pt`` and 0.0 is returned.

    Raises:
        ValueError: if ``pt`` does not lie in any cell of ``mesh``.
    """
    cell_index = tree.compute_first_entity_collision(Point(*pt))
    # When no cell is hit, DOLFIN returns a sentinel index (the maximum
    # unsigned int), which must not be used to index the dofmap.
    if cell_index >= mesh.num_cells():
        raise ValueError(f"point {tuple(pt)} is not inside the mesh")
    cell_global_dofs = V.dofmap().cell_dofs(cell_index)
    for local_dof, global_dof in enumerate(cell_global_dofs):
        if global_dof == i:
            cell = Cell(mesh, cell_index)
            return V.element().evaluate_basis(
                local_dof,
                pt,
                cell.get_vertex_coordinates(),
                cell.orientation(),
            )[0]
    # Dof i is not attached to this cell, so its basis function vanishes here.
    return 0.0
def basis_func_list():
    """Return one callable ``x -> basis_func_temp(i, x)`` per dof ``i`` of ``V``."""
    def make_basis(index):
        # The factory gives each callable its own ``index``; a bare lambda in
        # the comprehension would late-bind and every entry would see the last dof.
        return lambda x: basis_func_temp(index, x)

    return [make_basis(dof) for dof in range(V.dim())]
my_basis_list = basis_func_list()


def basis_funcs(pt):
    """Evaluate every global basis function of ``V`` at ``pt`` as a JAX array."""
    values = [phi(pt) for phi in my_basis_list]
    return jnp.asarray(values)
def trapz(f, lims, n, t):
    """Integrate ``f(x, t)`` over ``[lims[0], lims[1]]`` with the trapezoid rule.

    ``f`` is sampled at ``n`` equally spaced points (both endpoints included).
    It may return a scalar or an array; array outputs are integrated
    component-wise because the samples are stacked along axis 0.
    """
    xs = np.linspace(lims[0], lims[1], n, endpoint=True)
    samples = jnp.asarray([f(x, t) for x in xs])
    # ``jnp.trapz`` is deprecated in favour of ``jnp.trapezoid`` (NumPy 2.0
    # naming) and may be missing from newer JAX releases; use whichever exists.
    trapezoid = getattr(jnp, "trapezoid", None) or jnp.trapz
    return trapezoid(samples, xs, axis=0)


# f, lims and n are static (hashable and fixed per call site); only t is traced.
trapz_JIT = jax.jit(trapz, static_argnums=(0, 1, 2))
def boundary_fun(pt, t):
    """Return ``(grad u(pt, t) . n(pt)) * phi_k(pt)`` for every basis function k."""
    flux = jnp.dot(grad_fun(pt[0], pt[1], t), normal(pt))
    return flux * basis_funcs(pt)
# The boundary of the rectangle is piecewise smooth, so the boundary integral is
# split into four edge integrals, each parametrized by a single variable l in
# the domain [0, 1].
a_x = xlims[1] - xlims[0]  # length of the bottom and top edges
a_y = ylims[1] - ylims[0]  # length of the left and right edges


def boundary_integrand(l, t):
    """Return the integrand (in l) of the parametrized boundary integral.

    Each edge is parametrized as r(l), l in [0, 1], and contributes
    ``boundary_fun(r(l), t) * |r'(l)|``. ``boundary_fun`` already contains the
    outward normal, so this is an integral with respect to arc length (FEniCS
    ``ds``), not a directed line integral. |r'(l)| is therefore ``a_x`` or
    ``a_y`` on every edge, whatever the traversal direction. Weighting the top
    and left edges by ``-a_x`` / ``-a_y`` would flip the sign of their
    contributions relative to FEniCS.
    """
    # Edges are traversed counterclockwise: bottom, right, top, left.
    bottom = boundary_fun((a_x * l + xlims[0], ylims[0]), t)
    right = boundary_fun((xlims[1], a_y * l + ylims[0]), t)
    top = boundary_fun((-a_x * l + xlims[1], ylims[1]), t)
    left = boundary_fun((xlims[0], -a_y * l + ylims[1]), t)
    # ds = |r'(l)| dl is positive on every edge.
    return a_x * (bottom + top) + a_y * (right + left)
def boundary_integral(t):
    """Return ``alpha`` times the JAX boundary integral vector at time ``t``.

    The edge parameter l runs over [0, 1] and is sampled at ``trapz_samples``
    points by the jitted trapezoid rule.
    """
    integral = trapz_JIT(boundary_integrand, (0, 1), trapz_samples, t)
    return alpha * integral
# Boundary mesh with a vector P1 function on it (not referenced further below).
boundary_mesh = BoundaryMesh(mesh, "exterior", True)
D = VectorFunctionSpace(boundary_mesh, "CG", 1)
u_boundary = Function(D)
u_boundary.set_allow_extrapolation(True)

# Approximate the facet normal by an L2 projection onto a vector P2 space using
# boundary (ds) integrals only: find nh such that
#     (nh, v)_ds = (n, v)_ds   for all v in V_ex.
# Dofs that do not live on the boundary get all-zero rows in this matrix;
# keep_diagonal + ident_zeros() put 1 on those diagonals so the solve succeeds.
n = FacetNormal(mesh)
V_ex = VectorFunctionSpace(mesh, "CG", 2)
u_ = TrialFunction(V_ex)
v_ = TestFunction(V_ex)
a_ = inner(u_, v_) * ds
l_ = inner(n, v_) * ds
A_ = assemble(a_, keep_diagonal=True)
L_ = assemble(l_)
A_.ident_zeros()
nh = Function(V_ex)
solve(A_, nh.vector(), L_)
def normal(pt):
    """Return the projected boundary normal ``nh`` at ``pt`` as a JAX 2-vector."""
    nx, ny = nh(Point(pt[0], pt[1]))[:2]
    return jnp.array([nx, ny])
def grad_fun(x, y, t):
    """Exact spatial gradient of u = (1 - exp(-40 t)) sin(x + y).

    Both partial derivatives equal (1 - exp(-40 t)) cos(x + y).
    """
    ramp = 1 - jnp.exp(-40 * t)
    component = ramp * jnp.cos(x + y)
    return jnp.array([component, component])
def fenics_boundary_integral(t):
    """Assemble ``alpha * (grad u . n) * phi_i * ds`` with FEniCS at time ``t``.

    ``true_fun`` is updated to time ``t`` and projected onto ``V_quad`` before
    it is differentiated. The assembled vector (one entry per dof of ``V``) is
    printed and also returned, so it can be compared numerically.
    """
    true_fun.t = t
    ac = Constant(alpha)
    n = FacetNormal(mesh)  # outward unit normal on exterior facets
    true_fun_prog = project(true_fun, V_quad)
    # ``np.dot`` is a NumPy function, not a UFL operator; UFL's ``dot`` builds
    # the form expression directly.
    boundary = ac * dot(grad(true_fun_prog), n) * phi_i * ds
    boundary_vec = assemble(boundary).get_local()
    print("fenics boundary vector")
    print(boundary_vec)
    return boundary_vec
def jax_boundary_integral(t):
    """Compute, print and return the JAX boundary integral vector at time ``t``."""
    bi = boundary_integral(t)
    print("jax boundary vector")
    jax.debug.print("bi: {}", bi)
    # Returned (previously discarded) so callers can compare it numerically
    # with the FEniCS-assembled vector.
    return bi
# Evaluate both implementations at the same time instant so they can be compared.
T_TEST = 0.18577451194110492
print("Running test")
fenics_boundary_integral(T_TEST)
jax_boundary_integral(T_TEST)