Preconditioning for monolithic block NSE

Hi guys! I used FEniCS to study the Hopf bifurcation in 2D cavity and the results came out well. I used MUMPS here, which is unaffordable for the next 3D studies. I could solve the operator split+IMEX with conjugate gradient correctly. However since I want to use residual based methods for variational multiscale later on, I need to solve the monolithic block of Navier Stokes. All my experiments with GMRES are failing.
Below is a typical failing 2D code. Basically I have exhausted the ‘shoot in the dark’ GPT ideas on preconditioners so I guess I need to study it in depth. I am quite weak at the theory of Krylov solvers, thus I am really sorry that I cannot explain these choices properly: if I should read any theory first, please suggest a good resource to start reading and I can get back with a much better response.

#!/usr/bin/env python3
# 2D Lid-driven cavity (unit square) | TH Pk/P{k-1} | Fully-implicit BDF2


from dolfin import *
import numpy as np
import os, json

# ----------------------- User parameters ---------------------------------
nx, ny      = 48, 48
k           = 2
Re          = 100.0
dt_value    = 0.1
T_end       = 20.0               # keep small while testing
n_frames    = 50
out_dir     = "ldc_bdf2_TH_SUPG_gmres_blockLU_xdmf"
lid_speed   = 1.0

# SUPG toggle and coefficients
enable_supg = True
Ct_val  = 4.0
Cnu_val = 4.0

# Solver controls
use_gmres          = True     # True: GMRES+FieldSplit, False: global MUMPS (direct)
use_block_lu       = True     # True: LU on A00 and LU on Schur (robust; slow)
newton_atol        = 1e-8
newton_rtol        = 1e-8
newton_maxi        = 40
dt_first_step_scale= 1.0      # e.g. 0.5 for an easier first step

# PSD settings
do_psd          = True
T_meas_start    = 0.5*T_end
fmin, fmax      = 0.0, 5.0

# Checkpoint / restart
use_checkpoint          = True
checkpoint_int          = 10
resume_from_ckpt        = True
T_extend_after_restart  = 0.0
combine_history_for_psd = True
# -------------------------------------------------------------------------

# ----------------------- PETSc options (our own prefix) ------------------
from dolfin import PETScOptions
PETScOptions.clear()
prefix = "ns_"  # private prefix bound to our KSP

if use_gmres:
    from dolfin import PETScOptions
PETScOptions.clear()
prefix = "ns_"

# --- Outer Krylov ---
PETScOptions.set(prefix + "ksp_type", "fgmres")
PETScOptions.set(prefix + "ksp_rtol", "1e-8")
PETScOptions.set(prefix + "ksp_atol", "0.0")
PETScOptions.set(prefix + "ksp_max_it", "500")
PETScOptions.set(prefix + "ksp_converged_reason")
# keep default right-preconditioning (do NOT force left here)

# --- Monolithic FieldSplit(Schur) ---
PETScOptions.set(prefix + "pc_type", "fieldsplit")
PETScOptions.set(prefix + "pc_fieldsplit_detect_saddle_point")
PETScOptions.set(prefix + "pc_fieldsplit_type", "schur")
PETScOptions.set(prefix + "pc_fieldsplit_schur_factorization_type", "lower")
PETScOptions.set(prefix + "pc_fieldsplit_off_diag_use_amat", "1")

# IMPORTANT: precondition with A11 so the pressure PC factors A11, not S
PETScOptions.set(prefix + "pc_fieldsplit_schur_precondition", "a11")

# Use exact A^{-1} (since A00 is LU)
PETScOptions.set(prefix + "mat_schur_complement_ainv_type", "full")

# --- Velocity block A00: LU (MUMPS if available) ---
PETScOptions.set(prefix + "fieldsplit_0_ksp_type", "preonly")
PETScOptions.set(prefix + "fieldsplit_0_pc_type", "lu")
PETScOptions.set(prefix + "fieldsplit_0_pc_factor_mat_solver_type", "mumps")

# --- Pressure block preconditioner (A11): LU (MUMPS) ---
PETScOptions.set(prefix + "fieldsplit_1_ksp_type", "preonly")
PETScOptions.set(prefix + "fieldsplit_1_pc_type", "lu")
PETScOptions.set(prefix + "fieldsplit_1_pc_factor_mat_solver_type", "mumps")

# Diagnostics: you should see "KSP Object: (ns_)" and field splits
PETScOptions.set(prefix + "ksp_view")
PETScOptions.set(prefix + "pc_view")


# ----------------------- Setup & I/O -------------------------------------
assert k >= 2, "Use k >= 2 (Taylor–Hood)."
os.makedirs(out_dir, exist_ok=True)
ckpt_dir = os.path.join(out_dir, "checkpoints")
if use_checkpoint:
    os.makedirs(ckpt_dir, exist_ok=True)
ckpt_h5   = os.path.join(ckpt_dir, "state.h5")
ckpt_meta = os.path.join(ckpt_dir, "state.json")

mesh = UnitSquareMesh(nx, ny)

# Light corner grading (optional)
corner_pts = [Point(0.0, 0.0), Point(1.0, 0.0), Point(1.0, 1.0), Point(0.0, 1.0)]
for r in (0.20, 0.10, 0.05):
    markers = MeshFunction("bool", mesh, mesh.topology().dim(), False)
    for cell in cells(mesh):
        mp = cell.midpoint()
        if min(mp.distance(P) for P in corner_pts) < r:
            markers[cell] = True
    mesh = refine(mesh, markers)

dim  = mesh.topology().dim()
I    = Identity(dim)

# Spaces
V_el = VectorElement("Lagrange", mesh.ufl_cell(), k)
Q_el = FiniteElement("Lagrange", mesh.ufl_cell(), k-1)
W    = FunctionSpace(mesh, MixedElement([V_el, Q_el]))
V    = FunctionSpace(mesh, V_el)
Q    = FunctionSpace(mesh, Q_el)
Ssc  = FunctionSpace(mesh, FiniteElement("Lagrange", mesh.ufl_cell(), k))

# Quadrature
quad_deg_form = max(3*k, 6)
dx_form = Measure("dx", domain=mesh, metadata={"quadrature_degree": quad_deg_form})

# Unknowns/tests
w = Function(W)         # (u, p)
u, p = split(w)
v, q = TestFunctions(W)

# Physics / time
nu = Constant(1.0/Re)
dt = Constant(dt_value)
Ct  = Constant(Ct_val)
Cnu = Constant(Cnu_val)
f   = Constant((0.0, 0.0))

# BDF2 history
u_n   = Function(V); u_n.vector().zero()
u_nm1 = Function(V); u_nm1.vector().zero()

# BCs
noslip = DirichletBC(W.sub(0), Constant((0.0, 0.0)),
                     "near(x[0],0.0) || near(x[0],1.0) || near(x[1],0.0)")
lid_bc = DirichletBC(W.sub(0), Constant((lid_speed, 0.0)), "near(x[1],1.0)")
p_pin  = DirichletBC(W.sub(1), Constant(0.0),
                     "near(x[0],0.0) && near(x[1],0.0)", "pointwise")
bcs = [noslip, lid_bc, p_pin]

# Helpers
def conv(vec):  # (u·∇)u
    return grad(vec) * vec

# Time operators
du_bdf1 = (u - u_n)/dt
du_bdf2 = (3.0*u - 4.0*u_n + u_nm1)/(2.0*dt)

# Strong momentum residuals
r_m_bdf1 = du_bdf1 + conv(u) - nu*div(grad(u)) + grad(p) - f
r_m_bdf2 = du_bdf2 + conv(u) - nu*div(grad(u)) + grad(p) - f

# Galerkin part
F_bdf1 = ( inner(du_bdf1, v)*dx_form
           + inner(conv(u), v)*dx_form
           + nu*inner(grad(u), grad(v))*dx_form
           - div(v)*p*dx_form
           - q*div(u)*dx_form )

F_bdf2 = ( inner(du_bdf2, v)*dx_form
           + inner(conv(u), v)*dx_form
           + nu*inner(grad(u), grad(v))*dx_form
           - div(v)*p*dx_form
           - q*div(u)*dx_form )

# SUPG (momentum only)
if enable_supg:
    h  = CellDiameter(mesh)
    G  = (1.0/(h*h)) * I
    GG = inner(G, G)
    u_tau = u
    tau_M = 1.0 / sqrt(Ct/dt**2 + inner(u_tau, G*u_tau) + (Cnu*nu)**2 * GG)
    supg_test = grad(v) * u
    F_bdf1 += inner(tau_M*supg_test, r_m_bdf1) * dx_form
    F_bdf2 += inner(tau_M*supg_test, r_m_bdf2) * dx_form

# Jacobians
dw     = TrialFunction(W)
J_bdf1 = derivative(F_bdf1, w, dw)
J_bdf2 = derivative(F_bdf2, w, dw)

# Writers
from dolfin import XDMFFile, MPI
comm = MPI.comm_world
xdmf_path = os.path.join(out_dir, "ldc_series.xdmf")
xdmf = XDMFFile(comm, xdmf_path)
xdmf.parameters["flush_output"] = True
xdmf.parameters["functions_share_mesh"] = True
xdmf.parameters["rewrite_function_mesh"] = False

def save_fields(u_fun, p_fun, t):
    u_fun.rename("u", "u");  p_fun.rename("p", "p")
    xdmf.write(u_fun, t);    xdmf.write(p_fun, t)
    ux, uy = u_fun.split(deepcopy=True)
    ux.rename("ux","ux"); xdmf.write(ux, t)
    uy.rename("uy","uy"); xdmf.write(uy, t)
    umag = project(sqrt(inner(u_fun, u_fun)), Ssc)
    umag.rename("umag","umag"); xdmf.write(umag, t)

# Time plan / restart
n_steps_nominal  = int(round(T_end/dt_value))
save_idx = np.linspace(0, n_steps_nominal, num=max(2, n_frames), dtype=int)

start_step = 0; t0 = 0.0
if resume_from_ckpt and use_checkpoint and os.path.exists(ckpt_h5) and os.path.exists(ckpt_meta):
    if comm.rank == 0: print("Loading checkpoint:", ckpt_h5)
    H = HDF5File(comm, ckpt_h5, "r")
    H.read(u_n,   "/u_n"); H.read(u_nm1, "/u_nm1"); H.close()
    if comm.rank == 0:
        with open(ckpt_meta, "r") as f: meta = json.load(f)
        start_step = int(meta.get("step", 0)); t0 = float(meta.get("time", 0.0))
    start_step = comm.bcast(start_step, root=0); t0 = comm.bcast(t0, root=0)

extend_steps = int(round(T_extend_after_restart/dt_value)) if T_extend_after_restart > 0.0 else 0
n_steps = max(n_steps_nominal, start_step + extend_steps)

if comm.rank == 0:
    print(f"=== LDC TH P{k}/P{k-1} | SUPG={'on' if enable_supg else 'off'} ===")
    print(f"Mesh {nx}x{ny} | Re={Re} | dt={dt_value} | total steps target={n_steps} | resume at step {start_step}")
    print(f"Quadrature degree (form) = {quad_deg_form}")
    if enable_supg: print(f"Tezduyar tau: Ct={Ct_val}, Cnu={Cnu_val}")
    print(f"Output -> {xdmf_path}")

# ----------------------- Explicit Newton + our KSP -----------------------
from dolfin import PETScKrylovSolver, PETScLUSolver, assemble_system, assemble

def newton_solve(F_form, J_form, w, bcs, atol=1e-8, rtol=1e-8, maxi=40):
    """Solve F(w)=0 with assembled J and either GMRES+FieldSplit or global LU."""
    def res_norm():
        r = assemble(F_form)
        for bc in bcs: bc.apply(r)
        return r.norm("l2")

    if use_gmres:
        solver = PETScKrylovSolver()
        solver.ksp().setOptionsPrefix(prefix)
        solver.ksp().setFromOptions()
    else:
        solver = PETScLUSolver("mumps")
        solver.parameters["reuse_factorization"] = False

    dw = Function(w.function_space())

    r0 = res_norm()
    if MPI.comm_world.rank == 0:
        print(f"  Newton iteration 0: r (abs) = {r0:.3e}")

    for it in range(1, maxi+1):
        A, b = assemble_system(J_form, -F_form, bcs)
        solver.set_operator(A)
        solver.solve(dw.vector(), b)
        w.vector().axpy(1.0, dw.vector())

        r = res_norm()
        if MPI.comm_world.rank == 0:
            print(f"  Newton iteration {it}: r (abs) = {r:.3e} | r/r0 = {r/r0:.3e}")
        if r < atol or r/r0 < rtol:
            break

# ----------------------- Time loop ---------------------------------------
t = t0
first_dt = float(dt_value) * dt_first_step_scale
for n in range(start_step+1, n_steps+1):
    if n == 1 and start_step == 0 and dt_first_step_scale != 1.0:
        dt.assign(first_dt)
    else:
        dt.assign(dt_value)

    t = n*float(dt_value)
    if comm.rank == 0:
        print("\n--- Step {}/{}  t = {:.6f} ---".format(n, n_steps, t))

    w.vector().zero()
    assign(w.sub(0), u_n)

    if n == 1 and start_step == 0:
        newton_solve(F_bdf1, J_bdf1, w, bcs, atol=newton_atol, rtol=newton_rtol, maxi=newton_maxi)
    else:
        newton_solve(F_bdf2, J_bdf2, w, bcs, atol=newton_atol, rtol=newton_rtol, maxi=newton_maxi)

    u_new, p_new = w.split(deepcopy=True)

    # totals
    keTot  = assemble(0.5*inner(u_new, u_new) * dx)
    ensTot = assemble(0.5*curl(u_new)**2 * dx)
    if comm.rank == 0:
        time_series = globals().get("time_series", []); time_series.append(t)
        keTot_series = globals().get("keTot_series", []); keTot_series.append(keTot)
        ensTot_series = globals().get("ensTot_series", []); ensTot_series.append(ensTot)
        globals()["time_series"], globals()["keTot_series"], globals()["ensTot_series"] = time_series, keTot_series, ensTot_series

    if n in save_idx:
        save_fields(u_new, p_new, t)

    # checkpoint
    if use_checkpoint and (n % checkpoint_int == 0 or n == n_steps):
        if comm.rank == 0:
            print(f"[checkpoint] step={n}, t={t:.6f}")
        H = HDF5File(comm, ckpt_h5, "w")
        H.write(u_new, "/u_n"); H.write(u_n, "/u_nm1"); H.close()
        if comm.rank == 0:
            with open(ckpt_meta, "w") as f:
                json.dump({"step": n, "time": t, "dt": float(dt)}, f)

    # shift BDF2 history
    u_nm1.assign(u_n); u_n.assign(u_new)

xdmf.close()

# ----------------------- Post: plots, CSV, PSD ---------------------------
if MPI.comm_world.rank == 0:
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    time_series   = globals().get("time_series", [])
    keTot_series  = globals().get("keTot_series", [])
    ensTot_series = globals().get("ensTot_series", [])

    ts   = np.array(time_series, float)
    keT  = np.array(keTot_series, float)
    ensT = np.array(ensTot_series, float)

    csv_tot = os.path.join(out_dir, "globals_totals_timeseries.csv")
    if combine_history_for_psd and os.path.exists(csv_tot) and ts.size == 0:
        try:
            data_prev = np.loadtxt(csv_tot, delimiter=",", skiprows=1)
            if data_prev.ndim == 1 and data_prev.size >= 3: data_prev = data_prev.reshape(1, -1)
            if data_prev.shape[1] >= 3:
                ts   = data_prev[:,0]; keT = data_prev[:,1]; ensT = data_prev[:,2]
                print(f"Loaded previous totals ({ts.size} samples) from {csv_tot}")
        except Exception as e:
            print(f"Warning: could not load previous totals CSV ({e}). Proceeding without.")

    np.savetxt(csv_tot, np.column_stack([ts, keT, ensT]),
               delimiter=",", header="t,KE_total,Enstrophy_total", comments="")
    print(f"\nSaved totals time series: {csv_tot}")

    plt.figure(figsize=(9,4.2), dpi=140)
    plt.plot(ts, keT, lw=1.6)
    plt.xlabel("t"); plt.ylabel("∫ 0.5 |u|^2 dx")
    plt.title("Total kinetic energy vs time")
    plt.grid(True, alpha=0.25); plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "ke_total_vs_time.png")); plt.close()

    plt.figure(figsize=(9,4.2), dpi=140)
    plt.plot(ts, ensT, lw=1.6)
    plt.xlabel("t"); plt.ylabel("∫ 0.5 ω^2 dx")
    plt.title("Total enstrophy vs time")
    plt.grid(True, alpha=0.25); plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "enstrophy_total_vs_time.png")); plt.close()

    if do_psd and len(ts) >= 64:
        m = ts >= T_meas_start
        if np.count_nonzero(m) < 64: m = np.arange(len(ts)) >= len(ts)//2
        ts2 = ts[m]; ke2 = keT[m]
        if len(ts2) >= 64:
            dt_samp = np.median(np.diff(ts2))
            y = ke2 - np.mean(ke2)
            n = len(y); nseg = max(128, int(n//4)); step = nseg//2
            win = np.hanning(nseg)
            accP = None; freqs = None; nacc = 0
            for start in range(0, n - nseg + 1, step):
                seg = y[start:start+nseg]
                Y = np.fft.rfft(seg * win)
                f = np.fft.rfftfreq(nseg, d=dt_samp)
                P = (np.abs(Y)**2) / np.sum(win**2)
                accP = P if accP is None else accP + P
                freqs = f; nacc += 1
            P = accP / max(nacc,1)
            band = (freqs >= fmin) & (freqs <= fmax)
            f_plot = freqs[band]; P_plot = P[band]
            if f_plot.size > 0 and f_plot[0] == 0.0: P_plot[0] = 0.0
            if f_plot.size > 1:
                kpk = int(np.argmax(P_plot)); fpk = float(f_plot[kpk])
                print(f"PSD peak frequency f ~ {fpk:.6f} (Strouhal)")
            np.savetxt(os.path.join(out_dir, "ke_total_psd.csv"),
                       np.column_stack([f_plot, P_plot]),
                       delimiter=",", header="f,PSD", comments="")
            plt.figure(figsize=(8.5,4.2), dpi=140)
            plt.plot(f_plot, P_plot, lw=1.8)
            if f_plot.size > 1:
                plt.axvline(f_plot[np.argmax(P_plot)], ls="--", lw=1.0)
            plt.xlim(fmin, fmax)
            plt.xlabel("frequency f (Strouhal)"); plt.ylabel("PSD [arb. units]")
            plt.title("Welch PSD of total kinetic energy")
            plt.grid(True, alpha=0.25); plt.tight_layout()
            plt.savefig(os.path.join(out_dir, "ke_total_psd.png")); plt.close()
        else:
            print("PSD: not enough samples after burn-in.")
    else:
        print("PSD skipped or insufficient samples.")

    print("\nAll done. Open in ParaView:", os.path.join(out_dir, "ldc_series.xdmf"))