Issues with Checkpointing in Mixed Element Function Spaces (FEniCS Legacy)

Hello everyone,

I’m working with a mixed element function space in FEniCS Legacy, and I’m trying to implement checkpointing (both reading and writing) to save and reload simulation states. However, after reading the checkpoint files, the solution vectors on the mesh seem to be incorrect.

Here’s an outline of my approach:

  • I have a mixed element function space consisting of two fields.
  • For checkpointing, I write the mesh and the solution functions (c and phi) into XDMF files.
  • When reloading the solution, I use FunctionAssigner to assign the values to the mixed function.

I suspect that something is going wrong during the reading or reassignment process, but I’m unsure of the correct procedure. Below is a part of my checkpointing code:

def writecheckpoint(self):
    """Save ``self.mesh`` and the two components of ``self.sf0`` to XDMF.

    The mesh goes to ``mesh.xdmf``; phi and c are written as checkpoints
    (stamped with ``self.Time``) to ``phi_.xdmf`` and ``c_.xdmf``.
    """
    with fe.XDMFFile("mesh.xdmf") as xdmf_mesh:
        xdmf_mesh.write(self.mesh)
    c_, phi_ = self.sf0.split(deepcopy=True)
    targets = ((phi_, "phi_.xdmf", "phi_"), (c_, "c_.xdmf", "c_"))
    for func, path, label in targets:
        with fe.XDMFFile(path) as xdmf:
            xdmf.write_checkpoint(func, label, self.Time, fe.XDMFFile.Encoding.HDF5, append=False)

def readcheckpoint(self):
    """Load phi and c from their XDMF checkpoints and copy them into ``self.sf0``."""
    phi_ = fe.Function(self.spacephi_)
    c_ = fe.Function(self.spacec_)
    sources = ((phi_, "phi_.xdmf", "phi_"), (c_, "c_.xdmf", "c_"))
    for func, path, label in sources:
        with fe.XDMFFile(path) as xdmf:
            xdmf.read_checkpoint(func, label)

    # Scatter the collapsed sub-functions into their slots of the mixed function.
    subspaces = [self.spacec_, self.spacephi_]
    fe.FunctionAssigner(self.ME, subspaces).assign(self.sf0, [c_, phi_])

The issue arises after reading the files—the solution vectors seem incorrect, and the results don’t match what I expect from the saved state.

What is the correct way to write checkpoint functions for a mixed element function space in FEniCS Legacy? Am I missing any crucial steps in reading or writing these files?

Any insights or suggestions would be greatly appreciated!

This is the minimal code for solving a diffusion equation using mixed elements. I am solving for c and phi, where the second field (phi) is always zero.

import fenics as fe
import numpy as np
from tqdm import tqdm
fe.set_log_level(fe.LogLevel.ERROR)
#################### Define Parallel Variables ####################
# World communicator, this process's rank, and the number of processes.
comm = fe.MPI.comm_world
rank, size = fe.MPI.rank(comm), fe.MPI.size(comm)
#############################  END  ################################
class Diffusion:
    """Backward-Euler diffusion solver on a mixed (c, phi) P1 x P1 space.

    ``c`` follows the weak form built in ``form``:
    ``(dc/dt, v_c) + D * (grad c, grad v_c) = 0`` (no Dirichlet conditions
    are applied). ``phi`` follows ``(dphi/dt, v_phi) = 0`` and therefore
    keeps the value it starts with.

    Args:
        params_dict: dict with the keys ``"mesh"``, ``"dt"``, ``"D"``,
            ``"T_hot"``, ``"T_cold"``, ``"Center"`` (an ``(x, y)`` pair),
            ``"rad"``, ``"Time"`` and ``"checkpoint"``. When
            ``"checkpoint"`` is truthy the previous state is loaded from
            ``c_.xdmf`` / ``phi_.xdmf``; otherwise a hot disc on a cold
            background is interpolated.
    """

    def __init__(self, params_dict):
        self.params_dict = params_dict
        self.mesh = self.params_dict["mesh"]
        self.dt = self.params_dict["dt"]
        self.D = self.params_dict["D"]  # Diffusion coefficient
        self.T_hot = self.params_dict["T_hot"]
        self.T_cold = self.params_dict["T_cold"]
        self.Center = self.params_dict["Center"]
        self.rad = self.params_dict["rad"]
        # Default time stamp for writecheckpoint(); this class never advances it.
        self.Time = self.params_dict["Time"]
        # Initialize
        self.funcspace()
        self.form()
        # Only the values of sf0 are set below; self.F and self.sf are not
        # replaced, so building the solver once afterwards is sufficient.
        if self.params_dict["checkpoint"]:
            self.readcheckpoint()
        else:
            self.InitialCondition()
        self.define_solver()

    def funcspace(self):
        """Create the mixed space, its test functions and the state functions.

        ``sf`` is the unknown at the new time level, ``sf0`` the previous
        level. ``spacec_`` / ``spacephi_`` are the collapsed component
        spaces used for checkpoint I/O.
        """
        P1 = fe.FiniteElement("Lagrange", self.mesh.ufl_cell(), 1)
        element = fe.MixedElement([P1, P1])
        self.ME = fe.FunctionSpace(self.mesh, element)
        self.v_c, self.v_phi = fe.TestFunctions(self.ME)
        self.sf = fe.Function(self.ME)
        self.sf0 = fe.Function(self.ME)
        self.c, self.phi = fe.split(self.sf)  # current time step
        self.c_, self.phi_ = fe.split(self.sf0)  # previous time step
        self.spacec_, _ = self.ME.sub(0).collapse(collapsed_dofs=True)
        self.spacephi_, _ = self.ME.sub(1).collapse(collapsed_dofs=True)

    def form(self):
        """Assemble the residual ``self.F`` of one backward-Euler step."""
        # 1. ∫( (dc/dt * v_c) + D * (∇c · ∇v_c) ) dx = 0
        dcdt = (self.c - self.c_) / self.dt
        gradc = fe.grad(self.c)
        eqdiff = (
            fe.inner(dcdt, self.v_c)
            + self.D * fe.inner(gradc, fe.grad(self.v_c))
        ) * fe.dx
        # 2. dphi/dt = 0, so phi keeps its initial value
        dphidt = (self.phi - self.phi_) / self.dt
        eqphi = fe.inner(dphidt, self.v_phi) * fe.dx

        self.F = eqdiff + eqphi

    def define_solver(self):
        """Build a Newton solver for ``self.F == 0`` in the unknown ``self.sf``."""
        J = fe.derivative(self.F, self.sf)
        problem = fe.NonlinearVariationalProblem(self.F, self.sf, J=J)
        solver = fe.NonlinearVariationalSolver(problem)
        prm = solver.parameters
        prm["newton_solver"]["absolute_tolerance"] = 1E-8
        prm["newton_solver"]["relative_tolerance"] = 1E-7
        prm["newton_solver"]["maximum_iterations"] = 100
        self.solver = solver

    def solve(self):
        """Advance one time step and copy the new state into ``sf0``."""
        self.solver.solve()
        self.sf0.assign(self.sf)

    def InitialCondition(self):
        """Interpolate the hot-disc initial state into ``self.sf0``.

        c is ``T_hot`` inside the closed disc of radius ``rad`` about
        ``Center`` and ``T_cold`` outside; phi is 0 everywhere.
        """

        class InitialConditions(fe.UserExpression):
            def __init__(self, rad, center, T_hot, T_cold, **kwargs):
                super().__init__(**kwargs)
                self.rad = rad
                self.center = center
                self.T_hot = T_hot
                self.T_cold = T_cold

            def eval(self, values, x):
                xc, yc = self.center
                dist = (x[0] - xc) ** 2 + (x[1] - yc) ** 2
                if dist <= self.rad ** 2:
                    values[0] = self.T_hot
                else:
                    values[0] = self.T_cold
                values[1] = 0

            def value_shape(self):
                return (2,)

        Expr = InitialConditions(rad=self.rad, center=self.Center, T_hot=self.T_hot, T_cold=self.T_cold)
        self.sf0.interpolate(Expr)

    def writecheckpoint(self, time_step=None):
        """Write the mesh and both components of ``self.sf0`` to XDMF.

        Produces ``mesh.xdmf`` plus the checkpoints ``phi_.xdmf`` and
        ``c_.xdmf`` that ``readcheckpoint`` loads.

        Args:
            time_step: time value stored with the checkpoints. Defaults to
                ``self.Time``, which this class never advances, so callers
                running a time loop should pass the current time (or update
                ``self.Time``).
        """
        if time_step is None:
            time_step = self.Time
        with fe.XDMFFile("mesh.xdmf") as mesh_file:
            mesh_file.write(self.mesh)
        c_, phi_ = self.sf0.split(deepcopy=True)
        with fe.XDMFFile("phi_.xdmf") as file:
            file.write_checkpoint(phi_, "phi_", time_step, fe.XDMFFile.Encoding.HDF5, append=False)
        with fe.XDMFFile("c_.xdmf") as file:
            file.write_checkpoint(c_, "c_", time_step, fe.XDMFFile.Encoding.HDF5, append=False)

    def readcheckpoint(self):
        """Load c and phi from ``c_.xdmf`` / ``phi_.xdmf`` into ``self.sf0``.

        Does nothing unless ``params_dict["checkpoint"]`` is truthy.

        NOTE(review): in parallel, correctness presumably depends on
        ``self.mesh`` having the same cell numbering as the mesh used when
        the checkpoint was written. Confirm with a write/read round trip
        under mpirun.
        """
        if not self.params_dict["checkpoint"]:
            return
        phi_ = fe.Function(self.spacephi_)
        c_ = fe.Function(self.spacec_)
        with fe.XDMFFile("phi_.xdmf") as file:
            file.read_checkpoint(phi_, "phi_")
        with fe.XDMFFile("c_.xdmf") as file:
            file.read_checkpoint(c_, "c_")

        # Copy the collapsed sub-functions into their slots of the mixed function.
        assigner = fe.FunctionAssigner(self.ME, [self.spacec_, self.spacephi_])
        assigner.assign(self.sf0, [c_, phi_])

############### Define Parameters ############
params = {
    "dt": 1E-2,
    "dy": 1E-1,
    "D": 0.4,
    "Nx": 15,
    "Ny": 15,
    "Center": [7.5, 7.5],
    "rad": 2,
    "T_cold": 0,
    "T_hot": 20,
    "Time": 0,
    "checkpoint": False,
}
############## Define Mesh ################
# round() instead of int(): int() truncates, so a quotient that lands just
# below an integer (e.g. 0.3 / 0.1 == 2.9999999999999996) would lose a cell.
nx = round(params["Nx"] / params["dy"])
ny = round(params["Ny"] / params["dy"])
if params["checkpoint"]:
    # Restart: read the mesh saved by writecheckpoint() and continue from
    # the stored solution.
    mesh = fe.Mesh()
    with fe.XDMFFile("mesh.xdmf") as file:
        file.read(mesh)
else:
    mesh = fe.RectangleMesh(fe.Point(0, 0), fe.Point(params["Nx"], params["Ny"]), nx, ny)
params["mesh"] = mesh
################# END ######################
################# Define file ##############
file = fe.XDMFFile("Diff.xdmf")
file.parameters["rewrite_function_mesh"] = True
file.parameters["flush_output"] = True
file.parameters["functions_share_mesh"] = True
################# END ######################
################# Define Diffusion Object ###
diff_problem = Diffusion(params)
# NOTE(review): on restart the clock starts again from params["Time"]; the
# time stamp stored in the checkpoint files is not read back.
Time = params["Time"]
dt = params["dt"]


for it in range(10_000):

    diff_problem.solve()
    Time += dt
    params["Time"] = Time
    # writecheckpoint() stamps the files with diff_problem.Time, which the
    # Diffusion object never advances itself, so keep it in sync here.
    diff_problem.Time = Time

    if it % 20 == 0:

        c, phi = diff_problem.sf0.split(deepcopy=True)
        c.rename("c", "c")
        phi.rename("phi", "phi")
        file.write(c, Time)
        file.write(phi, Time)
        if rank == 0:
            print(f"iteration = ", it, flush=True)

        # write checkpoint
        diff_problem.writecheckpoint()

You need to re-read the mesh from the checkpoint, rather than reading the function onto its initial grid. See the forum post referenced here (the link itself is not included in this copy).

For further questions, please try to reduce the amount of code to something similar to the post I am referring to. You do not need to solve a diffusion problem to illustrate your issue. Simply create some initial conditions in the mixed space, and write them to file and try to read them back in.

Thank you for your guidance.

I have followed your suggestion to minimize the code and focus on the core issue of checkpointing in a mixed element function space. Below is the simplified example. The code interpolates an initial condition and writes it to a checkpoint file. With one processor the result seems correct, but the problem arises when I read the checkpoint file using multiple processors (mpirun with 4 processes): the solution vectors on the mesh differ between the written and the read versions.

Here’s the minimal code:

import fenics as fe
Time = 0
################## Define Parameters ######
params = {
    "dy": 1E-1,
    "Nx": 15,
    "Ny": 15,
    "Center": [7.5, 7.5],
    "rad": 2,
    "T_cold": 0,
    "T_hot": 20,
}
################### END ###################
############## Define Mesh ################
# round() instead of int(): int() truncates, so a quotient that lands just
# below an integer (e.g. 0.3 / 0.1 == 2.9999999999999996) would lose a cell.
nx = round(params["Nx"] / params["dy"])
ny = round(params["Ny"] / params["dy"])
mesh = fe.RectangleMesh(fe.Point(0, 0), fe.Point(params["Nx"], params["Ny"]), nx, ny)
################### END ####################
############## Define Function space #######
def FunctionSpace(mesh):
    """Build the mixed P1 x P1 space on ``mesh`` and its companion objects.

    Returns a dict containing:
    - the mixed space ("ME") and its test functions ("v_c", "v_phi");
    - a mixed function ("sf") and UFL views of its components ("c", "phi");
    - the collapsed component spaces ("spacec_", "spacephi_").
    """
    p1 = fe.FiniteElement("Lagrange", mesh.ufl_cell(), 1)
    mixed_space = fe.FunctionSpace(mesh, fe.MixedElement([p1, p1]))
    test_c, test_phi = fe.TestFunctions(mixed_space)
    solution = fe.Function(mixed_space)
    comp_c, comp_phi = fe.split(solution)
    # collapse(collapsed_dofs=True) returns (space, dof map); keep the space.
    collapsed_c = mixed_space.sub(0).collapse(collapsed_dofs=True)[0]
    collapsed_phi = mixed_space.sub(1).collapse(collapsed_dofs=True)[0]
    return {
        "ME": mixed_space,
        "v_c": test_c,
        "v_phi": test_phi,
        "sf": solution,
        "c": comp_c,
        "phi": comp_phi,
        "spacec_": collapsed_c,
        "spacephi_": collapsed_phi,
    }

################### END ####################
################# Define fils ##############
def _open_output_xdmf(path):
    """Open an XDMF file at ``path`` with the output parameters used for comparison plots."""
    xdmf = fe.XDMFFile(path)
    xdmf.parameters["rewrite_function_mesh"] = True
    xdmf.parameters["flush_output"] = True
    xdmf.parameters["functions_share_mesh"] = True
    return xdmf


file_write = _open_output_xdmf("write_checkpoint.xdmf")
file_read = _open_output_xdmf("read_checkpoint.xdmf")
################# END ######################
################# Functions ###############
def writecheckpoint(mesh, dict_variables, time_step=0.0):
    """Write ``mesh`` and both components of ``dict_variables["sf"]`` to XDMF.

    Produces ``mesh.xdmf`` (the plain mesh) plus the checkpoint files
    ``phi_.xdmf`` and ``c_.xdmf`` that ``readcheckpoint`` loads back.

    Args:
        mesh: mesh written to ``mesh.xdmf``.
        dict_variables: dict as returned by ``FunctionSpace``; only its
            ``"sf"`` entry (the mixed function) is used.
        time_step: time value stored with each checkpoint (default 0.0).
    """
    sf = dict_variables["sf"]
    c_, phi_ = sf.split(deepcopy=True)
    with fe.XDMFFile("mesh.xdmf") as mesh_file:
        mesh_file.write(mesh)
    # The third positional argument of write_checkpoint is the time value and
    # the encoding comes after it, so both are passed here in that order.
    with fe.XDMFFile("phi_.xdmf") as file:
        file.write_checkpoint(phi_, "phi_", time_step, fe.XDMFFile.Encoding.HDF5, append=False)
    with fe.XDMFFile("c_.xdmf") as file:
        file.write_checkpoint(c_, "c_", time_step, fe.XDMFFile.Encoding.HDF5, append=False)

def readcheckpoint(dict_variables):
    """Read the c and phi checkpoints into new functions on the collapsed spaces.

    Args:
        dict_variables: dict as returned by ``FunctionSpace``; its
            ``"spacec_"`` and ``"spacephi_"`` entries are used.

    Returns:
        ``(c_, phi_)`` loaded from ``c_.xdmf`` and ``phi_.xdmf``.
    """
    phi_ = fe.Function(dict_variables["spacephi_"])
    c_ = fe.Function(dict_variables["spacec_"])

    for func, path, label in ((phi_, "phi_.xdmf", "phi_"), (c_, "c_.xdmf", "c_")):
        with fe.XDMFFile(path) as xdmf:
            xdmf.read_checkpoint(func, label)
    return c_, phi_

def InitialCondition(params, dict_variables):
    """Interpolate a hot disc on a cold background into ``dict_variables["sf"]``.

    Component 0 (c) is ``params["T_hot"]`` inside the closed disc of radius
    ``params["rad"]`` about ``params["Center"]`` and ``params["T_cold"]``
    outside; component 1 (phi) is 0 everywhere.
    """

    class InitialConditions(fe.UserExpression):
        def __init__(self, params, **kwargs):
            super().__init__(**kwargs)
            self.rad = params["rad"]
            self.center = params["Center"]
            self.T_hot = params["T_hot"]
            self.T_cold = params["T_cold"]

        def eval(self, values, x):
            cx, cy = self.center
            r2 = (x[0] - cx) ** 2 + (x[1] - cy) ** 2
            values[0] = self.T_hot if r2 <= self.rad ** 2 else self.T_cold
            values[1] = 0

        def value_shape(self):
            return (2,)

    target = dict_variables["sf"]
    target.interpolate(InitialConditions(params))

################# END ######################
################# main ####################

# Interpolate the initial condition into the solution function sf and write the checkpoint files:
dict_variables = FunctionSpace(mesh)
InitialCondition(params, dict_variables)
writecheckpoint(mesh, dict_variables)
# Deep copies of the state that was just written, kept for the comparison output below.
written_c, written_phi = dict_variables["sf"].split(deepcopy=True)
# Read the checkpoint files and assign the values to a solution function sf on the mesh
# that has been read back from mesh.xdmf:
# NOTE(review): under mpirun the mesh read here may not carry the same cell numbering as
# `mesh` had when the checkpoint was written; that would match the "correct in serial,
# wrong in parallel" symptom — confirm before relying on this round trip in parallel.
mesh_checkpoint = fe.Mesh()
with fe.XDMFFile("mesh.xdmf") as file:
    file.read(mesh_checkpoint)
dict_variables = FunctionSpace(mesh_checkpoint)
c_, phi_ = readcheckpoint(dict_variables)
######## Assign the values read from the checkpoints to a new solution function sf on the new mesh:
ME = dict_variables["ME"]
spacec_ = dict_variables["spacec_"]
spacephi_ = dict_variables["spacephi_"]
sf = fe.Function(ME)
# Map the collapsed component functions back into their slots of the mixed function.
assigner = fe.FunctionAssigner(ME, [spacec_, spacephi_])
assigner.assign(sf, [c_, phi_])
read_c, read_phi = sf.split(deepcopy=True)
# Write the written and the read-back fields to separate files for comparison:
written_c.rename("written_c", "written_c")
written_phi.rename("written_phi", "written_phi")
read_c.rename("read_c", "read_c")
read_phi.rename("read_phi", "read_phi")
file_read.write(read_c, Time)
file_read.write(read_phi, Time)
file_write.write(written_c, Time)
file_write.write(written_phi, Time)
file_read.close()
file_write.close()