Thank you. Here are my implementation and my comments:
def _interval_wcoeffs(degree, max_degree):
    """Return 1D Lagrange basis coefficients w.r.t. basix's Legendre polynomials.

    Row ``i`` holds the L2 inner products of the ``i``-th equispaced Lagrange
    basis function of the given degree on the reference interval with basix's
    Legendre polynomials of degree ``0..degree``. Because basix's Legendre set
    is orthonormal, these are the expansion coefficients. Columns
    ``degree + 1 .. max_degree`` are left zero so that directions of different
    degree can be combined with ``np.kron``.

    Returns:
        Array of shape ``(degree + 1, max_degree + 1)``.
    """
    # Lagrange(degree) * Legendre(degree) has degree 2 * degree, so this rule
    # integrates every inner product exactly.
    pts, wts = basix.make_quadrature(CellType.interval, 2 * degree)
    legendre = basix.tabulate_polynomials(
        PolynomialType.legendre, CellType.interval, degree, pts
    )
    lagrange = basix.create_element(
        ElementFamily.P, CellType.interval, degree, LagrangeVariant.equispaced
    )
    values = lagrange.tabulate(0, pts)[0, :, :, 0]  # (npts, degree + 1)
    coeffs = np.zeros((degree + 1, max_degree + 1))
    coeffs[:, : degree + 1] = values.T @ (legendre * wts).T
    return coeffs


def _tensor_points(origin, directions, coords):
    """Return ``origin + sum_k coords[k][i_k] * directions[k]`` on a tensor grid.

    The first direction varies fastest in the returned ordering. This is the
    same ordering as the nested loops ``for c_last ...: for c_first ...:``.

    Args:
        origin: Base point, shape ``(3,)``.
        directions: One shape-``(3,)`` vector per grid direction.
        coords: One 1D array of parameter values per grid direction (may be
            empty, giving zero points).

    Returns:
        Array of shape ``(prod(len(c) for c in coords), 3)``.
    """
    # With indexing="ij" the LAST meshgrid argument varies fastest after a
    # C-order ravel, so pass the coordinates reversed to make coords[0] fastest.
    grids = np.meshgrid(*coords[::-1], indexing="ij")
    params = np.stack([g.ravel() for g in grids[::-1]], axis=1)
    return np.asarray(origin) + params @ np.asarray(directions)


def _point_evaluation_matrix(npts):
    """Return an identity interpolation matrix for ``npts`` point evaluations.

    The shape is ``(npts, 1, npts, 1)``, i.e. (dofs, value size, points,
    derivatives) in basix's ``M`` layout for a scalar element with no
    derivative DOFs.
    """
    return np.eye(npts).reshape(npts, 1, npts, 1)


def _is_horizontal(vec, tol=1e-3):
    """Return True when ``vec`` has (numerically) no z component."""
    return abs(vec[2]) < tol


def create_custom_element(degree_p, degree_t, dim=3):
    """Create a continuous, anisotropic Lagrange-type element on a hexahedron.

    The polynomial space is spanned by the tensor products of degree
    ``degree_p`` polynomials in x, degree ``degree_p`` in y and degree
    ``degree_t`` in z. The degrees of freedom are point evaluations at
    equispaced points: every vertex, then the interior lattice points of each
    edge, face and of the cell. Along x/y directions the lattice uses
    ``degree_p``; along z it uses ``degree_t``.

    Args:
        degree_p: Polynomial degree in the x and y directions (must be >= 1).
        degree_t: Polynomial degree in the z direction (must be >= degree_p).
        dim: Number of components. ``dim > 1`` returns a blocked element of
            shape ``(dim,)``; otherwise the scalar element is returned.

    Returns:
        A ``basix.ufl`` element (blocked when ``dim > 1``).

    Raises:
        ValueError: If ``degree_p < 1`` or ``degree_t < degree_p``.
    """
    if degree_p < 1:
        raise ValueError(f"degree_p must be at least 1, got {degree_p}")
    if degree_t < degree_p:
        raise ValueError(
            f"degree_t ({degree_t}) must be >= degree_p ({degree_p})"
        )

    # Spanning set of the polynomial space: the kron product of the x, y and z
    # 1D coefficient matrices. The x/y factors are zero-padded up to degree_t,
    # so the result has shape ((degree_p+1)^2 * (degree_t+1), (degree_t+1)^3),
    # i.e. columns index the degree_t hexahedron Legendre polyset.
    # NOTE(review): this relies on basix ordering the hexahedron polyset with
    # x slowest and z fastest, so that the last kron factor is the z
    # direction -- TODO confirm against the installed basix version.
    w_p = _interval_wcoeffs(degree_p, degree_t)
    w_t = _interval_wcoeffs(degree_t, degree_t)
    wcoeffs = np.kron(np.kron(w_p, w_p), w_t)

    geometry = basix.geometry(CellType.hexahedron)
    topology = basix.topology(CellType.hexahedron)

    # Interior (exterior=False) equispaced lattice parameters on [0, 1].
    lat_p = basix.create_lattice(
        CellType.interval, degree_p, LatticeType.equispaced, False
    )[:, 0]
    lat_t = basix.create_lattice(
        CellType.interval, degree_t, LatticeType.equispaced, False
    )[:, 0]

    x = [[], [], [], []]
    M = [[], [], [], []]

    # Vertices: one point evaluation each, as a (1, 3) point array.
    for vertex in topology[0]:
        x[0].append(np.array(geometry[vertex[0]]).reshape(1, -1))
        M[0].append(_point_evaluation_matrix(1))

    # Edges, faces and the cell interior. In the reference numbering the
    # entity's local axis k runs from its vertex 0 to its vertex 2**k
    # (ent[1] for an edge; ent[1], ent[2] for a face; ent[1], ent[2], ent[4]
    # for the cell). Each axis gets the degree_p lattice when it lies in the
    # x/y plane and the degree_t lattice when it points along z.
    # NOTE(review): faces spanning x/z or y/z carry a
    # (degree_p - 1) x (degree_t - 1) point grid, which is not invariant under
    # 90-degree face rotations when degree_p != degree_t; check how basix
    # builds face DOF transformations for meshes whose faces are not aligned
    # with the reference -- TODO confirm.
    for entity_dim in (1, 2, 3):
        for entity in topology[entity_dim]:
            origin = geometry[entity[0]]
            axes = [geometry[entity[2**k]] - origin for k in range(entity_dim)]
            coords = [lat_p if _is_horizontal(a) else lat_t for a in axes]
            points = _tensor_points(origin, axes, coords)
            x[entity_dim].append(points)
            M[entity_dim].append(_point_evaluation_matrix(points.shape[0]))

    scalar_element = basix.ufl.custom_element(
        CellType.hexahedron,
        [],  # scalar reference value shape
        wcoeffs,
        x,
        M,
        0,  # interpolation uses point values only (no derivatives)
        MapType.identity,
        SobolevSpace.H1,
        False,  # continuous element
        degree_p,  # embedded subdegree: Q_{degree_p} is contained in the space
        degree_t,  # embedded superdegree: the space is contained in Q_{degree_t}
        PolysetType.standard,
    )
    if dim > 1:
        return basix.ufl.blocked_element(scalar_element, (dim,))
    return scalar_element