How to collect the global mesh without writing a file

Here is some example code showing how to do it:

# Copyright (C) 2023 Jørgen S. Dokken
#
# SPDX-License-Identifier:    MIT
#
# Code to plot a mesh with pyvista in parallel by gathering the data on a root process
#
# Last changed: 2023-10-23

import dolfinx
from mpi4py import MPI
import pyvista
import numpy as np
# Start a virtual framebuffer so pyvista can render without a display
pyvista.start_xvfb(0.2)

# Distributed mesh and a second-order Lagrange function to visualise
mesh = dolfinx.mesh.create_unit_square(MPI.COMM_WORLD, 10, 10)
V = dolfinx.fem.functionspace(mesh, ("Lagrange", 2))
u = dolfinx.fem.Function(V)
u.interpolate(lambda x: x[0] + 2 * x[1] ** 2)


# Create local VTK mesh data structures
topology, cell_types, geometry = dolfinx.plot.vtk_mesh(V)
index_map = V.dofmap.index_map
num_cells_local = mesh.topology.index_map(mesh.topology.dim).size_local
num_dofs_local = index_map.size_local * V.dofmap.index_map_bs
# The VTK topology is a flat array, stored cell by cell as
#   (num_dofs_cell_0, dof_0, ..., dof_(num_dofs_cell_0 - 1), num_dofs_cell_1, ...)
# We assume we only have one cell type and one dofmap, thus every
# `num_dofs_per_cell` is the same and each cell occupies
# `num_dofs_per_cell + 1` consecutive entries.
num_dofs_per_cell = topology[0]
# Boolean mask selecting only the dof indices (every entry that is not a
# leading per-cell count)
topology_dofs = (np.arange(len(topology)) % (num_dofs_per_cell + 1)) != 0

# Map the process-local dof indices to global dof indices, so that the
# connectivity stays valid once the data from all processes is merged
global_dofs = index_map.local_to_global(topology[topology_dofs].copy())
# Overwrite topology
topology[topology_dofs] = global_dofs
# Choose root
root = 0

# Gather data. Only the leading `size_local` entries (the ones sized by the
# local index maps) are sent from each process. Every gather must target
# the same `root`: relying on the default root of `gather` would silently
# break the script as soon as `root` is changed above.
global_topology = mesh.comm.gather(
    topology[: (num_dofs_per_cell + 1) * num_cells_local], root=root
)
global_geometry = mesh.comm.gather(geometry[: index_map.size_local, :], root=root)
global_ct = mesh.comm.gather(cell_types[:num_cells_local], root=root)
global_vals = mesh.comm.gather(u.x.array[:num_dofs_local], root=root)

# `gather` returns the list of per-process arrays on `root` and `None` on
# all other processes, so only `root` can assemble and plot the data
if mesh.comm.rank == root:
    # Stack data in rank order
    # NOTE(review): this relies on the global dof numbering matching the
    # rank-ordered concatenation of the locally sized data -- confirm
    # against the dolfinx index map documentation.
    root_geom = np.vstack(global_geometry)
    root_top = np.concatenate(global_topology)
    root_ct = np.concatenate(global_ct)
    root_vals = np.concatenate(global_vals)

    # Plot as in serial
    pyvista.OFF_SCREEN = True
    grid = pyvista.UnstructuredGrid(root_top, root_ct, root_geom)
    grid.point_data["u"] = root_vals
    grid.set_active_scalars("u")
    u_plotter = pyvista.Plotter()
    u_plotter.add_mesh(grid, show_edges=True)
    u_plotter.view_xy()
    u_plotter.screenshot("figure.png")

I’ll add it to the list of things to add to my tutorial when I have time for it.

2 Likes