Run time of FEniCS adjoint

Hi all, I am using FEniCS and dolfin-adjoint to solve an inverse problem. I found that computing the gradient takes much longer than the forward modeling.
Is this reasonable? Any tips on how to speed up the gradient computation would be really helpful! Thanks in advance!

forward modeling takes:30.160411596298218
J computation modeling takes:0.000675201416015625
gradient computation modeling takes:111.1005494594574

This depends on how you structure your code, what dependencies you have, etc.
One thing that might help is the optimize_tape function.

In general, solving the forward and adjoint models should take roughly the same time.
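To illustrate why, here is a minimal NumPy/SciPy sketch (not FEniCS or dolfin-adjoint code). For a linear model A x = b, the adjoint equation is a solve with the transpose A^T, so the LU factorization computed for the forward solve can be reused for the adjoint solve. The toy matrix, data d and functional J are assumptions made up for illustration.

```python
import numpy as np
from scipy.linalg import lu_factor, lu_solve

# Hypothetical toy problem: forward model A x = b, functional J(x) = 0.5 * ||x - d||^2
rng = np.random.default_rng(0)
n = 50
A = 4.0 * np.eye(n) + 0.1 * rng.standard_normal((n, n))  # well-conditioned, non-symmetric
b = rng.standard_normal(n)
d = rng.standard_normal(n)

lu = lu_factor(A)            # factorize once (the expensive part)
x = lu_solve(lu, b)          # forward solve: A x = b
J = 0.5 * np.dot(x - d, x - d)

# Adjoint solve: A^T lam = dJ/dx = x - d. trans=1 reuses the same LU factors for A^T.
lam = lu_solve(lu, x - d, trans=1)
dJdb = lam                   # x = A^{-1} b, hence dJ/db = A^{-T} (x - d)
```

The adjoint costs one extra triangular solve pair on top of the forward solve, which is the sense in which both should take roughly the same time. Automatically derived adjoints add bookkeeping on top of this.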

If your problem is time dependent, one thing to avoid is calling
solve(a==L,u) inside the time loop, as it rebuilds and recomputes structures at every time step (such as the system matrix and the linear solver) that are time independent. This holds for both the forward solve and the corresponding automatically derived adjoint solve.
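As a rough NumPy/SciPy analogue (not the DOLFIN API), the sketch below contrasts rebuilding and solving the system from scratch at every step with building and factorizing a time-independent matrix once and reusing it. The 1D heat equation, the finite-difference matrix and the step sizes are hypothetical choices for illustration.

```python
import numpy as np
from scipy.linalg import lu_factor, lu_solve

def stiffness(n, h):
    """1D finite-difference Laplacian (Dirichlet ends); a stand-in for an assembled form."""
    K = 2.0 * np.eye(n) - np.eye(n, k=1) - np.eye(n, k=-1)
    return K / h**2

def run_rebuild_each_step(u0, dt, steps):
    # Analogue of solve(a == L, u) inside the loop: the matrix is rebuilt
    # and the system is solved from scratch at every time step.
    u = u0.copy()
    for _ in range(steps):
        A = np.eye(u.size) + dt * stiffness(u.size, 1.0 / (u.size + 1))
        u = np.linalg.solve(A, u)
    return u

def run_factor_once(u0, dt, steps):
    # The matrix is time independent: build and factorize it once, reuse it every step.
    A = np.eye(u0.size) + dt * stiffness(u0.size, 1.0 / (u0.size + 1))
    lu = lu_factor(A)
    u = u0.copy()
    for _ in range(steps):
        u = lu_solve(lu, u)  # only the right-hand side changes
    return u

u0 = np.sin(np.pi * np.linspace(0.0, 1.0, 102)[1:-1])  # 100 interior points
u_a = run_rebuild_each_step(u0, 1e-3, 50)
u_b = run_factor_once(u0, 1e-3, 50)
```

Both loops produce the same solution; the second one only pays for the matrix construction and factorization once, and an adjoint derived from it can reuse that work in the same way.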


Hi Dokken, thanks for the prompt response. I am looking into the optimize_tape function, but I have had a hard time finding examples of how to use it. Do you know of any samples I could check, or how I could figure this out?

My problem is time dependent, so I do have calls like solve(a==L,u) in my code. How can I avoid them? Are there any reference codes I could check to learn how to do this? Thanks a lot!

b1 = assemble(L1)
for bc in bcu:
    bc.apply(A1, b1)
solve(A1, u1.vector(), b1, 'gmres', 'default')

You can define an LUSolver, as shown here: Minimeze function in Navier-Sokes's problem - #2 by dokken
If A1 does not change between time steps, you should assemble it only once, before the time loop; otherwise, you can reassemble it into the existing matrix with assemble(a, tensor=A1), where a is the bilinear form of A1.
Without a minimal code example that illustrates what you are doing (and reproduces your timings), I cannot offer you any more guidance.
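When the matrix does change every step, the idea behind assemble(a, tensor=A1), as I understand it, is to write the new values into storage that already exists instead of allocating a new matrix each time. A minimal NumPy sketch of that pattern follows; assemble_into and the time-dependent coefficient kappa are hypothetical and not part of DOLFIN.

```python
import numpy as np

def assemble_into(A, kappa, h):
    """Refill a preallocated tridiagonal matrix in place for coefficient kappa.

    Hypothetical helper, similar in spirit to assemble(a, tensor=A1):
    the storage is reused and only the values are recomputed.
    """
    n = A.shape[0]
    idx = np.arange(n)
    A.fill(0.0)
    A[idx, idx] = 2.0 * kappa / h**2
    A[idx[:-1], idx[:-1] + 1] = -kappa / h**2
    A[idx[1:], idx[1:] - 1] = -kappa / h**2
    return A

n, h = 5, 0.25
A = np.empty((n, n))       # allocated once, outside the time loop
storage_id = id(A)
for step in range(3):
    kappa = 1.0 + step     # coefficient that changes every time step
    assemble_into(A, kappa, h)
```

After the loop, A still refers to the original array but holds the values for the last coefficient (kappa = 3).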


Thank you so much, Dokken, this is very helpful! I will check the notes you shared first. If I still cannot figure it out, I will put together a minimal code example that illustrates what I am doing and reproduces my timings!

Hi Dokken, I have a minimal sample code here. I have defined the solver as shown in the reference post you shared. However, the backward time is still significantly longer than the forward time. Do you have any hints/thoughts on what I may do to speed things up here?

Update: I also profiled the code, and it seems that
assembling.py:79(assemble)
takes most of the time.



from dolfin import *
from dolfin_adjoint import *
import numpy as np
import time

def forward(u_init,V,total_steps,objective,storage,points,mesh,timesteps=1,store=False):

    dt = 0.001
    mu = 0.1
    nu = mu
    rho = 1            # density

    inflow   = 'near(x[0], 0)'
    outflow  = 'near(x[0], 1)'
    walls    = 'near(x[1], 0) || near(x[1],1)'

    Q = FunctionSpace(mesh, "Lagrange", 1)
    
    # Define inflow profile
    inflow_profile = ('1', '0')

    # Define boundary conditions
    bcu_inflow = DirichletBC(V, Expression(inflow_profile, degree=2), inflow)
    bcu_walls = DirichletBC(V, Constant((0, 0)), walls)

    bcp_outflow = DirichletBC(Q, Constant(0), outflow)
    bcu = [bcu_inflow,bcu_walls]
    bcp = [bcp_outflow]

    # Define trial and test functions
    u = TrialFunction(V)
    v = TestFunction(V)
    p = TrialFunction(Q)
    q = TestFunction(Q)

    # Define functions for solutions at previous and current time steps
    u0 = u_init
    u1 = Function(V)
    p1 = Function(Q)


    # Define expressions used in variational forms
    n  = FacetNormal(mesh)
    f  = Constant((0, 0))
    k  = Constant(dt)
    nu = Constant(nu)
    rho = Constant(rho)

    # Define variational problem for step 1
    F1 = inner(u - u0, v)*dx + k*inner(grad(u0)*u0, v)*dx + \
         k*nu*inner(grad(u), grad(v))*dx - inner(f, v)*dx
    a1 = lhs(F1)
    L1 = rhs(F1)

    # Define variational problem for step 2
    a2 = inner(grad(p), grad(q))*dx
    L2 = -(1/k)*div(u1)*q*dx

    # Define variational problem for step 3
    a3 = inner(u, v)*dx
    L3 = inner(u1, v)*dx - k*inner(grad(p1), v)*dx

    # Assemble matrices
    A1 = assemble(a1)
    A2 = assemble(a2)
    A3 = assemble(a3)

    # Create solvers
    solver1 = KrylovSolver(A1, 'gmres')
    solver2 = KrylovSolver(A2, 'gmres')
    solver3 = KrylovSolver(A3, 'gmres')



    # Use nonzero guesses - essential for CG with non-symmetric BC
    parameters['krylov_solver']['nonzero_initial_guess'] = True

    # Time-stepping
    t = 0
    u_solution_dof = np.zeros([total_steps+1,int(V.dim())])
    if store:
        u_solution = np.zeros([total_steps+1,2, Q.dim()])
        u_solution_dof = np.zeros([total_steps+1,int(V.dim())])
        flux_x, flux_y = u1.split(deepcopy=True)
        u_solution[0,0, :] = flux_x.compute_vertex_values()
        u_solution[0,1, :] = flux_y.compute_vertex_values()
        u_solution_dof[0,:] = u1.vector().get_local()
        p_solution = np.zeros([total_steps+1, Q.dim()])
        p_solution[0, :] = p1.compute_vertex_values()
        
    for n in range(total_steps):
        # Update current time
        t += dt

        # Step 1: Tentative velocity step
        b1 = assemble(L1)
        [bc.apply(A1, b1) for bc in bcu]
        solver1.solve(u1.vector(), b1)

        # Step 2: Pressure correction step
        b2 = assemble(L2)
        [bc.apply(A2, b2) for bc in bcp]
        [bc.apply(p1.vector()) for bc in bcp]
        solver2.solve(p1.vector(), b2)

        # Step 3: Velocity correction step
        b3 = assemble(L3)
        [bc.apply(A3, b3) for bc in bcu]
        solver3.solve(u1.vector(), b3)
        # Plot solution
        
        # Update previous solution
        u0.assign(u1)
        if not store:
            if n%timesteps==0 and (n!=0 or total_steps==1):
                # print("save:{} timestep".format(n))
                d_observe = Function(V, name="data")
                d_observe.vector()[:] = storage['observation'][n+1,:]
                if objective is not None:
                    objective(u0, d_observe, n,points)

        if store:
            flux_x, flux_y = u1.split(deepcopy=True)
            u_solution[n+1,0, :] = flux_x.compute_vertex_values()
            u_solution[n+1,1, :] = flux_y.compute_vertex_values()     
            u_solution_dof[n+1, :] = u1.vector().get_local()
            p_solution[n+1,:] = p1.compute_vertex_values()
    
    return u_solution_dof


def objective_point_wise(storage, u=None, d_observe=None,idx=None,points=None,finalize=False):
    if finalize:
        for loss in storage['functional']:
            storage["loss"] = storage["loss"]+loss
        loss = storage["loss"]
        return loss
    if idx==0:
        storage["functional"] = []
        storage["gt_field"]=[]
        storage["loss"] = 0
    storage["functional"].append(assemble(inner(d_observe-u,d_observe-u)*dx))
    storage["time"]=idx

def eval_cb(j, u):
    print(j)

    

## get observations
mesh = UnitSquareMesh(10, 10)
V = VectorFunctionSpace(mesh, 'P', 2)
u0 = Function(V)
u0.vector()[:] = 1
u_solution_dof = forward(u0,V,total_steps=100,objective=None,storage=None,points=None,mesh=mesh,store=True)
observed_coords = (V.tabulate_dof_coordinates())[::2,:]

storage = {"functional": [],'solution':u_solution_dof,'observation':u_solution_dof, "time": 0, "estimate_field": [], "gt_observation": [],'gt_field':[],'loss':0}

obj = lambda u, d, idx,observed_coords: objective_point_wise(storage,u, d, idx,observed_coords)
observation_setup = {"start_time":0,"timestep":1,"endtime":100}
observation_steps = np.arange(observation_setup["endtime"]-observation_setup["start_time"])[0:observation_setup["endtime"]-observation_setup["start_time"]:observation_setup["timestep"]]
observation_steps = np.floor(observation_steps/observation_setup["timestep"]).astype(int)  # np.int was removed in NumPy 1.24

## start recover init state
V = VectorFunctionSpace(mesh, 'P', 2)
u_init = Function(V)
control = Control(u_init)

t0 = time.time()
forward(u_init,V=V,total_steps=observation_setup["endtime"]-observation_setup["start_time"],
            objective=obj,storage=storage,points=observed_coords,mesh=mesh)
t1 = time.time()
print("forward modeling takes:{}".format(t1-t0))

t0 = time.time()
J = objective_point_wise(storage, finalize=True)
print("J:{}".format(J))
t1 = time.time()
print("J computation modeling takes:{}".format(t1-t0))

t0 = time.time()
dJd0 = compute_gradient(J, control)
t1 = time.time()
print("gradient computation modeling takes:{}".format(t1-t0))

and I get:

forward modeling takes:1.0288302898406982
J:12.45944001405794
J computation modeling takes:0.0014615058898925781
gradient computation modeling takes:6.213935136795044
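The entry assembling.py:79(assemble) looks like output from Python's cProfile. Below is a minimal sketch of collecting such a report, sorted by cumulative time, using only the standard library; fake_assemble and fake_time_loop are hypothetical stand-ins for the real model. For a whole script, python -m cProfile -s cumtime script.py should give a similar listing.

```python
import cProfile
import io
import pstats

def fake_assemble(n):
    # Hypothetical stand-in for an expensive call such as assemble()
    return sum(i * i for i in range(n))

def fake_time_loop(steps):
    total = 0
    for _ in range(steps):
        total += fake_assemble(20000)
    return total

profiler = cProfile.Profile()
profiler.enable()
fake_time_loop(20)
profiler.disable()

# Print the 10 most expensive entries by cumulative time into a string
stream = io.StringIO()
stats = pstats.Stats(profiler, stream=stream)
stats.sort_stats("cumulative").print_stats(10)
report = stream.getvalue()
print(report)
```

When profiling compute_gradient, keep in mind that the cumulative time of assemble includes the assembly performed while solving the adjoint equations, not just the forward ones.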

Keep in mind that any automatically derived adjoint model comes with some overhead (to set up and solve the adjoint equations).

When solving problems on toy models (small meshes), this overhead is relatively large.
However, as one increases the number of cells, the relative overhead decreases drastically, as shown in the following code:

from dolfin import *
from dolfin_adjoint import *
import numpy as np
import time

def forward(u_init,V,total_steps,objective,storage,points,mesh,timesteps=1,store=False):

    dt = 0.001
    mu = 0.1
    nu = mu
    rho = 1            # density

    inflow   = 'near(x[0], 0)'
    outflow  = 'near(x[0], 1)'
    walls    = 'near(x[1], 0) || near(x[1],1)'

    Q = FunctionSpace(mesh, "Lagrange", 1)
    
    # Define inflow profile
    inflow_profile = ('1', '0')

    # Define boundary conditions
    bcu_inflow = DirichletBC(V, Expression(inflow_profile, degree=2), inflow)
    bcu_walls = DirichletBC(V, Constant((0, 0)), walls)

    bcp_outflow = DirichletBC(Q, Constant(0), outflow)
    bcu = [bcu_inflow,bcu_walls]
    bcp = [bcp_outflow]

    # Define trial and test functions
    u = TrialFunction(V)
    v = TestFunction(V)
    p = TrialFunction(Q)
    q = TestFunction(Q)

    # Define functions for solutions at previous and current time steps
    u0 = u_init
    u1 = Function(V)
    p1 = Function(Q)


    # Define expressions used in variational forms
    n  = FacetNormal(mesh)
    f  = Constant((0, 0))
    k  = Constant(dt)
    nu = Constant(nu)
    rho = Constant(rho)

    # Define variational problem for step 1
    F1 = inner(u - u0, v)*dx + k*inner(grad(u0)*u0, v)*dx + \
         k*nu*inner(grad(u), grad(v))*dx - inner(f, v)*dx
    a1 = lhs(F1)
    L1 = rhs(F1)

    # Define variational problem for step 2
    a2 = inner(grad(p), grad(q))*dx
    L2 = -(1/k)*div(u1)*q*dx

    # Define variational problem for step 3
    a3 = inner(u, v)*dx
    L3 = inner(u1, v)*dx - k*inner(grad(p1), v)*dx

    # Assemble matrices
    A1 = assemble(a1)
    A2 = assemble(a2)
    A3 = assemble(a3)

    # Create solvers
    solver1 = KrylovSolver(A1, 'gmres')
    solver2 = KrylovSolver(A2, 'gmres')
    solver3 = KrylovSolver(A3, 'gmres')



    # Use nonzero guesses - essential for CG with non-symmetric BC
    parameters['krylov_solver']['nonzero_initial_guess'] = True

    # Time-stepping
    t = 0
    u_solution_dof = np.zeros([total_steps+1,int(V.dim())])
    if store:
        u_solution = np.zeros([total_steps+1,2, Q.dim()])
        u_solution_dof = np.zeros([total_steps+1,int(V.dim())])
        flux_x, flux_y = u1.split(deepcopy=True)
        u_solution[0,0, :] = flux_x.compute_vertex_values()
        u_solution[0,1, :] = flux_y.compute_vertex_values()
        u_solution_dof[0,:] = u1.vector().get_local()
        p_solution = np.zeros([total_steps+1, Q.dim()])
        p_solution[0, :] = p1.compute_vertex_values()
        
    for n in range(total_steps):
        # Update current time
        t += dt

        # Step 1: Tentative velocity step
        b1 = assemble(L1)
        [bc.apply(A1, b1) for bc in bcu]
        solver1.solve(u1.vector(), b1)

        # Step 2: Pressure correction step
        b2 = assemble(L2)
        [bc.apply(A2, b2) for bc in bcp]
        [bc.apply(p1.vector()) for bc in bcp]
        solver2.solve(p1.vector(), b2)

        # Step 3: Velocity correction step
        b3 = assemble(L3)
        [bc.apply(A3, b3) for bc in bcu]
        solver3.solve(u1.vector(), b3)
        # Plot solution
        
        # Update previous solution
        u0.assign(u1)
        if not store:
            if n%timesteps==0 and (n!=0 or total_steps==1):
                # print("save:{} timestep".format(n))
                d_observe = Function(V, name="data")
                d_observe.vector()[:] = storage['observation'][n+1,:]
                if objective is not None:
                    objective(u0, d_observe, n,points)

        if store:
            flux_x, flux_y = u1.split(deepcopy=True)
            u_solution[n+1,0, :] = flux_x.compute_vertex_values()
            u_solution[n+1,1, :] = flux_y.compute_vertex_values()     
            u_solution_dof[n+1, :] = u1.vector().get_local()
            p_solution[n+1,:] = p1.compute_vertex_values()
    
    return u_solution_dof


def objective_point_wise(storage, u=None, d_observe=None,idx=None,points=None,finalize=False):
    if finalize:
        for loss in storage['functional']:
            storage["loss"] = storage["loss"]+loss
        loss = storage["loss"]
        return loss
    if idx==0:
        storage["functional"] = []
        storage["gt_field"]=[]
        storage["loss"] = 0
    storage["functional"].append(assemble(inner(d_observe-u,d_observe-u)*dx))
    storage["time"]=idx

def eval_cb(j, u):
    print(j)

    

def problem(N):
    ## get observations
    mesh = UnitSquareMesh(N, N)
    V = VectorFunctionSpace(mesh, 'P', 2)
    u0 = Function(V)
    u0.vector()[:] = 1
    u_solution_dof = forward(u0,V,total_steps=100,objective=None,storage=None,points=None,mesh=mesh,store=True)
    observed_coords = (V.tabulate_dof_coordinates())[::2,:]

    storage = {"functional": [],'solution':u_solution_dof,'observation':u_solution_dof, "time": 0, "estimate_field": [], "gt_observation": [],'gt_field':[],'loss':0}

    obj = lambda u, d, idx,observed_coords: objective_point_wise(storage,u, d, idx,observed_coords)
    observation_setup = {"start_time":0,"timestep":1,"endtime":100}
    observation_steps = np.arange(observation_setup["endtime"]-observation_setup["start_time"])[0:observation_setup["endtime"]-observation_setup["start_time"]:observation_setup["timestep"]]
    observation_steps = np.floor(observation_steps/observation_setup["timestep"]).astype(int)  # np.int was removed in NumPy 1.24

    ## start recover init state
    V = VectorFunctionSpace(mesh, 'P', 2)
    u_init = Function(V)
    control = Control(u_init)

    t0 = time.time()
    forward(u_init,V=V,total_steps=observation_setup["endtime"]-observation_setup["start_time"],
            objective=obj,storage=storage,points=observed_coords,mesh=mesh)
    t1 = time.time()
    forward_time = t1-t0
    t0 = time.time()
    J = objective_point_wise(storage, finalize=True)
    t1 = time.time()


    t0 = time.time()
    dJd0 = compute_gradient(J, control)
    t1 = time.time()
    gradient_time = t1-t0
    print(f"{N} Forward: {forward_time}, Gradient: {gradient_time}, Gradient/Forward: {gradient_time/forward_time:.2f}")


for N in [10, 20, 40, 80, 160]:
    problem(N)

resulting in

10 Forward: 0.704944372177124, Gradient: 3.5798561573028564, Gradient/Forward: 5.08
20 Forward: 1.0256361961364746, Gradient: 4.098656177520752, Gradient/Forward: 4.00
40 Forward: 2.5030620098114014, Gradient: 5.816085338592529, Gradient/Forward: 2.32
80 Forward: 9.350837707519531, Gradient: 12.537389278411865, Gradient/Forward: 1.34
160 Forward: 46.92704176902771, Gradient: 51.84504055976868, Gradient/Forward: 1.10

Even these problems are quite small: the velocity space V has 882, 3362, 13122, 51842 and 206082 degrees of freedom, respectively.
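The degree-of-freedom counts follow from the vector P2 space on UnitSquareMesh(N, N), which has (2N+1)^2 nodes per component. A short sketch that recomputes them and pairs them with the timings reported above (copied verbatim, in seconds):

```python
# Timings copied from the output above: N -> (forward_time, gradient_time)
results = {
    10: (0.704944372177124, 3.5798561573028564),
    20: (1.0256361961364746, 4.098656177520752),
    40: (2.5030620098114014, 5.816085338592529),
    80: (9.350837707519531, 12.537389278411865),
    160: (46.92704176902771, 51.84504055976868),
}

def velocity_dofs(N):
    # Vector P2 on UnitSquareMesh(N, N): (2N+1)^2 nodes, 2 components per node
    return 2 * (2 * N + 1) ** 2

ratios = {}
for N, (forward_time, gradient_time) in results.items():
    ratios[N] = gradient_time / forward_time
    print(f"N={N:4d}  dofs={velocity_dofs(N):7d}  gradient/forward={ratios[N]:.2f}")
```

The ratio drops steadily towards 1 as the number of degrees of freedom grows, which is what one expects if the adjoint bookkeeping is a roughly fixed cost on top of solves that scale with problem size.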


Hi Dokken, Thanks a lot for your detailed response. This makes perfect sense and the test is really good.

I have been studying the FEniCS code over the past few days. Do you think it is feasible to eliminate this overhead, or at least reduce it?

Thanks a ton!

As I said previously, there is always some overhead.
You could try using a ReducedFunctional
and optimizing the tape:

Jhat = ReducedFunctional(J, control)
Jhat.optimize_tape()
dJdc = Jhat.derivative()
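As far as I understand pyadjoint, optimize_tape() asks the tape to drop blocks that do not depend on the control or that do not influence the functional, so the adjoint sweep visits fewer blocks. The toy tape below is a hypothetical illustration of that pruning idea, not pyadjoint's implementation; Block and prune_tape are made-up names.

```python
from dataclasses import dataclass

@dataclass
class Block:
    """One recorded operation: reads `inputs`, writes `output` (hypothetical toy tape)."""
    name: str
    inputs: tuple
    output: str

def prune_tape(blocks, control, functional):
    # Pass 1 (forward): keep blocks whose inputs depend on the control.
    depends = {control}
    forward_kept = []
    for blk in blocks:
        if depends.intersection(blk.inputs):
            depends.add(blk.output)
            forward_kept.append(blk)
    # Pass 2 (backward): keep blocks whose output influences the functional.
    needed = {functional}
    kept = []
    for blk in reversed(forward_kept):
        if blk.output in needed:
            needed.update(blk.inputs)
            kept.append(blk)
    return list(reversed(kept))

tape = [
    Block("assemble_mass", ("mesh",), "M"),   # independent of the control
    Block("step1", ("u0", "M"), "u1"),
    Block("diagnostic", ("u1",), "energy"),   # never used by J
    Block("step2", ("u1", "M"), "u2"),
    Block("misfit", ("u2", "d"), "J"),
]
pruned = prune_tape(tape, control="u0", functional="J")
print([blk.name for blk in pruned])
```

Here assemble_mass and diagnostic are dropped from the sweep; M is still used by the remaining blocks, but only as a fixed input. In a setup like the one in this thread, where almost every block depends on the initial condition and feeds into J, such pruning may not remove much.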

Got it, Thanks, Dokken!

Hi Dokken, may I ask whether you would suggest deriving the adjoint equations explicitly and computing the gradient by hand? Would that make the whole process faster?

In terms of runtime, it might reduce the time it takes. However, please note that manually coding adjoints for time-dependent problems can be quite tedious and error-prone, as you must propagate variables backwards in time. You also need to consider how to handle boundary conditions, etc.

As I showed in the example above, the runtime of the gradient computation converges towards the runtime of the forward computation as you increase the number of degrees of freedom. This means that for applications on properly sized (large) meshes, the runtime tends towards the optimal runtime.
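To give an idea of what coding an adjoint by hand involves, here is a minimal NumPy sketch for a hypothetical linear time stepper u_{k+1} = S u_k with a misfit summed over all time steps. The adjoint variable is propagated backwards from the final step, and the result is checked against finite differences, which is the check a hand-written adjoint always needs. Nonlinear terms and boundary conditions, as in the Navier-Stokes code above, make the derivation considerably harder.

```python
import numpy as np

rng = np.random.default_rng(1)
n, steps, dt = 20, 30, 1e-2
# Hypothetical linear time stepper u_{k+1} = S u_k (one implicit Euler step of a 1D heat equation)
K = 2.0 * np.eye(n) - np.eye(n, k=1) - np.eye(n, k=-1)
S = np.linalg.inv(np.eye(n) + dt * (n + 1) ** 2 * K)
data = rng.standard_normal((steps + 1, n))  # synthetic observations d_k

def run_forward(u0):
    """Return all states u_0..u_N and J = 0.5 * sum_{k>=1} ||u_k - d_k||^2."""
    states = [u0]
    for _ in range(steps):
        states.append(S @ states[-1])
    J = 0.5 * sum(np.sum((states[k] - data[k]) ** 2) for k in range(1, steps + 1))
    return states, J

def adjoint_gradient(states):
    """Propagate backwards in time: lam_N = u_N - d_N,
    lam_k = S^T lam_{k+1} + (u_k - d_k) for k = N-1..1, and dJ/du0 = S^T lam_1."""
    lam = states[steps] - data[steps]
    for k in range(steps - 1, 0, -1):
        lam = S.T @ lam + (states[k] - data[k])
    return S.T @ lam

u0 = rng.standard_normal(n)
states, J = run_forward(u0)
grad = adjoint_gradient(states)
```

Note that the backward sweep needs every forward state, which is exactly the storage (or recomputation) that dolfin-adjoint manages for you.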