A question about optimizing assembly of the bilinear matrix A

Hi guys! I am optimizing a code for solving large PDEs.
In the end it will solve a system of linear equations with a matrix of size 150000 × 150000.
Assembling the bilinear matrix A at every time step therefore takes a lot of time.
However, A = A0 + A1, where A0 stays constant at each time step, while A1 contains the weak form of an upwind scheme and is therefore time-dependent.
So my idea is to assemble only A1 inside the time loop and then add A0 to it to form A,
using A.axpy(1.0, A0). But the program was killed.
Below is a minimal runnable example:

You only need to look at the code after the weak form.

from mpi4py import MPI
import petsc4py
from petsc4py import PETSc
import numpy as np
import dolfinx
from dolfinx import default_real_type, fem, io, mesh,default_scalar_type
from dolfinx.fem.petsc import assemble_matrix, assemble_vector
from ufl import (CellDiameter, FacetNormal, TestFunction, TrialFunction, avg,
                 conditional, div, dot, dS, ds, dx, grad, gt, inner, outer,
                 TrialFunctions,TestFunctions,curl,cross)
import ufl
from dolfinx.io import gmshio,XDMFFile
from basix.ufl import element,mixed_element 
from dolfinx.fem.petsc import (apply_lifting, assemble_matrix, assemble_vector, 
                                 create_vector, set_bc,create_matrix)
import time

startT=time.time()
num_time_steps =1000
out_interval=1
Ekman=0.001
Pm=5
Ra=100
Pr=1
ro=20/13
ri=7/13
tag="itertry3"


read_mesh = io.XDMFFile(MPI.COMM_WORLD, "cp6n_0.24w.xdmf" ,"r")
msh=read_mesh.read_mesh(name="Grid")

gdim = msh.geometry.dim


U_element=element("RT", msh.basix_cell(), 2)
#U_element=element("DG", msh.basix_cell(), k+1, shape=(msh.geometry.dim,))
P_element=element("DG", msh.basix_cell(), 1)
A_element=element("N1curl", msh.basix_cell(), 2)
T_element=element("DG", msh.basix_cell(), 1)
MIXED= mixed_element([U_element, P_element, A_element,T_element])


W =fem.functionspace(msh,MIXED)
W0,Wmap0=W.sub(0).collapse()
W1,Wmap1=W.sub(1).collapse()
W2,Wmap2=W.sub(2).collapse()
W3,Wmap3=W.sub(3).collapse()


(u,P,A,T)=ufl.TrialFunctions(W) 
(v,q,Av,Tv)=ufl.TestFunctions(W)


#####dofs
imap = msh.topology.index_map(gdim)
num_cells = imap.size_local
ghost_cells = imap.num_ghosts
print(f"ID {msh.comm.rank}, num owned cells {num_cells} num ghost cells {ghost_cells},num dofs {W.dofmap.index_map.size_local * W.dofmap.index_map_bs}", flush=True)

u_n_1=fem.Function(W0)
u_n_2=fem.Function(W0)
u_D = fem.Function(W0)
A_n_1=fem.Function(W2)
A_n_2=fem.Function(W2)
T_n_1=fem.Function(W3)
T_n_2=fem.Function(W3)
T_D=fem.Function(W3)
T_D0=fem.Function(W3)

####weak form    
t_end=0.01
delta_t = fem.Constant(msh, default_real_type(t_end / num_time_steps))
alpha = fem.Constant(msh, default_real_type(6.0))

h = CellDiameter(msh)
n = FacetNormal(msh)
x = ufl.SpatialCoordinate(msh)

def mean(vn,vn_1):
    return 1/2*(vn+vn_1)

def star(v_n_1,v_n_2):
    return 1/2*(3*v_n_1-v_n_2)

def jump(phi, n):
    return outer(phi("+"), n("+")) + outer(phi("-"), n("-"))

lmbda = conditional(gt(dot(star(u_n_1,u_n_2), n), 0), 1, 0)
u_uw = lmbda("+") * mean(u,u_n_1)("+") + lmbda("-") * mean(u,u_n_1)("-")
T_uw = lmbda("+") * mean(T,T_n_1)("+") + lmbda("-") * mean(T,T_n_1)("-")

F0=Ekman*(inner(u / delta_t, v) * dx - inner(u_n_1 / delta_t, v) * dx)

F1= Ekman*(-inner(mean(u,u_n_1), div(outer(v, star(u_n_1,u_n_2)))) * dx +\
    inner((dot(star(u_n_1,u_n_2), n))("+") * u_uw, v("+")) * dS + \
    inner((dot(star(u_n_1,u_n_2), n))("-") * u_uw, v("-")) * dS + \
    inner(dot(star(u_n_1,u_n_2), n) * lmbda * u, v) * ds)


F0+=Ekman * (inner(grad(mean(u,u_n_1)), grad(v)) * dx
                     - inner(avg(grad(mean(u,u_n_1))), jump(v, n)) * dS
                     - inner(jump(mean(u,u_n_1), n), avg(grad(v))) * dS
                     + (alpha / avg(h)) * inner(jump(mean(u,u_n_1), n), jump(v, n)) * dS
                     - inner(grad(mean(u,u_n_1)), outer(v, n)) * ds
                     - inner(outer(mean(u,u_n_1), n), grad(v)) * ds
                     + (alpha / h) * inner(outer(mean(u,u_n_1), n), outer(v, n)) * ds)

F0-= Ekman * (- inner(outer(mean(u_D,u_n_1), n), grad(v)) * ds
                + (alpha / h) * inner(outer(mean(u_D,u_n_1), n), outer(v, n)) * ds)


F0+=2*inner(cross(ufl.as_vector([0,0,1]),mean(u,u_n_1)),v)*dx


F0+=- inner(P, div(v)) * dx - inner(div(u), q) * dx


F0+= 1.0/Pm * inner(A/ delta_t,cross(curl(star(A_n_1,A_n_2)),v))*dx - 1.0/Pm * inner(A_n_1/ delta_t,cross(curl(star(A_n_1,A_n_2)),v))*dx
F0+= 1.0/Pm *inner(cross(curl(star(A_n_1,A_n_2)),mean(u,u_n_1)),cross(curl(star(A_n_1,A_n_2)),v))*dx

F0-= Ra * inner(ufl.as_vector([x[0],x[1],x[2]])*star(T_n_1,T_n_2),v) *dx 

F0+= inner(A/ delta_t,Av) * dx -inner(A_n_1/ delta_t,Av) * dx

F0+=inner(cross(curl(star(A_n_1,A_n_2)),mean(u,u_n_1)),Av) *dx

F0+=1/Pm *inner(curl(mean(A,A_n_1)),curl(Av)) *dx

F0+= inner(T / delta_t, Tv) * dx -inner(T_n_1 / delta_t, Tv) * dx

F1+= -inner(mean(T,T_n_1), div(outer(Tv, star(u_n_1,u_n_2)))) * dx +\
    inner((dot(star(u_n_1,u_n_2), n))("+") * T_uw, Tv("+")) * dS + \
    inner((dot(star(u_n_1,u_n_2), n))("-") * T_uw, Tv("-")) * dS + \
    inner(dot(star(u_n_1,u_n_2), n) * lmbda * T, Tv) * ds

F0+= (1/Pr)*(inner(grad(mean(T,T_n_1)), grad(Tv)) * dx
                     - inner(avg(grad(mean(T,T_n_1))), jump(Tv, n)) * dS
                     - inner(jump(mean(T,T_n_1), n), avg(grad(Tv))) * dS
                     + (alpha / avg(h)) * inner(jump(mean(T,T_n_1), n), jump(Tv, n)) * dS
                     - inner(grad(mean(T,T_n_1)),outer(Tv,n)) * ds
                     - inner(outer(mean(T,T_n_1), n), grad(Tv)) * ds
                     + (alpha / h) * inner(mean(T,T_n_1), Tv) * ds)
                     
aa0=ufl.lhs(F0)
aa1=ufl.lhs(F1)     
LL=ufl.rhs(F1+F0)    

a0=fem.form(aa0)   
a1=fem.form(aa1)  

a=fem.form(aa0+aa1)
L=fem.form(LL)

##
assembly_start = time.time()  # renamed so it does not shadow the star() helper defined above
# Note: the assemble_matrix calls below return new matrices, so these
# pre-created ones are discarded (and A1 here uses the sparsity pattern of a0).
A0 = create_matrix(a0)
A1 = create_matrix(a0)
A = create_matrix(a0)

bcs=[]

A= assemble_matrix(a, bcs=bcs)    
A.assemble()

A0 = assemble_matrix(a0, bcs=bcs)
A0.assemble()

A1 = assemble_matrix(a1, bcs=bcs)
A1.assemble()

A.zeroEntries()
A.axpy(1.0,A0)
#A.axpy(1.0,A1)
A.assemble()
b=create_vector(L)

endtime = time.time()
print(endtime - assembly_start)

ksp = PETSc.KSP().create(msh.comm)  
ksp.setOperators(A)

ksp.setInitialGuessNonzero(True)     
ksp.setConvergenceHistory(reset=True)  

ksp.setType("gmres")          
ksp.getPC().setType("ilu")      

opts = PETSc.Options()  
opts["ksp_rtol"] = 1.0e-8   
ksp.setFromOptions()

The call A.axpy(1.0, A0) resulted in the program being killed.
Thank you very much for any comments!

Shouldn't A1 be created based on a1?
If you know the sparsity patterns are the same, you can optimize the addition:
https://petsc.org/release/manualpages/Mat/MatAXPY/
as for instance done in
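
For example, here is a minimal standalone petsc4py sketch (not from this thread; it just builds two diagonal matrices that are known to share the same nonzero pattern):

# Two AIJ matrices with identical (diagonal) nonzero patterns.
from petsc4py import PETSc

n = 8
Y = PETSc.Mat().createAIJ([n, n], nnz=1)
X = PETSc.Mat().createAIJ([n, n], nnz=1)
rstart, rend = Y.getOwnershipRange()
for i in range(rstart, rend):
    Y.setValue(i, i, 1.0)
    X.setValue(i, i, 2.0)
Y.assemble()
X.assemble()

# Y <- Y + 1.0*X; SAME_NONZERO_PATTERN tells PETSc it can reuse Y's storage.
Y.axpy(1.0, X, PETSc.Mat.Structure.SAME_NONZERO_PATTERN)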

Thanks dokken!
In serial it runs normally, but when I run it in parallel the memory fills up and the program is killed.
Does this mean that parallel is more memory-intensive than serial?

As I do not have your mesh, I can't reproduce the issue. Parallel shouldn't be much more memory-intensive than serial (there will be some extra memory usage to account for ghost cells and dofs). How much memory do you have available?

I don't know how to upload my file (an .xdmf file).

https://easylink.cc/5ugjtp

I created a hyperlink; I don't know if it can be used.

It cannot, as you have used the HDF5 format for storing the mesh, meaning that one would need both the .xdmf file and the .h5 file.

Could you try running the same assembly on a built-in mesh (a unit cube) and try to reproduce the memory error? Then it would be easy for anyone else to try to reproduce it.
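
For instance, something like this could replace the io.XDMFFile read at the top of your script (a sketch; the 16×16×16 resolution is arbitrary, and tetrahedra are kept so the RT/N1curl elements in your example still apply):

from mpi4py import MPI
from dolfinx import mesh

# Built-in tetrahedral mesh instead of reading cp6n_0.24w.xdmf
msh = mesh.create_unit_cube(MPI.COMM_WORLD, 16, 16, 16,
                            cell_type=mesh.CellType.tetrahedron)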

Hello dokken!
I have already resolved this issue.
I added the following code:

A_sp = dolfinx.fem.create_sparsity_pattern(a)
A_sp.finalize()
A = dolfinx.cpp.la.petsc.create_matrix(msh.comm, A_sp)

and the following in the time loop:

A.zeroEntries()
assemble_matrix(A,a1,bcs)
A.setOption(PETSc.Mat.Option.KEEP_NONZERO_PATTERN, True) # type: ignore
A.setOption(PETSc.Mat.Option.IGNORE_ZERO_ENTRIES, False) # type: ignore
A.assemble()
A.axpy(1.0,A0,PETSc.Mat.Structure.SUBSET_NONZERO_PATTERN)

where a1 is the time-dependent form and A0 is the time-independent matrix, which has already been assembled.
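
Putting the pieces together, the structure roughly looks like this (a sketch; the forms a, a0, a1, the bcs list and the imports are the same as in my original post):

# One-time setup: a matrix with the sparsity pattern of the full form a,
# plus the time-independent part A0 assembled once.
A_sp = dolfinx.fem.create_sparsity_pattern(a)
A_sp.finalize()
A = dolfinx.cpp.la.petsc.create_matrix(msh.comm, A_sp)

A0 = assemble_matrix(a0, bcs=bcs)
A0.assemble()

for step in range(num_time_steps):
    A.zeroEntries()
    assemble_matrix(A, a1, bcs=bcs)  # reassemble only the time-dependent part
    A.setOption(PETSc.Mat.Option.KEEP_NONZERO_PATTERN, True)
    A.setOption(PETSc.Mat.Option.IGNORE_ZERO_ENTRIES, False)
    A.assemble()
    # A0's nonzeros are a subset of A's full pattern, hence SUBSET_NONZERO_PATTERN
    A.axpy(1.0, A0, PETSc.Mat.Structure.SUBSET_NONZERO_PATTERN)
    # ... assemble the right-hand side and solve with A here ...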

I have uploaded both my XDMF and H5 files:
https://easylink.cc/gmmjvf
https://easylink.cc/8bhsvd