What exactly does apply_lifting do?

I am observing that apply_lifting is relatively slow: it takes time comparable to assembling the vector itself.

MWE

import dolfinx
import dolfinx.fem.petsc
import matplotlib.pyplot as plt
import numpy as np
import pyvista
import ufl
from mpi4py import MPI
from petsc4py import PETSc

# World communicator for this run; its rank is used below to give every MPI
# process its own profiling output file.
comm = MPI.COMM_WORLD
rank = comm.rank

def main(nelems=10000, max_iterations=25):
    """Solve a 1D nonlinear Poisson problem with a hand-written Newton loop.

    Solves ``-div(q(u) grad(u)) = f`` on the unit interval with
    ``q(u) = 1 + u**2``. The source ``f`` is manufactured from the exact
    solution ``u = 1 + 2x``, which is also imposed as the Dirichlet condition
    on the exterior facets. After each Newton step, the 1-norm of the
    correction and the L2 error against the exact solution are printed.

    Args:
        nelems: Number of cells in the unit-interval mesh.
        max_iterations: Maximum number of Newton iterations.

    Returns:
        A tuple ``(L2_error, du_norm)`` holding the per-iteration L2 errors
        and correction norms.
    """

    def q(u):
        return 1 + u**2

    def u_exact(x):
        # Only x[0] is used, so the same function works on a UFL
        # SpatialCoordinate (giving a symbolic expression) and on the
        # coordinate array passed in by Function.interpolate. This replaces
        # the fragile eval(str(u_ufl)) round trip through UFL's printer.
        return 1 + 2 * x[0]

    domain = dolfinx.mesh.create_unit_interval(MPI.COMM_WORLD, nelems)
    x = ufl.SpatialCoordinate(domain)
    u_ufl = u_exact(x)
    f = -ufl.div(q(u_ufl) * ufl.grad(u_ufl))

    V = dolfinx.fem.functionspace(domain, ("Lagrange", 1))
    u_D = dolfinx.fem.Function(V)
    u_D.interpolate(u_exact)
    fdim = domain.topology.dim - 1
    # Facet -> cell connectivity; presumably needed by exterior_facet_indices.
    domain.topology.create_connectivity(fdim, fdim + 1)
    boundary_facets = dolfinx.mesh.exterior_facet_indices(domain.topology)
    bc = dolfinx.fem.dirichletbc(u_D, dolfinx.fem.locate_dofs_topological(V, fdim, boundary_facets))

    # Residual F(uh; v) and its derivative (the Jacobian) with respect to uh.
    uh = dolfinx.fem.Function(V)
    v = ufl.TestFunction(V)
    F = q(uh) * ufl.dot(ufl.grad(uh), ufl.grad(v)) * ufl.dx - f * v * ufl.dx
    J = ufl.derivative(F, uh)
    residual = dolfinx.fem.form(F)
    jacobian = dolfinx.fem.form(J)

    du = dolfinx.fem.Function(V)
    A = dolfinx.fem.petsc.create_matrix(jacobian)
    L = dolfinx.fem.petsc.create_vector(residual)
    solver = PETSc.KSP().create(domain.comm)
    solver.setOperators(A)

    error = dolfinx.fem.form(ufl.inner(uh - u_ufl, uh - u_ufl) * ufl.dx(metadata={"quadrature_degree": 4}))
    L2_error = []
    du_norm = []

    for i in range(1, max_iterations + 1):
        # Assemble Jacobian and residual
        with L.localForm() as loc_L:
            loc_L.set(0)
        A.zeroEntries()
        dolfinx.fem.petsc.assemble_matrix(A, jacobian, bcs=[bc])
        A.assemble()
        dolfinx.fem.petsc.assemble_vector(L, residual)
        L.ghostUpdate(addv=PETSc.InsertMode.ADD, mode=PETSc.ScatterMode.REVERSE)
        L.scale(-1)

        # Compute b - J(u_D-u_(i-1))
        # NOTE(review): profiling this script showed roughly 1/3 of
        # apply_lifting's time is spent packing the coefficients of
        # `jacobian` (which depends on uh). They could presumably be packed
        # once per iteration and shared with assemble_matrix via `coeffs=`.
        # Confirm against the installed DOLFINx API before changing this.
        dolfinx.fem.petsc.apply_lifting(L, [jacobian], [[bc]], x0=[uh.x.petsc_vec], alpha=1)
        # Set du|_bc = u_{i-1}-u_D
        dolfinx.fem.petsc.set_bc(L, [bc], uh.x.petsc_vec, 1.0)
        L.ghostUpdate(addv=PETSc.InsertMode.INSERT_VALUES, mode=PETSc.ScatterMode.FORWARD)

        # Solve linear problem
        solver.solve(L, du.x.petsc_vec)
        du.x.scatter_forward()

        # Update u_{i+1} = u_i + delta u_i
        uh.x.array[:] += du.x.array

        # 1-norm of the update (the original used the bare value 0, i.e. NORM_1).
        correction_norm = du.x.petsc_vec.norm(PETSc.NormType.NORM_1)

        # Compute L2 error comparing to the analytical solution; the allreduce
        # sums the per-process contributions before taking the square root.
        L2_error.append(np.sqrt(domain.comm.allreduce(dolfinx.fem.assemble_scalar(error), op=MPI.SUM)))
        du_norm.append(correction_norm)

        print(f"Iteration {i}: Correction norm {correction_norm}, L2 error: {L2_error[-1]}")
        if correction_norm < 1e-10:
            break
    else:
        # Reached only when the loop ran out without hitting the tolerance.
        print(f"Newton solver did not converge in {max_iterations} iterations")

    return L2_error, du_norm

if __name__ == "__main__":
    # Imported inside the guard so importing this module does not require
    # line_profiler.
    from line_profiler import LineProfiler

    profiler = LineProfiler()
    # profiler(main) returns a wrapped main whose lines are timed when called.
    profiled_main = profiler(main)
    profiled_main()
    # One output file per MPI rank, so parallel runs don't overwrite each
    # other's statistics.
    profiling_file = f"profiling_rank{rank}.txt"
    with open(profiling_file, "w", encoding="utf-8") as pf:
        profiler.print_stats(stream=pf)

This is basically copied from @dokken's tutorial, with the LineProfiler library added to obtain profiling results. For instance, for this 1D example with 10K elements:

Timer unit: 1e-09 s
Total time: 69.9461 s

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================

    60        10        1e+10    1e+09     20.5          dolfinx.fem.petsc.assemble_vector(L, residual)
    65        10 7422697952.0    7e+08     10.6          dolfinx.fem.petsc.apply_lifting(L, [jacobian], [[bc]], x0=[uh.x.petsc_vec], alpha=1)

These numbers come from a developer-compiled build of the latest DOLFINx, run in the test-env container. What puzzles me is that, for this example with first-order elements, apply_lifting should only modify two entries of the RHS vector. Yet it still takes about half as long as assemble_vector takes to assemble all 10K elements. So I am wondering what apply_lifting is actually doing. I paste the full profiling dump here in case anyone is interested in comparing:

Timer unit: 1e-09 s

Total time: 69.9461 s
File: /root/shared/jorgensd_nr_example_dirichlet/main.py
Function: main at line 13

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    13                                           def main():
    14         1    2002979.0    2e+06      0.0      def q(u):
    15                                                   return 1 + u**2
    16                                           
    17         1      50000.0  50000.0      0.0      nelems = 10000
    18         1 8136370529.0    8e+09     11.6      domain = dolfinx.mesh.create_unit_interval(MPI.COMM_WORLD, nelems)
    19         1    3159787.0    3e+06      0.0      x = ufl.SpatialCoordinate(domain)
    20         1   17659926.0    2e+07      0.0      u_ufl = 1 + 2 * x[0]
    21         1   57820158.0    6e+07      0.1      f = - ufl.div(q(u_ufl) * ufl.grad(u_ufl))
    22                                           
    23         1      41700.0  41700.0      0.0      def u_exact(x):
    24                                                   return eval(str(u_ufl))
    25                                           
    26         1  280213824.0    3e+08      0.4      V = dolfinx.fem.functionspace(domain, ("Lagrange", 1))
    27         1   57316960.0    6e+07      0.1      u_D = dolfinx.fem.Function(V)
    28         1 2015771241.0    2e+09      2.9      u_D.interpolate(u_exact)
    29         1     211899.0 211899.0      0.0      fdim = domain.topology.dim - 1
    30         1   37620242.0    4e+07      0.1      domain.topology.create_connectivity(fdim, fdim + 1)
    31         1   14590039.0    1e+07      0.0      boundary_facets = dolfinx.mesh.exterior_facet_indices(domain.topology)
    32         1  101360975.0    1e+08      0.1      bc = dolfinx.fem.dirichletbc(u_D, dolfinx.fem.locate_dofs_topological(V, fdim, boundary_facets))
    33                                           
    34         1    5075178.0    5e+06      0.0      uh = dolfinx.fem.Function(V)
    35         1    4625481.0    5e+06      0.0      v = ufl.TestFunction(V)
    36         1  140063512.0    1e+08      0.2      F = q(uh) * ufl.dot(ufl.grad(uh), ufl.grad(v)) * ufl.dx - f * v * ufl.dx
    37         1   77671074.0    8e+07      0.1      J = ufl.derivative(F, uh)
    38         1 6058416009.0    6e+09      8.7      residual = dolfinx.fem.form(F)
    39         1  287922433.0    3e+08      0.4      jacobian = dolfinx.fem.form(J)
    40                                           
    41         1    8652710.0    9e+06      0.0      du = dolfinx.fem.Function(V)
    42         1  670876085.0    7e+08      1.0      A = dolfinx.fem.petsc.create_matrix(jacobian)
    43         1  384135142.0    4e+08      0.5      L = dolfinx.fem.petsc.create_vector(residual)
    44         1    9427903.0    9e+06      0.0      solver = PETSc.KSP().create(domain.comm)
    45         1   30273988.0    3e+07      0.0      solver.setOperators(A)
    46                                           
    47         1      26799.0  26799.0      0.0      i = 0
    48         1  187641566.0    2e+08      0.3      error = dolfinx.fem.form(ufl.inner(uh - u_ufl, uh - u_ufl) * ufl.dx(metadata={"quadrature_degree": 4}))
    49         1      15900.0  15900.0      0.0      L2_error = []
    50         1      39399.0  39399.0      0.0      du_norm = []
    51                                           
    52         1      12800.0  12800.0      0.0      max_iterations = 25
    53        10     448797.0  44879.7      0.0      while i < max_iterations:
    54                                                   # Assemble Jacobian and residual
    55        20   20229711.0    1e+06      0.0          with L.localForm() as loc_L:
    56        10   60706273.0    6e+06      0.1              loc_L.set(0)
    57        10  112133110.0    1e+07      0.2          A.zeroEntries()
    58        10        2e+10    2e+09     25.9          dolfinx.fem.petsc.assemble_matrix(A, jacobian, bcs=[bc])
    59        10   74877648.0    7e+06      0.1          A.assemble()
    60        10        1e+10    1e+09     20.5          dolfinx.fem.petsc.assemble_vector(L, residual)
    61        10   69274559.0    7e+06      0.1          L.ghostUpdate(addv=PETSc.InsertMode.ADD, mode=PETSc.ScatterMode.REVERSE)
    62        10   15525009.0    2e+06      0.0          L.scale(-1)
    63                                           
    64                                                   # Compute b - J(u_D-u_(i-1))
    65        10 7422697952.0    7e+08     10.6          dolfinx.fem.petsc.apply_lifting(L, [jacobian], [[bc]], x0=[uh.x.petsc_vec], alpha=1)
    66                                                   # Set du|_bc = u_{i-1}-u_D
    67        10   35894777.0    4e+06      0.1          dolfinx.fem.petsc.set_bc(L, [bc], uh.x.petsc_vec, 1.0)
    68        10   21038173.0    2e+06      0.0          L.ghostUpdate(addv=PETSc.InsertMode.INSERT_VALUES, mode=PETSc.ScatterMode.FORWARD)
    69                                           
    70                                                   # Solve linear problem
    71        10 1229572810.0    1e+08      1.8          solver.solve(L, du.x.petsc_vec)
    72        10    9607860.0 960786.0      0.0          du.x.scatter_forward()
    73                                           
    74                                                   # Update u_{i+1} = u_i + delta u_i
    75        10   14458225.0    1e+06      0.0          uh.x.array[:] += du.x.array
    76        10     212400.0  21240.0      0.0          i += 1
    77                                           
    78                                                   # Compute norm of update
    79        10    9343964.0 934396.4      0.0          correction_norm = du.x.petsc_vec.norm(0)
    80                                           
    81                                                   # Compute L2 error comparing to the analytical solution
    82        10 9756239235.0    1e+09     13.9          L2_error.append(np.sqrt(domain.comm.allreduce(dolfinx.fem.assemble_scalar(error), op=MPI.SUM)))
    83        10     288500.0  28850.0      0.0          du_norm.append(correction_norm)
    84                                           
    85        10   38677879.0    4e+06      0.1          print(f"Iteration {i}: Correction norm {correction_norm}, L2 error: {L2_error[-1]}")
    86        10     462300.0  46230.0      0.0          if correction_norm < 1e-10:
    87         1     217600.0 217600.0      0.0              break

I looked into it a bit more. Inside apply_lifting, about 1/3 of the time is spent “packing coefficients” and 2/3 in the C++ apply_lifting:

   332        10       7600.0    760.0      0.0      x0 = [] if x0 is None else x0
   333        10       4100.0    410.0      0.0      constants = (
   334        20      24099.0   1205.0      0.0          [
   335        10     250498.0  25049.8      0.3              _pack_constants(form._cpp_object) if form is not None else np.array([], dtype=b.dtype)
   336        20       6600.0    330.0      0.0              for form in a
   337                                                   ]
   338        10       3600.0    360.0      0.0          if constants is None
   339                                                   else constants
   340                                               )
   341        10       3800.0    380.0      0.0      coeffs = (
   342        20   27053587.0    1e+06     36.0          [{} if form is None else _pack_coefficients(form._cpp_object) for form in a]
   343        10       4300.0    430.0      0.0          if coeffs is None
   344                                                   else coeffs
   345                                               )
   346        20      17999.0    900.0      0.0      _a = [None if form is None else form._cpp_object for form in a]
   347        30      40500.0   1350.0      0.1      _bcs = [[bc._cpp_object for bc in bcs0] for bcs0 in bcs]
   348        10   47811522.0    5e+06     63.6      _cpp.fem.apply_lifting(b, _a, constants, coeffs, _bcs, x0, alpha)

I would appreciate any insight. It is not an urgent matter.

I’ve written out in detail what apply_lifting does at:
https://jsdokken.com/FEniCS-workshop/src/form_compilation/alternate_form.html#lifting

There is also an alternative approach for lifting sketched out in
https://jsdokken.com/FEniCS-workshop/src/deep_dive/lifting.html#alternative-lifting-procedure

1 Like