Issues with explicitly defining P1 basis functions at a vertex in a sub mesh

Context:

I am implementing an equilibrated flux reconstruction with FEniCSx.

Hereby, I have to do the following: I have to iterate over all the vertices v of the mesh and solve local problems in the patch arund the vertex v.

In particular, the problem formulation requires the P1 hat function \psi_v that is equal to 1 at v and 0 at the other vertices in the patch.

Issue:

  • I manage to correctly identify the patches around a vertex v.
  • I also manage to define the global P1 hat function equal to 1 at v in the parent mesh.
  • However, when I try to do the same on the extracted submesh, things start to go wrong. The local \psi_v is equal to 1 at the wrong vertex.

My approach:

  1. I tried to interpolate the global hat function into the local P1 space by explicitly providing the cells.
  2. I tried to identify the local vertex index using the parent_vertex_map returned by create_submesh

Both approaches fail in the same way: the wrong vertex is set to 1. What’s confusing me is that when I plot the local vertex, it is the correct one. Furthermore, the approach does not always fail (see case vertex_index = 1 in the MWE)

Question: What would be the correct, reliable way to do this?

Here’s an MWE:

import numpy as np
from mpi4py import MPI
from dolfinx import fem
from dolfinx.mesh import create_submesh
from dolfinx.plot import vtk_mesh
from dolfinx.io import gmsh as gmshio
import gmsh
import pyvista as pv
import matplotlib.pyplot as plt
from matplotlib.collections import PolyCollection


def cell_polygon(cell_index, domain, coords, tdim):
    cell_vertices = domain.topology.connectivity(tdim, 0).links(cell_index)
    return coords[cell_vertices][:, :2]  # only x, y


def plot_patch(vi, cells, domain, coords):
    tdim = domain.topology.dim
    fig, ax = plt.subplots(figsize=(6, 6))

    # Ensure required connectivities are available.
    domain.topology.create_connectivity(tdim, 0)
    domain.topology.create_connectivity(tdim - 1, 0)
    domain.topology.create_connectivity(tdim - 1, tdim)

    all_polys = [
        cell_polygon(c, domain, coords, tdim)
        for c in range(domain.topology.index_map(tdim).size_local)
    ]
    pc_all = PolyCollection(
        all_polys, facecolor="lightgray", edgecolor="black", linewidth=0.5, alpha=0.4
    )
    ax.add_collection(pc_all)

    patch_polys = [cell_polygon(c, domain, coords, tdim) for c in cells]
    pc_patch = PolyCollection(
        patch_polys,
        facecolor="cornflowerblue",
        edgecolor="black",
        linewidth=1.5,
        alpha=0.9,
    )
    ax.add_collection(pc_patch)

    ax.scatter(coords[vi, 0], coords[vi, 1], c="red", s=60, label=f"vertex a={vi}")

    ax.set_title(f"Patch ω_{vi}: {len(cells)} cells")
    ax.set_aspect("equal")
    ax.legend()
    plt.show()


def create_mesh(mesh_size: float = 0.5) -> gmshio.MeshData:
    gmsh.initialize()
    gmsh.model.add("unit_square")

    geom = gmsh.model.occ
    p1 = geom.addPoint(0, 0, 0, mesh_size)
    p2 = geom.addPoint(1, 0, 0, mesh_size)
    p3 = geom.addPoint(1, 1, 0, mesh_size)
    p4 = geom.addPoint(0, 1, 0, mesh_size)
    geom.synchronize()

    l1 = geom.addLine(p1, p2)
    l2 = geom.addLine(p2, p3)
    l3 = geom.addLine(p3, p4)
    l4 = geom.addLine(p4, p1)
    geom.synchronize()
    cl = geom.addCurveLoop([l1, l2, l3, l4])
    ps = geom.addPlaneSurface([cl])
    geom.synchronize()

    # gmsh.model.addPhysicalGroup(1, [l1, l2, l3, l4], name="DirichletBoundary")
    gmsh.model.addPhysicalGroup(1, [l1, l2], name="DirichletBoundary")
    gmsh.model.addPhysicalGroup(1, [l3, l4], name="NeumannBoundary")
    gmsh.model.addPhysicalGroup(2, [ps], name="Domain")

    gmsh.option.setNumber("Mesh.CharacteristicLengthMin", 1e-12)
    gmsh.option.setNumber("Mesh.CharacteristicLengthMax", mesh_size)
    gmsh.model.mesh.generate(2)

    # gmsh.fltk.run()

    mesh_data = gmshio.model_to_mesh(gmsh.model, MPI.COMM_WORLD, 0, gdim=2)

    gmsh.finalize()

    return mesh_data


def plot_psi(psi, title):
    cells_plot, types_plot, x_plot = vtk_mesh(
        psi.function_space.mesh, psi.function_space.mesh.topology.dim
    )
    grid = pv.UnstructuredGrid(cells_plot, types_plot, x_plot)
    grid.point_data["psi"] = psi.x.array.real
    pl = pv.Plotter()
    pl.add_mesh(grid, scalars="psi", show_edges=True)
    pl.add_title(title, font_size=14)
    pl.add_axes()
    pl.view_xy()
    pl.show()


if __name__ == "__main__":

    mesh_data = create_mesh()
    mesh = mesh_data.mesh

    # vertex_index = 1  # works
    vertex_index = 3  # fails
    tdim = mesh.topology.dim
    mesh.topology.create_connectivity(0, tdim)
    v2c = mesh.topology.connectivity(0, tdim)
    patch_cells = v2c.links(vertex_index)
    plot_patch(vertex_index, patch_cells, mesh, mesh.geometry.x)

    # global hat function ψ_v on full mesh
    P1_global = fem.functionspace(mesh, ("CG", 1))
    psi_global = fem.Function(P1_global)

    vertex_dofs = fem.locate_dofs_topological(
        P1_global, 0, np.array([vertex_index], dtype=np.int32)
    )
    psi_global.x.array[:] = 0.0
    psi_global.x.array[vertex_dofs] = 1.0
    plot_psi(psi_global, "Global hat function ψ_v on full mesh")

    submesh, parent_cell_map, parent_vertex_map, parent_node_map = create_submesh(
        mesh, tdim, patch_cells
    )

    # interpolated hat function ψ_v on submesh
    patch_cells_sub = parent_cell_map.sub_topology_to_topology(
        patch_cells, inverse=True
    )
    P1_sub = fem.functionspace(submesh, ("CG", 1))
    psi_sub_interp = fem.Function(P1_sub)
    psi_sub_interp.interpolate(
        psi_global,
        cells0=patch_cells,
        cells1=patch_cells_sub,
    )
    plot_psi(psi_sub_interp, "Interpolated hat function ψ_v on submesh")

    # exact hat function ψ_v on submesh
    vertex_local_index = parent_vertex_map.sub_topology_to_topology(
        np.array([vertex_index]), inverse=True
    )[0]
    submesh.topology.create_connectivity(0, submesh.topology.dim)
    vertex_dofs_sub = fem.locate_dofs_topological(
        P1_sub, 0, np.array([vertex_local_index], dtype=np.int32)
    )
    plot_patch(
        vertex_local_index,
        np.arange(len(patch_cells), dtype=np.int32),
        submesh,
        submesh.geometry.x,
    )

    psi_sub_exact = fem.Function(P1_sub)
    psi_sub_exact.x.array[:] = 0.0
    psi_sub_exact.x.array[vertex_dofs_sub] = 1.0
    plot_psi(psi_sub_exact, "Exact hat function ψ_v on submesh")

Visualizations:


Version info:

fenics-basix              0.10.0          py311ha99e8d3_0    conda-forge
fenics-basix-nanobind-abi 2.16            py310hf7f2cdb_0    conda-forge
fenics-dolfinx            0.10.0          py311hc730058_101    conda-forge
fenics-ffcx               0.10.0             pyh1e3ed43_0    conda-forge
fenics-libbasix           0.10.0          py311haf9c301_0    conda-forge
fenics-libdolfinx         0.10.0          py311h0e29e79_101    conda-forge
fenics-ufcx               0.10.0               haec879b_0    conda-forge
fenics-ufl                2025.2.0           pyhcf101f3_0    conda-forge
1 Like

Hi @Philipp_Weder

There are a few things to consider here:

  1. The relation between mesh topology (cell<->vertex info) and mesh geometry (cell<->mesh_node) info. These could be 1-1 for linear meshes on P1 spaces, but they are not. Therefore, I’ve rewritten cell_polygon to take care of the mapping
  2. The relation between mesh geometry and function dofmap. For P1, these could be equivalent, but they are note. I’ve rewritten plot_psi to take this into account.

Code below:

import matplotlib as mpl

mpl.use("TkAgg")

import numpy as np
from mpi4py import MPI
from dolfinx import fem
from dolfinx.mesh import create_submesh, entities_to_geometry, Mesh
from dolfinx.plot import vtk_mesh
from dolfinx.io import gmsh as gmshio
import gmsh
import pyvista as pv
import matplotlib.pyplot as plt
from matplotlib.collections import PolyCollection
import numpy.typing as npt


def cell_polygons(domain: Mesh, cells: npt.NDArray[np.int32]):
    cells_as_vertices = entities_to_geometry(domain, domain.topology.dim, cells)
    coords = domain.geometry.x
    return coords[cells_as_vertices][:, :, :2]  # only x, y


def plot_patch(vi, cells, domain, coords):
    tdim = domain.topology.dim
    fig, ax = plt.subplots(figsize=(6, 6))

    # Ensure required connectivities are available.
    domain.topology.create_connectivity(tdim, 0)
    domain.topology.create_connectivity(tdim - 1, 0)
    domain.topology.create_connectivity(tdim - 1, tdim)

    all_polys = cell_polygons(
        domain, np.arange(domain.topology.index_map(tdim).size_local, dtype=np.int32)
    )
    pc_all = PolyCollection(
        all_polys, facecolor="lightgray", edgecolor="black", linewidth=0.5, alpha=0.4
    )
    ax.add_collection(pc_all)

    patch_polys = cell_polygons(domain, cells)
    pc_patch = PolyCollection(
        patch_polys,
        facecolor="cornflowerblue",
        edgecolor="black",
        linewidth=1.5,
        alpha=0.9,
    )
    ax.add_collection(pc_patch)

    ax.scatter(coords[vi, 0], coords[vi, 1], c="red", s=60, label=f"vertex a={vi}")

    ax.set_title(f"Patch ω_{vi}: {len(cells)} cells")
    ax.set_aspect("equal")
    ax.legend()
    plt.show()


def create_mesh(mesh_size: float = 0.5):  # -> gmshio.MeshData:
    gmsh.initialize()
    gmsh.model.add("unit_square")

    geom = gmsh.model.occ
    p1 = geom.addPoint(0, 0, 0, mesh_size)
    p2 = geom.addPoint(1, 0, 0, mesh_size)
    p3 = geom.addPoint(1, 1, 0, mesh_size)
    p4 = geom.addPoint(0, 1, 0, mesh_size)
    geom.synchronize()

    l1 = geom.addLine(p1, p2)
    l2 = geom.addLine(p2, p3)
    l3 = geom.addLine(p3, p4)
    l4 = geom.addLine(p4, p1)
    geom.synchronize()
    cl = geom.addCurveLoop([l1, l2, l3, l4])
    ps = geom.addPlaneSurface([cl])
    geom.synchronize()

    # gmsh.model.addPhysicalGroup(1, [l1, l2, l3, l4], name="DirichletBoundary")
    gmsh.model.addPhysicalGroup(1, [l1, l2], name="DirichletBoundary")
    gmsh.model.addPhysicalGroup(1, [l3, l4], name="NeumannBoundary")
    gmsh.model.addPhysicalGroup(2, [ps], name="Domain")

    gmsh.option.setNumber("Mesh.CharacteristicLengthMin", 1e-12)
    gmsh.option.setNumber("Mesh.CharacteristicLengthMax", mesh_size)
    gmsh.model.mesh.generate(2)

    # gmsh.fltk.run()

    mesh_data = gmshio.model_to_mesh(gmsh.model, MPI.COMM_WORLD, 0, gdim=2)

    gmsh.finalize()

    return mesh_data


def plot_psi(psi, title):
    cells_plot, types_plot, x_plot = vtk_mesh(
        psi.function_space,
    )
    grid = pv.UnstructuredGrid(cells_plot, types_plot, x_plot)
    grid.point_data["psi"] = psi.x.array.real

    pl = pv.Plotter()
    pl.add_mesh(grid, scalars="psi", show_edges=True)
    pl.add_title(title, font_size=14)
    pl.add_axes()
    pl.view_xy()
    pl.show()


if __name__ == "__main__":
    mesh_data = create_mesh()
    mesh = mesh_data.mesh

    # vertex_index = 1  # works
    vertex_index = 3  # fails
    tdim = mesh.topology.dim
    mesh.topology.create_connectivity(0, tdim)
    v2c = mesh.topology.connectivity(0, tdim)
    patch_cells = v2c.links(vertex_index)
    plot_patch(vertex_index, patch_cells, mesh, mesh.geometry.x)

    # global hat function ψ_v on full mesh
    P1_global = fem.functionspace(mesh, ("CG", 1))
    psi_global = fem.Function(P1_global)

    vertex_dofs = fem.locate_dofs_topological(
        P1_global, 0, np.array([vertex_index], dtype=np.int32)
    )
    psi_global.x.array[:] = 0.0
    psi_global.x.array[vertex_dofs] = 1.0
    plot_psi(psi_global, "Global hat function ψ_v on full mesh")

    submesh, parent_cell_map, parent_vertex_map, parent_node_map = create_submesh(
        mesh, tdim, patch_cells
    )

    # interpolated hat function ψ_v on submesh
    num_submesh_cells = submesh.topology.index_map(tdim).size_local + submesh.topology.index_map(tdim).num_ghosts
    patch_cells_sub = np.arange(num_submesh_cells, dtype=np.int32)
    patch_cells_mapped = parent_cell_map.sub_topology_to_topology(
        patch_cells_sub, inverse=False
    )
    P1_sub = fem.functionspace(submesh, ("CG", 1))
    psi_sub_interp = fem.Function(P1_sub)
    psi_sub_interp.interpolate(
        psi_global,
        cells0=patch_cells_mapped,
        cells1=patch_cells_sub,
    )
    plot_psi(psi_sub_interp, "Interpolated hat function ψ_v on submesh")

    # exact hat function ψ_v on submesh
    vertex_local_index = parent_vertex_map.sub_topology_to_topology(
        np.array([vertex_index]), inverse=True
    )[0]
    submesh.topology.create_connectivity(0, submesh.topology.dim)
    vertex_dofs_sub = fem.locate_dofs_topological(
        P1_sub, 0, np.array([vertex_local_index], dtype=np.int32)
    )
    plot_patch(
        vertex_local_index,
        np.arange(len(patch_cells), dtype=np.int32),
        submesh,
        submesh.geometry.x,
    )

    psi_sub_exact = fem.Function(P1_sub)
    psi_sub_exact.x.array[:] = 0.0
    psi_sub_exact.x.array[vertex_dofs_sub] = 1.0
    plot_psi(psi_sub_exact, "Exact hat function ψ_v on submesh")