Although the same code works for the Allen–Cahn equation, I am encountering convergence issues when using the fixed-point method for the nonlinear term in the Cahn–Hilliard equation

Hello, everyone, i am trying to solve the nonlinear Cahn-Hilliard equation using Fixed point method for nonlinear term. I have used the following fixed point method code for the Allen-Cahn equation.

import numpy as np
from mpi4py import MPI
from dolfinx import mesh, fem, io
from dolfinx.fem.petsc import assemble_matrix, assemble_vector, create_vector
from dolfinx import default_real_type, log, plot
import ufl
from petsc4py import PETSc
from ufl import Measure, SpatialCoordinate, TestFunctions, TrialFunctions, div, exp, inner, grad
from dolfinx.fem.petsc import LinearProblem
from dolfinx.fem.petsc import NonlinearProblem

def run_simulation(N, dt, T=1.0, degree=1, max_fp_iter=10, fp_tol=1e-6):
    # Create structured mesh
    domain = mesh.create_unit_square(MPI.COMM_WORLD, N, N, mesh.CellType.quadrilateral)
    
    x = SpatialCoordinate(domain)
    dx = Measure("dx", domain)
    
    V = fem.functionspace(domain, ("Lagrange", degree))

    class exact_solution:
        def __init__(self, t):
            self.t = t
        def __call__(self, x):
            return np.exp(np.cos(self.t)) * np.cos(np.pi * x[0]) * np.cos(np.pi * x[1])
            
    def f_expr(x, t):
        return (-np.sin(t) * np.exp(np.cos(t)) * np.cos(np.pi * x[0]) * np.cos(np.pi * x[1]) 
                + 2 * np.pi**2 * np.exp(np.cos(t)) * np.cos(np.pi * x[0]) * np.cos(np.pi * x[1])
                + (np.exp(np.cos(t)) * np.cos(np.pi*x[0]) * np.cos(np.pi*x[1]))**3
                - np.exp(np.cos(t)) * np.cos(np.pi*x[0]) * np.cos(np.pi*x[1]))
                
    f = fem.Function(V)
    t = 0
    u_exact = exact_solution(t)
    
    # Initial condition
    u_n = fem.Function(V)
    u_n.interpolate(u_exact)

    # Variational problem components
    u, phi = ufl.TrialFunction(V), ufl.TestFunction(V)
    uh = fem.Function(V)  # Current solution
    
    # Time stepping
    num_steps = int(T / dt)
    for n in range(num_steps):
        t += dt
        u_exact.t = t
        f.interpolate(lambda x: f_expr(x, t))
        
        # Fixed-point iteration
        uh.x.array[:] = u_n.x.array  # Initial guess for fixed-point iteration
        for k in range(max_fp_iter):
            # Linearized form (Newton-like)
            a = (inner(u, phi)*dx + dt*inner(grad(u), grad(phi))*dx 
                + dt*inner(3*uh**2 * u, phi)*dx - dt*inner(u, phi)*dx)
    
            # Remaining nonlinear terms
            L = (inner(u_n + dt*f, phi)*dx + dt*inner(2*uh**3, phi)*dx)
            
            # Solve the linear system
            problem = LinearProblem(
                a, L,
                petsc_options={
                    "ksp_type": "preonly",
                    "pc_type": "lu",
                    "pc_factor_mat_solver_type": "superlu_dist",
                },
            )
            uh_new = problem.solve()
            
            # Compute difference between iterations
            diff = uh_new.x.array - uh.x.array
            diff_norm = np.sqrt(domain.comm.allreduce(
                np.sum(diff**2), op=MPI.SUM))
            
            # Update solution for next iteration
            uh.x.array[:] = uh_new.x.array
            
            # Check for convergence
            if diff_norm < fp_tol:
                if MPI.COMM_WORLD.rank == 0 and n % 10 == 0:
                    print(f"Time step {n}, FP iteration {k}: converged with diff {diff_norm:.2e}")
                break
            
        # Update for next time step
        u_n.x.array[:] = uh.x.array
    
    # Save solution
    with io.XDMFFile(domain.comm, "solution.xdmf", "w") as xdmf:
        xdmf.write_mesh(domain)
        xdmf.write_function(uh, t)

    # Error computation
    V_ex = fem.functionspace(domain, ("Lagrange", degree + 1))
    u_ex = fem.Function(V_ex)
    u_ex.interpolate(u_exact)
    error = np.sqrt(domain.comm.allreduce(
        fem.assemble_scalar(fem.form(inner(uh - u_ex, uh - u_ex)*dx)), op=MPI.SUM))
    
    h = 1.0 / N
    return h, dt, error

if __name__ == "__main__":
    Ns = [8, 16, 32, 64]
    hs, dts, errors = [], [], []

    for N in Ns:
        h = 1.0 / N
        dt = h**2
        h_val, dt_val, error = run_simulation(N, dt)
        hs.append(h_val)
        dts.append(dt_val)
        errors.append(error)
        if MPI.COMM_WORLD.rank == 0:
            print(f"N={N:2d}, h={h_val:.3e}, dt={dt_val:.3e}, L2 error={error:.3e}")

    # Compute convergence rate
    if MPI.COMM_WORLD.rank == 0:
        print("\nConvergence rates (combined h and dt):")
        for i in range(1, len(errors)):
            ratesp = np.log(errors[i-1]/errors[i])/np.log(hs[i-1]/hs[i])
            print(f"From N={Ns[i-1]} to N={Ns[i]}: ratesp ≈ {ratesp:.2f}")
            ratetime = np.log(errors[i-1]/errors[i])/np.log(dts[i-1]/dts[i])
            print(f"From dt={dts[i-1]:.3e} to dt={dts[i]:.3e}: ratetime ≈ {ratetime:.2f}")

Above code yields for P1 elements with dt=h^2:

But while i am trying to use same fixed point loop in the Cahn-Hilliard equation in mixed form, it does not converging.

Here, is my CH code:

import os

try:
    from petsc4py import PETSc
    import dolfinx
    if not dolfinx.has_petsc:
        print("This demo requires DOLFINx to be compiled with PETSc enabled.")
        exit(0)
except ModuleNotFoundError:
    print("This demo requires petsc4py.")
    exit(0)

from mpi4py import MPI
import numpy as np
import ufl
from basix.ufl import element, mixed_element
from dolfinx import default_real_type, log, plot
from dolfinx.fem import Function, functionspace
from dolfinx.fem.petsc import LinearProblem
from dolfinx.io import XDMFFile
from dolfinx.mesh import CellType, create_unit_square
from dolfinx import mesh, fem, io
from ufl import Measure, SpatialCoordinate, TestFunctions, TrialFunctions, div, exp, inner, grad
from math import log2  # Add this import at the top of your file

try:
    import pyvista as pv
    import pyvistaqt as pvqt
    have_pyvista = True
    if pv.OFF_SCREEN:
        pv.start_xvfb(wait=0.5)
except ModuleNotFoundError:
    print("pyvista and pyvistaqt are required to visualise the solution")
    have_pyvista = False

# Save all logging to file
#log.set_output_file("log.txt")

def run_simulation(N, dt, T=1.0, degree=1, max_fp_iter=10, fp_tol=1e-6,p=0.1):
    #msh = create_unit_square(MPI.COMM_WORLD, N, N, CellType.triangle)
    msh = create_unit_square(MPI.COMM_WORLD, N, N, CellType.quadrilateral)
    P1 = element("Lagrange", msh.basix_cell(), 1, dtype=default_real_type)
    ME = functionspace(msh, mixed_element([P1, P1]))
    V = functionspace(msh, P1)
    
    # Test and trial functions for the mixed space
    u, w = ufl.TrialFunctions(ME)
    phi, psi = ufl.TestFunctions(ME)
    
    x = SpatialCoordinate(msh)
    dx = Measure("dx", msh)
    
    """
    # Set default tolerance if not provided
    if fp_tol is None:
        fp_tol = 1e-3/log2(N) if N > 1 else 1e-3
    """
    
    class exact_solution_u:
        def __init__(self, t):
            self.t = t
        def __call__(self, x):
            return np.exp(np.cos(self.t)) * np.cos(np.pi * x[0]) * np.cos(np.pi * x[1])
     
        
    class exact_solution_w:
        def __init__(self,t):
            self.t = t
            self.p = p  # Assuming p is a parameter needed for the solution
        def __call__(self, x):
            return ((np.exp(np.cos(self.t)) *np.cos(np.pi * x[0])* np.cos(np.pi * x[1]))**3
                   -(np.exp(np.cos(self.t)) *np.cos(np.pi * x[0])* np.cos(np.pi * x[1]))
                   +2*(np.pi**2)*self.p*(np.exp(np.cos(self.t)) *np.cos(np.pi * x[0])* np.cos(np.pi * x[1])))
                   
                   
    def f_expr(x,t):
        return (np.cos(np.pi*x[0])*np.cos(np.pi*x[1])*np.exp(np.cos(t))*(-np.sin(t)) 
               - 6*(np.pi)**2*np.exp(3*np.cos(t))*(np.cos(np.pi*x[0])*((np.sin(np.pi*x[0]))**2)*(np.cos(np.pi*x[1]))**3 
               + np.cos(np.pi*x[1])*((np.sin(np.pi*x[1]))**2)*(np.cos(np.pi*x[0]))**3) 
               + 6*(np.pi)**2*np.exp(3*np.cos(t))*(np.cos(np.pi*x[0])*np.cos(np.pi*x[1]))**3 
               + (4*(np.pi)**4*p - 2*(np.pi)**2)*np.cos(np.pi*x[0])*np.cos(np.pi*x[1])*np.exp(np.cos(t)))

    f = fem.Function(V)

    
    
    t = 0
    u_exact = exact_solution_u(t)
    w_exact = exact_solution_w(t)
    
    # Initial conditions
    u_n = fem.Function(V)
    u_n.interpolate(u_exact)
    w_n = fem.Function(V)
    w_n.interpolate(w_exact)
    
    # Solution functions
    vh = fem.Function(ME)
    uh, wh = vh.split()
    
    
    
    
    #stab = fem.Constant(msh, default_real_type(10.0))
    #epsilon = 1e-2  # Stabilization parameter (adjust as needed)
    
    # Time stepping
    num_steps = int(T / dt)
    for n in range(num_steps):
        t += dt
        
        # Update exact solutions and source term
        u_exact.t = t
        w_exact.t = t
        f.interpolate(lambda x: f_expr(x, t))
        
        # Initialize with previous solution
        uh.interpolate(u_n)
        wh.interpolate(w_n)  # Important: initialize both variables
        
        # Fixed-point iterations
        for k in range(max_fp_iter):
            
    
            # Bilinear form
            a1 =(inner(u, phi)*dx
                    + dt*inner(grad(w), grad(phi))*dx
                    +inner(w, psi)*dx
                    - p*dt*inner(grad(u), grad(psi))*dx
                    - inner(3*uh**2 * u, psi)*dx 
                    #+stab*inner(u,psi)*dx
                    + inner(u, psi)*dx
                )
            
            # Linear form
            L = (inner(u_n + dt*f, phi)*dx +inner(2*uh**3, psi)*dx)
                #-stab*inner(u_n,psi)*dx)
            
            # Solve the linearized system
            problem1 = LinearProblem(
                a1, L,
                petsc_options={
                    "ksp_type": "preonly",
                    "pc_type": "lu",
                    "pc_factor_mat_solver_type": "superlu_dist",
                },
            )
            
            
            
            
            #Solve 
            vh_new = problem1.solve()
            uh_new, wh_new = vh_new.split()
            
            # Compute difference between iterations
            diff_u = uh_new.x.array - uh.x.array
            diff_w = wh_new.x.array - wh.x.array
            diff_norm_u = np.sqrt(msh.comm.allreduce(np.sum(diff_u**2), op=MPI.SUM))
            diff_norm_w = np.sqrt(msh.comm.allreduce(np.sum(diff_w**2), op=MPI.SUM))
            diff_norm = max(diff_norm_u, diff_norm_w)  # Use worst-case convergence

            # Update solution for next FP iteration
            #uh.x.array[:] = uh_new.x.array
            #wh.x.array[:] = wh_new.x.array
            uh.interpolate(uh_new)
            wh.interpolate(wh_new)

            # Check for convergence
            if diff_norm < fp_tol:
                if MPI.COMM_WORLD.rank == 0:  # Removed n%10 condition to see all convergences
                    print(f"Time step {n:3d}, FP iteration {k:2d}: converged with diff {diff_norm:.2e} "
                         f"(u: {diff_norm_u:.2e}, w: {diff_norm_w:.2e})")
                break
            elif k == max_fp_iter-1 and MPI.COMM_WORLD.rank == 0:
               print(f"WARNING: Time step {n} did not converge after {max_fp_iter} iterations "
                         f"(final diff: {diff_norm:.2e})")
        
            # Only update time step variables AFTER FP iterations converge
            u_n.interpolate(uh)
            w_n.interpolate(wh)
    
    # Error computation (moved outside time loop)
    V_ex = fem.functionspace(msh, ("Lagrange", degree + 1))
    u_ex = fem.Function(V_ex)
    w_ex = fem.Function(V_ex)
    u_ex.interpolate(u_exact)
    w_ex.interpolate(w_exact)
    erroru = np.sqrt(msh.comm.allreduce(
        fem.assemble_scalar(fem.form(inner((uh - u_ex),(uh - u_ex))*dx)), op=MPI.SUM))
    errorw = np.sqrt(msh.comm.allreduce(
        fem.assemble_scalar(fem.form(inner((wh - w_ex),(wh - w_ex))*dx)), op=MPI.SUM))
    h = 1.0 / N
    return h, dt, erroru, errorw

if __name__ == "__main__":
    Ns = [8, 16, 32, 64]
    hs, dts, errorsu, errorsw = [], [], [], []

    for N in Ns:
        h = 1.0 / N
        dt = h**2
        h_val, dt_val, erroru, errorw = run_simulation(N, dt)
        hs.append(h_val)
        dts.append(dt_val)
        errorsu.append(erroru)
        errorsw.append(errorw)
        if MPI.COMM_WORLD.rank == 0:
            print(f"N={N:2d}, h={h_val:.3e}, dt={dt_val:.3e}, L2 erroru={erroru:.3e}, L2 errorw={errorw:.3e}")

    # Compute convergence rate of u
    if MPI.COMM_WORLD.rank == 0:
        print("\nConvergence rates (combined h and dt):")
        for i in range(1, len(errorsu)):
            ratespu = np.log(errorsu[i-1] / errorsu[i]) / np.log(hs[i-1] / hs[i])
            print(f"From N={Ns[i-1]} to N={Ns[i]}: ratespu ≈ {ratespu:.2f}")
            ratetimeu = np.log(errorsu[i-1] / errorsu[i]) / np.log(dts[i-1] / dts[i])
            print(f"From dt={dts[i-1]:.3e} to dt={dts[i]:.3e}: ratetimeu ≈ {ratetimeu:.2f}")
            
    # Compute convergence rate of w
    if MPI.COMM_WORLD.rank == 0:
        print("\nConvergence rates (combined h and dt):")
        for i in range(1, len(errorsw)):
            ratespw = np.log(errorsw[i-1] / errorsw[i]) / np.log(hs[i-1] / hs[i])
            print(f"From N={Ns[i-1]} to N={Ns[i]}: ratespw ≈ {ratespw:.2f}")
            ratetimew = np.log(errorsw[i-1] / errorsw[i]) / np.log(dts[i-1] / dts[i])
            print(f"From dt={dts[i-1]:.3e} to dt={dts[i]:.3e}: ratetimew ≈ {ratetimew:.2f}")

(For mixed i have used the following poisson mixed form code ideas:
Mixed formulation for the Poisson equation — DOLFINx 0.9.0 documentation)
What is the issues in CH code?. Is it due to mixed for or for nonlinear term. But same nonlinear term code i have used in above Allen-Chan code that works fine.

I will very gratefull if i get some help on this CH code to get optimal order for P1 and P2.

Thanks in advance!.