Parallel computation TAO solver

Hello,

I have developed some code that uses PETSc TAO optimization solver with box constraints.
When I run the code on a single processor, the code works as expected.
When I run on more than one processor, the code it gets somehow locked at a certain point and doesn’t terminate. It looks like some of the processors call one of the functions f, F, or J from the class TAOProblem and the code cannot proceed because the other processors don’t do the same.

This is a minimal working example of my code.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import sys
import numpy as np
from petsc4py import PETSc
import petsc4py
import ufl
from dolfinx import fem, mesh, la
from mpi4py import MPI

petsc4py.init(sys.argv)
os.system('clear')

# Create MPI communicator
comm = MPI.COMM_WORLD

# ----- Create the mesh
L = 1.0 
N = 10
domain = mesh.create_box(comm=comm, points=[[0.0, 0.0, 0.0], [L, L, L]],\
                         n=[N, N, N], cell_type=mesh.CellType.tetrahedron)

# ----- Create the measure dv
metadata = {'quadrature_degree' : 4}
dv = ufl.Measure("dx", domain=domain, metadata=metadata)

# Define the function space
V_theta = fem.VectorFunctionSpace(domain,("CG", 1), dim=2)

# Solution, Test, and Trial functions 
theta = fem.Function(V_theta)
theta_test, theta_trial =  ufl.TestFunction(V_theta), ufl.TrialFunction(V_theta)

# initalise the value of theta with a specific function
def initalization(pos):
    radius = pos[0]**2 + pos[1]**2
    if radius<1.0e-8:
        return (0., np.pi/2)
    mx = -pos[1]/radius * np.sqrt(1-np.exp(-4*radius**2/L**2))
    my = pos[0]/radius * np.sqrt(1-np.exp(-4*radius**2/L**2))
    mz = np.exp(-2*radius**2/L**2)
    if mx==0:
        return (0., np.pi/2)
    else:
        return (np.arctan(my/mx), np.arcsin(mz))

def eval_theta(x):
    d = x.shape[1]
    angles = [np.zeros(d), np.zeros(d)]
    for i in range(d):
        angles[0][i], angles[1][i] = initalization([x[0][i], x[1][i], x[2][i]])
    return angles

theta.interpolate(eval_theta)

""" 
---------------------------------------------------------
Define the optimization problem class
---------------------------------------------------------
"""       
class TAOProblem:

    """Nonlinear problem class compatible with PETSC.TAO solver"""

    def __init__(self, f, F, J, u, bcs):

        self.bcs = bcs                 # bcs = boundary conditions
        self.u   = u                   # u = solution 
        self.L   = fem.form(F) # F = residual
        self.a   = fem.form(J) # J = jacobian
        self.obj = fem.form(f) # f = objective function

        # .. Generation of matrices in the no-linear problem
        self.A = fem.petsc.create_matrix(self.a)
        self.b = fem.petsc.create_vector(self.L)
        
    def f(self, tao, x):
        
        # .... Connection between ranks in the vector x: FORWARD
        x.ghostUpdate(addv=PETSc.InsertMode.INSERT, mode=PETSc.ScatterMode.FORWARD)
        # ..... We copy the vector x in the vector u
        x.copy(self.u.vector)
        # .... Connection between ranks in the vector u in the class
        self.u.vector.ghostUpdate(addv=PETSc.InsertMode.INSERT, mode=PETSc.ScatterMode.FORWARD)
        print(fem.assemble_scalar(fem.form(self.obj)), flush=True)
        return fem.assemble_scalar(fem.form(self.obj))

    def F(self, tao, x, F):
    
        # .... Connection between ranks in the vector x: FORWARD
        x.ghostUpdate(addv=PETSc.InsertMode.INSERT, mode=PETSc.ScatterMode.FORWARD)
        # ..... We copy the vector x in the vector u
        x.copy(self.u.vector)
        # .... Connection between ranks in the vector u in the class
        self.u.vector.ghostUpdate(addv=PETSc.InsertMode.INSERT, mode=PETSc.ScatterMode.FORWARD)

        with F.localForm() as f_local:
            f_local.set(0.0)

        fem.petsc.assemble_vector(F, self.L)
        # .... Include the boundary conditions using the lifting technique
        fem.petsc.apply_lifting(F, [self.a], bcs=[self.bcs], x0=[x], scale=-1.0)
        # .... Connection between ranks in the vector x: REVERSE
        F.ghostUpdate(addv=PETSc.InsertMode.ADD, mode=PETSc.ScatterMode.REVERSE)
        # .... Redefine the function F with the boundary condition.
        fem.petsc.set_bc(F, self.bcs, x, -1.0)
                 
    def J(self, tao, x, A, P):
        A.zeroEntries()
        fem.petsc.assemble_matrix(A, self.a, self.bcs)
        A.assemble()

# --- Define lower bound and upper bound for theta
theta_min = fem.Function(V_theta)
theta_max = fem.Function(V_theta)

def eval_theta_lowerbound(x):
    d = x.shape[1]
    return [-np.pi*np.ones(d), -np.pi/2*np.ones(d)]

def eval_theta_upperbound(x):
    d = x.shape[1]
    return [np.pi*np.ones(d), np.pi/2*np.ones(d)]

theta_min.interpolate(eval_theta_lowerbound)
theta_max.interpolate(eval_theta_upperbound)

""" 
---------------------------------------------------------
Initialize energy functions and directional derivatives
---------------------------------------------------------
"""
# convert m to Cartesian coordinates
def m(theta):
    return ufl.as_vector([ufl.cos(theta[0])*ufl.cos(theta[1]), 
                          ufl.sin(theta[0])*ufl.cos(theta[1]),
                          ufl.sin(theta[1])])

# --- Compute current energies and derivatives

W = ufl.tr(ufl.grad(m(theta))*ufl.grad(m(theta)).T)*dv
W_residual_theta = ufl.derivative(W, theta, theta_test)
W_jacobian_theta = ufl.derivative(W_residual_theta, theta, theta_trial)
# --- Allocate the memory needed for TAO
dofs_domain, dofs_borders = V_theta.dofmap.index_map, V_theta.dofmap.index_map_bs
b = la.create_petsc_vector(dofs_domain, dofs_borders)
J = fem.petsc.create_matrix(fem.form(W_jacobian_theta))

# --- Create optimization problem
problem = TAOProblem(W, W_residual_theta, W_jacobian_theta, theta, [])

# --- Setup optimization
solver_tao = PETSc.TAO().create(comm=comm)
solver_tao.setObjective(problem.f)
solver_tao.setGradient(problem.F, b)
solver_tao.setHessian(problem.J, J)

solver_tao.setType("tron")
solver_tao.getKSP().setType("preonly")
solver_tao.getKSP().getPC().setType("lu")
solver_tao.setVariableBounds(theta_min.vector, theta_max.vector)

solver_tao.solve(theta.vector)

# --- Print out solution status
curr_iter, curr_f, gnorm, cnorm, xdiff, converged_reason = solver_tao.getSolutionStatus()

print("Finish")

The problem persists when I change the mesh and also the form W.
I tried to change the solver but the problem eventually keeps showing up also with other solvers, and I want to use “tron”. Finally, the problem seems highly dependent on the initialization of the function theta.

The issue seems very similar to this and this.

I installed dolfinx using conda, here there is the list of the packets in the environment.

Name Version

_libgcc_mutex 0.1
_openmp_mutex 4.5
aiohttp 3.8.4
aiosignal 1.3.1
alabaster 0.7.13
alsa-lib 1.2.8
anyio 3.7.1
aom 3.5.0
argon2-cffi 21.3.0
argon2-cffi-bindings 21.2.0
arrow 1.2.3
astroid 2.15.6
asttokens 2.2.1
async-timeout 4.0.2
atomicwrites 1.4.1
attr 2.5.1
attrs 23.1.0
autopep8 2.0.2
babel 2.12.1
backcall 0.2.0
backports 1.0
backports.functools_lru_cache 1.6.5
beautifulsoup4 4.12.2
binaryornot 0.4.4
binutils_impl_linux-64 2.40
binutils_linux-64 2.40
black 23.7.0
bleach 6.0.0
blosc 1.21.4
bokeh 3.2.0
boltons 23.0.0
boost-cpp 1.78.0
brotli 1.0.9
brotli-bin 1.0.9
brotli-python 1.0.9
bzip2 1.0.8
c-ares 1.19.1
ca-certificates 2023.7.22
cairo 1.16.0
certifi 2023.7.22
cffi 1.15.1
chardet 5.2.0
charset-normalizer 3.2.0
click 8.1.7
cloudpickle 2.2.1
cmocean 3.0.3
colorama 0.4.6
colorcet 3.0.1
comm 0.1.3
conda 23.5.0
conda-package-handling 2.0.2
conda-package-streaming 0.8.0
contourpy 1.1.0
cookiecutter 2.3.0
coverage 7.2.7
cppimport 22.8.2
cryptography 41.0.1
curl 8.1.2
cycler 0.11.0
cython 3.0.2
dav1d 1.2.1
dbus 1.13.6
debugpy 1.6.7
decorator 5.1.1
defusedxml 0.7.1
diff-match-patch 20230430
dill 0.3.7
docstring-to-markdown 0.12
docutils 0.20.1
double-conversion 3.3.0
eigen 3.4.0
entrypoints 0.4
exceptiongroup 1.1.2
executing 1.2.0
expat 2.5.0
fastjsonschema 2.17.1
fenics-basix 0.6.0
fenics-basix-pybind11-abi 0.4.12
fenics-dolfinx 0.6.0
fenics-ffcx 0.6.0
fenics-instant 2017.2.0
fenics-libbasix 0.6.0
fenics-libdolfinx 0.6.0
fenics-ufcx 0.6.0
fenics-ufl 2023.1.1
ffmpeg 5.1.2
fftw 3.3.10
filelock 3.12.2
flake8 6.0.0
font-ttf-dejavu-sans-mono 2.37
font-ttf-inconsolata 3.000
font-ttf-source-code-pro 2.038
font-ttf-ubuntu 0.83
fontconfig 2.14.2
fonts-conda-ecosystem 1
fonts-conda-forge 1
fonttools 4.40.0
fqdn 1.5.1
freetype 2.12.1
fribidi 1.0.10
frozenlist 1.3.3
gcc 12.3.0
gcc_impl_linux-64 12.3.0
gcc_linux-64 12.3.0
gettext 0.21.1
giflib 5.2.1
gl2ps 1.4.2
glew 2.1.0
glib 2.76.4
glib-tools 2.76.4
gmp 6.2.1
gmsh 4.11.1
gnutls 3.7.8
graphite2 1.3.13
gst-plugins-base 1.22.3
gstreamer 1.22.3
gxx_impl_linux-64 12.3.0
gxx_linux-64 12.3.0
h2tools 1.1.1
h5py 3.9.0
harfbuzz 7.3.0
hdf4 4.2.15
hdf5 1.14.0
hypre 2.27.0
icu 72.1
idna 3.4
imageio 2.31.1
imagesize 1.4.1
importlib-metadata 6.8.0
importlib-resources 6.0.0
importlib_metadata 6.8.0
importlib_resources 6.0.0
inflection 0.5.1
iniconfig 2.0.0
intervaltree 3.1.0
ipycanvas 0.13.1
ipydatawidgets 4.3.2
ipyevents 2.0.1
ipykernel 6.24.0
ipympl 0.9.3
ipyparallel 8.6.1
ipython 8.14.0
ipython-genutils 0.2.0
ipython_genutils 0.2.0
ipyvtklink 0.2.3
ipywidgets 7.7.5
isoduration 20.11.0
isort 5.12.0
itk-core 5.3.0
itk-filtering 5.3.0
itk-meshtopolydata 0.10.0
itk-numerics 5.3.0
itkwidgets 0.32.6
jaraco.classes 3.3.0
jedi 0.18.2
jeepney 0.8.0
jellyfish 1.0.0
jinja2 3.1.2
jsoncpp 1.9.5
jsonpatch 1.32
jsonpointer 2.0
jsonschema 4.18.0
jsonschema-specifications 2023.6.1
jupyter 1.0.0
jupyter-console 6.6.3
jupyter-events 0.6.3
jupyter-server 2.7.0
jupyter-server-proxy 4.0.0
jupyter-server-terminals 0.4.4
jupyter_client 8.3.0
jupyter_core 5.3.1
jupyterlab-widgets 1.1.4
jupyterlab_pygments 0.2.2
jupyterthemes 0.20.0
kaleido 0.2.1
kernel-headers_linux-64 2.6.32
keyring 24.2.0
keyutils 1.6.1
kiwisolver 1.4.4
krb5 1.20.1
lame 3.100
lazy-object-proxy 1.9.0
lcms2 2.15
ld_impl_linux-64 2.40
lerc 4.0.0
lesscpy 0.15.1
libadios2 2.8.3
libaec 1.0.6
libass 0.17.1
libblas 3.9.0
libbrotlicommon 1.0.9
libbrotlidec 1.0.9
libbrotlienc 1.0.9
libcap 2.67
libcblas 3.9.0
libclang 16.0.6
libclang13 16.0.6
libcups 2.3.3 h36d4200_3 conda-forge
libcurl 8.1.2 h409715c_0 conda-forge
libdeflate 1.18
libdrm 2.4.114
libedit 3.1.20191231
libev 4.33
libevent 2.1.12
libexpat 2.5.0
libffi 3.4.2
libflac 1.4.3
libgcc-devel_linux-64 12.3.0
libgcc-ng 13.1.0
libgcrypt 1.10.1
libgfortran-ng 13.1.0
libgfortran5 13.1.0
libglib 2.76.4
libglu 9.0.0
libgomp 13.1.0
libgpg-error 1.47
libhwloc 2.9.1
libiconv 1.17
libidn2 2.3.4
libjpeg-turbo 2.1.5.1
liblapack 3.9.0
libllvm16 16.0.6
libnetcdf 4.9.2
libnghttp2 1.52.0
libnsl 2.0.0
libogg 1.3.4
libopenblas 0.3.23
libopus 1.3.1
libpciaccess 0.17
libpng 1.6.39
libpq 15.3
libsanitizer 12.3.0
libsndfile 1.2.0
libsodium 1.0.18
libspatialindex 1.9.3
libsqlite 3.42.0
libssh2 1.11.0
libstdcxx-devel_linux-64 12.3.0
libstdcxx-ng 13.1.0
libsystemd0 253
libtasn1 4.19.0
libtheora 1.1.1
libtiff 4.5.1
libunistring 0.9.10
libuuid 2.38.1
libva 2.19.0
libvorbis 1.3.7
libvpx 1.13.0
libwebp 1.3.1
libwebp-base 1.3.1
libxcb 1.15
libxkbcommon 1.5.0
libxml2 2.11.4
libzip 1.9.2
libzlib 1.2.13
linkify-it-py 2.0.2
loguru 0.7.0
lz4-c 1.9.4
lzo 2.10
mako 1.2.4
markdown 3.4.3
markdown-it-py 3.0.0
markupsafe 2.1.3
matplotlib-base 3.7.1
matplotlib-inline 0.1.6
maxvolpy 0.3.8
mccabe 0.7.0
mdit-py-plugins 0.4.0
mdurl 0.1.2
meshio 5.3.4
metis 5.1.0
mistune 3.0.1
more-itertools 10.1.0
mpfr 4.2.0
mpg123 1.31.3
mpi 1.0
mpi4py 3.1.4
mpich 4.1.1
mpmath 1.3.0
multidict 6.0.4
multiphenicsx 0.2.dev1
mumps-include 5.2.1
mumps-mpi 5.2.1
munkres 1.1.4
mypy_extensions 1.0.0
mysql-common 8.0.33
mysql-libs 8.0.33 hca2cd23_1 conda-forge
nbclassic 1.0.0 pypi_0 pypi
nbclient 0.8.0 pyhd8ed1ab_0 conda-forge
nbconvert 7.6.0 pypi_0 pypi
nbconvert-core 7.7.4 pyhd8ed1ab_0 conda-forge
nbconvert-pandoc 7.7.4 pyhd8ed1ab_0 conda-forge
nbformat 5.9.0 pypi_0 pypi
nbval 0.10.0 pypi_0 pypi
nbvalx 0.0.dev1 pypi_0 pypi
ncurses 6.4 hcb278e6_0 conda-forge
nest-asyncio 1.5.6 pyhd8ed1ab_0 conda-forge
nettle 3.8.1 hc379101_1 conda-forge
nlohmann_json 3.11.2 h27087fc_0 conda-forge
nomkl 1.0 h5ca1d4c_0 conda-forge
notebook 6.4.12 pypi_0 pypi
notebook-shim 0.2.3 pypi_0 pypi
nspr 4.35 h27087fc_0 conda-forge
nss 3.89
numexpr 2.8.7
numpy 1.25.1
numpydoc 1.5.0
openh264 2.3.1 hcb278e6_2 conda-forge
openjpeg 2.5.0 hfec8fc6_2 conda-forge
openssl 3.1.3
overrides 7.3.1
p11-kit 0.24.1
packaging 23.1
pandas 2.0.3
pandoc 3.1.3
pandocfilters 1.5.0
panel 1.2.0
param 1.13.0
parmetis 4.0.3 h2a9763c_1005 conda-forge
parso 0.8.3 pyhd8ed1ab_0 conda-forge
pathspec 0.11.2 pyhd8ed1ab_0 conda-forge
pcre2 10.40 hc3806b6_0 conda-forge
petsc 3.18.5 real_h348f904_100 conda-forge
petsc4py 3.18.5 real_h1bd14e5_100 conda-forge
pexpect 4.8.0 pyh1a96a4e_2 conda-forge
pickleshare 0.7.5 py_1003 conda-forge
pillow 10.0.0 py39haaeba84_0 conda-forge
pip 23.1.2 pyhd8ed1ab_0 conda-forge
pixman 0.40.0 h36c2ea0_0 conda-forge
pkg-config 0.29.2 h36c2ea0_1008 conda-forge
pkgutil-resolve-name 1.3.10 pyhd8ed1ab_0 conda-forge
platformdirs 3.8.1 pyhd8ed1ab_0 conda-forge
plotly 5.15.0 pypi_0 pypi
pluggy 1.2.0 pyhd8ed1ab_0 conda-forge
ply 3.11 py_1 conda-forge
pooch 1.7.0 pyha770c72_3 conda-forge
proj 9.2.1 ha643af7_0 conda-forge
prometheus-client 0.17.0 pypi_0 pypi
prompt-toolkit 3.0.39 pyha770c72_0 conda-forge
prompt_toolkit 3.0.39 hd8ed1ab_0 conda-forge
psutil 5.9.5 py39h72bdee0_0 conda-forge
pthread-stubs 0.4 h36c2ea0_1001 conda-forge
ptscotch 6.0.9 hb499603_2 conda-forge
ptyprocess 0.7.0 pyhd3deb0d_0 conda-forge
pugixml 1.13 h59595ed_1 conda-forge
pulseaudio-client 16.1 hb77b528_4 conda-forge
pure_eval 0.2.2 pyhd8ed1ab_0 conda-forge
pybind11 2.11.0 pypi_0 pypi
pybind11-abi 4 hd8ed1ab_3 conda-forge
pycodestyle 2.10.0 pyhd8ed1ab_0 conda-forge
pycosat 0.6.4 py39hb9d737c_1 conda-forge
pycparser 2.21 pyhd8ed1ab_0 conda-forge
pyct 0.5.0 pypi_0 pypi
pydocstyle 6.3.0 pyhd8ed1ab_0 conda-forge
pyflakes 3.0.1 pyhd8ed1ab_0 conda-forge
pygments 2.15.1 pyhd8ed1ab_0 conda-forge
pylint 2.17.5 pyhd8ed1ab_0 conda-forge
pylint-venv 3.0.2 pyhd8ed1ab_0 conda-forge
pyls-spyder 0.4.0
pyopenssl 23.2.0
pyparsing 3.1.0
pyqt 5.15.9
pyqt5-sip 12.12.2
pyqtwebengine 5.15.9
pysocks 1.7.1
pytables 3.7.0
pytest 7.4.0
python 3.9.16
python-dateutil 2.8.2
python-fastjsonschema 2.18.0
python-json-logger 2.0.7
python-lsp-black 1.3.0
python-lsp-jsonrpc 1.0.0
python-lsp-server 1.7.4
python-lsp-server-base 1.7.4 pyhd8ed1ab_0 conda-forge
python-slugify 8.0.1
python_abi 3.9
pythreejs 2.4.2
pytoolconfig 1.2.5
pytz 2023.3
pyvista 0.40.0
pyviz-comms 2.3.2
pyxdg 0.28
pyyaml 6.0
pyzmq 25.1.0 pi
qdarkstyle 3.1
qstylizer 0.2.2
qt-main 5.15.8
qt-webengine 5.15.8
qtawesome 1.2.3
qtconsole 5.4.3 pyhd8ed1ab_0 conda-forge
qtconsole-base 5.4.3 pyha770c72_0 conda-forge
qtpy 2.3.1 pyhd8ed1ab_0 conda-forge
readline 8.2 h8228510_1 conda-forge
referencing 0.29.1 pypi_0 pypi
requests 2.31.0 pyhd8ed1ab_0 conda-forge
rfc3339-validator 0.1.4 pypi_0 pypi
rfc3986-validator 0.1.1 pypi_0 pypi
rich 13.4.2
rope 1.9.0
rpds-py 0.8.10
rtree 1.0.1
ruamel.yaml 0.17.32
ruamel.yaml.clib 0.2.7
scalapack 2.2.0
scipy 1.11.3
scooby 0.7.2
scotch 6.0.9
secretstorage 3.3.3
send2trash 1.8.2
setuptools 68.0.0
simpervisor 1.0.0
sip 6.7.11
six 1.16.0
slepc 3.18.3
slepc4py 3.18.3
snappy 1.1.10
sniffio 1.3.0
snowballstemmer 2.2.0
sortedcontainers 2.4.0
soupsieve 2.4.1
sphinx 7.2.3
sphinxcontrib-applehelp 1.0.7
sphinxcontrib-devhelp 1.0.5
sphinxcontrib-htmlhelp 2.0.4
sphinxcontrib-jsmath 1.0.1
sphinxcontrib-qthelp 1.0.6
sphinxcontrib-serializinghtml 1.1.9
spyder 5.4.4
spyder-kernels 2.4.4
sqlite 3.42.0
stack_data 0.6.2
suitesparse 5.10.1
superlu 5.2.2
superlu_dist 7.2.0
svt-av1 1.4.1
sympy 1.12
sysroot_linux-64 2.12
tbb 2021.9.0
tbb-devel 2021.9.0 hf52228f_0 conda-forge
tenacity 8.2.2 pypi_0 pypi
terminado 0.17.1 pypi_0 pypi
text-unidecode 1.3 py_0 conda-forge
textdistance 4.5.0 pyhd8ed1ab_0 conda-forge
three-merge 0.1.1 pyh9f0ad1d_0 conda-forge
tinycss2 1.2.1 pyhd8ed1ab_0 conda-forge
tk 8.6.12
toml 0.10.2
tomli 2.0.1
tomlkit 0.12.1
toolz 0.12.0
tornado 6.3.2
tqdm 4.65.0
traitlets 5.6.0
traittypes 0.2.1
trame 2.5.0
trame-client 2.9.4
trame-components 2.1.1
trame-deckgl 2.0.2
trame-markdown 2.0.2
trame-matplotlib 2.0.2
trame-plotly 2.1.1
trame-rca 0.3.1
trame-router 2.0.2
trame-server 2.11.4
trame-simput 2.3.2
trame-vega 2.0.3
trame-vtk 2.5.4
trame-vuetify 2.2.4
typing-extensions 4.7.1
typing_extensions 4.7.1
tzdata 2023.3
uc-micro-py 1.0.2
ujson 5.7.0 py39h227be39_0 conda-forge
unicodedata2 15.0.0 py39hb9d737c_0 conda-forge
unidecode 1.3.6 pyhd8ed1ab_0 conda-forge
uri-template 1.3.0 pypi_0 pypi
urllib3 2.0.3 pyhd8ed1ab_1 conda-forge
utfcpp 3.2.3 ha770c72_0 conda-forge
viskex 0.0.dev1 pypi_0 pypi
vitables 3.0.3 py39h2dd6158_1 conda-forge
vtk 9.2.6
vtk-base 9.2.6
vtk-io-ffmpeg 9.2.6
watchdog 3.0.0
wcwidth 0.2.6
webcolors 1.13
webencodings 0.5.1
websocket-client 1.6.1
whatthepatch 1.0.5
wheel 0.40.0
widgetsnbextension 3.6.4
wrapt 1.15.0
wslink 1.11.1
wurlitzer 3.0.3
x264 1!164.3095
x265 3.5
xcb-util 0.4.0
xcb-util-image 0.4.0
xcb-util-keysyms 0.4.0
xcb-util-renderutil 0.3.9
xcb-util-wm 0.4.1
xkeyboard-config 2.39
xorg-compositeproto 0.4.2
xorg-damageproto 1.2.1
xorg-fixesproto 5.0
xorg-inputproto 2.3.2
xorg-kbproto 1.0.7
xorg-libice 1.0.10
xorg-libsm 1.2.3
xorg-libx11 1.8.6
xorg-libxau 1.0.11
xorg-libxcomposite 0.4.6
xorg-libxdamage 1.1.5
xorg-libxdmcp 1.1.3
xorg-libxext 1.3.4
xorg-libxfixes 5.0.3
xorg-libxi 1.7.10
xorg-libxrandr 1.5.2
xorg-libxrender 0.9.11
xorg-libxt 1.3.0
xorg-libxtst 1.2.3
xorg-randrproto 1.5.0
xorg-recordproto 1.14.2
xorg-renderproto 0.11.1
xorg-util-macros 1.19.3
xorg-xextproto 7.3.0
xorg-xf86vidmodeproto 2.3.1
xorg-xproto 7.0.31
xyzservices 2023.5.0
xz 5.2.6
yaml 0.2.5
yapf 0.40.1
yarl 1.9.2
zeromq 4.3.4
zfp 0.5.5
zipp 3.16.0
zlib 1.2.13
zstandard 0.19.0
zstd 1.5.2

1 Like

Hello,
did anyone encounter the same problem and managed to solve it?
thanks

It is quite easy to solve.
You need to sync the local residuals of f between processes.
MWE:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import sys
import numpy as np
from mpi4py import MPI

from petsc4py import PETSc
import petsc4py
import ufl
from dolfinx import fem, mesh, la
from dolfinx.fem import petsc

petsc4py.init(sys.argv)
os.system('clear')

# Create MPI communicator
comm = MPI.COMM_WORLD

# ----- Create the mesh
L = 1.0
N = 10
domain = mesh.create_box(comm=comm, points=[[0.0, 0.0, 0.0], [L, L, L]],
                         n=[N, N, N], cell_type=mesh.CellType.tetrahedron)

# ----- Create the measure dv
metadata = {'quadrature_degree': 4}
dv = ufl.Measure("dx", domain=domain, metadata=metadata)

# Define the function space
V_theta = fem.VectorFunctionSpace(domain, ("CG", 1), dim=2)

# Solution, Test, and Trial functions
theta = fem.Function(V_theta)
theta_test, theta_trial = ufl.TestFunction(V_theta), ufl.TrialFunction(V_theta)

# initalise the value of theta with a specific function


def initalization(pos):
    radius = pos[0]**2 + pos[1]**2
    if radius < 1.0e-8:
        return (0., np.pi/2)
    mx = -pos[1]/radius * np.sqrt(1-np.exp(-4*radius**2/L**2))
    my = pos[0]/radius * np.sqrt(1-np.exp(-4*radius**2/L**2))
    mz = np.exp(-2*radius**2/L**2)
    if mx == 0:
        return (0., np.pi/2)
    else:
        return (np.arctan(my/mx), np.arcsin(mz))


def eval_theta(x):
    d = x.shape[1]
    angles = [np.zeros(d), np.zeros(d)]
    for i in range(d):
        angles[0][i], angles[1][i] = initalization([x[0][i], x[1][i], x[2][i]])
    return angles


theta.interpolate(eval_theta)

""" 
---------------------------------------------------------
Define the optimization problem class
---------------------------------------------------------
"""


class TAOProblem:

    """Nonlinear problem class compatible with PETSC.TAO solver"""

    def __init__(self, f, F, J, u, bcs):

        self.bcs = bcs                 # bcs = boundary conditions
        self.u = u                   # u = solution
        self.L = fem.form(F)  # F = residual
        self.a = fem.form(J)  # J = jacobian
        self.obj = fem.form(f)  # f = objective function

        # .. Generation of matrices in the no-linear problem
        self.A = petsc.create_matrix(self.a)
        self.b = petsc.create_vector(self.L)

    def f(self, tao, x):
        # .... Connection between ranks in the vector x: FORWARD
        x.ghostUpdate(addv=PETSc.InsertMode.INSERT,
                      mode=PETSc.ScatterMode.FORWARD)
        # ..... We copy the vector x in the vector u
        x.copy(self.u.vector)
        # .... Connection between ranks in the vector u in the class
        self.u.vector.ghostUpdate(
            addv=PETSc.InsertMode.INSERT, mode=PETSc.ScatterMode.FORWARD)
        local_res = fem.assemble_scalar(fem.form(self.obj))
        global_res = self.u.function_space.mesh.comm.allreduce(
            local_res, op=MPI.SUM)

        return global_res

    def F(self, tao, x, F):

        # .... Connection between ranks in the vector x: FORWARD
        x.ghostUpdate(addv=PETSc.InsertMode.INSERT,
                      mode=PETSc.ScatterMode.FORWARD)
        # ..... We copy the vector x in the vector u
        x.copy(self.u.vector)
        # .... Connection between ranks in the vector u in the class
        self.u.vector.ghostUpdate(
            addv=PETSc.InsertMode.INSERT, mode=PETSc.ScatterMode.FORWARD)

        with F.localForm() as f_local:
            f_local.set(0.0)

        petsc.assemble_vector(F, self.L)
        # .... Include the boundary conditions using the lifting technique
        petsc.apply_lifting(
            F, [self.a], bcs=[self.bcs], x0=[x], scale=-1.0)
        # .... Connection between ranks in the vector x: REVERSE
        F.ghostUpdate(addv=PETSc.InsertMode.ADD,
                      mode=PETSc.ScatterMode.REVERSE)
        # .... Redefine the function F with the boundary condition.
        petsc.set_bc(F, self.bcs, x, -1.0)
        F.ghostUpdate(addv=PETSc.InsertMode.INSERT,
                      mode=PETSc.ScatterMode.FORWARD)

    def J(self, tao, x, A, P):
        A.zeroEntries()
        petsc.assemble_matrix(A, self.a, self.bcs)
        A.assemble()


# --- Define lower bound and upper bound for theta
theta_min = fem.Function(V_theta)
theta_max = fem.Function(V_theta)


def eval_theta_lowerbound(x):
    d = x.shape[1]
    return [-np.pi*np.ones(d), -np.pi/2*np.ones(d)]


def eval_theta_upperbound(x):
    d = x.shape[1]
    return [np.pi*np.ones(d), np.pi/2*np.ones(d)]


theta_min.interpolate(eval_theta_lowerbound)
theta_max.interpolate(eval_theta_upperbound)

""" 
---------------------------------------------------------
Initialize energy functions and directional derivatives
---------------------------------------------------------
"""
# convert m to Cartesian coordinates


def m(theta):
    return ufl.as_vector([ufl.cos(theta[0])*ufl.cos(theta[1]),
                          ufl.sin(theta[0])*ufl.cos(theta[1]),
                          ufl.sin(theta[1])])

# --- Compute current energies and derivatives


W = ufl.tr(ufl.grad(m(theta))*ufl.grad(m(theta)).T)*dv
W_residual_theta = ufl.derivative(W, theta, theta_test)
W_jacobian_theta = ufl.derivative(W_residual_theta, theta, theta_trial)
# --- Allocate the memory needed for TAO
dofs_domain, dofs_borders = V_theta.dofmap.index_map, V_theta.dofmap.index_map_bs
b = la.create_petsc_vector(dofs_domain, dofs_borders)
J = petsc.create_matrix(fem.form(W_jacobian_theta))

# --- Create optimization problem
problem = TAOProblem(W, W_residual_theta, W_jacobian_theta, theta, [])

# --- Setup optimization
solver_tao = PETSc.TAO().create(comm=comm)
solver_tao.setObjective(problem.f)
solver_tao.setGradient(problem.F, b)
solver_tao.setHessian(problem.J, J)

solver_tao.setType("tron")
solver_tao.getKSP().setType("preonly")
solver_tao.getKSP().getPC().setType("lu")
solver_tao.setVariableBounds(theta_min.vector, theta_max.vector)

solver_tao.solve(theta.vector)

# --- Print out solution status
curr_iter, curr_f, gnorm, cnorm, xdiff, converged_reason = solver_tao.getSolutionStatus()

print("Finish")

I’ve also added a ghost update inside F after setting bcs.

Dear Dokken,

thank you very much! It works perfectly.
Can you please explain to me why J does not need any ghostUpdate or MPI broadcasting after assembling, as it seems to be the case for f and F?
Thank you again for your help and for all the work you do

The updating is here done by:

which is PETSc’s way of communicating matrix data.