How to efficiently use `dolfinx.fem.Expression`?

Hello,

I have created a stationary Navier-Stokes solver in 3D with pseudo time stepping, but the code spends a lot of time when calling dolfinx.fem.Expression. I am using dolfinx installed via Spack on a cluster with 192 processes for my calculation.

I have created a function k on a DG-0 space (named T) that represents the local pseudo time step dt. After each iteration, I update dt from the CFL number:

vnorm = sqrt(dot(v,v))
dt_local = CFL * h/vnorm
dt_expr = dolfinx.fem.Expression(dt_local, T.element.interpolation_points())
k.interpolate(dt_expr)

I then noticed that calling dolfinx.fem.Expression can take more than 30 seconds on some processes. This wastes a lot of time between Navier-Stokes iterations. Since my goal is to run optimisations involving many such calculations, I am trying to achieve the best possible performance.

Thank you very much!

Please provide a minimal reproducible code.

For instance, are you calling dt_expr = dolfinx.fem.Expression(...) inside a loop?
If so, move it outside the loop: these expressions should only be compiled once, since each of them generates C code. You are probably flooding your cache because CFL, h or vnorm is not built from dolfinx.fem.Function and dolfinx.fem.Constant objects, but is instead multiplied by a plain floating-point value, which forces recompilation whenever that value changes.

This is also explained in more detail in: Efficient usage of the Unified Form Language — FEniCS Workshop
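To illustrate why a changing float leads to repeated compilation, here is a toy model in plain Python. It is not the actual FFCx/DOLFINx cache; ToyJitCache, cfl_history and the expression strings are hypothetical names used only for illustration. The assumption is that the cache is keyed by a signature of the expression, and that a float literal is part of that signature whereas a symbolic constant is not.

```python
import hashlib


class ToyJitCache:
    """Toy stand-in for a JIT cache keyed by the text of an expression."""

    def __init__(self):
        self.kernels = {}
        self.compile_count = 0

    def get(self, expression_text):
        """Return the 'kernel' for this expression, compiling it if unseen."""
        key = hashlib.sha1(expression_text.encode()).hexdigest()
        if key not in self.kernels:
            # Stands in for the expensive C code generation and compilation.
            self.compile_count += 1
            self.kernels[key] = f"kernel_{key[:8]}"
        return self.kernels[key]


# Hypothetical CFL values taken over four pseudo time steps.
cfl_history = [1.0, 1.7, 3.2, 8.9]

# Case 1: the CFL value is a Python float, so it ends up in the signature.
float_cache = ToyJitCache()
for cfl in cfl_history:
    float_cache.get(f"{cfl!r} * h / sqrt(dot(v, v))")

# Case 2: the CFL value is a symbol (c0) whose value is supplied at run time.
symbol_cache = ToyJitCache()
for cfl in cfl_history:
    symbol_cache.get("c0 * h / sqrt(dot(v, v))")

print("compilations with floats:  ", float_cache.compile_count)
print("compilations with a symbol:", symbol_cache.compile_count)
```

In this toy model every new float value produces a new cache entry, while the symbolic version is compiled once and reused.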

Thanks for your quick reply!

Here is some code that reproduces the problem. I tried to reduce the number of lines as much as possible, but since this only happens in large 3D cases and I am using a ‘modified augmented Lagrangian’ method, the code is still a bit heavy. Nevertheless, lines 265–349 contain the Navier-Stokes loop with the CFL and k (time variable) update.

To respond to your suggestion: if I define dt_expr only outside the loop, the values of k no longer follow the CFL update. I get much slower convergence (of the steady residual norm) and a very different norm of k (it increases only slightly compared to when dt_expr is defined inside the loop).

The mesh: https://drive.google.com/file/d/1Z-lPthDMXNWDLOTaIFLggHgpx9x5-kmR/view?usp=drive_link
The code:

import sys
import petsc4py
petsc4py.init(sys.argv)

from mpi4py import MPI
import dolfinx
import time
from ufl import split, TestFunctions, TestFunction, derivative, grad, inner, dx, div, Measure, TrialFunction, sqrt, dot
import ufl
import basix.ufl
from petsc4py import PETSc
from dolfinx.fem.petsc import  apply_lifting, assemble_vector, set_bc, assemble_matrix, create_matrix, create_vector
from dolfinx.fem import Function, functionspace, Constant, dirichletbc, locate_dofs_topological, form
from dolfinx import mesh
import numpy as np
from dolfinx.io.gmshio import read_from_msh

comm = MPI.COMM_WORLD

msh_file = "Mesh/Rectangle3D.msh"
mesh, cell_tags, facet_tags = read_from_msh(msh_file, MPI.COMM_WORLD, gdim=3)

tags = {}
tags["inlet"] = 1
tags["outlet"] = 2
tags["walls"] = 3

def alpha_initial2(x):
    # Parameters of the inclined band (to be set for your case)
    x_start = 1.45   # Start in x
    x_end = 1.75     # End in x
    y_start = 0.5   # Start in y
    y_end = 1.0     # End in y
    z_center = 0.75  # Center in z
    x_center = 1.6  # Center in x for the inclination
    m = 1        # Slope, tan(45°) = 1
    d = 0.075         # Half-thickness of the band

    # Initialize alpha to 1 (fluid) everywhere
    values = np.ones(x.shape[1])

    # Define the condition for the inclined band
    z_line = z_center + m * (x[0] - x_center)  # Equation of the band in z
    mask = (
        (x_start <= x[0]) & (x[0] <= x_end) &  # Bounds in x
        (y_start <= x[1]) & (x[1] <= y_end) &  # Bounds in y
        (np.abs(x[2] - z_line) <= d))

    # Set alpha = 0 (solid) inside the band
    values[mask] = 0.0

    return values


fdim = mesh.topology.dim-1
metadata = {"quadrature_degree": 4}
dx = Measure("dx", domain=mesh, metadata=metadata)

v_max_inlet = Constant(mesh, PETSc.ScalarType(5))
Re = 1000
rho_ = Constant(mesh, PETSc.ScalarType(1))
mu_ = Constant(mesh, PETSc.ScalarType(rho_*v_max_inlet*2/Re))
nu_ = Constant(mesh, PETSc.ScalarType(mu_/rho_))

########################## ELEMENT #############################################
V_element2 = basix.ufl.element("CG", mesh.basix_cell(), 2, shape = (mesh.topology.dim,))
V_element = basix.ufl.element("CG", mesh.basix_cell(), 1)
Bubble_element = basix.ufl.element("Bubble", mesh.basix_cell(), 4)
enriched_element = basix.ufl.enriched_element([V_element, Bubble_element])
Mini_element = basix.ufl.blocked_element(enriched_element, shape = (mesh.topology.dim,))
P_element = basix.ufl.element("CG", mesh.basix_cell(), 1)
U_element = basix.ufl.mixed_element([Mini_element, P_element])
U = functionspace(mesh, U_element);U_v = U.sub(0); U_p = U.sub(1)
if comm.rank ==0:
    print("Number of DoFs ", U.dofmap.index_map.size_global)

w_v, w_p = TestFunctions(U)
uh = Function(U)
v, p = split(uh)
uh_prev = Function(U)
v_prev, p_prev = split(uh_prev)

T = functionspace(mesh, basix.ufl.element("DG", mesh.basix_cell(), 0))
k = Function(T)

A_element = basix.ufl.element("CG", mesh.basix_cell(), 1)
A = functionspace(mesh, A_element)
alpha = Function(A)

k_min = 0  # Minimum value of k
k_max = 10000 # Maximum value of k
q_ = 0.01

def kappa(alpha):
    return k_max + (k_min - k_max) * alpha * (1. + q_) / (alpha + q_)

########################## BC #############################################

def inlet_velocity_expr(x):
    values = np.zeros((mesh.topology.dim, x.shape[1]), dtype=PETSc.ScalarType)
    values[0] = v_max_inlet  # v_x = v_max_inlet, v_y = 0, v_z = 0
    return values

V = U.sub(0).collapse()[0]
P = U.sub(1).collapse()[0]
inlet_velocity = Function(V)
inlet_velocity.interpolate(inlet_velocity_expr)
noslip_velocity = Function(V)
noslip_velocity.x.array[:] = 0.0

# Locate the DOFs
dofs_inlet = locate_dofs_topological((U_v, V), fdim, facet_tags.find(tags["inlet"]))
dofs_outlet = locate_dofs_topological(U.sub(1), fdim, facet_tags.find(tags["outlet"]))
dofs_walls = locate_dofs_topological((U_v, V), fdim, facet_tags.find(tags["walls"]))

# Dirichlet boundary conditions
bc_inlet = dirichletbc(inlet_velocity, dofs_inlet, U_v)
bc_outlet = dirichletbc(Constant(mesh, PETSc.ScalarType(0.0)), dofs_outlet, U_p)
bc_walls = dirichletbc(noslip_velocity, dofs_walls, U_v)
bcs = [bc_inlet, bc_walls]

uh.sub(0).interpolate(inlet_velocity_expr)
uh.sub(1).interpolate(lambda x: np.zeros(x.shape[1]))
uh.x.scatter_forward()

alpha.interpolate(alpha_initial2)
alpha.x.scatter_forward()

mu_T = 0
nu_T = 0


########################## Weak form  #############################################

F_PV = (-div(v) * w_p * dx + nu_ * inner(grad(v),grad(w_v)) * dx + inner(grad(v)*v, w_v) * dx - p * div(w_v) * dx + inner(kappa(alpha) * v, w_v)*dx)

# Div-grad and Pressure matrix
gamma_ = 0.20
F_PV +=  gamma_*div(v)*div(w_v) *dx
Press = functionspace(mesh, P_element)
q_p, r_p = ufl.TrialFunction(Press), ufl.TestFunction(Press)
S_form = form(-1.0/(gamma_ + 1.0/Re) * q_p*r_p *dx)
Mp = assemble_matrix(S_form)
Mp.assemble()
Mp_norm = Mp.norm()

J_PV = 1* inner(v, w_v)/k *dx + F_PV

class NonlinearPDE_PseudoProblem:
    def __init__(self, F, Jac, u, bc):
        self.V = u.function_space
        self.du = TrialFunction(self.V)
        self.L = form(F)
        self.a = form(derivative(Jac, u, self.du))
        self.bc = bc
        self._F, self._J = None, None
        self.u = u
        self.norm_res_l2 = 0
        self.norm_res_inf = 0

    def F(self, snes, x, F):
        """Assemble residual vector."""
        x.ghostUpdate(addv=PETSc.InsertMode.INSERT, mode=PETSc.ScatterMode.FORWARD)
        x.copy(self.u.x.petsc_vec)
        self.u.x.petsc_vec.ghostUpdate(addv=PETSc.InsertMode.INSERT, mode=PETSc.ScatterMode.FORWARD)

        with F.localForm() as f_local:
            f_local.set(0.0)
        assemble_vector(F, self.L)
        apply_lifting(F, [self.a], bcs=[self.bc], x0=[x], alpha=-1.0)
        F.ghostUpdate(addv=PETSc.InsertMode.ADD, mode=PETSc.ScatterMode.REVERSE)
        set_bc(F, self.bc, x, -1.0)
        self.norm_res_l2 = F.norm(PETSc.NormType.NORM_2)
        self.norm_res_inf = F.norm(PETSc.NormType.NORM_INFINITY)

    def J(self, snes, x, J, P):
        """Assemble Jacobian matrix."""
        x.ghostUpdate(addv=PETSc.InsertMode.INSERT, mode=PETSc.ScatterMode.FORWARD)
        x.copy(self.u.x.petsc_vec)
        self.u.x.petsc_vec.ghostUpdate(addv=PETSc.InsertMode.INSERT, mode=PETSc.ScatterMode.FORWARD)

        J.zeroEntries()
        assemble_matrix(J, self.a, bcs=self.bc)
        J.assemble()
        j_norm = J.norm(PETSc.NormType.NORM_FROBENIUS)

        snes.getKSP().setOperators(J)
        snes.getKSP().setUp()

        vx_ksp, vy_ksp, vz_ksp, p_ksp = snes.getKSP().getPC().getFieldSplitSubKSP()
        p_ksp.setOperators(Mp, Mp)


##################################### PETSc Solver #############################################
#Snes solver
problem = NonlinearPDE_PseudoProblem(F_PV,J_PV, uh, bcs)

J = create_matrix(problem.a)
b = create_vector(problem.L)

# Create Newton solver and solve
snes = PETSc.SNES().create()
snes.setFunction(problem.F, b)
snes.setJacobian(problem.J, J)

problem_prefix = "ns_"
snes.setOptionsPrefix(problem_prefix)

opts = PETSc.Options()
opts.prefixPush(problem_prefix)
opts["snes_type"] = "newtonls"
opts['snes_linesearch_type'] ='basic'
#opts['snes_monitor'] = None
#opts["snes_converged_reason"] = None

_, map_vx = U.sub(0).sub(0).collapse()
_, map_vy = U.sub(0).sub(1).collapse()
_, map_vz = U.sub(0).sub(2).collapse()
_, map_vel = U.sub(0).collapse()
_, map_p  = U.sub(1).collapse()

dofmap_U = U.dofmap
U_indexmap = dofmap_U.index_map
localdofs_vx = np.setdiff1d(U_indexmap.local_to_global(np.array(map_vx)), U_indexmap.ghosts)
localdofs_vy = np.setdiff1d(U_indexmap.local_to_global(np.array(map_vy)), U_indexmap.ghosts)
localdofs_vz = np.setdiff1d(U_indexmap.local_to_global(np.array(map_vz)), U_indexmap.ghosts)
localdofs_p = np.setdiff1d(U_indexmap.local_to_global(np.array(map_p)), U_indexmap.ghosts)
localdofs_vel = np.setdiff1d(U_indexmap.local_to_global(np.array(map_vel)), U_indexmap.ghosts)

snes.getKSP().setType("fgmres")
snes.getKSP().getPC().setType(PETSc.PC.Type.FIELDSPLIT)


ISvx_ns = PETSc.IS().createGeneral(np.array(localdofs_vx, dtype=np.int32), comm=mesh.comm)
ISvy_ns = PETSc.IS().createGeneral(np.array(localdofs_vy, dtype=np.int32), comm=mesh.comm)
ISvz_ns = PETSc.IS().createGeneral(np.array(localdofs_vz, dtype=np.int32), comm=mesh.comm)
ISp_ns = PETSc.IS().createGeneral(np.array(localdofs_p, dtype=np.int32), comm=mesh.comm)

ns_ksp = snes.getKSP().getPC()
ns_ksp.setFieldSplitIS(("vx", ISvx_ns), ("vy", ISvy_ns),("vz", ISvz_ns), ("p", ISp_ns))
ns_ksp.setFieldSplitType(PETSc.PC.CompositeType.MULTIPLICATIVE)

opts["ksp_rtol"] = 1e-1
opts["ksp_max_it"] = 10000
opts["ksp_converged_reason"] = None
opts["ksp_gmres_restart"] = 150
#opts["ksp_monitor"] = None
opts["ksp_pc_side"] = "right"

vel_components = ["vx", "vy", "vz"]
for vel in vel_components:
    opts[f"fieldsplit_{vel}_ksp_type"]= "gmres"
    opts[f"fieldsplit_{vel}_ksp_rtol"]= 1e-1
    opts[f"fieldsplit_{vel}_pc_type"]= "asm"
    opts[f"fieldsplit_{vel}_ksp_gmres_restart"] = 50
    opts[f"fieldsplit_{vel}_pc_asm_overlap"] = 1
    opts[f"fieldsplit_{vel}_sub_pc_type"] = "lu"
    opts[f"fieldsplit_{vel}_sub_pc_factor_mat_solver_type"] = "mumps"


opts["fieldsplit_p_ksp_type"]= "cg"
opts["fieldsplit_p_pc_type"]= "jacobi"
opts["fieldsplit_p_ksp_max_it"] = 15

################################ Solver loop #########################################
opts["snes_max_it"] = 1
opts.prefixPop()  # ns_
snes.setFromOptions()

x = uh.x.petsc_vec.copy()
x.ghostUpdate(addv=PETSc.InsertMode.INSERT, mode=PETSc.ScatterMode.FORWARD)
x_norm = x.norm()

CFL_prev = 0
CFL_min = 1
CFL = CFL_min    # Minimum CFL value
CFL_max = 1000 # Maximum CFL value
beta = 0.8    # beta parameter (set to 1 in the paper)

vnorm = sqrt(dot(v,v))
h = ufl.CellDiameter(mesh)
dt_local = CFL * h/vnorm
dt_expr = dolfinx.fem.Expression(dt_local, T.element.interpolation_points())
k.interpolate(dt_expr)

ite = 1
uh_prev.x.array[:] = uh.x.array[:]
uh_prev.x.scatter_forward()

norm_res_l2_0 = 1
norm_res_inf_0 = 1

t0 = time.time()

for i in range(100):

    t2 = time.time()
    comm.barrier()
    snes.solve(None, x)
    comm.barrier()
    t3 = time.time()
    if comm.rank ==0:
        print("time solve: ", t3-t2)

    x.copy(uh.x.petsc_vec)
    uh.x.petsc_vec.ghostUpdate(addv=PETSc.InsertMode.INSERT, mode=PETSc.ScatterMode.FORWARD)
    uh.x.scatter_forward()
    
    #Update cfl and dt
    if problem.norm_res_l2 != 0 and problem.norm_res_inf != 0 and ite > 1:
        r = max(problem.norm_res_l2 / norm_res_l2_0, problem.norm_res_inf/ norm_res_inf_0 )
        CFL = min(CFL_min/r**beta, CFL_max)

        vnorm = sqrt(dot(v,v))
        dt_local = CFL * h/vnorm

        comm.barrier()
        t12 = time.time()
        dt_expr = dolfinx.fem.Expression(dt_local, T.element.interpolation_points())
        comm.barrier()
        t13 = time.time()

        k.interpolate(dt_expr)

        k.x.scatter_forward()  # Synchronise the local data with the global PETSc vector
        k.x.petsc_vec.ghostUpdate(addv=PETSc.InsertMode.INSERT, mode=PETSc.ScatterMode.REVERSE)
        CFL_prev = CFL

        norm_k = k.x.petsc_vec.norm()
        
        if comm.rank ==0:
            print("time in expression: ", t13-t12)
            print("norm k: ", norm_k)

    else:
        CFL = CFL *1
        CFL_prev = CFL

    if (ite == 1 or ite == 2) and problem.norm_res_l2 != 0 and norm_res_l2_0 == 1:
        norm_res_l2_0  = problem.norm_res_l2

    if MPI.COMM_WORLD.rank == 0:
        print("-" * 60)
        print(f"Iteration {ite:3d} | L2 residual: {problem.norm_res_l2:.3e} CFL: {CFL:.2f}")

    if (problem.norm_res_l2) < 1e-9 and ite >2:
        break

    ite += 1


uh_norm = uh.x.petsc_vec.norm()
t1 = time.time()
if comm.rank ==0:
    print("Elapsed time: ", t1-t0)
    print(uh_norm)

The point is that quantities that are constant in space (but potentially varying in time) should be wrapped in a dolfinx.fem.Constant. Within your for-loop, you currently have:

r = max(problem.norm_res_l2 / norm_res_l2_0, problem.norm_res_inf/ norm_res_inf_0 )
CFL = min(CFL_min/r**beta, CFL_max)

where r and CFL are time-dependent floating-point values.
Thus I would suggest doing the following outside the loop:

r = dolfinx.fem.Constant(mesh, 1.0) # Dummy value
CFL = dolfinx.fem.Constant(mesh, 1.0) # Dummy value
vnorm = ufl.sqrt(ufl.dot(v, v))
dt_local = CFL * h / vnorm
dt_expr = dolfinx.fem.Expression(dt_local, T.element.interpolation_points())

then, within the loop call:

for i in range(100):
    # solve problem here
    # ...

    # Update cfl and dt
    r.value = max(problem.norm_res_l2 / norm_res_l2_0, problem.norm_res_inf / norm_res_inf_0)
    CFL.value = min(CFL_min / float(r)**beta, CFL_max)
    k.interpolate(dt_expr)

as you would then not have to recompile your code at every time step.
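As a small sketch, the CFL update itself can be kept in a plain Python helper so that only a float is assigned to CFL.value in the loop. The function name update_cfl and its argument names are hypothetical; the formula and the default values (CFL_min = 1, CFL_max = 1000, beta = 0.8) are the ones from the posted code, and the guard for zero residuals mirrors the condition used there.

```python
def update_cfl(cfl, res_l2, res_inf, res_l2_0, res_inf_0,
               cfl_min=1.0, cfl_max=1000.0, beta=0.8):
    """Return the next CFL number from the relative residual reduction.

    cfl is the current value; it is returned unchanged when a residual
    is exactly zero, as in the else-branch of the original loop.
    """
    if res_l2 == 0.0 or res_inf == 0.0:
        return cfl
    # Relative residual: the worse of the L2 and infinity-norm ratios.
    r = max(res_l2 / res_l2_0, res_inf / res_inf_0)
    # Grow the CFL number as the residual drops, capped at cfl_max.
    return min(cfl_min / r**beta, cfl_max)


# Hypothetical residual history, relative to reference residuals of 1.0.
cfl = 1.0
for residual in (1.0, 1e-1, 1e-2, 1e-6):
    cfl = update_cfl(cfl, residual, residual, 1.0, 1.0)
    print(f"relative residual {residual:.0e} -> CFL {cfl:.2f}")
```

Inside the solver loop, the returned float would then be assigned with CFL.value = update_cfl(float(CFL), ...), followed by k.interpolate(dt_expr), without creating a new dolfinx.fem.Expression.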

Note that you should also clear your current cache, as it is probably flooded with compiled expressions. It is usually located in /home/username/.cache/fenics.
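If you prefer to clear the cache from Python rather than by hand, a minimal sketch could look like the following. The helper clear_jit_cache is hypothetical (it is not part of DOLFINx), and the default path assumes the usual location mentioned above; your installation may be configured to use a different cache directory. The demonstration at the end runs on a temporary directory, so it does not touch the real cache.

```python
import shutil
import tempfile
from pathlib import Path


def clear_jit_cache(cache_dir=None):
    """Delete every entry inside cache_dir and return how many were removed.

    Defaults to ~/.cache/fenics (assumed location of the JIT cache).
    """
    if cache_dir is None:
        cache_dir = Path.home() / ".cache" / "fenics"
    cache_dir = Path(cache_dir)
    if not cache_dir.is_dir():
        return 0
    removed = 0
    for entry in cache_dir.iterdir():
        if entry.is_dir() and not entry.is_symlink():
            shutil.rmtree(entry)   # remove a sub-directory and its contents
        else:
            entry.unlink()         # remove a regular file or a symlink
        removed += 1
    return removed


# Demonstration on a throwaway directory with two dummy entries.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "dummy_expression.c").write_text("/* dummy */")
    (Path(tmp) / "dummy_build_dir").mkdir()
    print("entries removed:", clear_jit_cache(tmp))
```

On a cluster, this should be run once from a single process before launching the parallel job, not from every MPI rank.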

It seems to be working for now, thank you very much!