Mesh operations as part of a JAX program

As part of a larger JAX/FEniCS program, I have code that iterates over mesh cells and uses the cells and their nodal coordinates in some arbitrary NumPy computations, similar to the code snippet below. Could anyone suggest a way to make this work with either pyadjoint or jax.grad for reverse-mode AD?
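For illustration, here is a minimal sketch of which parts of such a computation jax.grad can handle, assuming the "arbitrary numpy computations" can be rewritten with jax.numpy. The integer cell connectivity is topology and is only used for indexing. The floating-point nodal coordinates are what gets differentiated. The two-triangle arrays are hand-written stand-ins; with legacy dolfin they would come from mesh.coordinates() and mesh.cells(). total_area is a hypothetical placeholder for the real computation.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Two triangles covering the unit square (stand-in for mesh.coordinates()).
coords = np.array([[0.0, 0.0],
                   [1.0, 0.0],
                   [1.0, 1.0],
                   [0.0, 1.0]])
# Cell-to-vertex connectivity (stand-in for mesh.cells()); integers, never differentiated.
cells = np.array([[0, 1, 2],
                  [0, 2, 3]])


def total_area(x):
    """Hypothetical example computation: sum of the triangle areas."""
    # Gather the vertices of every cell at once instead of a Python loop over cells.
    v = x[cells]                      # shape (n_cells, 3, 2)
    e1 = v[:, 1] - v[:, 0]
    e2 = v[:, 2] - v[:, 0]
    # 2D cross product of the two edge vectors equals twice the signed area.
    cross = e1[:, 0] * e2[:, 1] - e1[:, 1] * e2[:, 0]
    return 0.5 * jnp.sum(jnp.abs(cross))


x0 = jnp.asarray(coords)
area = total_area(x0)
d_area = jax.grad(total_area)(x0)     # shape (4, 2): one entry per nodal coordinate
```

Gathering with x[cells] replaces the explicit loop over cells; the gradient then flows back to every nodal coordinate through the indexing, while the connectivity stays fixed.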

```python
import dolfin as d
import jax.numpy as jnp
from jax import grad

mesh = d.UnitSquareMesh(10, 10)

def bfs(mesh, point):
    tree = mesh.bounding_box_tree()
    # Index of the first mesh cell containing `point`
    start_cell = tree.compute_first_entity_collision(d.Point(point))

    #....
    return start_cell  # an integer cell index


grad_fn = grad(bfs)  # get the gradient function (w.r.t. the first argument, `mesh`)
test = bfs(mesh, (0.5, 0.5))
print(test)
test_grad = grad_fn(mesh, (0.5, 0.5))
```
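As written, grad(bfs) differentiates with respect to its first argument, the dolfin Mesh, which is not a valid JAX type. In addition, bfs returns an integer cell index, which is piecewise constant in the point and so has no useful derivative. A minimal sketch of one common workaround, using plain arrays in place of dolfin: locate the cell on concrete values outside the differentiated function, then differentiate only a floating-point quantity computed inside that fixed cell. locate_cell, barycentric, value_at and the nodal values u are hypothetical, and the brute-force search merely stands in for bounding_box_tree().

```python
import jax
import jax.numpy as jnp
import numpy as np

# Two triangles covering the unit square (stand-ins for mesh.coordinates() / mesh.cells()).
coords = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
cells = np.array([[0, 1, 2], [0, 2, 3]])
# Hypothetical nodal values of the linear field u(x, y) = x + 2y.
u = coords[:, 0] + 2.0 * coords[:, 1]


def barycentric(point, tri):
    """Barycentric coordinates of `point` with respect to triangle vertices `tri` (3x2)."""
    a, b, c = tri[0], tri[1], tri[2]
    # Solve point = a + l1 * (b - a) + l2 * (c - a) for (l1, l2).
    T = jnp.stack([b - a, c - a], axis=1)  # 2x2 matrix whose columns are edge vectors
    l1, l2 = jnp.linalg.solve(T, point - a)
    return jnp.stack([1.0 - l1 - l2, l1, l2])


def locate_cell(point, tol=1e-6):
    """Hypothetical brute-force point location on concrete (non-traced) values."""
    for i, cell in enumerate(cells):
        lam = np.asarray(barycentric(jnp.asarray(point), jnp.asarray(coords[cell])))
        if np.all(lam >= -tol):
            return i
    return -1  # point lies outside the mesh


def value_at(point, cell_index):
    """Linear interpolation of u at `point`, with the containing cell held fixed."""
    lam = barycentric(point, jnp.asarray(coords[cells[cell_index]]))
    return jnp.dot(lam, jnp.asarray(u[cells[cell_index]]))


point = jnp.array([0.7, 0.2])
cell = locate_cell(point)                             # plain int, not differentiated
val = value_at(point, cell)
dval = jax.grad(lambda p: value_at(p, cell))(point)  # gradient w.r.t. the point only
```

Because the cell index is treated as a constant, the derivative is the one within that cell, so it is one-sided when the point sits on a cell boundary. For this linear field the interpolant is exact, giving a value of 1.1 and a gradient of (1, 2).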