Conversion from MatrixCSR to PETSc matrix

I am trying to write a Python script to solve a particular heat equation. The issue at the moment is that the PETSc solver requires a PETSc matrix, whereas the fem.assemble_matrix function produces a MatrixCSR object (see the code below). How would I convert this into the correct form so that a solution can be found? Any other comments or input would be appreciated!
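Concretely, this is the mismatch, reduced to a minimal excerpt of the script:

A = fem.create_matrix(form(a))
fem.assemble_matrix(A, form(a), bcs=boundary_conditions)  # A is a dolfinx MatrixCSR

solver = PETSc.KSP().create(msh.comm)
solver.setOperators(A.petsc)  # fails here: KSP expects a petsc4py Mat, and MatrixCSR has no such attribute

The full script: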

import dolfinx
import ufl
from mpi4py import MPI
from petsc4py import PETSc
import numpy as np
from dolfinx.io import XDMFFile
from dolfinx.fem import (Function, FunctionSpace, dirichletbc, locate_dofs_geometrical, Constant, form, assemble_vector, assemble_matrix)
from dolfinx import mesh, fem
import matplotlib.pyplot as plt
import pyvista as pv


def visualiseMesh(msh):
    '''
    Writes a 2D rendering file (mesh.xdmf) showing the mesh structure and
    displays the mesh with PyVista. `msh` must be a mesh object from the
    dolfinx library.
    '''
    print("Plotting mesh...")
    with XDMFFile(msh.comm, "mesh.xdmf", "w") as xdmf:
        xdmf.write_mesh(msh)

    # Build the VTK data structures for plotting (create_vtk_mesh returns
    # the topology, the cell types and the geometry)
    topology, cell_types, x = dolfinx.plot.create_vtk_mesh(msh)
    grid = pv.UnstructuredGrid(topology, cell_types, x)

    # Plot the mesh using PyVista
    plotter = pv.Plotter()
    plotter.add_mesh(grid, show_edges=True)
    plotter.show()


def findHeatSolution(length, width, rho_param, c_p_param, k_param):
    print("""
        ===================
        Starting simulation
        ===================
          """)
    
    print("Generating mesh...\n")
    # Creating the mesh (assuming a rectangular substrate for now)
    msh = mesh.create_rectangle(
        comm=MPI.COMM_WORLD,
        points=((0.0,0.0), (length, width)),
        n=(32, 16),
        cell_type=mesh.CellType.triangle,
    )

    # visualiseMesh(msh)

    # Defining the function space V
    V = fem.FunctionSpace(msh, ("Lagrange", 1))

    # Defining trial and test functions u and v
    T_n = Function(V) # Temperature at previous time step
    v = ufl.TestFunction(V)
    u = ufl.TrialFunction(V) # u is the temperature at the current timestep (what we are trying to solve)
    n = ufl.FacetNormal(msh)

    delta_t = 0.1

    # Defining the localised heat source power dissipation (uW) and coordinates
    # STILL NEED TO DYNAMICALLY READ IN POWER DISSIPATED AND COORDINATES OF RESISTOR
    print("Defining the localised heat source power dissipation and coordinates... \n")
    P1, P2, P3 = 420, 250, 500

    x1, y1 = -50.0, 0.0
    x2, y2 = 0.0, 0.0
    x3, y3 = 50.0, 0.0

    def localized_heat_source(x, power1=P1, power2=P2, power3=P3, radius=0.05):
        return (
            power1 * np.exp(-((x[0] - x1) ** 2 + (x[1] - y1) ** 2) / (2 * radius ** 2))
            + power2 * np.exp(-((x[0] - x2) ** 2 + (x[1] - y2) ** 2) / (2 * radius ** 2))
            + power3 * np.exp(-((x[0] - x3) ** 2 + (x[1] - y3) ** 2) / (2 * radius ** 2))
        )
    
    # Define the heat source as a Function in the FunctionSpace
    Q = Function(V)
    Q.interpolate(localized_heat_source)

    # Define the bilinear form (contains the unknowns u and v)
    a = ((rho_param * c_p_param / delta_t) * u * v * ufl.dx
         + k_param * ufl.dot(ufl.grad(u), ufl.grad(v)) * ufl.dx
         - k_param * ufl.dot(ufl.grad(u), n) * v * ufl.ds)

    # Define the linear form (including contributions from known T_n and Q)
    L = (rho_param * c_p_param / delta_t) * T_n * v * ufl.dx + Q * v * ufl.dx

    # Apply Dirichlet boundary conditions (T=0 on all four edges of the rectangle)
    def on_boundary(x):
        return (np.isclose(x[0], 0.0) | np.isclose(x[0], length)
                | np.isclose(x[1], 0.0) | np.isclose(x[1], width))

    boundary_conditions = [dirichletbc(PETSc.ScalarType(0.0), locate_dofs_geometrical(V, on_boundary), V)]

    # Time-stepping parameters
    num_steps = 50  # Number of time steps
    time = 0.0

    with XDMFFile(MPI.COMM_WORLD, "heat_solution.xdmf", "w") as file:
        file.write_mesh(msh)
        # Write the initial temperature field
        file.write_function(T_n, time)

        # Assemble the system matrix (A)
        A = fem.create_matrix(form(a))
        fem.assemble_matrix(A, form(a), bcs=boundary_conditions)

        # Create a solver
        solver = PETSc.KSP().create(msh.comm)
        solver.setOperators(A.petsc)  # Use the PETSc matrix
        solver.setType(PETSc.KSP.Type.PREONLY)
        solver.getPC().setType(PETSc.PC.Type.LU)

        for step in range(num_steps):
            time += delta_t
            print(f"Time step {step + 1}/{num_steps}, Time: {time:.2f}")

            # Assemble the right-hand side vector (b)
            b = fem.create_vector(form(L))
            fem.assemble_vector(b, form(L))
            fem.apply_lifting(b, [form(a)], bcs=[boundary_conditions])
            b.ghostUpdate(addv=PETSc.InsertMode.ADD, mode=PETSc.ScatterMode.REVERSE)
            fem.set_bc(b, boundary_conditions)

            # Function to hold the solution at the new time step
            T_new = Function(V)
            solver.solve(b, T_new.vector)
            T_new.x.scatter_forward()

            # Update T_n for the next time step
            T_n.x.array[:] = T_new.x.array[:]

            # Save the current solution
            file.write_function(T_new, time)


    print("""
        ===================
        Simulation complete
        ===================
          """)

Why not just use

from dolfinx.fem.petsc import assemble_vector, assemble_matrix

instead? You will then get the matrix and vector directly in PETSc format.
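For reference, a minimal sketch of what the assembly and solve would then look like, using the names from your script (this assumes a dolfinx version in which create_vector, apply_lifting and set_bc also live in dolfinx.fem.petsc):

from dolfinx.fem.petsc import (assemble_matrix, assemble_vector,
                               create_vector, apply_lifting, set_bc)

a_form, L_form = form(a), form(L)

# Assemble the system matrix once as a petsc4py Mat (a is time-independent)
A = assemble_matrix(a_form, bcs=boundary_conditions)
A.assemble()  # finalise the PETSc matrix before handing it to the solver

solver = PETSc.KSP().create(msh.comm)
solver.setOperators(A)  # A is already a PETSc matrix, no conversion needed
solver.setType(PETSc.KSP.Type.PREONLY)
solver.getPC().setType(PETSc.PC.Type.LU)

# Create the right-hand side vector once; it is reused every time step
b = create_vector(L_form)

# Inside the time loop:
with b.localForm() as loc:
    loc.set(0.0)  # zero b before accumulating into it again
assemble_vector(b, L_form)
apply_lifting(b, [a_form], bcs=[boundary_conditions])
b.ghostUpdate(addv=PETSc.InsertMode.ADD, mode=PETSc.ScatterMode.REVERSE)
set_bc(b, boundary_conditions)

solver.solve(b, T_new.vector)
T_new.x.scatter_forward()

Creating b once outside the loop and zeroing it at the start of each step also avoids allocating a new vector every iteration.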

This seems to have solved the problem I was having. At least I think it did…
My simulation is still not running, so I’m not sure whether this change is enough to make it work, but it did fix the error I was getting on those lines. Thanks!

I would then suggest opening a separate post for any future issues, with a minimal reproducible example :)