Inter-mesh interpolation slower than Poisson solve

I am experiencing slow inter-mesh interpolation. Here is an MWE, profiled with the line_profiler library:

from dolfinx import io, fem, mesh, cpp, geometry
import ufl
from mpi4py import MPI
import dolfinx.fem.petsc
import basix.ufl
from line_profiler import LineProfiler

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

def exact_sol(x):
    return 2 -(x[0]**2 + x[1]**2)
def rhs():
    return 4

def interpolate(sending_func,
                receiving_func,):
    '''
    Interpolate sending_func to receiving_func,
    each coming from separate meshes
    '''
    targetSpace = receiving_func.function_space
    nmmid = dolfinx.fem.create_nonmatching_meshes_interpolation_data(
                                 targetSpace.mesh,
                                 targetSpace.element,
                                 sending_func.ufl_function_space().mesh,
                                 padding=1e-6,)
    receiving_func.interpolate(sending_func, nmm_interpolation_data=nmmid)
    return receiving_func

def solve(domain,u):
    '''
    Solve Poisson problem
    '''
    V = u.function_space
    # Dirichlet BC
    cdim = domain.topology.dim
    fdim = cdim-1
    domain.topology.create_connectivity(fdim,cdim)
    bfacets = mesh.exterior_facet_indices(domain.topology)
    u_bc = fem.Function(V)
    u_bc.interpolate(exact_sol)
    dirichlet_bcs = [fem.dirichletbc(u_bc, fem.locate_dofs_topological(V, fdim, bfacets))]
    # Forms
    u, v = ufl.TrialFunction(V), ufl.TestFunction(V)
    a =  ufl.inner(ufl.grad(u),ufl.grad(v))*ufl.dx
    l =  rhs()*v*ufl.dx
    problem = dolfinx.fem.petsc.LinearProblem(a,l,
                                              bcs=dirichlet_bcs,)
    return problem.solve()

def write_post(mesh1,mesh2,u1,u2):
    writer1, writer2 = io.VTKFile(mesh1.comm, f"out/result1.vtk",'w'),io.VTKFile(mesh2.comm, f"out/result2.vtk",'w')
    writer1.write_function(u1)
    writer2.write_function(u2)
    writer1.close()
    writer2.close()

def main():
    # Mesh and problems
    points_side = 128
    mesh1 = mesh.create_unit_square(MPI.COMM_WORLD, points_side, points_side, mesh.CellType.triangle)
    mesh2 = mesh.create_unit_square(MPI.COMM_WORLD, points_side, points_side, mesh.CellType.quadrilateral)

    (V1, V2) = (fem.functionspace(mesh1, ("Lagrange", 1)),fem.functionspace(mesh2, ("Lagrange", 1)))
    (u1, u2)   = (fem.Function(V1),fem.Function(V2))
    # Solve Poisson problem on mesh1
    u1 = solve(mesh1,u1)
    # Interpolate from mesh1 to mesh2
    interpolate(u1,u2)
    # Write post
    #write_post(mesh1,mesh2,u1,u2)

if __name__=="__main__":
    profiling = True
    if profiling:
        lp = LineProfiler()
        lp.add_function(interpolate)
        lp.add_function(solve)
        lp.add_function(main)
        lp.add_function(write_post)
        lp.add_function(fem.Function.interpolate)
        lp_wrapper = lp(main)
        lp_wrapper()
        lp.print_stats()
    else:
        main()

I solve a Poisson problem on a mesh that has 32K elements and interpolate the solution to a mesh that has 16K elements.

I ran this with the latest development version of DOLFINx (from March 26th). The profiler outputs:

Total time: 4.25439 s
File: /root/shared/cases/test_profiling/main.py
Function: main at line 59

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    59                                           def main():
    60                                               # Mesh and problems
    61         1        863.0    863.0      0.0      points_side = 128
    62         1  677506314.0    7e+08     15.9      mesh1 = mesh.create_unit_square(MPI.COMM_WORLD, points_side, points_side, mesh.CellType.triangle)
    63         1  448408917.0    4e+08     10.5      mesh2 = mesh.create_unit_square(MPI.COMM_WORLD, points_side, points_side, mesh.CellType.quadrilateral)
    64                                           
    65         1   84272193.0    8e+07      2.0      (V1, V2) = (fem.functionspace(mesh1, ("Lagrange", 1)),fem.functionspace(mesh2, ("Lagrange", 1)))
    66         1     329775.0 329775.0      0.0      (u1, u2)   = (fem.Function(V1),fem.Function(V2))
    67                                               # Solve Poisson problem on mesh1
    68         1  526576436.0    5e+08     12.4      u1 = solve(mesh1,u1)
    69                                               # Interpolate from mesh1 to mesh2
    70         1 2517292189.0    3e+09     59.2      interpolate(u1,u2)
    71                                               # Write post
    72                                               #write_post(mesh1,mesh2,u1,u2)


As you can see, the interpolation takes almost 5 times as long as the Poisson solve (59.2% vs. 12.4% of the total time). The padding used
for create_nonmatching_meshes_interpolation_data is rather small (1e-6), and zero padding does not change these times.

Any help would be appreciated.

I feel like you are trying to crack a nut with a sledgehammer.

If you look at the breakdown of your interpolate function, 87% of the time is spent in
create_nonmatching_meshes_interpolation_data (see timings below).
This operation is not trivial, as it has to work in parallel and between any pair of finite element spaces.
What it does is:

  1. Compute the interpolation points for all dofs in the receiving space
  2. For every interpolation point, determine a single cell of the sending mesh that contains it (this works in parallel as well)
  3. Distribute these points among the processes

This is an operation that should only be done once for time-independent meshes.
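To make step 2 concrete, here is a minimal serial sketch in plain NumPy (not DOLFINx code) that finds one containing triangle per point by testing barycentric coordinates against every cell. The function name `locate_points_brute_force` and its `tol` parameter (playing a role loosely similar to `padding`) are hypothetical. This brute-force search costs O(N·M) for N points and M cells; as far as I know, DOLFINx avoids that with bounding-box trees, and in parallel it additionally communicates candidate points between processes, which this sketch ignores.

```python
import numpy as np


def locate_points_brute_force(points, vertices, triangles, tol=1e-12):
    """Return, for each 2D point, the index of one triangle containing it (or -1).

    Brute-force O(N * M) search: every point is tested against all M triangles
    using barycentric coordinates. `tol` lets points slightly outside a cell count.
    """
    points = np.asarray(points, dtype=float)
    vertices = np.asarray(vertices, dtype=float)
    triangles = np.asarray(triangles)
    v0 = vertices[triangles[:, 0]]  # (M, 2) first vertex of every triangle
    e1 = vertices[triangles[:, 1]] - v0  # edge vectors spanning each triangle
    e2 = vertices[triangles[:, 2]] - v0
    det = e1[:, 0] * e2[:, 1] - e1[:, 1] * e2[:, 0]  # Jacobian determinants

    owners = np.full(len(points), -1, dtype=np.int64)
    for i, p in enumerate(points):
        d = p - v0  # (M, 2) offset of p from each triangle's first vertex
        # Solve [e1 e2] @ [l1, l2] = d with Cramer's rule, for all triangles at once
        l1 = (d[:, 0] * e2[:, 1] - d[:, 1] * e2[:, 0]) / det
        l2 = (e1[:, 0] * d[:, 1] - e1[:, 1] * d[:, 0]) / det
        l0 = 1.0 - l1 - l2
        inside = (l0 >= -tol) & (l1 >= -tol) & (l2 >= -tol)
        hits = np.flatnonzero(inside)
        if hits.size > 0:
            owners[i] = hits[0]  # keep a single owning cell, even on shared edges
    return owners
```

For a unit square split along its diagonal into triangles `[0, 1, 2]` and `[0, 2, 3]`, a point on the diagonal lies in both cells, and the sketch simply keeps the first one, which mirrors the "single cell" requirement in step 2.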

In your case, the degrees of freedom will align exactly (as you are taking quad and triangular unit squares with matching vertices), and as far as I can tell you are only running this in serial.

Secondly, I cannot reproduce your timings (or profiling output), as the whole problem runs in less than a second on my system (Ubuntu 20.04, using Docker with the DOLFINx nightly image):

Total time: 0.218408 s
File: /root/shared/mwe_refine.py
Function: main at line 76

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    76                                           def main():
    77                                               # Mesh and problems
    78         1        790.0    790.0      0.0      points_side = 128
    79         2   34180841.0    2e+07     15.7      mesh1 = mesh.create_unit_square(
    80         1       1542.0   1542.0      0.0          MPI.COMM_WORLD, points_side, points_side, mesh.CellType.triangle)
    81         2   23667372.0    1e+07     10.8      mesh2 = mesh.create_unit_square(
    82         1       2362.0   2362.0      0.0          MPI.COMM_WORLD, points_side, points_side, mesh.CellType.quadrilateral)
    83                                           
    84         2    4835751.0    2e+06      2.2      (V1, V2) = (fem.functionspace(mesh1, ("Lagrange", 1)),
    85         1    3466145.0    3e+06      1.6                  fem.functionspace(mesh2, ("Lagrange", 1)))
    86         1     125374.0 125374.0      0.1      (u1, u2) = (fem.Function(V1), fem.Function(V2))
    87                                               # Solve Poisson problem on mesh1
    88         1   62614797.0    6e+07     28.7      u1 = solve(mesh1, u1)
    89                                               # Interpolate from mesh1 to mesh2
    90         1   89512952.0    9e+07     41.0      interpolate(u1, u2)
    91                                               # Write post
    92                                               # write_post(mesh1,mesh2,u1,u2)
Timer unit: 1e-09 s

Total time: 0.0901313 s
File: /root/shared/mwe_refine.py
Function: interpolate at line 20

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    20                                           def interpolate(sending_func,
    21                                                           receiving_func,):
    22                                               '''
    23                                               Interpolate sending_func to receiving_func,
    24                                               each comming from separate meshes
    25                                               '''
    26                                               # import time
    27         1       1313.0   1313.0      0.0      targetSpace = receiving_func.function_space
    28                                               # start = time.perf_counter()
    29         2   78678521.0    4e+07     87.3      nmmid = dolfinx.fem.create_nonmatching_meshes_interpolation_data(
    30         1       1227.0   1227.0      0.0          targetSpace.mesh,
    31         1       2266.0   2266.0      0.0          targetSpace.element,
    32         1       2905.0   2905.0      0.0          sending_func.ufl_function_space().mesh,
    33         1        338.0    338.0      0.0          padding=1e-6,)
    34                                               # end = time.perf_counter()
    35                                               # print(end-start)
    36                                           
    37                                               # start = time.perf_counter()
    38         1   11444444.0    1e+07     12.7      receiving_func.interpolate(sending_func, nmm_interpolation_data=nmmid)
    39                                               # end = time.perf_counter()
    40                                               # print(end-start)
    41         1        244.0    244.0      0.0      return receiving_func

Total time: 0.0632843 s
File: /root/shared/mwe_refine.py
Function: solve at line 44

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    44                                           def solve(domain, u):
    45                                               '''
    46                                               Solve Poisson problem
    47                                               '''
    48         1       3414.0   3414.0      0.0      V = u.function_space
    49                                               # Dirichlet BC
    50         1       2491.0   2491.0      0.0      cdim = domain.topology.dim
    51         1        473.0    473.0      0.0      fdim = cdim-1
    52         1    8901403.0    9e+06     14.1      domain.topology.create_connectivity(fdim, cdim)
    53         1      61109.0  61109.0      0.1      bfacets = mesh.exterior_facet_indices(domain.topology)
    54         1      85707.0  85707.0      0.1      u_bc = fem.Function(V)
    55         1    4505463.0    5e+06      7.1      u_bc.interpolate(exact_sol)
    56         2      54425.0  27212.5      0.1      dirichlet_bcs = [fem.dirichletbc(
    57         1      92563.0  92563.0      0.1          u_bc, fem.locate_dofs_topological(V, fdim, bfacets))]
    58                                               # Forms
    59         1      85119.0  85119.0      0.1      u, v = ufl.TrialFunction(V), ufl.TestFunction(V)
    60         1     384044.0 384044.0      0.6      a = ufl.inner(ufl.grad(u), ufl.grad(v))*ufl.dx
    61         1     171744.0 171744.0      0.3      l = rhs()*v*ufl.dx
    62         2    7430577.0    4e+06     11.7      problem = dolfinx.fem.petsc.LinearProblem(a, l,
    63         1        169.0    169.0      0.0                                                bcs=dirichlet_bcs,)
    64         1   41505636.0    4e+07     65.6      return problem.solve()

Thank you for your explanation. I am not sure I understand your sledgehammer comment. Is there another way to interpolate functions across meshes?

From your description of create_nonmatching_meshes_interpolation_data, it sounds like the cost should still be O(n) in the number of nodes, so I can’t wrap my head around it taking more time than the solve, which should be O(n^2) to O(n^3).

Correction: after asking around, I understand that, optimally, the costs of the solve and the interpolation should be O(n) and O(n log n) respectively, but I have been assured that an interpolation should take less time than a solve.
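As a rough illustration of where the O(n log n) comes from, here is an operation count for the MWE sizes, assuming the point search uses a tree with roughly logarithmic cost per query (a simplification for well-shaped meshes, not a measured DOLFINx cost). With M cells on mesh1 and N P1 dofs on mesh2, a naive search would be far more expensive than a tree-based one:

```latex
\begin{aligned}
M &= 2 \cdot 128^2 = 32\,768 \quad \text{(triangles in mesh1)}\\
N &= 129^2 = 16\,641 \quad \text{(P1 dofs on mesh2)}\\
\text{brute force:}\quad N M &= 16\,641 \cdot 32\,768 \approx 5.5 \times 10^{8} \ \text{point-cell tests}\\
\text{tree:}\quad M \log_2 M + N \log_2 M &\approx 32\,768 \cdot 15 + 16\,641 \cdot 15 \approx 7.4 \times 10^{5}
\end{aligned}
```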

I understand that your computer performs better than mine. We can scale the example, but the ratios do not change much. Here is the output after refining the mesh twice. As you say, most of the time is taken by create_nonmatching_meshes_interpolation_data.

Total time: 76.4825 s
File: /root/shared/cases/test_profiling/main.py
Function: main at line 59

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    59                                           def main():
    60                                               # Mesh and problems
    61         1        682.0    682.0      0.0      points_side = 512
    62         1        1e+10    1e+10     16.0      mesh1 = mesh.create_unit_square(MPI.COMM_WORLD, points_side, points_side, mesh.CellType.triangle)
    63         1 8001788882.0    8e+09     10.5      mesh2 = mesh.create_unit_square(MPI.COMM_WORLD, points_side, points_side, mesh.CellType.quadrilateral)
    64                                           
    65         1 1074880191.0    1e+09      1.4      (V1, V2) = (fem.functionspace(mesh1, ("Lagrange", 1)),fem.functionspace(mesh2, ("Lagrange", 1)))
    66         1    1533554.0    2e+06      0.0      (u1, u2)   = (fem.Function(V1),fem.Function(V2))
    67                                               # Solve Poisson problem on mesh1
    68         1        1e+10    1e+10     16.1      u1 = solve(mesh1,u1)
    69                                               # Interpolate from mesh1 to mesh2
    70         1        4e+10    4e+10     56.0      interpolate(u1,u2)
    71                                               # Write post
    72                                               #write_post(mesh1,mesh2,u1,u2)
Timer unit: 1e-09 s

Total time: 42.8211 s
File: /root/shared/cases/test_profiling/main.py
Function: interpolate at line 16

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    16                                           def interpolate(sending_func,
    17                                                           receiving_func,):
    18                                               '''
    19                                               Interpolate sending_func to receiving_func,
    20                                               each comming from separate meshes
    21                                               '''
    22         1       3181.0   3181.0      0.0      targetSpace = receiving_func.function_space
    23         2        3e+10    2e+10     74.5      nmmid = dolfinx.fem.create_nonmatching_meshes_interpolation_data(
    24         1       1793.0   1793.0      0.0                                   targetSpace.mesh,
    25         1       6999.0   6999.0      0.0                                   targetSpace.element,
    26         1       5499.0   5499.0      0.0                                   sending_func.ufl_function_space().mesh,
    27         1        241.0    241.0      0.0                                   padding=0,)
    28         1        1e+10    1e+10     25.5      receiving_func.interpolate(sending_func, nmm_interpolation_data=nmmid)
    29         1        411.0    411.0      0.0      return receiving_func

Total time: 12.3099 s
File: /root/shared/cases/test_profiling/main.py
Function: solve at line 31

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    31                                           def solve(domain,u):
    32                                               '''
    33                                               Solve Poisson problem
    34                                               '''
    35         1       2198.0   2198.0      0.0      V = u.function_space
    36                                               # Dirichlet BC
    37         1      16125.0  16125.0      0.0      cdim = domain.topology.dim
    38         1        547.0    547.0      0.0      fdim = cdim-1
    39         1 2746449053.0    3e+09     22.3      domain.topology.create_connectivity(fdim,cdim)
    40         1   14476771.0    1e+07      0.1      bfacets = mesh.exterior_facet_indices(domain.topology)
    41         1     850958.0 850958.0      0.0      u_bc = fem.Function(V)
    42         1  878984889.0    9e+08      7.1      u_bc.interpolate(exact_sol)
    43         1    2748550.0    3e+06      0.0      dirichlet_bcs = [fem.dirichletbc(u_bc, fem.locate_dofs_topological(V, fdim, bfacets))]
    44                                               # Forms
    45         1     156880.0 156880.0      0.0      u, v = ufl.TrialFunction(V), ufl.TestFunction(V)
    46         1     815880.0 815880.0      0.0      a =  ufl.inner(ufl.grad(u),ufl.grad(v))*ufl.dx
    47         1     364210.0 364210.0      0.0      l =  rhs()*v*ufl.dx
    48         2 1787785348.0    9e+08     14.5      problem = dolfinx.fem.petsc.LinearProblem(a,l,
    49         1        175.0    175.0      0.0                                                bcs=dirichlet_bcs,)
    50         1 6877202316.0    7e+09     55.9      return problem.solve()

Total time: 0 s
File: /root/shared/cases/test_profiling/main.py
Function: write_post at line 52

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    52                                           def write_post(mesh1,mesh2,u1,u2):
    53                                               writer1, writer2 = io.VTKFile(mesh1.comm, f"out/result1.vtk",'w'),io.VTKFile(mesh2.comm, f"out/result2.vtk",'w')
    54                                               writer1.write_function(u1)
    55                                               writer2.write_function(u2)
    56                                               writer1.close()
    57                                               writer2.close()

Total time: 11.7836 s
File: /root/shared/dolfinx/python/dolfinx/fem/function.py
Function: interpolate at line 419

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   419                                               def interpolate(
   420                                                   self,
   421                                                   u: typing.Union[typing.Callable, Expression, Function],
   422                                                   cells: typing.Optional[np.ndarray] = None,
   423                                                   nmm_interpolation_data: typing.Optional[PointOwnershipData] = None,
   424                                               ) -> None:
   425                                                   """Interpolate an expression
   426                                           
   427                                                   Args:
   428                                                       u: The function, Expression or Function to interpolate.
   429                                                       cells: The cells to interpolate over. If `None` then all
   430                                                           cells are interpolated over.
   431                                                       nmm_interpolation_data: Data needed to interpolate functions defined on other meshes
   432                                                   """
   433         2       1588.0    794.0      0.0          if nmm_interpolation_data is None:
   434         1      35557.0  35557.0      0.0              x_dtype = self.function_space.mesh.geometry.x.dtype
   435         2      15652.0   7826.0      0.0              nmm_interpolation_data = PointOwnershipData(
   436         1       5628.0   5628.0      0.0                  src_owner=np.empty(0, dtype=np.int32),
   437         1       1199.0   1199.0      0.0                  dest_owners=np.empty(0, dtype=np.int32),
   438         1        910.0    910.0      0.0                  dest_points=np.empty(0, dtype=x_dtype),
   439         1        904.0    904.0      0.0                  dest_cells=np.empty(0, dtype=np.int32),
   440                                                       )
   441                                           
   442         2       1692.0    846.0      0.0          @singledispatch
   443         2      85495.0  42747.5      0.0          def _interpolate(u, cells: typing.Optional[np.ndarray] = None):
   444                                                       """Interpolate a cpp.fem.Function"""
   445                                                       self._cpp_object.interpolate(u, cells, nmm_interpolation_data)  # type: ignore
   446                                           
   447         2      11163.0   5581.5      0.0          @_interpolate.register(Function)
   448         2      28809.0  14404.5      0.0          def _(u: Function, cells: typing.Optional[np.ndarray] = None):
   449                                                       """Interpolate a fem.Function"""
   450                                                       self._cpp_object.interpolate(u._cpp_object, cells, nmm_interpolation_data)  # type: ignore
   451                                           
   452         2       4073.0   2036.5      0.0          @_interpolate.register(int)
   453         2      10782.0   5391.0      0.0          def _(u_ptr: int, cells: typing.Optional[np.ndarray] = None):
   454                                                       """Interpolate using a pointer to a function f(x)"""
   455                                                       self._cpp_object.interpolate_ptr(u_ptr, cells)  # type: ignore
   456                                           
   457         2       3709.0   1854.5      0.0          @_interpolate.register(Expression)
   458         2      10127.0   5063.5      0.0          def _(expr: Expression, cells: typing.Optional[np.ndarray] = None):
   459                                                       """Interpolate Expression for the set of cells"""
   460                                                       self._cpp_object.interpolate(expr._cpp_object, cells)  # type: ignore
   461                                           
   462         2        517.0    258.5      0.0          if cells is None:
   463         2       4456.0   2228.0      0.0              mesh = self.function_space.mesh
   464         2      72789.0  36394.5      0.0              map = mesh.topology.index_map(mesh.topology.dim)
   465         2     415742.0 207871.0      0.0              cells = np.arange(map.size_local + map.num_ghosts, dtype=np.int32)
   466                                           
   467         2        592.0    296.0      0.0          try:
   468                                                       # u is a Function or Expression (or pointer to one)
   469         2        1e+10    5e+09     92.5              _interpolate(u, cells)
   470         1        482.0    482.0      0.0          except TypeError:
   471                                                       # u is callable
   472         1       1032.0   1032.0      0.0              assert callable(u)
   473         1  677340494.0    7e+08      5.7              x = _cpp.fem.interpolation_coords(self._V.element, self._V.mesh.geometry, cells)
   474         1  200819188.0    2e+08      1.7              self._cpp_object.interpolate(np.asarray(u(x), dtype=self.dtype), cells)  # type: ignore


Would a specialized interpolate for CG1 take drastically less time?

PS: I am ultimately interested in solving problems on moving meshes, so I cannot avoid recomputing create_nonmatching_meshes_interpolation_data many times per iteration.

If you know that the nodes of one mesh will align with the other at any time, a specialized routine exploiting that relationship would be beneficial.
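A minimal sketch of such a specialized routine, assuming the dofs of both spaces sit at exactly the same coordinates (as for P1 on matching triangle and quad unit squares) and that you run in serial: match the two coordinate sets once by sorting them, then transfer values with a plain array permutation. The helper name `matching_dof_permutation` and its `decimals` parameter are hypothetical, not DOLFINx API.

```python
import numpy as np


def matching_dof_permutation(src_coords, dst_coords, decimals=10):
    """Return perm such that src_coords[perm] matches dst_coords row by row.

    Assumes both arrays contain the same set of points (e.g. coincident P1 dofs),
    possibly in a different order. Coordinates are rounded to `decimals` so that
    tiny floating-point differences do not change the sort order.
    """
    src = np.round(np.asarray(src_coords, dtype=float), decimals)
    dst = np.round(np.asarray(dst_coords, dtype=float), decimals)
    if src.shape != dst.shape:
        raise ValueError("coordinate arrays must have the same shape")
    # Sort both point sets lexicographically by their coordinates
    src_order = np.lexsort(src.T)
    dst_order = np.lexsort(dst.T)
    # The i-th point in sorted src order matches the i-th point in sorted dst order
    perm = np.empty(len(dst), dtype=np.int64)
    perm[dst_order] = src_order
    if not np.allclose(src[perm], dst):
        raise ValueError("dof coordinates do not coincide")
    return perm
```

In DOLFINx one might, untested and only for serial runs with scalar P1 spaces, call it as `perm = matching_dof_permutation(V1.tabulate_dof_coordinates(), V2.tabulate_dof_coordinates())` and then set `u2.x.array[:] = u1.x.array[perm]`. If the node correspondence is fixed while the meshes move, `perm` only has to be computed once; in parallel the dofs are distributed differently on each process, so this shortcut no longer applies directly.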

How did you build DOLFINx? Are you using Docker? Did you build from source inside Docker? What is the CMAKE_BUILD_TYPE?

Here are timings from another computer (a slightly newer Dell XPS 13 with Ubuntu 22.04):

Total time: 0.642664 s
File: /root/shared/mwe_profile.py
Function: main at line 59

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    59                                           def main():
    60                                               # Mesh and problems
    61         1        459.0    459.0      0.0      points_side = 128
    62         1   20398903.0    2e+07      3.2      mesh1 = mesh.create_unit_square(MPI.COMM_WORLD, points_side, points_side, mesh.CellType.triangle)
    63         1   14487304.0    1e+07      2.3      mesh2 = mesh.create_unit_square(MPI.COMM_WORLD, points_side, points_side, mesh.CellType.quadrilateral)
    64                                           
    65         1  247386970.0    2e+08     38.5      (V1, V2) = (fem.functionspace(mesh1, ("Lagrange", 1)),fem.functionspace(mesh2, ("Lagrange", 1)))
    66         1     135373.0 135373.0      0.0      (u1, u2)   = (fem.Function(V1),fem.Function(V2))
    67                                               # Solve Poisson problem on mesh1
    68         1  310448935.0    3e+08     48.3      u1 = solve(mesh1,u1)
    69                                               # Interpolate from mesh1 to mesh2
    70         1   49806116.0    5e+07      7.7      interpolate(u1,u2)
    71                                               # Write post
    72                                               #write_post(mesh1,mesh2,u1,u2)
Total time: 0.593025 s
File: /root/shared/mwe_profile.py
Function: main at line 59

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    59                                           def main():
    60                                               # Mesh and problems
    61         1        420.0    420.0      0.0      points_side = 256
    62         1   90943908.0    9e+07     15.3      mesh1 = mesh.create_unit_square(MPI.COMM_WORLD, points_side, points_side, mesh.CellType.triangle)
    63         1   64938951.0    6e+07     11.0      mesh2 = mesh.create_unit_square(MPI.COMM_WORLD, points_side, points_side, mesh.CellType.quadrilateral)
    64                                           
    65         1    9910954.0    1e+07      1.7      (V1, V2) = (fem.functionspace(mesh1, ("Lagrange", 1)),fem.functionspace(mesh2, ("Lagrange", 1)))
    66         1      96671.0  96671.0      0.0      (u1, u2)   = (fem.Function(V1),fem.Function(V2))
    67                                               # Solve Poisson problem on mesh1
    68         1  228231904.0    2e+08     38.5      u1 = solve(mesh1,u1)
    69                                               # Interpolate from mesh1 to mesh2
    70         1  198902125.0    2e+08     33.5      interpolate(u1,u2)
    71                                               # Write post
    72                                               #write_post(mesh1,mesh2,u1,u2)
Total time: 4.23752 s
File: /root/shared/mwe_profile.py
Function: main at line 59

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    59                                           def main():
    60                                               # Mesh and problems
    61         1        497.0    497.0      0.0      points_side = 512
    62         1  410854677.0    4e+08      9.7      mesh1 = mesh.create_unit_square(MPI.COMM_WORLD, points_side, points_side, mesh.CellType.triangle)
    63         1  286373937.0    3e+08      6.8      mesh2 = mesh.create_unit_square(MPI.COMM_WORLD, points_side, points_side, mesh.CellType.quadrilateral)
    64                                           
    65         1   36981305.0    4e+07      0.9      (V1, V2) = (fem.functionspace(mesh1, ("Lagrange", 1)),fem.functionspace(mesh2, ("Lagrange", 1)))
    66         1     231240.0 231240.0      0.0      (u1, u2)   = (fem.Function(V1),fem.Function(V2))
    67                                               # Solve Poisson problem on mesh1
    68         1 2557013013.0    3e+09     60.3      u1 = solve(mesh1,u1)
    69                                               # Interpolate from mesh1 to mesh2
    70         1  946066732.0    9e+08     22.3      interpolate(u1,u2)
    71                                               # Write post
    72                                               #write_post(mesh1,mesh2,u1,u2)

These differ drastically from your results: here the solve becomes more and more costly relative to the interpolation as the mesh is refined.

Total time: 40.1687 s
File: /root/shared/mwe_profile.py
Function: main at line 59

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    59                                           def main():
    60                                               # Mesh and problems
    61         1        493.0    493.0      0.0      points_side = 1024
    62         1 1965245924.0    2e+09      4.9      mesh1 = mesh.create_unit_square(MPI.COMM_WORLD, points_side, points_side, mesh.CellType.triangle)
    63         1 1355902478.0    1e+09      3.4      mesh2 = mesh.create_unit_square(MPI.COMM_WORLD, points_side, points_side, mesh.CellType.quadrilateral)
    64                                           
    65         1  136625088.0    1e+08      0.3      (V1, V2) = (fem.functionspace(mesh1, ("Lagrange", 1)),fem.functionspace(mesh2, ("Lagrange", 1)))
    66         1    1232068.0    1e+06      0.0      (u1, u2)   = (fem.Function(V1),fem.Function(V2))
    67                                               # Solve Poisson problem on mesh1
    68         1        3e+10    3e+10     79.5      u1 = solve(mesh1,u1)
    69                                               # Interpolate from mesh1 to mesh2
    70         1 4795120953.0    5e+09     11.9      interpolate(u1,u2)
    71                                               # Write post
    72                                               #write_post(mesh1,mesh2,u1,u2)
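To put a number on how the interpolation scales, one can fit a power law to the `interpolate(u1,u2)` timings from the Ubuntu 22.04 runs above (128, 256, 512 and 1024 points per side). The helper `scaling_exponent` is hypothetical; the solve timings are left out because the 1024 case is only reported to one significant figure (3e+10).

```python
import numpy as np


def scaling_exponent(sizes, times):
    """Least-squares slope p of log(time) vs. log(size), i.e. time ~ C * size**p."""
    x = np.log(np.asarray(sizes, dtype=float))
    y = np.log(np.asarray(times, dtype=float))
    slope, _intercept = np.polyfit(x, y, 1)
    return slope


# interpolate(u1, u2) timings in ns, copied from the profiler output above
points_side = np.array([128, 256, 512, 1024])
num_triangles = 2 * points_side**2  # number of cells in mesh1
interp_ns = [49806116.0, 198902125.0, 946066732.0, 4795120953.0]

p = scaling_exponent(num_triangles, interp_ns)
print(f"interpolation time grows roughly like n^{p:.2f}")
```

For these numbers the fitted exponent comes out slightly above 1, which is consistent with the O(n log n) expectation discussed earlier, since the log factor grows slowly.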

Since I cannot reproduce your timings, there isn’t much more I can do.
For the timings above, I pulled the latest version of the following image:

docker run -ti --network=host -e DISPLAY=$DISPLAY -v /tmp/.X11-unix:/tmp/.X11-unix -v $(pwd):/root/shared -w /root/shared --rm ghcr.io/fenics/dolfinx/dolfinx:nightly

Thank you for the fast response @dokken.

This is not the case in general; this setup is just for benchmarking purposes.

I am using the development image and compiling DOLFINx myself using the routines you provided in this post.
DOLFINX_CMAKE_BUILD_TYPE is RELWITHDEBUG.

I changed DOLFINX_CMAKE_BUILD_TYPE to RELEASE and it is much faster. I also recover ratios similar to yours. Here is the profiling output with the same problem size as in my previous post and the changed compilation flag:


Total time: 9.21398 s
File: /root/shared/cases/test_profiling/main.py
Function: main at line 59

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    59                                           def main():
    60                                               # Mesh and problems
    61         1        711.0    711.0      0.0      points_side = 512
    62         1  746203133.0    7e+08      8.1      mesh1 = mesh.create_unit_square(MPI.COMM_WORLD, points_side, points_side, mesh.CellType.triangle)
    63         1  501749562.0    5e+08      5.4      mesh2 = mesh.create_unit_square(MPI.COMM_WORLD, points_side, points_side, mesh.CellType.quadrilateral)
    64                                           
    65         1   85418134.0    9e+07      0.9      (V1, V2) = (fem.functionspace(mesh1, ("Lagrange", 1)),fem.functionspace(mesh2, ("Lagrange", 1)))
    66         1     602114.0 602114.0      0.0      (u1, u2)   = (fem.Function(V1),fem.Function(V2))
    67                                               # Solve Poisson problem on mesh1
    68         1 5740625915.0    6e+09     62.3      u1 = solve(mesh1,u1)
    69                                               # Interpolate from mesh1 to mesh2
    70         1 2139375620.0    2e+09     23.2      interpolate(u1,u2)
    71                                               # Write post
    72                                               #write_post(mesh1,mesh2,u1,u2)
Timer unit: 1e-09 s

Total time: 2.13935 s
File: /root/shared/cases/test_profiling/main.py
Function: interpolate at line 16

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    16                                           def interpolate(sending_func,
    17                                                           receiving_func,):
    18                                               '''
    19                                               Interpolate sending_func to receiving_func,
    20                                               each comming from separate meshes
    21                                               '''
    22         1       2992.0   2992.0      0.0      targetSpace = receiving_func.function_space
    23         2 1625102126.0    8e+08     76.0      nmmid = dolfinx.fem.create_nonmatching_meshes_interpolation_data(
    24         1       1992.0   1992.0      0.0                                   targetSpace.mesh,
    25         1       4648.0   4648.0      0.0                                   targetSpace.element,
    26         1       4988.0   4988.0      0.0                                   sending_func.ufl_function_space().mesh,
    27         1        240.0    240.0      0.0                                   padding=0,)
    28         1  514233567.0    5e+08     24.0      receiving_func.interpolate(sending_func, nmm_interpolation_data=nmmid)
    29         1        662.0    662.0      0.0      return receiving_func

Total time: 5.73651 s
File: /root/shared/cases/test_profiling/main.py
Function: solve at line 31

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    31                                           def solve(domain,u):
    32                                               '''
    33                                               Solve Poisson problem
    34                                               '''
    35         1       2006.0   2006.0      0.0      V = u.function_space
    36                                               # Dirichlet BC
    37         1       6169.0   6169.0      0.0      cdim = domain.topology.dim
    38         1        434.0    434.0      0.0      fdim = cdim-1
    39         1  130067690.0    1e+08      2.3      domain.topology.create_connectivity(fdim,cdim)
    40         1     834879.0 834879.0      0.0      bfacets = mesh.exterior_facet_indices(domain.topology)
    41         1     358413.0 358413.0      0.0      u_bc = fem.Function(V)
    42         1   55060674.0    6e+07      1.0      u_bc.interpolate(exact_sol)
    43         1     791758.0 791758.0      0.0      dirichlet_bcs = [fem.dirichletbc(u_bc, fem.locate_dofs_topological(V, fdim, bfacets))]
    44                                               # Forms
    45         1     189370.0 189370.0      0.0      u, v = ufl.TrialFunction(V), ufl.TestFunction(V)
    46         1    1423918.0    1e+06      0.0      a =  ufl.inner(ufl.grad(u),ufl.grad(v))*ufl.dx
    47         1     552468.0 552468.0      0.0      l =  rhs()*v*ufl.dx
    48         2  116950761.0    6e+07      2.0      problem = dolfinx.fem.petsc.LinearProblem(a,l,
    49         1        284.0    284.0      0.0                                                bcs=dirichlet_bcs,)
    50         1 5430272793.0    5e+09     94.7      return problem.solve()

Total time: 0 s
File: /root/shared/cases/test_profiling/main.py
Function: write_post at line 52

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    52                                           def write_post(mesh1,mesh2,u1,u2):
    53                                               writer1, writer2 = io.VTKFile(mesh1.comm, f"out/result1.vtk",'w'),io.VTKFile(mesh2.comm, f"out/result2.vtk",'w')
    54                                               writer1.write_function(u1)
    55                                               writer2.write_function(u2)
    56                                               writer1.close()
    57                                               writer2.close()


Total time: 0.568868 s
File: /root/shared/dolfinx/python/dolfinx/fem/function.py
Function: interpolate at line 419

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   419                                               def interpolate(
   420                                                   self,
   421                                                   u: typing.Union[typing.Callable, Expression, Function],
   422                                                   cells: typing.Optional[np.ndarray] = None,
   423                                                   nmm_interpolation_data: typing.Optional[PointOwnershipData] = None,
   424                                               ) -> None:
   425                                                   """Interpolate an expression
   426                                           
   427                                                   Args:
   428                                                       u: The function, Expression or Function to interpolate.
   429                                                       cells: The cells to interpolate over. If `None` then all
   430                                                           cells are interpolated over.
   431                                                       nmm_interpolation_data: Data needed to interpolate functions defined on other meshes
   432                                                   """
   433         2       2360.0   1180.0      0.0          if nmm_interpolation_data is None:
   434         1      14532.0  14532.0      0.0              x_dtype = self.function_space.mesh.geometry.x.dtype
   435         2       7494.0   3747.0      0.0              nmm_interpolation_data = PointOwnershipData(
   436         1       5191.0   5191.0      0.0                  src_owner=np.empty(0, dtype=np.int32),
   437         1       1168.0   1168.0      0.0                  dest_owners=np.empty(0, dtype=np.int32),
   438         1        930.0    930.0      0.0                  dest_points=np.empty(0, dtype=x_dtype),
   439         1        906.0    906.0      0.0                  dest_cells=np.empty(0, dtype=np.int32),
   440                                                       )
   441                                           
   442         2       2575.0   1287.5      0.0          @singledispatch
   443         2      81575.0  40787.5      0.0          def _interpolate(u, cells: typing.Optional[np.ndarray] = None):
   444                                                       """Interpolate a cpp.fem.Function"""
   445                                                       self._cpp_object.interpolate(u, cells, nmm_interpolation_data)  # type: ignore
   446                                           
   447         2      11195.0   5597.5      0.0          @_interpolate.register(Function)
   448         2      29680.0  14840.0      0.0          def _(u: Function, cells: typing.Optional[np.ndarray] = None):
   449                                                       """Interpolate a fem.Function"""
   450                                                       self._cpp_object.interpolate(u._cpp_object, cells, nmm_interpolation_data)  # type: ignore
   451                                           
   452         2       4630.0   2315.0      0.0          @_interpolate.register(int)
   453         2      10731.0   5365.5      0.0          def _(u_ptr: int, cells: typing.Optional[np.ndarray] = None):
   454                                                       """Interpolate using a pointer to a function f(x)"""
   455                                                       self._cpp_object.interpolate_ptr(u_ptr, cells)  # type: ignore
   456                                           
   457         2       3924.0   1962.0      0.0          @_interpolate.register(Expression)
   458         2       9801.0   4900.5      0.0          def _(expr: Expression, cells: typing.Optional[np.ndarray] = None):
   459                                                       """Interpolate Expression for the set of cells"""
   460                                                       self._cpp_object.interpolate(expr._cpp_object, cells)  # type: ignore
   461                                           
   462         2        795.0    397.5      0.0          if cells is None:
   463         2       5627.0   2813.5      0.0              mesh = self.function_space.mesh
   464         2      21926.0  10963.0      0.0              map = mesh.topology.index_map(mesh.topology.dim)
   465         2     412585.0 206292.5      0.1              cells = np.arange(map.size_local + map.num_ghosts, dtype=np.int32)
   466                                           
   467         2        767.0    383.5      0.0          try:
   468                                                       # u is a Function or Expression (or pointer to one)
   469         2  514046880.0    3e+08     90.4              _interpolate(u, cells)
   470         1        612.0    612.0      0.0          except TypeError:
   471                                                       # u is callable
   472         1        964.0    964.0      0.0              assert callable(u)
   473         1   22065639.0    2e+07      3.9              x = _cpp.fem.interpolation_coords(self._V.element, self._V.mesh.geometry, cells)
   474         1   32125337.0    3e+07      5.6              self._cpp_object.interpolate(np.asarray(u(x), dtype=self.dtype), cells)  # type: ignore
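
In the Release profile above, `create_nonmatching_meshes_interpolation_data` accounts for about 76% of the time spent in `interpolate`, while the actual `receiving_func.interpolate(...)` call takes only 24%. If the same pair of meshes is used for several interpolations (an assumption; the MWE interpolates only once), one option is to build this data once and reuse it. The sketch below uses a hypothetical helper, `InterpolationDataCache` (not part of DOLFINx), that memoizes any expensive build step per (target, source) key. The DOLFINx usage is shown in comments only and assumes the data depends on the target space and source mesh, not on the function values.

```python
from typing import Callable, Dict, Hashable, Tuple, TypeVar

T = TypeVar("T")


class InterpolationDataCache:
    """Hypothetical helper: run an expensive build step once per key, then reuse it."""

    def __init__(self) -> None:
        self._data: Dict[Tuple[Hashable, Hashable], object] = {}
        self.builds = 0  # number of times the expensive build step actually ran

    def get(self, target_key: Hashable, source_key: Hashable, build: Callable[[], T]) -> T:
        key = (target_key, source_key)
        if key not in self._data:
            self._data[key] = build()  # only executed on a cache miss
            self.builds += 1
        return self._data[key]  # type: ignore[return-value]


# Possible use in the MWE's interpolate() (untested sketch, commented out so this
# block runs without DOLFINx). id() keys are only valid while both objects stay alive.
#
# _cache = InterpolationDataCache()
#
# def interpolate(sending_func, receiving_func):
#     V = receiving_func.function_space
#     src_mesh = sending_func.ufl_function_space().mesh
#     nmmid = _cache.get(
#         id(V), id(src_mesh),
#         lambda: dolfinx.fem.create_nonmatching_meshes_interpolation_data(
#             V.mesh, V.element, src_mesh, padding=0),
#     )
#     receiving_func.interpolate(sending_func, nmm_interpolation_data=nmmid)
#     return receiving_func

if __name__ == "__main__":
    cache = InterpolationDataCache()
    for _ in range(3):
        data = cache.get("V2", "mesh1", lambda: "expensive data")
    print(data, cache.builds)  # → expensive data 1
```

With this pattern only the first call pays for building the interpolation data; later calls with the same keys only pay for `receiving_func.interpolate`.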

Appreciate the help @dokken!

I attach profiling data from DEBUG executables for completeness:

Timer unit: 1e-09 s
Total time: 85.7511 s
File: /root/shared/cases/test_profiling/main.py
Function: main at line 59

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    59                                           def main():
    60                                               # Mesh and problems
    61         1        642.0    642.0      0.0      points_side = 512
    62         1        1e+10    1e+10     16.4      mesh1 = mesh.create_unit_square(MPI.COMM_WORLD, points_side, points_side, mesh.CellType.triangle)
    63         1 8582753223.0    9e+09     10.0      mesh2 = mesh.create_unit_square(MPI.COMM_WORLD, points_side, points_side, mesh.CellType.quadrilateral)
    64                                           
    65         1 1521525088.0    2e+09      1.8      (V1, V2) = (fem.functionspace(mesh1, ("Lagrange", 1)),fem.functionspace(mesh2, ("Lagrange", 1)))
    66         1    1537440.0    2e+06      0.0      (u1, u2)   = (fem.Function(V1),fem.Function(V2))
    67                                               # Solve Poisson problem on mesh1
    68         1        1e+10    1e+10     16.0      u1 = solve(mesh1,u1)
    69                                               # Interpolate from mesh1 to mesh2
    70         1        5e+10    5e+10     55.8      interpolate(u1,u2)
    71                                               # Write post
    72                                               #write_post(mesh1,mesh2,u1,u2)

Total time: 47.8807 s
File: /root/shared/cases/test_profiling/main.py
Function: interpolate at line 16

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    16                                           def interpolate(sending_func,
    17                                                           receiving_func,):
    18                                               '''
    19                                               Interpolate sending_func to receiving_func,
    20                                               each comming from separate meshes
    21                                               '''
    22         1       3155.0   3155.0      0.0      targetSpace = receiving_func.function_space
    23         2        4e+10    2e+10     76.1      nmmid = dolfinx.fem.create_nonmatching_meshes_interpolation_data(
    24         1       2348.0   2348.0      0.0                                   targetSpace.mesh,
    25         1       6606.0   6606.0      0.0                                   targetSpace.element,
    26         1       4138.0   4138.0      0.0                                   sending_func.ufl_function_space().mesh,
    27         1        273.0    273.0      0.0                                   padding=0,)
    28         1        1e+10    1e+10     23.9      receiving_func.interpolate(sending_func, nmm_interpolation_data=nmmid)
    29         1        433.0    433.0      0.0      return receiving_func

Total time: 13.7008 s
File: /root/shared/cases/test_profiling/main.py
Function: solve at line 31

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    31                                           def solve(domain,u):
    32                                               '''
    33                                               Solve Poisson problem
    34                                               '''
    35         1       1966.0   1966.0      0.0      V = u.function_space
    36                                               # Dirichlet BC
    37         1      17482.0  17482.0      0.0      cdim = domain.topology.dim
    38         1        768.0    768.0      0.0      fdim = cdim-1
    39         1 2958914361.0    3e+09     21.6      domain.topology.create_connectivity(fdim,cdim)
    40         1   14939164.0    1e+07      0.1      bfacets = mesh.exterior_facet_indices(domain.topology)
    41         1     876764.0 876764.0      0.0      u_bc = fem.Function(V)
    42         1 1203069563.0    1e+09      8.8      u_bc.interpolate(exact_sol)
    43         1    3790859.0    4e+06      0.0      dirichlet_bcs = [fem.dirichletbc(u_bc, fem.locate_dofs_topological(V, fdim, bfacets))]
    44                                               # Forms
    45         1     155273.0 155273.0      0.0      u, v = ufl.TrialFunction(V), ufl.TestFunction(V)
    46         1     918670.0 918670.0      0.0      a =  ufl.inner(ufl.grad(u),ufl.grad(v))*ufl.dx
    47         1     325844.0 325844.0      0.0      l =  rhs()*v*ufl.dx
    48         2 2270950279.0    1e+09     16.6      problem = dolfinx.fem.petsc.LinearProblem(a,l,
    49         1        148.0    148.0      0.0                                                bcs=dirichlet_bcs,)
    50         1 7246822885.0    7e+09     52.9      return problem.solve()

Total time: 0 s
File: /root/shared/cases/test_profiling/main.py
Function: write_post at line 52

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    52                                           def write_post(mesh1,mesh2,u1,u2):
    53                                               writer1, writer2 = io.VTKFile(mesh1.comm, f"out/result1.vtk",'w'),io.VTKFile(mesh2.comm, f"out/result2.vtk",'w')
    54                                               writer1.write_function(u1)
    55                                               writer2.write_function(u2)
    56                                               writer1.close()
    57                                               writer2.close()

Total time: 12.6526 s
File: /root/shared/dolfinx/python/dolfinx/fem/function.py
Function: interpolate at line 419

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   419                                               def interpolate(
   420                                                   self,
   421                                                   u: typing.Union[typing.Callable, Expression, Function],
   422                                                   cells: typing.Optional[np.ndarray] = None,
   423                                                   nmm_interpolation_data: typing.Optional[PointOwnershipData] = None,
   424                                               ) -> None:
   425                                                   """Interpolate an expression
   426                                           
   427                                                   Args:
   428                                                       u: The function, Expression or Function to interpolate.
   429                                                       cells: The cells to interpolate over. If `None` then all
   430                                                           cells are interpolated over.
   431                                                       nmm_interpolation_data: Data needed to interpolate functions defined on other meshes
   432                                                   """
   433         2       2236.0   1118.0      0.0          if nmm_interpolation_data is None:
   434         1      31264.0  31264.0      0.0              x_dtype = self.function_space.mesh.geometry.x.dtype
   435         2       8619.0   4309.5      0.0              nmm_interpolation_data = PointOwnershipData(
   436         1       4718.0   4718.0      0.0                  src_owner=np.empty(0, dtype=np.int32),
   437         1       1127.0   1127.0      0.0                  dest_owners=np.empty(0, dtype=np.int32),
   438         1       1035.0   1035.0      0.0                  dest_points=np.empty(0, dtype=x_dtype),
   439         1       1064.0   1064.0      0.0                  dest_cells=np.empty(0, dtype=np.int32),
   440                                                       )
   441                                           
   442         2       1957.0    978.5      0.0          @singledispatch
   443         2      92576.0  46288.0      0.0          def _interpolate(u, cells: typing.Optional[np.ndarray] = None):
   444                                                       """Interpolate a cpp.fem.Function"""
   445                                                       self._cpp_object.interpolate(u, cells, nmm_interpolation_data)  # type: ignore
   446                                           
   447         2      15112.0   7556.0      0.0          @_interpolate.register(Function)
   448         2      41750.0  20875.0      0.0          def _(u: Function, cells: typing.Optional[np.ndarray] = None):
   449                                                       """Interpolate a fem.Function"""
   450                                                       self._cpp_object.interpolate(u._cpp_object, cells, nmm_interpolation_data)  # type: ignore
   451                                           
   452         2       7014.0   3507.0      0.0          @_interpolate.register(int)
   453         2      17571.0   8785.5      0.0          def _(u_ptr: int, cells: typing.Optional[np.ndarray] = None):
   454                                                       """Interpolate using a pointer to a function f(x)"""
   455                                                       self._cpp_object.interpolate_ptr(u_ptr, cells)  # type: ignore
   456                                           
   457         2       6220.0   3110.0      0.0          @_interpolate.register(Expression)
   458         2      14579.0   7289.5      0.0          def _(expr: Expression, cells: typing.Optional[np.ndarray] = None):
   459                                                       """Interpolate Expression for the set of cells"""
   460                                                       self._cpp_object.interpolate(expr._cpp_object, cells)  # type: ignore
   461                                           
   462         2        991.0    495.5      0.0          if cells is None:
   463         2       5276.0   2638.0      0.0              mesh = self.function_space.mesh
   464         2      81276.0  40638.0      0.0              map = mesh.topology.index_map(mesh.topology.dim)
   465         2     432863.0 216431.5      0.0              cells = np.arange(map.size_local + map.num_ghosts, dtype=np.int32)
   466                                           
   467         2       1003.0    501.5      0.0          try:
   468                                                       # u is a Function or Expression (or pointer to one)
   469         2        1e+10    6e+09     90.5              _interpolate(u, cells)
   470         1        668.0    668.0      0.0          except TypeError:
   471                                                       # u is callable
   472         1       1221.0   1221.0      0.0              assert callable(u)
   473         1  759942823.0    8e+08      6.0              x = _cpp.fem.interpolation_coords(self._V.element, self._V.mesh.geometry, cells)
   474         1  442300897.0    4e+08      3.5              self._cpp_object.interpolate(np.asarray(u(x), dtype=self.dtype), cells)  # type: ignore

The running times are similar to those of the RELWITHDEBINFO executables. If this is reproducible, I think proper profiling with debug executables would also conclude that interpolation is slower than the solve. Since the Release build is fine, that is what matters most. Thank you for your help, @dokken.

As PETSc is not built in debug mode, it is not sensible to compare timings between them. Debug mode has many asserts for sanity checks.
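
The per-function totals in the two profiles (Release and Debug, same 512 x 512 problem) illustrate this point. The short script below only does arithmetic on numbers copied from those outputs; `problem_solve` is the `return problem.solve()` line, which includes assembly as well as the PETSc linear solve.

```python
# Per-function totals (seconds) copied from the line_profiler output above.
release = {"solve": 5.73651, "problem_solve": 5.430272793, "interpolate": 2.13935}
debug = {"solve": 13.7008, "problem_solve": 7.246822885, "interpolate": 47.8807}


def slowdown(stage: str) -> float:
    """Debug time divided by Release time for one profiled stage."""
    return debug[stage] / release[stage]


def interp_to_solve(times: dict) -> float:
    """Cost of the inter-mesh interpolation relative to the whole solve()."""
    return times["interpolate"] / times["solve"]


if __name__ == "__main__":
    for stage in release:
        print(f"{stage}: Debug/Release = {slowdown(stage):.1f}x")
    print(f"interpolate/solve, Release: {interp_to_solve(release):.2f}")
    print(f"interpolate/solve, Debug:   {interp_to_solve(debug):.2f}")
```

This gives Debug/Release slowdowns of roughly 2.4x for `solve`, 1.3x for `problem.solve()` and 22x for `interpolate`, so the interpolate/solve ratio moves from about 0.37 to about 3.5. The line that contains the PETSc solve barely slows down, which is consistent with PETSc not being built in debug mode, so Debug timings say little about the relative cost in a Release build.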
