I would suggest an approach like the one described in: https://link.springer.com/content/pdf/10.1007/s00366-025-02107-1.pdf
Below is an alternative approach, based on the comments in:
[petsc-users] Outer Product of two vectors
# Usage of a PETSc Python context to apply the outer product of two vectors matrix-free.
# Author: Jørgen S. Dokken
# SPDX-License-Identifier: MIT
from mpi4py import MPI
from petsc4py import PETSc
import dolfinx.fem.petsc
import ufl
import numpy as np
class CustomOperator:
    """PETSc Python-matrix context applying the outer product of two vectors.

    The operator represents ``A = outer(b1, b2)``, i.e. ``A_ij = b1_i * b2_j``
    (no complex conjugation, matching ``numpy.outer``), where ``b1`` and ``b2``
    are the vectors assembled from ``form1`` and ``form2``. The dense matrix is
    never formed; only the two vectors are stored.

    Parameters
    ----------
    form1
        Rank-one form whose assembled vector ``b1`` is the left factor
        (the range of the operator).
    form2
        Rank-one form whose assembled vector ``b2`` is the right factor
        (the input of the operator is dotted with it).
    """

    def __init__(self, form1, form2):
        # Passed through dolfinx.fem.form; the script below hands in forms
        # that have already been compiled.
        self.form1 = dolfinx.fem.form(form1)
        self.form2 = dolfinx.fem.form(form2)
        self.b1 = dolfinx.fem.petsc.create_vector(self.form1)
        self.b2 = dolfinx.fem.petsc.create_vector(self.form2)
        self._assemble_vectors()

    def _assemble_vectors(self):
        """(Re)assemble ``b1`` and ``b2`` from scratch, including ghost values."""
        # Zero owned and ghost entries, since assembly accumulates into them.
        with self.b1.localForm() as loc1, self.b2.localForm() as loc2:
            loc1.set(0)
            loc2.set(0)
        dolfinx.fem.petsc.assemble_vector(self.b1, self.form1)
        # Sum ghost contributions into their owners, then push the final
        # owned values back out to the ghosts.
        self.b1.ghostUpdate(PETSc.InsertMode.ADD, PETSc.ScatterMode.REVERSE)
        self.b1.ghostUpdate(PETSc.InsertMode.INSERT, PETSc.ScatterMode.FORWARD)
        dolfinx.fem.petsc.assemble_vector(self.b2, self.form2)
        self.b2.ghostUpdate(PETSc.InsertMode.ADD, PETSc.ScatterMode.REVERSE)
        self.b2.ghostUpdate(PETSc.InsertMode.INSERT, PETSc.ScatterMode.FORWARD)

    def mult(self, mat, X, Y):
        """Compute ``Y = outer(b1, b2) X = (b2^T X) b1``."""
        # Reassemble the vectors in case the forms depend on coefficients or
        # other data that has changed (e.g. time dependency).
        self._assemble_vectors()
        # tDot is the non-conjugated inner product (VecTDot). Plain ``dot``
        # conjugates one argument, which gives the wrong result in complex
        # PETSc builds; in real builds the two are identical.
        val = X.tDot(self.b2)
        self.b1.copy(Y)
        Y.scale(val)

    def getDiagonal(self, mat, vec):
        """Fill ``vec`` with the diagonal of the operator, ``b1_i * b2_i``.

        Used, for example, by a Jacobi preconditioner.
        """
        # Refresh the factors so the diagonal is consistent with ``mult``,
        # which also reassembles before use.
        self._assemble_vectors()
        vec.pointwiseMult(self.b1, self.b2)
# Problem setup: quadratic Lagrange space on a 15 x 7 unit-square mesh.
mesh = dolfinx.mesh.create_unit_square(MPI.COMM_WORLD, 15, 7)
V = dolfinx.fem.functionspace(mesh, ("Lagrange", 2))
u = ufl.TrialFunction(V)
v = ufl.TestFunction(V)
# b1 is assembled from an exterior-facet integral, b2 from a cell integral.
form1 = dolfinx.fem.form(u * ufl.ds)
form2 = dolfinx.fem.form(v * ufl.dx)
assembler = CustomOperator(form1, form2)

# A PETSc "python" matrix only needs its parallel row/column layout. An empty
# (no entries inserted), finalized sparsity pattern over V's index maps
# provides that layout without describing any stored nonzeros.
sp = dolfinx.cpp.la.SparsityPattern(
    mesh.comm,
    [V.dofmap.index_map, V.dofmap.index_map],
    [V.dofmap.index_map_bs, V.dofmap.index_map_bs],
)
sp.finalize()
A = dolfinx.cpp.la.petsc.create_matrix(mesh.comm, sp, "python")
# All matrix operations (mult, getDiagonal) are delegated to the context.
A.setPythonContext(assembler)
# Example of using A in a matrix-free Krylov solve (kept for reference).
# NOTE(review): outer(b1, b2) has rank at most one, so A on its own is
# singular whenever V has more than one dof; this sketch presumably assumes
# A is part of a larger, invertible operator.
# x = ufl.SpatialCoordinate(mesh)
# f = x[0] * v * ufl.dx
# # Set up matrix free solver
# ksp = PETSc.KSP().create(mesh.comm)
# ksp.setOperators(A)
# ksp.setType("cg")
# ksp.getPC().setType("jacobi")
# ksp.setErrorIfNotConverged(True)
# uh = dolfinx.fem.Function(V)
# b = dolfinx.fem.petsc.assemble_vector(dolfinx.fem.form(f))
# b.ghostUpdate(PETSc.InsertMode.ADD, PETSc.ScatterMode.REVERSE)
# ksp.solve(b, uh.x.petsc_vec)
# print(uh.x.array)

# Verification: compare the matrix-free product with a NumPy reference,
# using outer(b1, b2) c = (b2^T c) b1 rather than forming the dense matrix.
# The input is dotted with b2, so it lives in form2's space.
c = dolfinx.fem.petsc.create_vector(form2)
with c.localForm() as loc:
    loc.set(1)
# The output is a multiple of b1, so it lives in form1's space.
d = dolfinx.fem.petsc.create_vector(form1)
A.mult(c, d)
# ``.array`` exposes only the locally owned entries, so the inner product
# must be summed over all ranks before scaling b1; a per-rank np.outer would
# give the wrong answer when running on more than one process.
local_dot = np.dot(assembler.b2.array, c.array)
global_dot = mesh.comm.allreduce(local_dot, op=MPI.SUM)
d_numpy = assembler.b1.array * global_dot
np.testing.assert_allclose(d.array, d_numpy)
print(d.array, d_numpy)
Note that with this approach, one cannot use a direct solver, because the full dense matrix is never actually formed.