Solving 1D nonlinear mixed domain coupled PDE with MixedFunctionSpace API

Hello everyone,

I am trying to assess whether I can use dolfinx v0.9.0 and the MixedFunctionSpace API to implement a set of coupled PDEs that model a solid-state battery in one dimension, based on previous work done with commercial software. I am very new to FEM, and some important steps are missing from the examples I have found. I would like to know whether I am going in the right direction, or whether someone could point me towards the right tools for the job in the FEniCS ecosystem, based on the problem description and my first attempt at implementing it:

If you want to skip the (rather lengthy) problem description and my first implementation, my questions are:

  • What is the best way to implement the double Cauchy boundary condition (Dirichlet and Neumann conditions on the same boundary) in my nonlinear mixed-domain case? As I understand it, Nitsche's method is not trivial to apply here, so I would guess Lagrange multipliers, but I am not sure they are straightforward with this API.
  • The examples I found use second-order Lagrange basis functions. Is this the way to go for this type of problem in FEniCSx, or should I switch to a discontinuous Galerkin approach, given that I have variable coefficients (the diffusion coefficient in the flux depends on the intercalation fraction and the concentration)?
  • Is a simple backward Euler time-stepping scheme a good choice, or should I look into other time integration schemes because I might run into convergence issues?
  • I need a refined mesh at the interfaces, otherwise I expect severe numerical instabilities. Will this cause trouble when running in parallel, or can I simply generate the 1D mesh with the desired spacing and keep the same approach?
  • Is the MixedFunctionSpace API the way to go here, or should I turn to other options such as mixed elements or multiphenicsx?

Any guidance, advice, or criticism is most welcome.

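1D model (x = 0 at the current collector, x = l_cathode at the cathode/electrolyte interface, x = l_total at the anode):
|-------1-------|--------2--------|
     cathode        electrolyte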

The system describes lithium diffusion over two domains (cathode and electrolyte). The cathode domain (1) is governed by pure diffusion, with c_1 the lithium concentration and x_1 = \frac{c_1}{c_{1max}} the intercalation fraction:

\frac{\partial c_1}{\partial t} = - \vec{\nabla} \cdot ({- c_1 (1 - x_1) \frac{D_1}{RT} \vec{\nabla} \mu_1})

and \mu_1 the chemical potential, which depends on c_1 through x_1:

\mu_1 = \mu_{01} + RT \ln\left(\frac{x_1}{1 - x_1}\right)
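Since \mu_1 depends on position only through c_1, the chain rule collapses this flux. A short derivation, assuming D_1 is constant (as in the script below) and using x_1 = c_1 / c_{1max}:

\frac{\partial \mu_1}{\partial c_1} = \frac{RT}{c_{1max}} \left( \frac{1}{x_1} + \frac{1}{1 - x_1} \right) = \frac{RT}{c_1 (1 - x_1)}

- c_1 (1 - x_1) \frac{D_1}{RT} \vec{\nabla} \mu_1 = - c_1 (1 - x_1) \frac{D_1}{RT} \frac{\partial \mu_1}{\partial c_1} \vec{\nabla} c_1 = - D_1 \vec{\nabla} c_1

So the cathode equation reduces to linear Fickian diffusion, \frac{\partial c_1}{\partial t} = \vec{\nabla} \cdot (D_1 \vec{\nabla} c_1), and the same identity applies to the \mu_2 part of the electrolyte flux. In the weak form, c_1 (1 - x_1) \frac{D_1}{RT} \vec{\nabla} \mu_1 could therefore be written as D_1 \vec{\nabla} c_1, which avoids differentiating \ln\left(\frac{x_1}{1 - x_1}\right), singular as x_1 \to 0 or x_1 \to 1. This only holds if D_1 does not itself depend on x_1.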

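In the electrolyte domain (2), Li^+ (charge number z = 1) undergoes drift-diffusion, with x_2 = \frac{c_2}{c_{2max}} and \mu_2 defined like \mu_1 with x_2 in place of x_1:

\frac{\partial c_2}{\partial t} = - \vec{\nabla} \cdot \left( - c_2 (1 - x_2) \frac{D_2}{RT} \vec{\nabla}(\mu_2 + z F \phi_2) \right)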

and Poisson equation for the electric field:

\Delta \phi_2 = - \frac{F(c_2 - c_{2_{eq}})}{\epsilon_0 \epsilon_r}

We have the following boundary conditions on each subdomain (1D only):

On cathode domain:

J_1|_{x=0} = 0

with J_1 = - c_1 (1 - x_1) \frac{D_1}{RT} \frac{\partial \mu_1}{\partial x}

At the internal cathode/electrolyte boundary, \phi satisfies a Neumann boundary condition:

\left. \frac{d \phi}{d x}\right|_{x=l_{cathode}} = - \frac{\sigma_{12}}{\epsilon_0 \epsilon_r}
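Multiplying the Poisson equation by a test function v and integrating by parts over the electrolyte domain (l_{cathode}, l_{total}) shows how this condition enters the weak form (in 1D the boundary terms are point evaluations):

\int_{l_{cathode}}^{l_{total}} \frac{d \phi_2}{d x} \frac{d v}{d x} \, dx - \int_{l_{cathode}}^{l_{total}} \frac{F(c_2 - c_{2_{eq}})}{\epsilon_0 \epsilon_r} v \, dx + \left. \frac{d \phi_2}{d x} v \right|_{x=l_{cathode}} - \left. \frac{d \phi_2}{d x} v \right|_{x=l_{total}} = 0

Inserting the interface condition, the term at x = l_{cathode} becomes - \frac{\sigma_{12}}{\epsilon_0 \epsilon_r} v|_{x=l_{cathode}}. It carries a minus sign because the outward normal of the electrolyte points in the -x direction at the interface.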

The surface charge density \sigma_{12} is obtained by accumulating the difference between the externally imposed current j_{ext} and the reaction current j_{12}:

\sigma_{12}(t) = \int_{0}^{t} \left(j_{ext}(\tau) - j_{12}(\tau)\right) \, d\tau
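In a time-stepping loop this integral becomes a running sum. A minimal sketch (accumulate_charge is a hypothetical helper) comparing the explicit forward Euler update, which is what updating \sigma_{12} with the previous step's j_{12} amounts to, with a trapezoidal update:

```python
import numpy as np


def accumulate_charge(j_in, j_out, dt, scheme="explicit"):
    """Approximate sigma(t_n) = int_0^{t_n} (j_in - j_out) dt on a uniform time grid.

    j_in, j_out: current densities (A/m^2) sampled at t_0, t_1, ..., t_N (arrays).
    Returns sigma at t_0..t_N (C/m^2), starting from sigma(t_0) = 0.
    """
    rate = np.asarray(j_in, dtype=float) - np.asarray(j_out, dtype=float)
    sigma = np.zeros_like(rate)
    if scheme == "explicit":
        # Forward Euler: each step uses the rate at the start of the step (lagged update)
        sigma[1:] = np.cumsum(rate[:-1]) * dt
    elif scheme == "trapezoidal":
        # Average of the rates at both ends of each step (second order in dt)
        sigma[1:] = np.cumsum(0.5 * (rate[:-1] + rate[1:])) * dt
    else:
        raise ValueError(f"unknown scheme {scheme!r}")
    return sigma
```

The explicit update is first order in \Delta t and lags the PDE solution by one step; a trapezoidal (or fully implicit) update needs j_{12} at the end of the step, i.e. it couples \sigma_{12} to the Newton iteration.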

where the reaction current j_{12} is given by the Butler-Volmer equation:

j_{12} = c_1 \left( 1 - \frac{c_2}{c_{2max}} \right) k_{c01} F \exp\left( \frac{-(1 - \alpha_1) \eta_1 F}{RT} \right) - c_2 \left( 1 - \frac{c_1}{c_{1max}}\right) k_{a01} F \exp\left(\frac{\alpha_1 \eta_1 F}{RT}\right)
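For checking values outside UFL, the same expression can be evaluated with NumPy. A minimal sketch: butler_volmer_j12 is a hypothetical helper, and R, F and the parameter values are copied from the script below.

```python
import numpy as np

R = 8.314462618  # gas constant, J/(mol K)
F = 96485.33212  # Faraday constant, C/mol


def butler_volmer_j12(c1, c2, eta1, c1_max, c2_max, k_c01, k_a01, alpha1, T):
    """Reaction current density j12 (A/m^2) at the cathode/electrolyte interface.

    Term-by-term transcription of the Butler-Volmer expression above; works on
    scalars or NumPy arrays of interface concentrations (mol/m^3) and overpotential (V).
    """
    f = F / (R * T)
    first = c1 * (1.0 - c2 / c2_max) * k_c01 * F * np.exp(-(1.0 - alpha1) * eta1 * f)
    second = c2 * (1.0 - c1 / c1_max) * k_a01 * F * np.exp(alpha1 * eta1 * f)
    return first - second


# Parameter values taken from the script below
T = 273.15 + 37
c1_max, c2_max = 5.6e4, 100.0
i0 = 1.17e2  # exchange current density, A/m^2
k_c01 = i0 / (F * c1_max)
k_a01 = i0 / (F * c2_max)

j_eq = butler_volmer_j12(0.5 * c1_max, 0.5 * c2_max, 0.0, c1_max, c2_max, k_c01, k_a01, 0.5, T)
```

With these parameters, c_1 = 0.5 c_{1max}, c_2 = 0.5 c_{2max} and \eta_1 = 0 give j_{12} = 0 (both terms equal i_0 / 4), so this state is an equilibrium of the interface reaction, and increasing \eta_1 makes j_{12} decrease.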

Two more Neumann boundary conditions apply on this internal facet (for c_1 in the cathode domain and c_2 in the electrolyte domain):

j_{12} = J_1|_{x=l_{cathode}} F = -J_2|_{x=l_{cathode}} F
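Written out for the cathode, with J_1|_{x=0} = 0 and J_1|_{x=l_{cathode}} = j_{12}/F (outward normal +x), integration by parts gives:

\int_{0}^{l_{cathode}} \frac{\partial c_1}{\partial t} v \, dx + \int_{0}^{l_{cathode}} c_1 (1 - x_1) \frac{D_1}{RT} \frac{\partial \mu_1}{\partial x} \frac{d v}{d x} \, dx + \frac{j_{12}}{F} \, v|_{x=l_{cathode}} = 0

so a positive j_{12} acts as a sink of c_1 at the interface.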

Similarly, at the electrolyte/anode boundary (the anode is pure lithium and is therefore not modelled), c_2 satisfies the Neumann boundary condition:

J_2|_{x=l_{total}} F = j_{23}

In addition, \phi is subject to a double Cauchy boundary condition at x = l_{total}, composed of a Dirichlet BC:

\phi|_{x=l_{total}} = 0

and a Neumann BC:

\left. \frac{d \phi}{dx}\right|_{x=l_{total}} = - \frac{\sigma_{23}}{\epsilon_0 \epsilon_r}

with j_{23} and \sigma_{23} defined as:

j_{23} = c_2 k_{a02} F \exp\left(\frac{\alpha_2 \eta_2 F}{R T}\right) - \left( 1 - \frac{c_2}{c_{2max}} \right) k_{c02} F \exp\left(\frac{-(1 - \alpha_2) \eta_2 F}{RT}\right)
\sigma_{23}(t) = \int_{0}^{t} \left(j_{23}(\tau) - j_{app}(\tau)\right) \, d\tau
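Integrating the Poisson equation over the electrolyte and substituting both Neumann conditions on \phi gives a global relation (a consequence of the equations above, not an additional assumption):

\left. \frac{d \phi_2}{d x}\right|_{x=l_{total}} - \left. \frac{d \phi_2}{d x}\right|_{x=l_{cathode}} = - \frac{F}{\epsilon_0 \epsilon_r} \int_{l_{cathode}}^{l_{total}} (c_2 - c_{2_{eq}}) \, dx

\sigma_{23} = \sigma_{12} + F \int_{l_{cathode}}^{l_{total}} (c_2 - c_{2_{eq}}) \, dx

Since the Poisson equation is second order, \phi_2 is already fixed by the Neumann condition at x = l_{cathode} and the Dirichlet condition at x = l_{total}. The Neumann condition at x = l_{total} then acts as this charge-balance constraint rather than as an independent boundary condition. If both \sigma's are integrated independently in time they will generally not satisfy it, so presumably one quantity (for instance \sigma_{23}, or a global unknown enforced through a single scalar Lagrange multiplier) has to be determined by it; which one is meant should be checked against the reference.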

Reference: Celè, "Minimal Architecture Lithium Batteries: Toward High Energy Density Storage Solutions", Small, 2023.

Apologies for the rather lengthy code, but large chunks are straight adaptations of "Coupling PDEs of multiple dimensions" (FEniCS Workshop) and "Poisson submesh DG coupling". I thought it better to include the whole functions to ease reproducibility.

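import dolfinx
import dolfinx.fem.petsc  # must be imported explicitly for the block assembly routines below
import numpy.typing as npt
from mpi4py import MPI
import numpy as np
from dolfinx.mesh import create_interval, locate_entities
from petsc4py import PETSc
import ufl
from dolfinx.fem import Constant
# dolfinx.nls.petsc.NewtonSolver is not imported: the custom block NewtonSolver below would shadow it
from packaging.version import Version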

def compute_interface_data(
    cell_tags: dolfinx.mesh.MeshTags, facet_indices: npt.NDArray[np.int32]
) -> npt.NDArray[np.int32]:
    """Compute interior facet integrals consistently ordered according to cell_tags"""
    integration_args: tuple[int] | tuple
    if Version("0.10.0") <= Version(dolfinx.__version__):
        integration_args = ()
    else:
        fdim = cell_tags.dim - 1
        integration_args = (fdim,)
    idata = dolfinx.cpp.fem.compute_integration_domains(
        dolfinx.fem.IntegralType.interior_facet,
        cell_tags.topology,
        facet_indices,
        *integration_args,
    )
    ordered_idata = idata.reshape(-1, 4).copy()
    switch = cell_tags.values[ordered_idata[:, 0]] > cell_tags.values[ordered_idata[:, 2]]
    if np.any(switch):
        ordered_idata[switch, :] = ordered_idata[switch][:, [2, 3, 0, 1]]
    return ordered_idata

def communicate_indices(
    index_map: dolfinx.common.IndexMap, local_indices: npt.NDArray[np.int32]
) -> npt.NDArray[np.int32]:
    """Find ghosted indices marked by any other process"""
    index_accumulator = dolfinx.la.vector(index_map, 1)
    index_accumulator.array[:] = 0
    index_accumulator.array[local_indices] = 1
    index_accumulator.scatter_reverse(dolfinx.la.InsertMode.add)
    return np.flatnonzero(index_accumulator.array).astype(np.int32)

def transfer_meshtags_to_submesh(
    mesh, entity_tag, submesh, sub_vertex_to_parent, sub_cell_to_parent
):
    """Transfer a meshtag from a parent mesh to a sub-mesh."""
    tdim = mesh.topology.dim
    cell_imap = mesh.topology.index_map(tdim)
    num_cells = cell_imap.size_local + cell_imap.num_ghosts
    mesh_to_submesh = np.full(num_cells, -1)
    mesh_to_submesh[sub_cell_to_parent] = np.arange(
        len(sub_cell_to_parent), dtype=np.int32
    )
    sub_vertex_to_parent = np.asarray(sub_vertex_to_parent)

    submesh.topology.create_connectivity(entity_tag.dim, 0)

    num_child_entities = (
        submesh.topology.index_map(entity_tag.dim).size_local
        + submesh.topology.index_map(entity_tag.dim).num_ghosts
    )
    submesh.topology.create_connectivity(submesh.topology.dim, entity_tag.dim)

    c_c_to_e = submesh.topology.connectivity(submesh.topology.dim, entity_tag.dim)
    c_e_to_v = submesh.topology.connectivity(entity_tag.dim, 0)

    child_markers = np.full(num_child_entities, 0, dtype=np.int32)

    mesh.topology.create_connectivity(entity_tag.dim, 0)
    mesh.topology.create_connectivity(entity_tag.dim, mesh.topology.dim)
    p_f_to_v = mesh.topology.connectivity(entity_tag.dim, 0)
    p_f_to_c = mesh.topology.connectivity(entity_tag.dim, mesh.topology.dim)
    sub_to_parent_entity_map = np.full(num_child_entities, -1, dtype=np.int32)
    for facet, value in zip(entity_tag.indices, entity_tag.values):
        facet_found = False
        for cell in p_f_to_c.links(facet):
            if facet_found:
                break
            if (child_cell := mesh_to_submesh[cell]) != -1:
                for child_facet in c_c_to_e.links(child_cell):
                    child_vertices = c_e_to_v.links(child_facet)
                    child_vertices_as_parent = sub_vertex_to_parent[child_vertices]
                    is_facet = np.isin(
                        child_vertices_as_parent, p_f_to_v.links(facet)
                    ).all()
                    if is_facet:
                        child_markers[child_facet] = value
                        facet_found = True
                        sub_to_parent_entity_map[child_facet] = facet
    tags = dolfinx.mesh.meshtags(
        submesh,
        entity_tag.dim,
        np.arange(num_child_entities, dtype=np.int32),
        child_markers,
    )
    tags.name = entity_tag.name
    return tags, sub_to_parent_entity_map

class NewtonSolver:
    def __init__(
        self,
        F: list[dolfinx.fem.form],
        J: list[list[dolfinx.fem.form]],
        w: list[dolfinx.fem.Function],
        bcs: list[dolfinx.fem.DirichletBC] | None = None,
        max_iterations: int = 10,
        petsc_options: dict[str, str | float | int | None] | None = None,
    ):
        self.max_iterations = max_iterations
        self.bcs = [] if bcs is None else bcs
        self.b = dolfinx.fem.petsc.create_vector_block(F)
        self.F = F
        self.J = J
        self.A = dolfinx.fem.petsc.create_matrix_block(J)
        self.dx = self.A.createVecLeft()
        self.w = w
        self.x = dolfinx.fem.petsc.create_vector_block(F)

        # Set PETSc options
        opts = PETSc.Options()
        if petsc_options is not None:
            for k, v in petsc_options.items():
                opts[k] = v

        # Define KSP solver
        self._solver = PETSc.KSP().create(self.b.getComm().tompi4py())
        self._solver.setOperators(self.A)
        self._solver.setFromOptions()

        # Set matrix and vector PETSc options
        self.A.setFromOptions()
        self.b.setFromOptions()

    def solve(self, tol=1e-6, beta=1.0):
        i = 0
        while i < self.max_iterations:
            dolfinx.cpp.la.petsc.scatter_local_vectors(
                self.x,
                [si.x.petsc_vec.array_r for si in self.w],
                [
                    (
                        si.function_space.dofmap.index_map,
                        si.function_space.dofmap.index_map_bs,
                    )
                    for si in self.w
                ],
            )
            self.x.ghostUpdate(
                addv=PETSc.InsertMode.INSERT, mode=PETSc.ScatterMode.FORWARD
            )

            # Assemble F(u_{i-1}) - J(u_D - u_{i-1}) and set du|_bc= u_D - u_{i-1}
            with self.b.localForm() as b_local:
                b_local.set(0.0)

            if Version(dolfinx.__version__) < Version("0.9.0"):
                dolfinx.fem.petsc.assemble_vector_block(
                    self.b, self.F, self.J, bcs=self.bcs, x0=self.x, scale=-1.0
                )
            else:
                dolfinx.fem.petsc.assemble_vector_block(
                    self.b, self.F, self.J, bcs=self.bcs, x0=self.x, alpha=-1.0
                )
            self.b.ghostUpdate(
                PETSc.InsertMode.INSERT_VALUES, PETSc.ScatterMode.FORWARD
            )

            # Assemble Jacobian
            self.A.zeroEntries()
            dolfinx.fem.petsc.assemble_matrix_block(self.A, self.J, bcs=self.bcs)
            self.A.assemble()

            self._solver.solve(self.b, self.dx)
            assert (
                self._solver.getConvergedReason() > 0
            ), "Linear solver did not converge"
            
            offset_start = 0
            for s in self.w:
                num_sub_dofs = (
                    s.function_space.dofmap.index_map.size_local
                    * s.function_space.dofmap.index_map_bs
                )
                s.x.petsc_vec.array_w[:num_sub_dofs] -= (
                    beta * self.dx.array_r[offset_start : offset_start + num_sub_dofs]
                )
                s.x.petsc_vec.ghostUpdate(
                    addv=PETSc.InsertMode.INSERT, mode=PETSc.ScatterMode.FORWARD
                )
                offset_start += num_sub_dofs

            # Compute norm of update
            correction_norm = self.dx.norm(0)
            if i < 3 or correction_norm > 1e-6:
                print(f"  Newton iteration {i}: Correction norm {correction_norm:.2e}")
            if correction_norm < tol:
                break
            i += 1

# ========================
# MESH AND SUBDOMAIN SETUP 
# ========================

l_cathode = 20e-6     # Cathode thickness (m)
l_electrolyte = 10e-6 # Electrolyte thickness (m)
l_total = l_cathode + l_electrolyte

mesh = create_interval(MPI.COMM_WORLD, 300, [0.0, l_total], 
                      ghost_mode=dolfinx.mesh.GhostMode.shared_facet)

cdim = mesh.geometry.dim
tdim = mesh.topology.dim
fdim = tdim - 1  # facet dimension (points in 1D)
mesh.topology.create_connectivity(fdim, tdim)
facet_map = mesh.topology.index_map(fdim)

# Mark boundaries and domains
collector_facet = communicate_indices(facet_map, 
    dolfinx.mesh.locate_entities_boundary(mesh, fdim, lambda x: np.isclose(x[0], 0.0)))
anode_facet = communicate_indices(facet_map, 
    dolfinx.mesh.locate_entities_boundary(mesh, fdim, lambda x: np.isclose(x[0], l_total)))
num_facets_local = mesh.topology.index_map(fdim).size_local + mesh.topology.index_map(fdim).num_ghosts
facets = np.arange(num_facets_local, dtype=np.int32)
electrolyte_marker = 5
values = np.full_like(facets, electrolyte_marker, dtype=np.int32)
collector_marker = 1
anode_marker = 2
values[collector_facet] = collector_marker
values[anode_facet] = anode_marker
cell_map = mesh.topology.index_map(tdim)
cathode_cells = locate_entities(mesh, tdim, lambda x: x[0] <= l_cathode + 1e-14)
num_cells_local = cell_map.size_local + cell_map.num_ghosts
cells = np.full(num_cells_local, electrolyte_marker, dtype=np.int32)
cathode_marker = 4
cells[cathode_cells] = cathode_marker
ct = dolfinx.mesh.meshtags(mesh, tdim, np.arange(num_cells_local, dtype=np.int32), cells)

ct_cathode = communicate_indices(cell_map, ct.find(cathode_marker))
ct_electrolyte = communicate_indices(cell_map, ct.find(electrolyte_marker))

all_cathode_facets = dolfinx.mesh.compute_incident_entities(mesh.topology, ct_cathode, tdim, fdim)
all_electrolyte_facets = dolfinx.mesh.compute_incident_entities(mesh.topology, ct_electrolyte, tdim, fdim)

acf = communicate_indices(facet_map, all_cathode_facets)
aef = communicate_indices(facet_map, all_electrolyte_facets)

interface = np.intersect1d(acf, aef)

interface_marker = 6
values[interface] = interface_marker

mt = dolfinx.mesh.meshtags(mesh, mesh.topology.dim - 1, facets, values)

submesh_cathode, submesh_cathode_to_mesh, cathode_v_map = dolfinx.mesh.create_submesh(mesh, tdim, ct_cathode)[0:3]
submesh_electrolyte, submesh_electrolyte_to_mesh, electrolyte_v_map = dolfinx.mesh.create_submesh(mesh, tdim, ct_electrolyte)[0:3]

# These arrays map parent *cells* to submesh cells, so size them by the number of cells
parent_to_sub_cathode = np.full(num_cells_local, -1, dtype=np.int32)
parent_to_sub_cathode[submesh_cathode_to_mesh] = np.arange(len(submesh_cathode_to_mesh), dtype=np.int32)
parent_to_sub_electrolyte = np.full(num_cells_local, -1, dtype=np.int32)
parent_to_sub_electrolyte[submesh_electrolyte_to_mesh] = np.arange(len(submesh_electrolyte_to_mesh), dtype=np.int32)

ft_cathode, cathode_facet_to_parent = transfer_meshtags_to_submesh(mesh, mt, submesh_cathode, cathode_v_map, submesh_cathode_to_mesh)
ft_electrolyte, electrolyte_facet_to_parent = transfer_meshtags_to_submesh(mesh, mt, submesh_electrolyte, electrolyte_v_map, submesh_electrolyte_to_mesh)

electrolyte_parent_to_facet = np.full(num_facets_local, -1)
electrolyte_parent_to_facet[electrolyte_facet_to_parent] = np.arange(len(electrolyte_facet_to_parent), dtype=np.int32)

mesh.topology.create_connectivity(fdim, tdim)
f_to_c = mesh.topology.connectivity(fdim, tdim)

for facet in mt.find(interface_marker):
    cells = f_to_c.links(facet)
    assert len(cells) == 2
    cathode_map = parent_to_sub_cathode[cells]
    electrolyte_map = parent_to_sub_electrolyte[cells]
    parent_to_sub_cathode[cells] = max(cathode_map)
    parent_to_sub_electrolyte[cells] = max(electrolyte_map)

entity_maps = {submesh_cathode: parent_to_sub_cathode, submesh_electrolyte: parent_to_sub_electrolyte}

# ===================
# MATERIAL PARAMETERS  
# ===================

# Physical constants
R = 8.314462618      # Gas constant (J/(mol·K))
T = 273.15 + 37      # Temperature (K)
F_const = 96485.33212      # Faraday constant (C/mol)

D_cathode = Constant(mesh, 1e-14)      #Li diffusion coefficient in cathode (m²/s)
D_electrolyte = Constant(mesh, 1e-14)  # Li diffusion coefficient in electrolyte (m²/s)
c_max_cathode = Constant(mesh, 5.6e4)  # Max Li concentration in cathode (mol/m³)
c_max_electrolyte = Constant(mesh, 100.0)  # Max Li concentration in electrolyte (mol/m³)
c_eq_electrolyte = Constant(mesh, 50.0)    # Equilibrium Li concentration (mol/m³)

# Interface kinetics parameters
exchange_current = 1.17e2  # Exchange current density (A/m²)

k_c01 = Constant(mesh, exchange_current / (F_const * c_max_cathode.value))           # Cathodic reaction rate constant (m/s)
k_a01 = Constant(mesh, exchange_current / (F_const * c_max_electrolyte.value))       # Anodic reaction rate constant (m/s)
k_c02 = Constant(mesh, exchange_current / (F_const * c_max_cathode.value))           # Cathodic reaction rate constant at anode (m/s)
k_a02 = Constant(mesh, exchange_current / (F_const * c_max_electrolyte.value))       # Anodic reaction rate constant at anode (m/s)

alpha_1 = Constant(mesh, 0.5)          # Transfer coefficient (cathode-electrolyte)
alpha_2 = Constant(mesh, 0.5)          # Transfer coefficient (electrolyte-anode)

# Double layer parameters
d_dl12 = Constant(mesh, 1.5e-9)        # Double layer thickness cathode interface (m)
d_dl23 = Constant(mesh, 1.5e-9)        # Double layer thickness anode interface (m)
A_elec = Constant(mesh, 4.7e-7)        # Electrode surface area (m²)
eps_0 = Constant(mesh, 8.8541878188e-12)      # Vacuum permittivity (F/m)
eps_r = Constant(mesh, 220.0)          # LiPON relative permittivity

# Time discretization
dt_val = 0.00001  # time step size (s)
dt = Constant(mesh, dt_val)
t_max = 0.00003  # maximum simulation time (s)

# =============================
# FUNCTION SPACES AND VARIABLES
# =============================

# Function spaces
V_cathode = dolfinx.fem.functionspace(submesh_cathode, ("Lagrange", 2))
V_electrolyte_c = dolfinx.fem.functionspace(submesh_electrolyte, ("Lagrange", 2))
V_electrolyte_phi = dolfinx.fem.functionspace(submesh_electrolyte, ("Lagrange", 2))

# Test functions
v1, v2, vphi = ufl.TestFunctions(ufl.MixedFunctionSpace(V_cathode, V_electrolyte_c, V_electrolyte_phi))

# Solution functions
c1 = dolfinx.fem.Function(V_cathode, name="c_cathode")      # Li concentration in cathode
c2 = dolfinx.fem.Function(V_electrolyte_c, name="c_electrolyte")  # Li concentration in electrolyte
phi = dolfinx.fem.Function(V_electrolyte_phi, name="phi_electrolyte")  # Electric potential

# Previous time step functions
c1_old = dolfinx.fem.Function(V_cathode)
c2_old = dolfinx.fem.Function(V_electrolyte_c)

# ==================
# INITIAL CONDITIONS
# ==================

starting_intercalation_1 = 0.5  
starting_intercalation_2 = 0.5
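c1_init = starting_intercalation_1 * c_max_cathode.value
# Intercalation fraction x2 = c2 / c2_max, so scale by c_max_electrolyte
# (with the values above this equals c_eq_electrolyte, i.e. an electroneutral start)
c2_init = starting_intercalation_2 * c_max_electrolyte.value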

c1_old.interpolate(lambda x: np.full(x.shape[1], c1_init))
c2_old.interpolate(lambda x: np.full(x.shape[1], c2_init))
c1.interpolate(lambda x: np.full(x.shape[1], c1_init))
c2.interpolate(lambda x: np.full(x.shape[1], c2_init))

phi.x.array[:] = 0.0
phi.x.scatter_forward()

# ===========================
# SURFACE CHARGE ACCUMULATORS
# ===========================

# Surface charge density accumulators (C/m²)
sigma_12 = 0.0  # Surface charge at cathode-electrolyte interface
sigma_23 = 0.0  # Surface charge at electrolyte-anode interface

# Applied current (A)
I_app = 0.0 # Applied current

# Double layer capacitances
C_dl12 = eps_0.value * eps_r.value * A_elec.value / d_dl12.value  
C_dl23 = eps_0.value * eps_r.value * A_elec.value / d_dl23.value  

# ==========
# WEAK FORMS
# ==========

# Cathode Domain
sort_order_cathode = np.argsort(submesh_cathode_to_mesh)
ct_r = dolfinx.mesh.meshtags(mesh, mesh.topology.dim, submesh_cathode_to_mesh[sort_order_cathode], 
                            np.full_like(submesh_cathode_to_mesh, 1, dtype=np.int32)[sort_order_cathode])

dx_r = ufl.Measure("dx", domain=mesh, subdomain_data=ct_r, subdomain_id=1)

x1 = c1 / c_max_cathode

# Chemical potential (μ1⁰ cancels in gradient)
mu1 = R * T * ufl.ln(x1 / (1.0 - x1))
time_term = ((c1 - c1_old) / dt) * v1 * dx_r
flux_term =  c1 * (1 - x1) * (D_cathode / (R * T)) * ufl.dot(ufl.grad(mu1), ufl.grad(v1)) * dx_r
F_cathode = time_term + flux_term


# Electrolyte Domain
sort_order_electrolyte = np.argsort(submesh_electrolyte_to_mesh)
ct_r = dolfinx.mesh.meshtags(mesh, mesh.topology.dim, submesh_electrolyte_to_mesh[sort_order_electrolyte], 
                            np.full_like(submesh_electrolyte_to_mesh, 1, dtype=np.int32)[sort_order_electrolyte])

dx_r = ufl.Measure("dx", domain=mesh, subdomain_data=ct_r, subdomain_id=1)

x2 = c2 / c_max_electrolyte
mu2 = R * T * ufl.ln(x2 / (1.0 - x2))
# Electrochemical potential of Li+ (z = +1)
mu2_bar = mu2 + F_const * phi

# Drift-diffusion weak form
time_term = ((c2 - c2_old) / dt) * v2 * dx_r

# Flux J2 = -c2(1-x2) D2/(RT) ∇μ̄2; after integration by parts the volume term is
# +c2(1-x2) D2/(RT) ∇μ̄2·∇v2 (same sign as in the cathode, otherwise diffusion runs backward)
flux_term = c2 * (1 - x2) * (D_electrolyte / (R * T)) * ufl.dot(ufl.grad(mu2_bar), ufl.grad(v2)) * dx_r

# Complete drift-diffusion equation
F_electrolyte_c = time_term + flux_term

# Poisson equation: ∇²φ = - F(c2 - c2_eq)/(ε0εr)
laplacian_term = ufl.inner(ufl.grad(phi), ufl.grad(vphi)) * dx_r
charge_density = F_const * (c2 - c_eq_electrolyte)
source_term = - (charge_density/(eps_0 * eps_r)) * vphi * dx_r

# Complete Poisson equation
F_electrolyte_phi = laplacian_term + source_term

# =============
# BUTLER-VOLMER
# =============

# Surface charge constants (will be updated in time loop)
sigma_12_const = Constant(mesh, sigma_12)
sigma_23_const = Constant(mesh, sigma_23)

# Overpotentials from surface charge densities
eta1 = -sigma_12_const * A_elec / C_dl12 
eta2 = -sigma_23_const * A_elec / C_dl23 

# BUTLER-VOLMER j12 (Cathode-Electrolyte Interface):
j12_cathodic = c1 * (1 - x2) * k_c01 * F_const * \
               ufl.exp(-(1-alpha_1) * eta1* F_const / (R * T))
j12_anodic = c2 * (1 - x1) * k_a01 * F_const * \
             ufl.exp(alpha_1 * eta1 * F_const / (R * T))
j12 = j12_cathodic - j12_anodic

# BUTLER-VOLMER j23 (Electrolyte-Anode Interface):
j23_anodic = c2 * k_a02 * F_const * \
             ufl.exp(alpha_2 * eta2 * F_const / (R * T))
j23_cathodic = (1 - x2) * k_c02 * F_const * \
               ufl.exp(-(1-alpha_2) * eta2 * F_const / (R * T))

j23 = j23_anodic - j23_cathodic


# =====================
# INTERFACE INTEGRATION 
# =====================

# Interface integration setup
ordered_integration_data = compute_interface_data(ct, interface)
parent_cells_plus = ordered_integration_data[:, 0]
parent_cells_minus = ordered_integration_data[:, 2]
entity_maps[submesh_cathode][parent_cells_minus] = entity_maps[submesh_cathode][parent_cells_plus]
entity_maps[submesh_electrolyte][parent_cells_plus] = entity_maps[submesh_electrolyte][parent_cells_minus]
ordered_integration_data = ordered_integration_data.flatten()
integral_data_interface = [(13, ordered_integration_data)]

# Interface restrictions
cathode_res = "+"
electrolyte_res = "-"
c1_int = c1(cathode_res)
c2_int = c2(electrolyte_res)
v1_int = v1(cathode_res)
v2_int = v2(electrolyte_res)

dInterface = ufl.Measure(
    "dS",
    domain=mesh,
    subdomain_data=integral_data_interface,
    subdomain_id=13,
)

n = ufl.FacetNormal(mesh)
n_cathode = n(cathode_res)
n_electrolyte = n(electrolyte_res)

# Exterior facets
ds = ufl.Measure("ds", domain=mesh, subdomain_data=mt, subdomain_id=anode_marker)

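# Interface flux balance at the cathode-electrolyte interface: j12 = J1*F = -J2*F,
# with J1 and J2 read as the normal fluxes leaving each subdomain (flux continuity).
# After integration by parts each transport equation gets a boundary term +(J.n) v:
#   cathode side (outward normal +x):     +(j12/F) v1
#   electrolyte side (outward normal -x): -(j12/F) v2
F_interface_cathode_electrolyte = (j12(cathode_res) / F_const) * v1_int * dInterface
F_interface_cathode_electrolyte -= (j12(electrolyte_res) / F_const) * v2_int * dInterface

# At the electrolyte-anode boundary: J2*F = j23 (outward normal +x), giving +(j23/F) v2
F_anode_boundary = (j23 / F_const) * v2 * ds

# Neumann BC for phi at the cathode-electrolyte interface: dphi/dx = -sigma12/(eps0 epsr).
# The boundary term is -(grad(phi).n) v with n = -x on the electrolyte side, i.e. (dphi/dx) v
F_interface_phi_12 = -(sigma_12_const / (eps_0 * eps_r)) * vphi(electrolyte_res) * dInterface

# Double Cauchy BC for phi at the electrolyte-anode boundary: dphi/dx = -sigma23/(eps0 epsr) and phi = 0.
# Only the Neumann part is imposed here (boundary term -(dphi/dx) v with n = +x); the Dirichlet
# part phi = 0 is still missing, so phi is only determined up to a constant.
F_interface_phi_23 = (sigma_23_const / (eps_0 * eps_r)) * vphi * ds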

# Complete weak form
F_total = F_cathode + F_electrolyte_c + F_electrolyte_phi + F_interface_cathode_electrolyte + F_anode_boundary + F_interface_phi_12 + F_interface_phi_23

F = dolfinx.fem.form(ufl.extract_blocks(F_total), entity_maps=entity_maps)
J_blocks = ufl.extract_blocks(ufl.derivative(F_total, [c1, c2, phi], ufl.TrialFunctions(ufl.MixedFunctionSpace(V_cathode, V_electrolyte_c, V_electrolyte_phi))))
J = dolfinx.fem.form(J_blocks, entity_maps=entity_maps)

solver = NewtonSolver(
    F,
    J,
    [c1, c2, phi],
    max_iterations=20,
    petsc_options={
        "ksp_type": "preonly",
        "pc_type": "lu",
        "pc_factor_mat_solver_type": "mumps",
    },
)

t = 0.0  # Current time (s)
step = 0

def update_surface_charges(step, dt_val):
    global sigma_12, sigma_23
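    # Geometry node indices are not dof indices (especially with P2 elements),
    # so locate the interface/anode dofs through the function spaces instead.
    # Note: in parallel, a given point's dof only exists on the rank(s) that own or ghost it.
    cathode_interface_dofs = dolfinx.fem.locate_dofs_geometrical(
        V_cathode, lambda x: np.isclose(x[0], l_cathode, atol=1e-12))
    c1_interface_val = c1.x.array[cathode_interface_dofs[0]]
    electrolyte_cathode_interface_dofs = dolfinx.fem.locate_dofs_geometrical(
        V_electrolyte_c, lambda x: np.isclose(x[0], l_cathode, atol=1e-12))
    c2_interface_cathode_val = c2.x.array[electrolyte_cathode_interface_dofs[0]]

    electrolyte_anode_interface_dofs = dolfinx.fem.locate_dofs_geometrical(
        V_electrolyte_c, lambda x: np.isclose(x[0], l_total, atol=1e-12))
    c2_interface_anode_val = c2.x.array[electrolyte_anode_interface_dofs[0]]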
    
    eta1 = -sigma_12 * A_elec.value / C_dl12
    eta2 = -sigma_23 * A_elec.value / C_dl23
    
    # For j12 (cathode-electrolyte interface)
    x1_12 = c1_interface_val / c_max_cathode.value
    x2_12 = c2_interface_cathode_val / c_max_electrolyte.value
    j12_cathodic = c1_interface_val * (1 - x2_12) * k_c01.value * F_const * \
                   np.exp(-(1-alpha_1.value) * eta1 * F_const / (R * T))
    j12_anodic = c2_interface_cathode_val * (1 - x1_12) * k_a01.value * F_const * \
                 np.exp(alpha_1.value * eta1 * F_const / (R * T))
    j12 = j12_cathodic - j12_anodic
    
    # For j23 (electrolyte-anode interface)
    x2_23 = c2_interface_anode_val / c_max_electrolyte.value
    j23_anodic = c2_interface_anode_val * k_a02.value * F_const * \
                 np.exp(alpha_2.value * eta2 * F_const / (R * T))
    j23_cathodic = (1 - x2_23) * k_c02.value * F_const * \
                   np.exp(-(1-alpha_2.value) * eta2 * F_const / (R * T))
    j23 = j23_anodic - j23_cathodic
    
    # Update surface charges
    j_app = I_app / A_elec.value
    sigma_12 += (j_app - j12) * dt_val
    sigma_23 += (j23 - j_app) * dt_val
    
    sigma_12_const.value = sigma_12
    sigma_23_const.value = sigma_23
    

solver.solve(1e-3)

while t < t_max:
    t += dt_val
    step += 1
    
    update_surface_charges(step, dt_val)
    
    solver.solve(1e-3)
    
    c1_old.x.array[:] = c1.x.array
    c1_old.x.scatter_forward()
    c2_old.x.array[:] = c2.x.array
    c2_old.x.scatter_forward()
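On the mesh refinement question: create_interval only produces uniform cells, but 1D node coordinates can be generated by hand. A minimal NumPy sketch (geometric_spacing, graded_interval and interval_cells are hypothetical helpers, and the growth ratio 1.05 is an arbitrary choice) that grades the cells towards x = l_cathode and x = l_total while keeping the interface exactly on a node, so that it remains a facet:

```python
import numpy as np


def geometric_spacing(length, n_cells, ratio):
    """Sizes of n_cells cells covering `length`, each `ratio` times the previous one."""
    if np.isclose(ratio, 1.0):
        return np.full(n_cells, length / n_cells)
    h0 = length * (1.0 - ratio) / (1.0 - ratio**n_cells)  # sum of the series equals length
    return h0 * ratio ** np.arange(n_cells)


def graded_interval(l_cathode, l_total, n_cathode, n_electrolyte, ratio=1.05):
    """Node coordinates on [0, l_total], finest at x = l_cathode and x = l_total."""
    # Cathode: largest cell at x = 0, smallest at the interface
    h_cat = geometric_spacing(l_cathode, n_cathode, ratio)[::-1]
    # Electrolyte: two halves, each graded towards its own end
    l_el = l_total - l_cathode
    n_left = n_electrolyte // 2
    n_right = n_electrolyte - n_left
    h_left = geometric_spacing(l_el / 2, n_left, ratio)          # finest at x = l_cathode
    h_right = geometric_spacing(l_el / 2, n_right, ratio)[::-1]  # finest at x = l_total
    x = np.concatenate([[0.0], np.cumsum(np.concatenate([h_cat, h_left, h_right]))])
    # Snap interface and end point to exact values (removes cumsum round-off)
    x[n_cathode] = l_cathode
    x[-1] = l_total
    return x


def interval_cells(num_nodes):
    """Two-node cell connectivity (node i to node i + 1) for a 1D mesh."""
    return np.column_stack([np.arange(num_nodes - 1), np.arange(1, num_nodes)]).astype(np.int64)


x_nodes = graded_interval(20e-6, 30e-6, 100, 100, ratio=1.05)
cells = interval_cells(len(x_nodes))
```

The coordinates (reshaped to shape (N, 1)) and the connectivity could then be passed to dolfinx.mesh.create_mesh with a first-order Lagrange interval coordinate element instead of calling create_interval; the exact signature of create_mesh should be checked for the installed dolfinx version. The cathode cell marker x[0] <= l_cathode + 1e-14 keeps working because l_cathode is a node.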