Implementing the Sherman-Morrison/Woodbury matrix formula in a Newton solver?

I am trying to overload the built-in Newton solver to implement a linear algebra trick. Specifically,
I am solving a system of linearized virtual work equations of the form $[\mathbf{K} - k\,\mathbf{a} \otimes \mathbf{a}]\,\Delta\mathbf{u} = -\mathbf{F}$. Here the tangent stiffness matrix $\mathbf{K}$ is modified by the rank-one term $k\,\mathbf{a} \otimes \mathbf{a}$. Notably, $\mathbf{K}$ is sparse, while the term $k\,\mathbf{a} \otimes \mathbf{a}$ results in a full matrix. This problem can be solved efficiently (without ever forming the full matrix) using the Sherman-Morrison formula as follows:
$$(\mathbf{K} - k\,\mathbf{a} \otimes \mathbf{a})^{-1} = \mathbf{K}^{-1} + \frac{k\,\mathbf{K}^{-1} \mathbf{a}\mathbf{a}^T \mathbf{K}^{-1}}{1 - k\,\mathbf{a}^T \mathbf{K}^{-1} \mathbf{a}}$$

so

$$\Delta\mathbf{u} = -\mathbf{K}^{-1} \mathbf{F} - \frac{k\,\mathbf{a}^T \mathbf{K}^{-1} \mathbf{F}}{1 - k\,\mathbf{a}^T \mathbf{K}^{-1} \mathbf{a}}\, \mathbf{K}^{-1} \mathbf{a}$$
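
As a sanity check on the signs, here is a minimal sketch in plain NumPy/SciPy (not FEniCS or PETSc) that solves $(\mathbf{K} - k\,\mathbf{a}\mathbf{a}^T)\Delta\mathbf{u} = -\mathbf{F}$ with two solves against a single sparse factorization of $\mathbf{K}$, then compares the result with a dense reference. It applies Sherman-Morrison with $\mathbf{u} = -k\mathbf{a}$, $\mathbf{v} = \mathbf{a}$, so the denominator is $1 - k\,\mathbf{a}^T\mathbf{K}^{-1}\mathbf{a}$. The function name `sherman_morrison_solve` and the tridiagonal toy matrix are assumptions made up for illustration.

```python
import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import splu


def sherman_morrison_solve(K, a, k, F):
    """Solve (K - k a a^T) du = -F using only solves with the sparse K."""
    lu = splu(sp.csc_matrix(K))      # factorize K once, reuse for both solves
    q = lu.solve(-F)                 # q = -K^{-1} F
    r = lu.solve(a)                  # r =  K^{-1} a
    denom = 1.0 - k * a.dot(r)       # 1 - k a^T K^{-1} a
    if abs(denom) < 1e-14:
        raise ZeroDivisionError("rank-one update makes the matrix singular")
    # du = q + k (a^T q) / denom * r; this form does not need K to be symmetric
    return q + (k * a.dot(q) / denom) * r


# Toy data (hypothetical): a tridiagonal "stiffness" matrix and random vectors.
n = 200
rng = np.random.default_rng(0)
K = sp.diags([-1.0, 4.0, -1.0], [-1, 0, 1], shape=(n, n), format="csc")
a = rng.standard_normal(n) / np.sqrt(n)
F = rng.standard_normal(n)
k = 0.5

du = sherman_morrison_solve(K, a, k, F)

# Reference: form the dense matrix explicitly (only feasible for small n).
dense = K.toarray() - k * np.outer(a, a)
du_ref = np.linalg.solve(dense, -F)
print("max difference:", np.max(np.abs(du - du_ref)))
```

Factorizing once and reusing the factors for both right-hand sides is probably the main saving over calling a generic `solve` twice per Newton iteration.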

(See for example:

I have implemented the above as shown in the pseudo-code below (MWE is difficult…), but it is quite slow and I have no real control over the size of the Newton step. I am hoping that by overloading the built-in solver, I can get access to PETSc’s line search. From what I understand, I just have to overload the methods for updating K and F, but I can’t seem to find the docs to help me. I did find this on the Firedrake docs -, but I have very limited experience with PETSc, so any tips would be greatly appreciated. Thanks!

def NRsolver(self, external_load):
    delU = Function(self.V)  # incremental displacement
    self.inc_load = Function(self.Vs)  # incremental load
    discr = 10
    f_form1 = ...  # bunch of calls to collect different parts of the RHS form
    f_form2 = ...

    k_form1 = ...  # LHS (form of K above)
    k_form2 = ...  # self.K_load()
    a_form = ...   # form of vector a

    # step up pressure vector in discr increments
    for i in np.linspace(0.0, 1.0, discr):

        # code to update external load vector goes here
        n = self.update_normals()
        delU.vector()[:] = -0.3
        counter = 0

        while (norm(delU) > self.tol):
            current_int_p = ...  # returns a Constant
            k = ...  # call to function to calculate scalar k
            a = assemble(a_form)   # assemble a here because it is needed for Sherman-Morrison
            # collect all RHS
            F = assemble(f_form2 - f_form1)  - current_int_p*a
            # collect all LHS
            K = assemble(k_form1 + current_int_p*k_form2)

            for bc in self.bc:
                bc.apply(K, F)
                bc.apply(K, a)

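            # Sherman-Morrison formula for (K - k a a^T) delU = -F
            Q = Function(self.V).vector()
            R = Function(self.V).vector()
            solve(K, Q, -F)   # Q = -K^{-1} F
            solve(K, R, -a)   # R = -K^{-1} a
            beta = 1 / (1 + k * a.inner(R))   # = 1 / (1 - k a^T K^{-1} a)
            # the rank-one correction enters with a minus sign because R = -K^{-1} a
            delU.vector()[:] = Q - beta * k * a.inner(Q) * R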

            self.u.vector().axpy(0.25, delU.vector())  #  poor man's line search
            counter += 1

            if (counter > 80):
                raise Exception("Counter allowance exceeded")
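
Until the PETSc line search is reachable, the fixed 0.25 damping in the loop above could be replaced by a simple backtracking search on the residual norm. The following is a minimal sketch in plain NumPy, not PETSc's implementation. The names `backtracking_step`, `residual`, and `newton_direction`, and the scalar arctan test problem, are assumptions made up for illustration. In the FEniCS loop, `residual` would correspond to reassembling `F` at the trial displacement.

```python
import numpy as np


def backtracking_step(residual, u, du, shrink=0.5, c=1e-4, max_halvings=20):
    """Return a step length t with sufficient decrease of ||residual(u + t*du)||."""
    r0 = np.linalg.norm(residual(u))
    t = 1.0
    for _ in range(max_halvings):
        if np.linalg.norm(residual(u + t * du)) <= (1.0 - c * t) * r0:
            return t
        t *= shrink                  # step rejected: shrink and try again
    return t                         # give up and return the smallest step tried


# Toy problem (hypothetical): residual(u) = arctan(u), for which undamped
# Newton diverges from the starting guess u = 3.
def residual(u):
    return np.arctan(u)


def newton_direction(u):
    # The Jacobian is diagonal with entries 1 / (1 + u^2), so -J^{-1} r is:
    return -np.arctan(u) * (1.0 + u**2)


u = np.array([3.0])
for iteration in range(50):
    if np.linalg.norm(residual(u)) < 1e-10:
        break
    du = newton_direction(u)
    t = backtracking_step(residual, u, du)
    u = u + t * du                   # damped Newton update

print("iterations:", iteration, "solution:", u)
```

The search takes the full Newton step whenever it reduces the residual, and only shrinks the step when it does not, so it should not slow convergence near the solution the way a constant factor of 0.25 does.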