Efficient way to sum up fem.function's

Ok, so I am solving a homogenous linear wave equation

image

image

image

and for the time derivatives I am considering the central scheme

image

or Newmark’s scheme solving for image

image

Code:

from mpi4py import MPI
from dolfinx import fem, mesh

import dolfinx.fem.petsc
import ufl, time

import numpy as np

msh = mesh.create_unit_square(MPI.COMM_WORLD, 30, 30, cell_type=mesh.CellType.quadrilateral)

t = fem.Constant(msh, 0.)
dt = fem.Constant(msh, 5e-3)
T = 1 
c = 1

V = fem.FunctionSpace(msh, ("Lagrange", 1))
f = fem.Function(V)

uh = fem.Function(V)
v = ufl.TestFunction(V)
u = ufl.TrialFunction(V)

def u_init(x, sigma=0.05, mu=0.5):
    return np.exp(-0.5*((x[0]-mu)/sigma)**2)*np.exp(-0.5*((x[1]-mu)/sigma)**2) 
    
TIME_SCHEME = "central"

if TIME_SCHEME == "central":
    u_n = fem.Function(V)
    u_nn = fem.Function(V)
        
    dudtdt = (u - 2 * u_n + u_nn) / dt**2
    du = (u + u_nn) / 2

    u_n.interpolate(u_init)
    u_nn.interpolate(u_init)

elif TIME_SCHEME == "Newmark":
    u_n = fem.Function(V)
    dudt_n = fem.Function(V)
    dudtdt_n = fem.Function(V)

    u_temp = fem.Function(V)
    dudt_temp = fem.Function(V)

    beta = 0.25
    gamma = 0.5

    dudtdt = u
    du = u_n + dudt_n*dt + ((1-beta)*dudtdt_n + 2*beta*dudtdt) * dt**2/2

    u_n.interpolate(u_init)
    dudt_n.interpolate(u_init)
    dudtdt_n.interpolate(u_init)
    
F = 1 / c**2 * ufl.inner(dudtdt, v) * ufl.dx \
    + ufl.inner(ufl.grad(du), ufl.grad(v)) * ufl.dx \
    - ufl.dot(f, v) * ufl.dx 

a, L = ufl.system(F)
petsc_options = {"ksp_type": "preonly",
                "pc_type": "lu", "pc_factor_mat_solver_type": "mumps"}
problem = dolfinx.fem.petsc.LinearProblem(a, L, u=uh, bcs=[], petsc_options=petsc_options)

start = time.perf_counter()
while t.value < T:
    t.value += dt.value
    problem.solve()
                            
    if TIME_SCHEME == "central":
        u_nn.x.array[:] = u_n.x.array
        u_n.x.array[:] = uh.x.array  
        
    elif TIME_SCHEME == "Newmark":
        u_temp.x.array[:] = u_n.x.array + dudt_n.x.array*dt + ((1-beta)*dudtdt_n.x.array + 2*beta*uh.x.array) * dt**2/2
        dudt_temp.x.array[:] = dudt_n.x.array + ((1-gamma)*dudtdt_n.x.array + gamma*uh.x.array) * dt

        u_n.x.array[:] = u_temp.x.array
        dudt_n.x.array[:] = dudt_temp.x.array
        dudtdt_n.x.array[:] = uh.x.array

end = time.perf_counter()
print(end-start)

Running this code with

TIME_SCHEME = "central"

takes

0.33813453299990215

but with

TIME_SCHEME = "Newmark"

it takes

40.131102617999204