Speeding up time-dependent diffusion equation

Hi all,

I am working on solving a diffusion equation with a time-dependent diffusion tensor. For that, I am following the Diffusion of a Gaussian function demo, as well as the discussion in the similar thread Time dependent diffusion constant in dolfinx. However, because my diffusion coefficient is a tensor, I cannot use a simple trick from the previous post (D_.value[0]= 0.1*D) to reassign the value of the diffusion constituent and have to rewrite the tensor completely. As a result, the following re-assembly of the A matrix is taking a very long time on each timestep and I was wondering if there is a way to speed things up.

The example code is as follows (very minor edits to the original Diffusion of a Gaussian function demo). While this example still works quite fast although noticeably slower than the original, this approach results in an extremely large delay for the more complex problem that I am trying to solve.

import numpy as np

from mpi4py import MPI
from petsc4py import PETSc

from dolfinx import fem, mesh, io, plot

# Define temporal parameters
t = 0 # Start time
T = 1.0 # Final time
num_steps = 50
dt = T / num_steps # time step size

# Define mesh
nx, ny = 50, 50
domain = mesh.create_rectangle(MPI.COMM_WORLD, [np.array([-2, -2]), np.array([2, 2])],
                               [nx, ny], mesh.CellType.triangle)
V = fem.FunctionSpace(domain, ("CG", 1))


# Create initial condition
def initial_condition(x, a=5):
    return np.exp(-a*(x[0]**2+x[1]**2))
u_n = fem.Function(V)
u_n.name = "u_n"
u_n.interpolate(initial_condition)

# Create boundary condition
fdim = domain.topology.dim - 1
boundary_facets = mesh.locate_entities_boundary(
    domain, fdim, lambda x: np.full(x.shape[1], True, dtype=bool))
bc = fem.dirichletbc(PETSc.ScalarType(0), fem.locate_dofs_topological(V, fdim, boundary_facets), V)



uh = fem.Function(V)
uh.name = "uh"
uh.interpolate(initial_condition)

import ufl
u, v = ufl.TrialFunction(V), ufl.TestFunction(V)
f = fem.Constant(domain, PETSc.ScalarType(0))

D = ufl.as_tensor(((1, 0), (0, 2)))
a = u * v * ufl.dx + dt*ufl.inner(D * ufl.grad(u), ufl.grad(v)) * ufl.dx
L = (u_n + dt * f) * v * ufl.dx

bilinear_form = fem.form(a)
linear_form = fem.form(L)

A = fem.petsc.assemble_matrix(bilinear_form, bcs=[bc])
A.assemble()
b = fem.petsc.create_vector(linear_form)

solver = PETSc.KSP().create(domain.comm)
solver.setOperators(A)
solver.setType(PETSc.KSP.Type.PREONLY)
solver.getPC().setType(PETSc.PC.Type.LU)

for i in range(num_steps):
    t += dt

    print(t)
    D = ufl.as_tensor(((1 - t, 0), (0, 2 - t)))
    a = u * v * ufl.dx + dt*ufl.inner(D * ufl.grad(u), ufl.grad(v)) * ufl.dx
    print('----> Assembling')
    A = fem.petsc.assemble_matrix(fem.form(a), bcs=[bc])
    A.assemble()
    solver.setOperators(A)

    # Update the right hand side reusing the initial vector
    with b.localForm() as loc_b:
        loc_b.set(0)
    fem.petsc.assemble_vector(b, linear_form)

    # Apply Dirichlet boundary condition to the vector
    fem.petsc.apply_lifting(b, [bilinear_form], [[bc]])
    b.ghostUpdate(addv=PETSc.InsertMode.ADD_VALUES, mode=PETSc.ScatterMode.REVERSE)
    fem.petsc.set_bc(b, [bc])

    # Solve linear problem
    solver.solve(b, uh.vector)
    uh.x.scatter_forward()

    # Update solution at previous time step (u_n)
    u_n.x.array[:] = uh.x.array