Thank you very much @dokken. It works perfectly.
def plot(self):
    """Gather the distributed field ``self.T`` onto one rank and plot it with pyvista.

    Each rank contributes only the cells and dof coordinates/values it owns.
    Cell connectivity is renumbered from process-local to global dof indices,
    so the per-rank pieces can be concatenated into one
    ``pyvista.UnstructuredGrid`` on the root rank of ``self.domain.comm``.
    That rank opens an interactive plot; the other ranks return without
    plotting.

    This must be called on every rank of ``self.domain.comm``, because the
    gathers are collective.
    """
    comm = self.domain.comm
    root = 0
    index_map = self.V.dofmap.index_map
    bs = self.V.dofmap.index_map_bs

    cells, types, x = plot.create_vtk_mesh(self.V)
    num_cells_local = self.domain.topology.index_map(self.domain.topology.dim).size_local
    num_nodes_local = index_map.size_local

    # The VTK connectivity array is flat: [n, d_0, ..., d_{n-1}, n, d_0, ...].
    # The per-cell count is read from the first cell only, so this assumes
    # every cell has the same number of dofs (single cell type).
    num_dofs_per_cell = cells[0]
    stride = num_dofs_per_cell + 1

    # Restrict to owned cells (the slice stops before any trailing ghost
    # cells). Then map every dof entry (each position that is not a per-cell
    # count) to its global index. `cells` is a local array, so writing
    # through this view is fine.
    owned_cells = cells[: stride * num_cells_local]
    is_dof = (np.arange(owned_cells.size) % stride) != 0
    owned_cells[is_dof] = index_map.local_to_global(owned_cells[is_dof])

    # Stacking owned coordinates/values in gather (rank) order reproduces the
    # global dof numbering. This assumes the index map numbers owned dofs
    # contiguously per rank, in rank order.
    gathered_cells = comm.gather(owned_cells, root=root)
    gathered_types = comm.gather(types[:num_cells_local], root=root)
    gathered_x = comm.gather(x[:num_nodes_local, :], root=root)
    # `.real` is a no-op for real builds and makes complex builds plottable.
    gathered_vals = comm.gather(
        np.real(self.T.x.array[: num_nodes_local * bs]), root=root
    )

    # The gathered lists exist only on `root` of the communicator used above.
    # Checking MPI.COMM_WORLD instead breaks when domain.comm is a
    # sub-communicator.
    if comm.rank == root:
        root_x = np.vstack(gathered_x)
        root_cells = np.concatenate(gathered_cells)
        root_types = np.concatenate(gathered_types)
        root_vals = np.concatenate(gathered_vals)
        if bs > 1:
            # Blocked (e.g. vector) space: the components are interleaved per
            # node, so reshape to one row per grid point.
            root_vals = root_vals.reshape(-1, bs)

        grid = pyvista.UnstructuredGrid(root_cells, root_types, root_x)
        grid.point_data["u"] = root_vals
        grid.set_active_scalars("u")
        plotter = pyvista.Plotter()
        plotter.add_mesh(grid, show_edges=False)
        plotter.view_xy()
        plotter.show()