Hi @dokken, and thanks for your reply. The idea is to accelerate the assembly of F by performing reduced integration on a (small) subset of the cells, with suitable weights obtained from a sparse non-negative least-squares fit.
Below is an illustrative example for a Poiseuille flow. At the bottom, I check that the full residual and the residual assembled cell by cell agree. The per-cell residuals and the full residual then form the matrix and the right-hand side of my sparse non-negative least-squares problem, whose solution gives the weights. However, I then need to integrate quickly over only the cells with non-zero weight, and this is where I wonder whether there is a better way than defining a dx measure. For example, for immersed methods I have previously used the repository augustjohansson/customquad (a library for using custom quadrature rules at runtime in FEniCSx), in which the kernel is extracted so that custom quadrature rules can be used. Something along those lines should definitely be doable here. It should be even easier, since the quadrature points are the standard ones and only the weights change.
I hope this clarifies what I am trying to do; if not, I am happy to provide further details.
from __future__ import annotations
import basix
import dolfinx
import numpy as np
import ufl
from dolfinx.fem.petsc import apply_lifting, assemble_matrix, assemble_vector, set_bc
from mpi4py import MPI
from petsc4py import PETSc
from dolfinx.fem.petsc import create_vector, assemble_vector, create_matrix, assemble_matrix
# Mesh: triangulated unit square with nx x ny subdivisions.
nx, ny = 50, 50
mesh = dolfinx.mesh.create_rectangle(
    MPI.COMM_WORLD,
    [(0.0, 0.0), (1.0, 1.0)],
    [nx, ny],
    cell_type=dolfinx.mesh.CellType.triangle,
)

# P2 (vector) velocity / P1 pressure mixed pair.
ve = basix.ufl.element("CG", mesh.basix_cell(), 2, shape=(2,))
pe = basix.ufl.element("CG", mesh.basix_cell(), 1)

# Stand-alone spaces for each field; used below to hold Dirichlet data.
ve_space = dolfinx.fem.functionspace(mesh, ve)
pe_space = dolfinx.fem.functionspace(mesh, pe)

VP_elem = basix.ufl.mixed_element([ve, pe])
VP = dolfinx.fem.functionspace(mesh, VP_elem)

v, q = ufl.TestFunctions(VP)

# The problem is nonlinear: (u, p) are the components of the state function
# `up`, not trial functions. The Jacobian is obtained later with
# ufl.derivative, so no TrialFunctions are needed here (the previous
# TrialFunctions assignment was immediately overwritten by this split).
up = dolfinx.fem.Function(VP)
u, p = ufl.split(up)
def inlet_boundary(x):
    """Return a boolean mask of the points lying on the inlet edge x = 0.

    ``x`` is indexed as ``x[0]`` (x-coordinates) and ``x[1]`` (y-coordinates).
    """
    x_coord = x[0]
    return np.isclose(x_coord, 0.0)
def wall_boundary(x):
    """Return a boolean mask of the points on the bottom (y = 0) or top (y = 1) wall."""
    y_coord = x[1]
    on_bottom = np.isclose(y_coord, 0.0)
    on_top = np.isclose(y_coord, 1.0)
    return np.logical_or(on_bottom, on_top)
def outlet_boundary(x):
    """Return a boolean mask of the points lying on the outlet edge x = 1."""
    x_coord = x[0]
    return np.isclose(x_coord, 1.0)
# Boundary markers used for the facet tags:
#   1 = walls (y = 0 and y = 1), 2 = inlet (x = 0), 3 = outlet (x = 1).
boundaries = [
    (1, wall_boundary),
    (2, inlet_boundary),
    (3, outlet_boundary),
]
facetdim = mesh.topology.dim - 1
# Facet-to-cell connectivity; presumably required by locate_entities_boundary
# to decide which facets are exterior.
mesh.topology.create_connectivity(facetdim, mesh.topology.dim)
# Owned plus ghost facets on this process.
num_facets_local = (
    mesh.topology.index_map(facetdim).size_local
    + mesh.topology.index_map(facetdim).num_ghosts
)
# -1 means "untagged". Markers are applied in list order, so a facet matched
# by several locators keeps the last marker.
facet_markers = np.full(num_facets_local, -1, dtype=np.int32)
for marker, locator in boundaries:
    boundary_facets = dolfinx.mesh.locate_entities_boundary(mesh, facetdim, locator)
    # Additionally require the facet midpoint to satisfy the locator.
    midpoints = dolfinx.mesh.compute_midpoints(mesh, facetdim, boundary_facets)
    boundary_facets = boundary_facets[locator(midpoints.T)]
    facet_markers[boundary_facets] = marker
is_marked = facet_markers >= 0
facets = np.flatnonzero(is_marked)
values = facet_markers[facets]
facet_tag = dolfinx.mesh.meshtags(mesh, facetdim, facets, values)
# Dirichlet conditions on the mixed space VP:
#   pressure p = 1 on the inlet (marker 2),
#   pressure p = 0 on the outlet (marker 3),
#   velocity u = 0 (no-slip) on the walls (marker 1).
# Dofs are located on the pair (sub-space of VP, matching stand-alone space)
# so the boundary data can live in the stand-alone space.
inlet_facets = facet_tag.find(2)
wall_facets = facet_tag.find(1)
outlet_facets = facet_tag.find(3)

inlet_dofs_p = dolfinx.fem.locate_dofs_topological(
    (VP.sub(1), pe_space), mesh.topology.dim - 1, inlet_facets
)
wall_dofs_v = dolfinx.fem.locate_dofs_topological(
    (VP.sub(0), ve_space), mesh.topology.dim - 1, wall_facets
)
outlet_dofs_p = dolfinx.fem.locate_dofs_topological(
    (VP.sub(1), pe_space), mesh.topology.dim - 1, outlet_facets
)

# Boundary data (freshly created Functions are zero-initialised).
one_p_func = dolfinx.fem.Function(pe_space)
one_p_func.interpolate(lambda x: np.ones_like(x[1]))
zero_v_func = dolfinx.fem.Function(ve_space)
zero_p_func = dolfinx.fem.Function(pe_space)

p_inlet = dolfinx.fem.dirichletbc(one_p_func, inlet_dofs_p, VP.sub(1))
u_wall = dolfinx.fem.dirichletbc(zero_v_func, wall_dofs_v, VP.sub(0))
p_outlet = dolfinx.fem.dirichletbc(zero_p_func, outlet_dofs_p, VP.sub(1))

bcs = [u_wall, p_inlet, p_outlet]
# Density and viscosity coefficients.
rho, mu = 1.0, 0.01

# Steady incompressible Navier-Stokes residual integrand in the current
# state (u, p) = split(up), tested against (v, q).
F_integrand = (
    rho * ufl.inner(ufl.grad(u) * u, v)  # convection (grad u) u
    + mu * ufl.inner(ufl.grad(u), ufl.grad(v))  # viscous term
    - p * ufl.div(v)  # pressure term (integrated by parts)
    + ufl.div(u) * q  # incompressibility constraint
)
F = F_integrand * ufl.dx

# Jacobian by automatic differentiation of F with respect to the state.
J = ufl.derivative(F, up)

F_form, J_form = dolfinx.fem.form(F), dolfinx.fem.form(J)
# Hand-written Newton iteration for F(up) = 0.
#
# Each step solves J(up) du = b. Following the dolfinx apply_lifting/set_bc
# conventions used below, b is (approximately):
#   b = -F(up) - J[:, D] (g - up)[D]   on free rows,
#   b = (g - up)                        on Dirichlet rows D,
# so the Dirichlet values g are imposed by the first update.
# Every iterate is stored in `snapshots`, one row per iteration with the
# layout of up.x.array.
vp_res = dolfinx.fem.Function(VP)
residual = 10**8
snapshots = []
newton_tol = 1e-14
max_newton_its = 10
newton_it = 0  # not named `iter`, to avoid shadowing the builtin

ksp = PETSc.KSP().create(MPI.COMM_WORLD)
ksp.setType(PETSc.KSP.Type.PREONLY)
ksp.getPC().setType(PETSc.PC.Type.LU)

# The Jacobian's sparsity pattern is fixed, so allocate the PETSc matrix once
# and re-assemble into it each iteration instead of creating a new one.
A = create_matrix(J_form)

while residual > newton_tol and newton_it < max_newton_its:
    b = assemble_vector(F_form)
    apply_lifting(b, [J_form], [bcs], x0=[up.x.petsc_vec], scale=-1.0)
    # Accumulate ghost contributions onto their owners.
    b.ghostUpdate(
        addv=PETSc.InsertMode.ADD_VALUES, mode=PETSc.ScatterMode.REVERSE
    )
    set_bc(b, bcs, up.x.petsc_vec, scale=-1.0)
    b.ghostUpdate(
        addv=PETSc.InsertMode.INSERT_VALUES, mode=PETSc.ScatterMode.FORWARD
    )
    b.scale(-1.0)

    A.zeroEntries()
    assemble_matrix(A, J_form, bcs=bcs)
    A.assemble()
    # Set the operator every iteration (as before) so the LU factorisation
    # is rebuilt for the updated Jacobian values.
    ksp.setOperators(A)
    ksp.solve(b, vp_res.x.petsc_vec)
    vp_res.x.scatter_forward()

    up.x.petsc_vec.axpy(1.0, vp_res.x.petsc_vec)
    up.x.scatter_forward()

    # Norm of the right-hand side, i.e. of the residual at the state *before*
    # this update.
    residual = b.norm()
    snapshots.append(up.x.array.copy())

    b.destroy()
    vp_res.x.petsc_vec.zeroEntries()
    print(newton_it, residual)
    newton_it += 1

A.destroy()
snapshots = np.array(snapshots)
tdim = mesh.topology.dim
n_cells_local = mesh.topology.index_map(tdim).size_local


def _single_cell_forms(cell):
    """Compile the residual and Jacobian forms restricted to one local cell.

    The cell is tagged with marker 1 in a one-entry meshtags object, and the
    integrand is integrated over ``dx(1)`` of a measure carrying that tag.
    Returns the pair ``(residual_form, jacobian_form)``.
    """
    single_tag = dolfinx.mesh.meshtags(
        mesh,
        tdim,
        np.array([cell], dtype=np.int32),
        np.array([1], dtype=np.int32),
    )
    dx_single = ufl.Measure("dx", domain=mesh, subdomain_data=single_tag)
    residual_single = F_integrand * dx_single(1)
    jacobian_single = ufl.derivative(residual_single, up)
    return dolfinx.fem.form(residual_single), dolfinx.fem.form(jacobian_single)


# One (F_i, J_i) pair per locally owned cell, in cell-index order.
cell_forms = [_single_cell_forms(cell) for cell in range(n_cells_local)]
def assemble_full_residual(state_vec: np.ndarray) -> np.ndarray:
    """Assemble ``-F(state)`` over the whole mesh and return its owned entries.

    Parameters
    ----------
    state_vec:
        Values copied into ``up.x.array`` (same layout, i.e. as stored in
        ``snapshots``). Side effect: the global ``up`` holds this state on
        return.

    Returns
    -------
    A copy of the locally owned entries of the assembled vector. Dirichlet
    conditions are not applied.
    """
    up.x.array[:] = state_vec
    up.x.scatter_forward()
    b = assemble_vector(F_form)
    try:
        # Add contributions assembled into ghost entries onto the owning
        # process before reading owned values; without this reverse scatter
        # the result is only correct in serial.
        b.ghostUpdate(
            addv=PETSc.InsertMode.ADD_VALUES, mode=PETSc.ScatterMode.REVERSE
        )
        b.ghostUpdate(
            addv=PETSc.InsertMode.INSERT_VALUES, mode=PETSc.ScatterMode.FORWARD
        )
        b.scale(-1.0)
        return b.array.copy()
    finally:
        b.destroy()
def assemble_cell_residuals(state_vec: np.ndarray) -> np.ndarray:
    """Return the sum over local cells of the per-cell residuals ``-F_i(state)``.

    Parameters
    ----------
    state_vec:
        Values copied into ``up.x.array`` (same layout as the entries of
        ``snapshots``). Side effect: the global ``up`` holds this state on
        return.

    Returns
    -------
    Owned entries of ``sum_i -F_i(state)``, where ``F_i`` are the residual
    forms in ``cell_forms``. If there are no cell forms, a zero vector of the
    owned size is returned.

    Notes
    -----
    The cell contributions are accumulated into one running array instead of
    storing one full-length vector per cell, which would need
    ``n_cells x n_dofs`` memory.
    """
    up.x.array[:] = state_vec
    up.x.scatter_forward()
    # Same owned size/dtype as the vectors assembled from F over VP.
    total = np.zeros_like(up.x.petsc_vec.array)
    for F_cell_form, _ in cell_forms:
        b = assemble_vector(F_cell_form)
        try:
            # Reverse/add scatter so contributions to ghost dofs reach their
            # owners (needed for correctness in parallel).
            b.ghostUpdate(
                addv=PETSc.InsertMode.ADD_VALUES, mode=PETSc.ScatterMode.REVERSE
            )
            b.ghostUpdate(
                addv=PETSc.InsertMode.INSERT_VALUES,
                mode=PETSc.ScatterMode.FORWARD,
            )
            b.scale(-1.0)
            total += b.array
        finally:
            b.destroy()
    return total
# Sanity check: for every Newton snapshot, the per-cell residuals must sum to
# the globally assembled residual (up to round-off).
for k, snap in enumerate(snapshots):
    r_full, r_cells = assemble_full_residual(snap), assemble_cell_residuals(snap)
    mismatch = np.linalg.norm(r_full - r_cells)
    print(f"Snapshot {k}: ||r_full - sum_i r_i|| = {mismatch:.3e}")