Using the right-hand-side function G(t, u) with PETSc TS

Hey there,

I'm trying to implement a time-stepping algorithm for a coupled differential equation. Whenever I implement F(u, udot, t) = 0, everything works fine. But using setRHSFunction does not work: the solver seems to ignore G, because the result does not change when I set G to zero. Am I making any obvious mistakes? Code example below; thanks in advance!

import numpy as np
from mpi4py import MPI
from dolfinx import mesh, log
from dolfinx.io import XDMFFile
from dolfinx.fem import (functionspace, Constant, Function, locate_dofs_topological, dirichletbc, DirichletBC,
                         Expression, assemble_scalar, assemble_vector, apply_lifting, set_bc, form)
import dolfinx.fem.petsc as petsc
from basix.ufl import element
from ufl import derivative, dx, div, grad, inner, SpatialCoordinate, TrialFunction, TestFunction
from petsc4py import PETSc
import matplotlib.pyplot as plt

class PETScTSSolver:
    """Integrate F(t, u, u') = G(t, u) with a PETSc TS (ARKIMEX).

    F is supplied through the IFunction/IJacobian callbacks and G through the
    RHS-function callback; ARKIMEX treats F implicitly and G explicitly.

    Dirichlet conditions are imposed *only* through F: on constrained dofs the
    IFunction residual is ``u - g(t)`` while G is zero, so the stage equation
    ``F = G`` reduces to ``u = g(t)`` there. If G also carried ``u - g(t)`` on
    those rows (as an earlier version did), the two would cancel and the
    boundary data would never be imposed.

    NOTE(review): the callbacks read the module-level objects ``r``, ``r_fun``,
    ``bc1_fun``, ``bc2_fun`` and ``u_exact_fun``, and ``solve`` reads
    ``u_exact`` as the initial condition; these must exist before solving.
    """

    def __init__(self, ts_problem, u, bcs):
        """Build residual/Jacobian forms and create the TS object.

        Parameters
        ----------
        ts_problem : object with ``F``, ``dFdu``, ``dFdut``, ``G`` and ``dGdu``
            methods returning UFL forms.
        u : dolfinx Function holding the state; receives the final solution.
        bcs : list of Dirichlet BCs on ``u``'s function space.
        """
        self.u = u
        self.V = u.ufl_function_space()
        self.comm = self.V.mesh.comm
        self.u_t = Function(self.V)
        # Shift passed by TS to the IJacobian: J = sigma_t * dF/du' + dF/du.
        self.sigma_t = Constant(self.V.mesh, 0.0)
        self.f = ts_problem.F(self.u, self.u_t)
        self.g = ts_problem.G(self.u)
        self.f_j = self.sigma_t * ts_problem.dFdut(self.u, self.u_t) + ts_problem.dFdu(self.u, self.u_t)
        self.g_j = ts_problem.dGdu(self.u)
        self.bcs = bcs
        self.ts = PETSc.TS().create(self.comm)

    def _set_state(self, x, xdot=None):
        """Copy TS vectors into ``u`` (and ``u_t``) and refresh their ghosts."""
        x.copy(self.u.x.petsc_vec)
        self.u.x.scatter_forward()
        if xdot is not None:
            xdot.copy(self.u_t.x.petsc_vec)
            self.u_t.x.scatter_forward()

    def evalIFunction(self, ts, t, x, xdot, F):
        """PETSc IFunction callback: assemble F(t, x, xdot) into ``F``.

        On Dirichlet dofs the residual is ``x - g(t)``.
        """
        r.interpolate(r_fun(t))
        # Update the Dirichlet data before lifting/set_bc read it.
        bc1_fun.interpolate(u_exact_fun(t))
        bc2_fun.interpolate(u_exact_fun(t))
        self._set_state(x, xdot)

        with F.localForm() as f_local:
            f_local.set(0.0)
        petsc.assemble_vector(F, form(self.f))
        # F <- F + J (g - x): consistent with the zeroed Dirichlet columns of
        # the IJacobian. x0 is the ghost-updated state, not the raw TS vector.
        petsc.apply_lifting(F, [form(self.f_j)], bcs=[self.bcs], x0=[self.u.x.petsc_vec], alpha=-1.0)
        F.ghostUpdate(addv=PETSc.InsertMode.ADD, mode=PETSc.ScatterMode.REVERSE)
        # F[bc] = -1 * (g - x) = x - g
        petsc.set_bc(F, self.bcs, self.u.x.petsc_vec, -1.0)

        F.assemble()
        return True

    def evalRHSFunction(self, ts, t, x, G):
        """PETSc RHS-function callback: assemble G(t, x) into ``G``.

        G is set to zero on Dirichlet dofs, so the constraint ``u = g(t)`` is
        carried entirely by the IFunction.
        """
        r.interpolate(r_fun(t))
        bc1_fun.interpolate(u_exact_fun(t))
        bc2_fun.interpolate(u_exact_fun(t))
        self._set_state(x)

        with G.localForm() as g_local:
            g_local.set(0.0)
        petsc.assemble_vector(G, form(self.g))
        # Evaluate the interior rows with the prescribed boundary data g(t).
        petsc.apply_lifting(G, [form(self.g_j)], bcs=[self.bcs], x0=[self.u.x.petsc_vec], alpha=-1.0)
        G.ghostUpdate(addv=PETSc.InsertMode.ADD, mode=PETSc.ScatterMode.REVERSE)
        # G[bc] = 0 * (g - x) = 0. Writing x - g here would cancel the
        # boundary rows of F = G and silently drop the Dirichlet conditions.
        petsc.set_bc(G, self.bcs, None, 0.0)

        G.assemble()
        return True

    def evalIJacobian(self, ts, t, x, xdot, sigma_t, J, P):
        """PETSc IJacobian callback: assemble sigma_t*dF/du' + dF/du into ``J``.

        Dirichlet rows/columns are zeroed by ``assemble_matrix`` with ``bcs``.
        NOTE(review): ``P`` is assumed to be the same matrix as ``J`` (as
        registered in ``solve``).
        """
        self._set_state(x, xdot)
        self.sigma_t.value = sigma_t
        J.zeroEntries()
        petsc.assemble_matrix(J, form(self.f_j), self.bcs)
        J.assemble()
        return True

    def set_from_options(self):
        """Configure the nonlinear solver's linear solve as a direct LU solve.

        Command-line PETSc options may still override the KSP/PC settings.
        """
        ksp = self.ts.getSNES().getKSP()
        ksp.setType(PETSc.KSP.Type.PREONLY)
        pc = ksp.getPC()
        pc.setType(PETSc.PC.Type.LU)
        ksp.setFromOptions()
        pc.setFromOptions()

    def solve(self, t0, dt_init, max_t, max_steps):
        """Integrate from ``t0`` to ``max_t``; store the final state in ``self.u``.

        Parameters
        ----------
        t0 : initial time.
        dt_init : initial time step (ARKIMEX may adapt it).
        max_t : final time; the last step is shortened to land on it exactly.
        max_steps : maximum number of time steps.
        """
        w = Function(self.V)
        # NOTE(review): the initial condition is the module-level u_exact.
        u_exact.x.petsc_vec.copy(w.x.petsc_vec)

        # TS needs a matrix to assemble the IJacobian into.
        self._J = petsc.create_matrix(form(self.f_j))

        self.ts.setRHSFunction(self.evalRHSFunction)
        self.ts.setIFunction(self.evalIFunction)
        self.ts.setIJacobian(self.evalIJacobian, J=self._J, P=self._J)
        self.ts.setSolution(w.x.petsc_vec)
        self.ts.setType(PETSc.TS.Type.ARKIMEX)
        self.ts.setTime(t0)
        self.ts.setTimeStep(dt_init)
        self.ts.setMaxTime(max_t)
        self.ts.setMaxSteps(max_steps)
        self.ts.setExactFinalTime(PETSc.TS.ExactFinalTime.MATCHSTEP)
        self.ts.solve(w.x.petsc_vec)

        w.x.petsc_vec.copy(self.u.x.petsc_vec)
        self.u.x.scatter_forward()


# 1d geometry parameters
# NOTE(review): interval_length is not used; the mesh below is the unit interval.
interval_length = 1
cells_number = 4
t = 0

# Unit-interval mesh with a first-order Lagrange space on it.
msh = mesh.create_unit_interval(MPI.COMM_WORLD, cells_number)
CG1_elem = element("Lagrange", msh.basix_cell(), 1)
V = functionspace(msh, CG1_elem)

u_test = TestFunction(V)
u = Function(V)

# Facet -> cell connectivity is required before locating boundary facets/dofs.
tdim = msh.topology.dim
fdim = tdim - 1
msh.topology.create_connectivity(fdim, tdim)

# (facet marker, key in dof_boundaries) for the left and right ends.
boundary_conditions = [
    (lambda x: np.isclose(x[0], 0), 'dof_bdry_1'),
    (lambda x: np.isclose(x[0],  1), 'dof_bdry_2')
]


def _locate_boundary_dofs(marker):
    """Return the dofs of V on the boundary facets selected by ``marker``."""
    facets = mesh.locate_entities_boundary(msh, dim=fdim, marker=marker)
    return locate_dofs_topological(V=V, entity_dim=fdim, entities=facets)


dof_boundaries = {
    dof_name: _locate_boundary_dofs(marker)
    for marker, dof_name in boundary_conditions
}

class u_exact_fun():
    """Interpolation callable for the manufactured solution u = exp(-t) * x0**2.

    ``t`` is fixed at construction; calling the object with a coordinate array
    ``x`` returns a ``(1, x.shape[1])`` array of PETSc scalars.
    """

    def __init__(self, t):
        self.t = t

    def __call__(self, x):
        decay = np.exp(-self.t)
        out = np.zeros((1, x.shape[1]), dtype=PETSc.ScalarType)
        out[0] = decay * x[0]**2
        return out


class r_fun():
    """Interpolation callable for the source r = -exp(-t) * (x0**2 + 2).

    ``t`` is fixed at construction; calling the object with a coordinate array
    ``x`` returns a ``(1, x.shape[1])`` array of PETSc scalars.
    """

    def __init__(self, t):
        self.t = t

    def __call__(self, x):
        scale = -np.exp(-self.t)
        out = np.zeros((1, x.shape[1]), dtype=PETSc.ScalarType)
        out[0] = scale * (x[0]**2 + 2)
        return out

class DiffusionProblem():
    """UFL forms splitting the heat equation as F(u, u_t) = G(u).

    F holds only the mass term; G holds the negated stiffness term and the
    source ``r``. NOTE(review): both read the module-level ``u_test`` and ``r``.
    """

    def F(self, u, u_t):
        """Implicit part: the mass term u_t * v dx."""
        # Earlier variant kept diffusion/source here:
        # + inner(grad(u), grad(u_test)) * dx - r * u_test * dx
        return u_t * u_test * dx

    def dFdu(self, u, u_t):
        """Gateaux derivative of F with respect to u."""
        residual = self.F(u, u_t)
        return derivative(residual, u)

    def dFdut(self, u, u_t):
        """Gateaux derivative of F with respect to u_t."""
        residual = self.F(u, u_t)
        return derivative(residual, u_t)

    def G(self, u):
        """Explicit part: -grad(u).grad(v) dx + r * v dx."""
        diffusion = -inner(grad(u), grad(u_test)) * dx
        source = r * u_test * dx
        return diffusion + source

    def dGdu(self, u):
        """Gateaux derivative of G with respect to u."""
        rhs = self.G(u)
        return derivative(rhs, u)

# Manufactured solution at time t: u_exact is the initial condition read by
# the solver; bc1_fun/bc2_fun carry the Dirichlet data.
u_exact = Function(V)
u_exact0 = Function(V)  # NOTE(review): not used below
u_exact.interpolate(u_exact_fun(t))

# Source term for the manufactured solution.
r = Function(V)
r.interpolate(r_fun(t))

bc1_fun = Function(V)
bc1_fun.interpolate(u_exact_fun(t))
bc1 = dirichletbc(bc1_fun, dof_boundaries['dof_bdry_1'])

bc2_fun = Function(V)
bc2_fun.interpolate(u_exact_fun(t))
bc2 = dirichletbc(bc2_fun, dof_boundaries['dof_bdry_2'])

Dirichlet_bcs = [bc1, bc2]

# NOTE(review): the figure is created but nothing is plotted into it.
fig, ax = plt.subplots()

PETScTSSolver_instance = PETScTSSolver(DiffusionProblem(), u, Dirichlet_bcs)
PETScTSSolver_instance.set_from_options()
PETScTSSolver_instance.solve(
    t0=0.0,
    dt_init=1e-3,
    max_t=10,
    max_steps=1000,
)