Submatrix creation using general index sets

Hello community,

I am assembling a problem with multiple variables (e.g. Navier-Stokes with v and p) and therefore use a nested matrix for efficient monolithic Jacobian setup.
However, I want to extract submatrices that do not necessarily comply with the original nested matrix structure (e.g. because I want to perform a direct subsolve on some sub-indices and an iterative solve on the others, inside a block preconditioner).

Consider the following MWE, which fails at line 60 of the script (the second createSubMatrix call, marked "now we fail"):

#!/usr/bin/env python3
from mpi4py import MPI
from petsc4py import PETSc
import numpy as np
from dolfinx import fem, mesh
import ufl

comm = MPI.COMM_WORLD

msh = mesh.create_box(comm, [np.array([0.0, 0.0, 0.0]),np.array([2.0, 1.0, 1.0])], [1, 1, 1],
               mesh.CellType.tetrahedron, mesh.GhostMode.none)

dim = msh.geometry.dim

# create finite element objects for v and p
P_v = ufl.VectorElement("CG", msh.ufl_cell(), 2)
P_p = ufl.FiniteElement("CG", msh.ufl_cell(), 1)
# function spaces for v and p
V_v = fem.FunctionSpace(msh, P_v)
V_p = fem.FunctionSpace(msh, P_p)
# functions
dv    = ufl.TrialFunction(V_v)            # Incremental velocity
var_v = ufl.TestFunction(V_v)             # Test function
dp    = ufl.TrialFunction(V_p)            # Incremental pressure
var_p = ufl.TestFunction(V_p)             # Test function
v     = fem.Function(V_v, name="v")
p     = fem.Function(V_p, name="p")

form_v = ufl.inner((-ufl.Identity(dim)*p + 0.5*(ufl.grad(v)+ufl.grad(v).T)),ufl.grad(var_v))*ufl.dx
form_p = ufl.div(v)*var_p*ufl.dx

form_lin_vv, form_lin_vp, form_lin_pv = ufl.derivative(form_v, v, dv), ufl.derivative(form_v, p, dp), ufl.derivative(form_p, v, dv)

jac_vv, jac_vp, jac_pv = fem.form(form_lin_vv), fem.form(form_lin_vp), fem.form(form_lin_pv)

K_vv = fem.petsc.assemble_matrix(jac_vv, [])
K_vv.assemble()
K_vp = fem.petsc.assemble_matrix(jac_vp, [])
K_vp.assemble()
K_pv = fem.petsc.assemble_matrix(jac_pv, [])
K_pv.assemble()
K_full_nest = PETSc.Mat().createNest([[K_vv,K_vp],[K_pv,None]], isrows=None, iscols=None, comm=comm)

# index sets
offset_v = v.vector.getOwnershipRange()[0] + p.vector.getOwnershipRange()[0]
iset_v = PETSc.IS().createStride(v.vector.getLocalSize(), first=offset_v, step=1, comm=comm)
offset_p = offset_v + v.vector.getLocalSize()
iset_p = PETSc.IS().createStride(p.vector.getLocalSize(), first=offset_p, step=1, comm=comm)

# get submatrix (to be used in a preconditioner)
A = K_full_nest.createSubMatrix(iset_v,iset_p)

# create another sub index set
iset_r = PETSc.IS().createStride(1, first=offset_v, step=1, comm=comm)

# now suppose we want to eliminate some indices (and group them later to another set)
iset_v = iset_v.difference(iset_r) # subtract

# now we fail - comment to pass
A = K_full_nest.createSubMatrix(iset_v,iset_p)

# quick fix: merge - BUT: this is massively expensive for very large problems!
Kmerge = PETSc.Mat()
K_full_nest.convert("aij", out=Kmerge)
Kmerge.assemble()
K_full_nest = Kmerge

# works
A = K_full_nest.createSubMatrix(iset_v,iset_p)

Error:
petsc4py.PETSc.Error: error code 75
[0] MatCreateSubMatrix() at /usr/local/petsc/src/mat/interface/matrix.c:8431
[0] MatCreateSubMatrix_Nest() at /usr/local/petsc/src/mat/impls/nest/matnest.c:690
[0] MatNestFindSubMat() at /usr/local/petsc/src/mat/impls/nest/matnest.c:677
[0] Arguments are incompatible
[0] Could not find index set
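
My reading of this trace is that the nest matrix only returns a submatrix when the requested index set matches one it was built with. The following is a toy pure-Python model of that lookup, not PETSc's actual code: `find_block` is a hypothetical helper, the index sets are plain lists with made-up sizes, and the real routine may accept more cases than exact matches.

```python
def find_block(block_index_sets, requested):
    """Toy stand-in for the nest lookup: return the position of the
    block whose index set equals `requested` (order is ignored)."""
    for i, block in enumerate(block_index_sets):
        if sorted(block) == sorted(requested):
            return i
    # Mirrors the message in the trace above; the real check lives in PETSc.
    raise ValueError("Could not find index set")


# Made-up serial layout: 6 velocity dofs followed by 2 pressure dofs.
iset_v = list(range(0, 6))
iset_p = list(range(6, 8))
blocks = [iset_v, iset_p]

# The unmodified velocity set matches block 0.
found = find_block(blocks, iset_v)

# Removing one index (as iset_v.difference(iset_r) does in the MWE)
# leaves a set that matches no block, so the lookup fails.
reduced = [i for i in iset_v if i != 0]
try:
    find_block(blocks, reduced)
    failed = False
except ValueError:
    failed = True
```

In this model the first call succeeds and the second raises, which is the same pattern as the first and second createSubMatrix calls in the MWE.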

The immediate remedy: comment out line 60 of the script (the failing createSubMatrix call) and convert the nested matrix to aij. After that, arbitrary index sets can be extracted.
However, the remaining problem is that this conversion from nested to aij is very expensive for large problems (> 1M dofs), so I want to avoid it.
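
For anyone checking the index sets in the MWE: the offsets assume that the nest matrix orders its global rows rank by rank, with each rank's v rows first and its p rows directly after. Below is a minimal pure-Python sketch of that arithmetic. `nest_offsets` is a hypothetical helper and the local sizes are made up; in the MWE the corresponding values come from getOwnershipRange() and getLocalSize().

```python
from itertools import accumulate


def nest_offsets(local_sizes_v, local_sizes_p):
    """Per-rank (offset_v, offset_p) for a two-block nest, assuming
    rank-contiguous ordering: [v on rank 0, p on rank 0, v on rank 1, ...]."""
    # Ownership start of each field on a rank = sum of that field's
    # local sizes on all lower ranks (exclusive prefix sum).
    start_v = [0] + list(accumulate(local_sizes_v))[:-1]
    start_p = [0] + list(accumulate(local_sizes_p))[:-1]
    offsets = []
    for sv, sp, nv in zip(start_v, start_p, local_sizes_v):
        offset_v = sv + sp        # same formula as offset_v in the MWE
        offset_p = offset_v + nv  # p rows follow this rank's v rows
        offsets.append((offset_v, offset_p))
    return offsets


# Two ranks: rank 0 owns 6 v dofs and 2 p dofs, rank 1 owns 4 and 3.
offsets = nest_offsets([6, 4], [2, 3])
print(offsets)  # [(0, 6), (8, 12)]
```

On rank 1 the velocity rows start at 8 because rank 0 owns 6 + 2 = 8 rows in total, which is the sum of the two ownership starts.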

Does anybody know a better solution to this? I know it’s more of a PETSc problem than a FEniCS one. However, I imagine it might still be helpful to the community to know how to efficiently extract submatrices for coupled problems.

Thanks for any help in advance!

Best,
Marc