Fluid reconstruction problem setup

Hi, I am trying to solve a fluid reconstruction problem and want to get some feedback on the code:

Helper functions: forward model and objective function

def forward(u0, V, total_steps, objective, storage, points, mask, mesh, loss_fn):
    """Advance the flow ``total_steps`` steps with a three-stage splitting scheme.

    Each step solves a tentative-velocity problem, a pressure Poisson problem
    and a velocity correction, then copies the result into ``u0``. On every
    10th step the observed field for that step is loaded from ``storage`` and
    handed to ``objective`` together with the current velocity.

    The variational forms carry no density factor, i.e. they are written for
    unit density.

    Args:
        u0: Velocity ``Function`` on ``V`` holding the initial state. It is
            overwritten in place with the newest velocity after every step.
        V: Velocity function space (also used for the boundary conditions and
            for the observation function).
        total_steps: Number of time steps to take (step size is fixed at
            ``10 / 10000``).
        objective: Callable invoked as
            ``objective(u0, d_observe, step, points, loss_fn=loss_fn, Q=Q)``
            on steps where ``step % 10 == 0``, or ``None`` to skip it.
        storage: Mapping whose ``'observation'`` entry is indexed as
            ``storage['observation'][step + 1, :]`` to fill the observation
            function's dof vector.
        points: Passed through unchanged to ``objective``.
        mask: Unused; kept for interface compatibility.
        mesh: Mesh on which the pressure space is built.
        loss_fn: Passed through unchanged to ``objective``.

    Returns:
        None. Results are communicated through ``u0`` and ``objective``.
    """
    T = 10               # final time
    num_steps = 10000    # number of time steps
    dt = T / num_steps   # time step size
    mu = 0.001           # viscosity

    # Boundary markers for the channel; the cylinder is "everything else"
    # on the boundary.
    inflow = 'near(x[0], -1)'
    outflow = 'near(x[0], 4)'
    walls = 'near(x[1], -0.5) || near(x[1], 0.5)'
    cylinder = 'on_boundary && not(near(x[0], -1)) && not(near(x[0], 4)) && not(near(x[1], -0.5)) && not(near(x[1], 0.5))'

    # Pressure space
    Q = FunctionSpace(mesh, "Lagrange", 1)

    # Define inflow profile
    inflow_profile = ('1', '0')

    # Define boundary conditions
    bcu_inflow = DirichletBC(V, Expression(inflow_profile, degree=2), inflow)
    bcu_walls = DirichletBC(V, Constant((0, 0)), walls)
    bcu_cylinder = DirichletBC(V, Constant((0, 0)), cylinder)
    bcp_outflow = DirichletBC(Q, Constant(0), outflow)
    bcu = [bcu_inflow, bcu_walls, bcu_cylinder]
    bcp = [bcp_outflow]

    # Define trial and test functions
    u = TrialFunction(V)
    v = TestFunction(V)
    p = TrialFunction(Q)
    q = TestFunction(Q)

    # Define functions for the solution at the current time step
    u1 = Function(V)
    p1 = Function(Q)

    # Define expressions used in variational forms
    f = Constant((0, 0))
    k = Constant(dt)
    nu = Constant(mu)

    # Define variational problem for step 1 (tentative velocity).
    # Every term is scaled by k, including the forcing term, so the form
    # stays consistent if f is ever made non-zero.
    F1 = inner(u - u0, v)*dx + k*inner(grad(u0)*u0, v)*dx + \
         k*nu*inner(grad(u), grad(v))*dx - k*inner(f, v)*dx
    a1 = lhs(F1)
    L1 = rhs(F1)

    # Define variational problem for step 2 (pressure)
    a2 = inner(grad(p), grad(q))*dx
    L2 = -(1/k)*div(u1)*q*dx

    # Define variational problem for step 3 (velocity correction)
    a3 = inner(u, v)*dx
    L3 = inner(u1, v)*dx - k*inner(grad(p1), v)*dx

    # Assemble matrices
    A1 = assemble(a1)
    A2 = assemble(a2)
    A3 = assemble(a3)

    # Use nonzero guesses - essential for CG with non-symmetric BC
    parameters['krylov_solver']['nonzero_initial_guess'] = True

    # Time-stepping
    for step in range(total_steps):

        # Step 1: Tentative velocity step
        b1 = assemble(L1)
        for bc in bcu:
            bc.apply(A1, b1)
        solve(A1, u1.vector(), b1, 'gmres', 'default')

        # Step 2: Pressure correction step
        b2 = assemble(L2)
        for bc in bcp:
            bc.apply(A2, b2)
        for bc in bcp:
            bc.apply(p1.vector())
        solve(A2, p1.vector(), b2, 'gmres', 'default')

        # Step 3: Velocity correction step
        b3 = assemble(L3)
        for bc in bcu:
            bc.apply(A3, b3)
        solve(A3, u1.vector(), b3, 'gmres', 'default')

        # Update previous solution
        u0.assign(u1)

        # Every 10th step: load the matching observation and score it.
        if step % 10 == 0:
            d_observe = Function(V, name="data")
            d_observe.vector()[:] = storage['observation'][step + 1, :]
            if objective is not None:
                # Q is passed by keyword so wrappers that declare a Q
                # parameter receive the pressure space.
                objective(u0, d_observe, step, points, loss_fn=loss_fn, Q=Q)
    

def objective_point_wise(storage, u=None, d_observe=None,idx=None,points=None,loss_fn="l2",Q=None,finalize=False,obversation_steps=None):
    """Record a point-wise misfit for one step, or average the recorded ones.

    Accumulation mode (``finalize=False``): evaluates ``d_observe`` and ``u``
    at every entry of ``points`` (each evaluation must be indexable with
    ``[0]`` and ``[1]``), computes the mean misfit over the points and appends
    it to ``storage["functional"]``. When ``idx == 0`` the bookkeeping lists
    in ``storage`` are reset first.

    Finalize mode (``finalize=True``): averages
    ``storage["functional"][step]`` over the entries of ``obversation_steps``
    that lie inside the recorded range, stores the result in
    ``storage["loss"]`` and returns it.

    Args:
        storage: Mutable mapping used to keep results between calls.
        u: Callable returning the predicted two-component value at a point.
        d_observe: Callable returning the observed two-component value at a
            point.
        idx: Step index; ``0`` triggers a reset of ``storage``.
        points: Sized iterable of evaluation points.
        loss_fn: ``"l2"`` (mean squared distance) or ``"l1"`` (mean sum of
            absolute component differences).
        Q: Unused; kept for interface compatibility.
        finalize: Select finalize mode.
        obversation_steps: Iterable of indices into ``storage["functional"]``
            (finalize mode only).

    Returns:
        The averaged loss in finalize mode, otherwise ``None``.

    Raises:
        ValueError: If ``loss_fn`` is not recognised, ``points`` is empty, or
            no entry of ``obversation_steps`` refers to a recorded value.
    """
    if finalize:
        recorded = storage['functional']
        # Apply the range check to every step, including the first one.
        terms = [recorded[step] for step in obversation_steps if step < len(recorded)]
        if not terms:
            raise ValueError("no recorded functional value matches obversation_steps")
        # Start from the first term (not 0) and avoid in-place addition so
        # the stored per-step values are never modified.
        loss = terms[0]
        for term in terms[1:]:
            loss = loss + term
        mean_loss = loss / len(terms)
        storage["loss"] = mean_loss
        return mean_loss

    # Validate before touching storage so a bad call leaves it unchanged and
    # storage["functional"] stays aligned with the observation steps.
    if loss_fn not in ("l2", "l1"):
        raise ValueError("unknown loss_fn: {!r} (expected 'l2' or 'l1')".format(loss_fn))
    n_points = len(points)
    if n_points == 0:
        raise ValueError("points must contain at least one evaluation point")

    if idx == 0:
        storage["functional"] = []
        storage["estimate_field"] = []
        storage["gt_observation"] = []
        storage["gt_field"] = []
        storage["loss"] = 0

    observation = [d_observe(p) for p in points]
    prediction = [u(p) for p in points]

    if loss_fn == "l2":
        total = sum(
            (obs[0] - pred[0])**2 + (obs[1] - pred[1])**2
            for obs, pred in zip(observation, prediction)
        )
    else:
        # Absolute value written as sqrt(x**2), as in the original expression,
        # but applied per component so the result is a scalar like "l2".
        total = sum(
            ((obs[0] - pred[0])**2)**0.5 + ((obs[1] - pred[1])**2)**0.5
            for obs, pred in zip(observation, prediction)
        )
    storage["functional"].append(total / n_points)
    

Run optimization/solve inverse problem

# NOTE(review): nodes, solution, mesh, loss_fn, eval_cb and get_observation
# are not defined in this snippet; they are assumed to exist earlier in the file.

# The observation schedule must exist before get_observation() reads it.
observation_setup = {"start_time":500,"timestep":10,"endtime":580}
total_steps = observation_setup["endtime"] - observation_setup["start_time"]

observation, observed_coords, mask, _ = get_observation(nodes, solution, observation_setup["timestep"])

storage = {"functional": [],'solution':solution,'observation':observation, "time": 0, "estimate_field": [], "gt_observation": [],'gt_field':[],'loss':0}


def obj(u, d, idx, observed_coords, loss_fn, Q=None):
    """Bind ``storage`` and forward one observation step to objective_point_wise.

    ``Q`` is optional so the callback works whether or not the caller
    supplies it.
    """
    return objective_point_wise(storage, u, d, idx, observed_coords, loss_fn, Q)


# Indices into storage["functional"]: one entry is recorded per observed step,
# i.e. every ``timestep`` simulation steps. ``int`` replaces the removed
# ``np.int`` alias.
obversation_steps = (
    np.arange(0, total_steps, observation_setup["timestep"])
    // observation_setup["timestep"]
).astype(int)


# initialize u at 0
V = VectorFunctionSpace(mesh, 'P', 2)
u = Function(V)
u.vector()[:] = 0

# Register the control before forward() runs: forward() overwrites u in place
# after every step, so a Control created afterwards would refer to the final
# velocity rather than the initial condition being optimised.
control = Control(u)

forward(u, V=V, total_steps=total_steps,
        objective=obj, storage=storage, points=observed_coords, mask=mask, mesh=mesh, loss_fn=loss_fn)

J = objective_point_wise(storage, finalize=True, obversation_steps=obversation_steps)
print("J:{}".format(J))
Jhat = ReducedFunctional(J, control, eval_cb_post=eval_cb)
problem = MinimizationProblem(Jhat)

# Kept under its own name: rebinding the global ``parameters`` would replace
# the object forward() indexes with ['krylov_solver'].
ipopt_parameters = {'maximum_iterations': 100}

solver = IPOPTSolver(problem, parameters=ipopt_parameters)
rho_opt = solver.solve()