ValueError: Expecting pulled back expression with shape '()', got '(1,)' during interpolation

Hello everyone,

I get this error when I try to interpolate the gradient of a collapsed subfunction. Here is a simplified snippet that reproduces it:

```python
from dolfinx import fem, mesh
import basix.ufl
import ufl
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD

# 20 x 20 quadrilateral mesh on [0, 20] x [0, 20]; the geometry dtype must be real
Nx = 20; Ny = 20
msh = mesh.create_rectangle(comm, [np.array([0, 0]), np.array([Nx, Ny])], [Nx, Ny],
                            cell_type=mesh.CellType.quadrilateral, dtype=np.float64)

# Mixed space: a 2-component P1 field plus a P1 field declared with shape=(1,)
element1 = basix.ufl.element('Lagrange', msh.topology.cell_name(), 1, shape=(2,))
element2 = basix.ufl.element('Lagrange', msh.topology.cell_name(), 1, shape=(1,))
element = basix.ufl.mixed_element([element1, element2])
V = fem.functionspace(msh, element); uh = fem.Function(V)

# Collapse each component into a standalone function
u_subs = uh.split()
u1 = u_subs[0].sub(0).collapse()
u2 = u_subs[0].sub(1).collapse()
u3 = u_subs[1].collapse()

# Cellwise (DG0) x-derivative of each component
Q = fem.functionspace(msh, ('DG', 0))
_u1 = fem.Function(Q)
_u2 = fem.Function(Q)
_u3 = fem.Function(Q)

_u1.interpolate(fem.Expression(ufl.grad(u1)[0], Q.element.interpolation_points()))
_u2.interpolate(fem.Expression(ufl.grad(u2)[0], Q.element.interpolation_points()))
_u3.interpolate(fem.Expression(ufl.grad(u3)[0], Q.element.interpolation_points()))
```

The error occurs at the last line (when interpolating _u3):

```
Traceback (most recent call last):
  File "/newhome/test.py", line 29, in <module>
    _u3.interpolate(fem.Expression(ufl.grad(u3)[0], Q.element.interpolation_points()))
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/fenicsx/lib/python3.12/site-packages/dolfinx/fem/function.py", line 151, in __init__
    self._ufcx_expression, module, self._code = jit.ffcx_jit(
                                                ^^^^^^^^^^^^^
  ...
  File "/home/miniconda3/envs/fenicsx/lib/python3.12/site-packages/ufl/algorithms/apply_function_pullbacks.py", line 43, in form_argument
    raise ValueError(
ValueError: Expecting pulled back expression with shape '()', got '(1,)'
```

The UFL shape of ufl.grad(u3)[0] is indeed (), so I don't understand why this error occurs.

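Update: I found the cause. element2 was declared with shape=(1,), which makes basix.ufl build a one-component blocked (vector) element instead of a scalar one. Removing shape=(1,) makes element2 a plain scalar Lagrange element, and the code then runs.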