ArityMismatch for Crank Nicolson method on mixed function space

As a starting experiment in FEniCS I am trying to solve the 1D wave equation $\frac{\partial^2 u}{\partial t^2} = c^2\frac{\partial^2 u}{\partial x^2}$ with fixed (homogeneous Dirichlet) boundary conditions on the unit interval. As the initial condition I choose $u_0(x) = x(1-x)$.

For the time discretisation I want to use the Crank–Nicolson method, i.e. $u_{n+1} = u_n + \Delta t \left( \frac{1}{2}\frac{du}{dt}(u_{n+1}) + \frac{1}{2}\frac{du}{dt}(u_n) \right)$.

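Since there is a second time derivative in the wave equation, I want to rewrite it as a first-order system and use a mixed function space with two components (the displacement and its time derivative) to solve the resulting system of equations. So far my code for solving the first time step is: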

from fenics import *

# Time step size dt is derived from T and nt
T = 1
nt = 100
dt = T / nt
# Wave speed; it enters the forms below as c**2
c = 1

# Unit-interval mesh and a mixed function space built from two identical
# first-order Lagrange elements (one per unknown component)
mesh = UnitIntervalMesh(100)
V1 = FiniteElement('Lagrange', mesh.ufl_cell(), 1)
V2 = FiniteElement('Lagrange', mesh.ufl_cell(), 1)
element = MixedElement([V1, V2])
V = FunctionSpace(mesh, element)

def boundary(x, on_boundary):
    """Select the whole mesh boundary for the Dirichlet conditions.

    The coordinates ``x`` are not inspected; the ``on_boundary`` flag is
    returned unchanged.
    """
    return on_boundary
# Homogeneous Dirichlet conditions on the boundary, one per component of the
# mixed space. bc2 must act on the second sub-space (V.sub(1)); applying it
# to V.sub(0) again would only duplicate bc1 and leave the second component
# unconstrained.
bc1 = DirichletBC(V.sub(0), Constant(0), boundary)
bc2 = DirichletBC(V.sub(1), Constant(0), boundary)

class InitialConditions(UserExpression):
    """Two-component initial state: ``x * (1 - x)`` and zero."""

    def __init__(self, **kwargs):
        # The constructor needs a body; forward keyword arguments such as
        # ``degree`` to the UserExpression base class.
        super(InitialConditions, self).__init__(**kwargs)

    def eval(self, values, x):
        """Write both component values at point ``x`` into ``values``."""
        values[0] = x[0] * (1 - x[0])
        values[1] = 0.0

    def value_shape(self):
        """Return the shape of the expression value: two components."""
        return (2,)

# Known function whose components (un1, un2) appear in the right-hand-side
# forms L1 and L2
un = Function(V)
# NOTE(review): u0 is created here but is not used in the lines shown
u0 = InitialConditions(degree=2)
un1, un2 = split(un)
# Unknowns (trial functions) and test functions, one pair per component
u1, u2 = TrialFunctions(V)
v1, v2 = TestFunctions(V)
# a1/L1: equation tested with v1; a2/L2: equation tested with v2
a1 = u1 * v1 * dx - Constant(0.5 * dt) * u2 * v1 * dx
L1 = un1 * v1 * dx + Constant(0.5 * dt) * un2 * v1 * dx
a2 = u2 * v2 * dx + Constant(0.5 * c**2 * dt) * inner(grad(u1), grad(v2)) * dx
# NOTE: u2 is a TrialFunction, so the gradient term below contains both a
# trial and a test function, while the first term contains only a test function
L2 = un2 * v2 * dx - Constant(0.5 * c**2 * dt) * inner(grad(u2), grad(v2)) * dx
a = a1 + a2
L = L1 + L2

# Solve the combined system for both components at once
uh = Function(V)
solve(a==L, uh, [bc1, bc2])

When I run this, I get an error message about an ArityMismatch which I cannot make sense of:

ArityMismatch                             Traceback (most recent call last)
<ipython-input-9-f0f502a6e073> in <module>()
      1 uh = Function(V)
----> 2 solve(a==L, uh, [bc1, bc2])

/usr/local/lib/python3.6/dist-packages/dolfin/fem/ in solve(*args, **kwargs)
    218     # tolerance)
    219     elif isinstance(args[0], ufl.classes.Equation):
--> 220         _solve_varproblem(*args, **kwargs)
    222     # Default case, just call the wrapped C++ solve function

/usr/local/lib/python3.6/dist-packages/dolfin/fem/ in _solve_varproblem(*args, **kwargs)
    240         # Create problem
    241         problem = LinearVariationalProblem(eq.lhs, eq.rhs, u, bcs,
--> 242                                            form_compiler_parameters=form_compiler_parameters)
    244         # Create solver and call solve

/usr/local/lib/python3.6/dist-packages/dolfin/fem/ in __init__(self, a, L, u, bcs, form_compiler_parameters)
     53             L = cpp.fem.Form(1, 0)
     54         else:
---> 55             L = Form(L, form_compiler_parameters=form_compiler_parameters)
     56         a = Form(a, form_compiler_parameters=form_compiler_parameters)

/usr/local/lib/python3.6/dist-packages/dolfin/fem/ in __init__(self, form, **kwargs)
     44         ufc_form = ffc_jit(form, form_compiler_parameters=form_compiler_parameters,
---> 45                            mpi_comm=mesh.mpi_comm())
     46         ufc_form = cpp.fem.make_ufc_form(ufc_form[0])

/usr/local/lib/python3.6/dist-packages/dolfin/jit/ in mpi_jit(*args, **kwargs)
     45         # Just call JIT compiler when running in serial
     46         if MPI.size(mpi_comm) == 1:
---> 47             return local_jit(*args, **kwargs)
     49         # Default status (0 == ok, 1 == fail)

/usr/local/lib/python3.6/dist-packages/dolfin/jit/ in ffc_jit(ufl_form, form_compiler_parameters)
     95     p.update(dict(parameters["form_compiler"]))
     96     p.update(form_compiler_parameters or {})
---> 97     return ffc.jit(ufl_form, parameters=p)

/usr/local/lib/python3.6/dist-packages/ffc/ in jit(ufl_object, parameters, indirect)
    216     # Inspect cache and generate+build if necessary
--> 217     module = jit_build(ufl_object, module_name, parameters)
    219     # Raise exception on failure to build or import module

/usr/local/lib/python3.6/dist-packages/ffc/ in jit_build(ufl_object, module_name, parameters)
    131                                     name=module_name,
    132                                     params=params,
--> 133                                     generate=jit_generate)
    134     return module

/usr/local/lib/python3.6/dist-packages/dijitso/ in jit(jitable, name, params, generate, send, receive, wait)
    163         elif generate:
    164             # 1) Generate source code
--> 165             header, source, dependencies = generate(jitable, name, signature, params["generator"])
    166             # Ensure we got unicode from generate
    167             header = as_unicode(header)

/usr/local/lib/python3.6/dist-packages/ffc/ in jit_generate(ufl_object, module_name, signature, parameters)
     65     code_h, code_c, dependent_ufl_objects = compile_object(ufl_object,
---> 66             prefix=module_name, parameters=parameters, jit=True)
     68     # Jit compile dependent objects separately,

/usr/local/lib/python3.6/dist-packages/ffc/ in compile_form(forms, object_names, prefix, parameters, jit)
    141     """This function generates UFC code for a given UFL form or list of UFL forms."""
    142     return compile_ufl_objects(forms, "form", object_names,
--> 143                                prefix, parameters, jit)

/usr/local/lib/python3.6/dist-packages/ffc/ in compile_ufl_objects(ufl_objects, kind, object_names, prefix, parameters, jit)
    183     # Stage 1: analysis
    184     cpu_time = time()
--> 185     analysis = analyze_ufl_objects(ufl_objects, kind, parameters)
    186     _print_timing(1, time() - cpu_time)

/usr/local/lib/python3.6/dist-packages/ffc/ in analyze_ufl_objects(ufl_objects, kind, parameters)
     88         # Analyze forms
     89         form_datas = tuple(_analyze_form(form, parameters)
---> 90                            for form in forms)
     92         # Extract unique elements accross all forms

/usr/local/lib/python3.6/dist-packages/ffc/ in <genexpr>(.0)
     88         # Analyze forms
     89         form_datas = tuple(_analyze_form(form, parameters)
---> 90                            for form in forms)
     92         # Extract unique elements accross all forms

/usr/local/lib/python3.6/dist-packages/ffc/ in _analyze_form(form, parameters)
    172                                       do_apply_geometry_lowering=True,
    173                                       preserve_geometry_types=(Jacobian,),
--> 174                                       do_apply_restrictions=True)
    175     elif r == "tsfc":
    176         try:

/usr/local/lib/python3.6/dist-packages/ufl/algorithms/ in compute_form_data(form, do_apply_function_pullbacks, do_apply_integral_scaling, do_apply_geometry_lowering, preserve_geometry_types, do_apply_default_restrictions, do_apply_restrictions, do_estimate_degrees)
    392     # faster!
    393     preprocessed_form = reconstruct_form_from_integral_data(self.integral_data)
--> 394     check_form_arity(preprocessed_form, self.original_form.arguments())  # Currently testing how fast this is
    396     # TODO: This member is used by unit tests, change the tests to

/usr/local/lib/python3.6/dist-packages/ufl/algorithms/ in check_form_arity(form, arguments)
    150 def check_form_arity(form, arguments):
    151     for itg in form.integrals():
--> 152         check_integrand_arity(itg.integrand(), arguments)

/usr/local/lib/python3.6/dist-packages/ufl/algorithms/ in check_integrand_arity(expr, arguments)
    143                              key=lambda x: (x.number(), x.part())))
    144     rules = ArityChecker(arguments)
--> 145     args = map_expr_dag(rules, expr, compress=False)
    146     if args != arguments:
    147         raise ArityMismatch("Integrand arguments {0} differ from form arguments {1}.".format(args, arguments))

/usr/local/lib/python3.6/dist-packages/ufl/corealg/ in map_expr_dag(function, expression, compress)
     35     Return the result of the final function call.
     36     """
---> 37     result, = map_expr_dags(function, [expression], compress=compress)
     38     return result

/usr/local/lib/python3.6/dist-packages/ufl/corealg/ in map_expr_dags(function, expressions, compress)
     84                 r = handlers[v._ufl_typecode_](v)
     85             else:
---> 86                 r = handlers[v._ufl_typecode_](v, *[vcache[u] for u in v.ufl_operands])
     88             # Optionally check if r is in rcache, a memory optimization

/usr/local/lib/python3.6/dist-packages/ufl/algorithms/ in sum(self, o, a, b)
     40     def sum(self, o, a, b):
     41         if a != b:
---> 42             raise ArityMismatch("Adding expressions with non-matching form arguments {0} vs {1}.".format(a, b))
     43         return a

ArityMismatch: Adding expressions with non-matching form arguments (Argument(FunctionSpace(Mesh(VectorElement(FiniteElement('Lagrange', interval, 1), dim=1), 5), MixedElement(FiniteElement('Lagrange', interval, 1), FiniteElement('Lagrange', interval, 1))), 0, None),) vs (Argument(FunctionSpace(Mesh(VectorElement(FiniteElement('Lagrange', interval, 1), dim=1), 5), MixedElement(FiniteElement('Lagrange', interval, 1), FiniteElement('Lagrange', interval, 1))), 0, None), Argument(FunctionSpace(Mesh(VectorElement(FiniteElement('Lagrange', interval, 1), dim=1), 5), MixedElement(FiniteElement('Lagrange', interval, 1), FiniteElement('Lagrange', interval, 1))), 1, None)).

I assume that some of the form definitions are not compatible with each other. Can someone help me with this issue? I am using FEniCS 2019.1 in a Docker container on Windows 10.

Hi wavy,

The ArityMismatch arises because your right-hand side L is not a linear form but a bilinear form (L2 contains both a trial and a test function):

In [1]: len(L2.arguments())                                                                                      
Out[1]: 2

For a linear form the number of arguments should be 1.
I didn’t go over the derivation of the weak form myself, so maybe you could check it again. If the term really is meant to be bilinear, you can simply move it to the left-hand side; FEniCS can also do this splitting for you:

# Collect all terms in a single form and let lhs()/rhs() separate the terms
# that contain a trial function (a) from those that do not (L)
form = a1 + a2 - L1 - L2
a = lhs(form)
L = rhs(form)


Thanks. This solved the problem. My mistake was that in L2 I used inner(grad(u2), grad(v2)), which contains the trial function u2, instead of fe.inner(fe.grad(un1), fe.grad(v2)), which uses the known solution from the previous time step.

For completeness, here is my example code for the Crank Nicolson time integration of the wave/string equation

from IPython.display import clear_output
import fenics as fe
import matplotlib.pyplot as plt
import numpy as np

# Time step size dt is derived from T and nt; nt is also the number of
# iterations of the time loop below
T = 2
nt = 200
dt = T / nt
# Wave speed; it enters the forms below as c**2
c = 1
# Mesh resolution; also reused as the number of sample points for plotting
nx = 200

# Unit-interval mesh and a mixed function space built from two identical
# first-order Lagrange elements (one per unknown component)
mesh = fe.UnitIntervalMesh(nx)
V1 = fe.FiniteElement('Lagrange', mesh.ufl_cell(), 1)
V2 = fe.FiniteElement('Lagrange', mesh.ufl_cell(), 1)
element = fe.MixedElement([V1, V2])
V = fe.FunctionSpace(mesh, element)

def boundary(x, on_boundary):
    """Select the whole mesh boundary for the Dirichlet conditions.

    The coordinates ``x`` are not inspected; the ``on_boundary`` flag is
    returned unchanged.
    """
    return on_boundary
# Homogeneous Dirichlet conditions on the boundary: bc1 constrains the first
# component of the mixed space, bc2 the second
bc1 = fe.DirichletBC(V.sub(0), fe.Constant(0), boundary)
bc2 = fe.DirichletBC(V.sub(1), fe.Constant(0), boundary)

class InitialConditions(fe.UserExpression):
    """Two-component initial state.

    Component 0 is ``x * (1 - x)``; component 1 is zero.
    """

    def __init__(self, **kwargs):
        # Forward keyword arguments (e.g. ``degree``) to the base class.
        super().__init__(**kwargs)

    def value_shape(self):
        """Return the shape of the expression value: two components."""
        return (2,)

    def eval(self, values, x):
        """Write both component values at point ``x`` into ``values``."""
        position = x[0]
        values[1] = 0.0
        values[0] = position * (1 - position)

# State at the previous time level; its components (un1, un2) enter the
# right-hand-side forms L1 and L2.
un = fe.Function(V)
u0 = InitialConditions(degree=2)
# Load the initial condition into un. Without this un stays identically
# zero and the computed solution never leaves the zero state.
un.interpolate(u0)
un1, un2 = fe.split(un)

# Unknowns at the new time level and the corresponding test functions
u1, u2 = fe.TrialFunctions(V)
v1, v2 = fe.TestFunctions(V)
# a1/L1: equation tested with v1; a2/L2: equation tested with v2.
# All terms of L1 and L2 use the known functions un1/un2, so L is linear.
a1 = u1 * v1 * fe.dx - fe.Constant(0.5 * dt) * u2 * v1 * fe.dx
L1 = un1 * v1 * fe.dx + fe.Constant(0.5 * dt) * un2 * v1 * fe.dx
a2 = u2 * v2 * fe.dx + fe.Constant(0.5 * c**2 * dt) * fe.inner(fe.grad(u1), fe.grad(v2)) * fe.dx
L2 = un2 * v2 * fe.dx - fe.Constant(0.5 * c**2 * dt) * fe.inner(fe.grad(un1), fe.grad(v2)) * fe.dx
a = a1 + a2
L = L1 + L2

uh = fe.Function(V)
# Sample points and buffers for plotting the two solution components
x = np.linspace(0, 1, nx)
y0 = np.zeros(nx)
y1 = np.zeros(nx)
for i in range(nt):
    # Advance one time step: solve for both components at the new level
    fe.solve(a == L, uh, [bc1, bc2])
    # Evaluate the mixed function once per point; it yields both components
    for j, xj in enumerate(x):
        y0[j], y1[j] = uh(xj)
    plt.figure(1, figsize=(8, 4))
    # Figure 1 is reused for every frame; clear it so that curves from
    # earlier time steps do not accumulate in later frames.
    plt.clf()
    plt.subplot(1, 2, 1)
    plt.plot(x, y0)
    plt.ylim(-0.3, 0.3)
    plt.subplot(1, 2, 2)
    plt.plot(x, y1)
    plt.ylim(-1.0, 1.0)
    # NOTE: the output directory 'frames 1' must already exist
    plt.savefig('frames 1/solution_%03i.png' % i, dpi=150)
    # The new solution becomes the previous state for the next step;
    # without this every iteration would solve the same problem again.
    un.assign(uh)

The plotting is surely not optimized. The (animated) result looks like this:


Cool example. I added it to the gallery of vtkplotter.dolfin as an example of 1D visualization (I hope you don’t mind!).

I am happy if you find the code useful. Nice visualization. I was not aware of vtkplotter so far.

