As part of a larger JAX/FEniCS program, I have code that iterates over mesh cells and uses the cells and nodal coordinates in some arbitrary NumPy computations, similar to the snippet below. Could anyone suggest a way to make this work with either pyadjoint or `jax.grad` for reverse-mode AD?
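```python
import dolfin as d
import jax.numpy as jnp
from jax import grad

mesh = d.UnitSquareMesh(10, 10)

def bfs(mesh, point):
    # Find the first mesh cell whose bounding box collides with `point`
    tree = mesh.bounding_box_tree()
    start_cell = tree.compute_first_entity_collision(d.Point(point))
    #....
    return start_cell  # integer cell index

# get the gradient function (jax.grad differentiates w.r.t. argument 0, i.e. `mesh`, by default)
grad_fn = grad(bfs)
test = bfs(mesh, (0.5, 0.5))
print(test)
test_grad = grad_fn(mesh, (0.5, 0.5))  # this is the call that fails
```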
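For context, a sketch of one common workaround, under stated assumptions (this is not a FEniCS or pyadjoint API). The cell index returned by the collision query is an integer, so it is piecewise constant in `point` and has no useful derivative, and `jax.grad` cannot trace through `d.Point`, which needs concrete floats. One option is to split the work into a non-differentiable lookup done on concrete values and a differentiable `jax.numpy` computation on that cell's vertex coordinates, so `jax.grad` only sees the second part. In the sketch, `COORDS`/`CELLS` are a hand-built two-triangle mesh standing in for `mesh.coordinates()` and `mesh.cells()`, the helpers `find_cell`, `barycentric`, `interpolate` and `value_and_grad_at` are hypothetical, and P1 interpolation is just a placeholder for the "arbitrary computations":

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical stand-in for mesh.coordinates() / mesh.cells():
# the unit square split into two triangles along the diagonal y = x.
COORDS = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
CELLS = np.array([[0, 1, 2], [0, 2, 3]])


def barycentric(verts, p):
    """Barycentric coordinates of point p in the triangle verts (3x2), using jnp."""
    v0, v1, v2 = verts[0], verts[1], verts[2]
    T = jnp.stack([v1 - v0, v2 - v0], axis=1)  # columns are the edge vectors
    lam12 = jnp.linalg.solve(T, p - v0)        # p = v0 + T @ lam12
    return jnp.concatenate([(1.0 - jnp.sum(lam12))[None], lam12])


def find_cell(coords, cells, p, tol=1e-6):
    """Non-differentiable step: index of the first cell containing p (concrete values only)."""
    for i, cell in enumerate(cells):
        lam = np.asarray(barycentric(jnp.asarray(coords[cell]), jnp.asarray(p)))
        if np.all(lam >= -tol):
            return i
    raise ValueError("point is outside the mesh")


def interpolate(p, verts, nodal_values):
    """Differentiable step: linear (P1) interpolation of nodal values at p."""
    return jnp.dot(barycentric(verts, p), nodal_values)


def value_and_grad_at(point, coords, cells, u):
    """Locate the cell outside of tracing, then differentiate w.r.t. the point only."""
    cell = find_cell(coords, cells, np.asarray(point, dtype=float))  # plain Python int
    p = jnp.asarray(point, dtype=jnp.float32)
    verts = jnp.asarray(coords[cells[cell]], dtype=jnp.float32)
    vals = jnp.asarray(u[cells[cell]], dtype=jnp.float32)
    # argnums defaults to 0, i.e. the gradient is taken w.r.t. p
    return jax.value_and_grad(interpolate)(p, verts, vals)


if __name__ == "__main__":
    u = COORDS[:, 0] + 2.0 * COORDS[:, 1]  # nodal values of u(x, y) = x + 2y
    val, g = value_and_grad_at((0.7, 0.2), COORDS, CELLS, u)
    # P1 interpolation reproduces a linear u exactly, so val is close to 1.1 and g to [1, 2]
    print(val, g)
```

The gradient obtained this way is only valid while a perturbation of the point stays inside the located cell; across cell boundaries the lookup switches discretely, which is exactly the part AD cannot see.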