Sample list `average` crashes in MPI runs with a MAP estimate
When computing a MAP estimate with multiple MPI tasks, the `average` method of the sample list crashes. The following example works when executed normally with `python`, but crashes when run under MPI (e.g. `mpirun -n 2 python script.py`):
import numpy as np
import nifty8 as ift
# Fix the random seed so that the mock data generated below are reproducible.
ift.random.push_sseq_from_seed(27)

# Use MPI if mpi4py is importable; otherwise run serially as the only (master) task.
try:
    from mpi4py import MPI
except ImportError:
    comm = None
    master = True
else:
    comm = MPI.COMM_WORLD
    master = comm.Get_rank() == 0  # rank 0 is responsible for writing output

# Response: a diagonal operator with the constant value 10 on every pixel.
space = ift.RGSpace([128, 128])
op = ift.makeOp(ift.full(space, 10.))

# Noise covariance: 0.1 times the identity, sampled with dtype float64.
noise_var = 0.1
noise_cov = ift.ScalingOperator(space, noise_var, np.float64)

# Synthetic data = response applied to a random ground truth + one noise draw.
# The two random draws stay in this order so the seeded data are unchanged.
truth = ift.from_random(op.domain)
signal = op(truth)
data = signal + noise_cov.draw_sample()

# Gaussian likelihood of the data, composed with the response.
lh = ift.GaussianEnergy(mean=data, inverse_covariance=noise_cov.inverse) @ op

# Stopping criteria for the sampling step and for the Newton-CG KL minimizer.
ic_sampling = ift.AbsDeltaEnergyController(
    name="Sampling (linear)", deltaE=0.05, iteration_limit=10
)
newton_ic = ift.AbsDeltaEnergyController(
    name="Newton", deltaE=0.5, convergence_level=2, iteration_limit=5
)
minimizer = ift.NewtonCG(newton_ic)
def callback(samples, i):
    """Plot the mean of ``op`` over the current sample list.

    Passed to ``ift.optimize_kl`` as ``callback``; presumably invoked once per
    global iteration with the current samples and the iteration index ``i``
    (``i`` is not used here).

    ``samples.average`` is evaluated on every task, because under MPI it may
    involve communication between tasks. Only the master task writes the plot.
    This ``average`` call is where the reported crash happens in MPI runs when
    ``samples`` holds a MAP estimate (zero samples).
    """
    posterior_mean = samples.average(op)
    fig = ift.Plot()
    fig.add(posterior_mean, title="Reconstruction", zmin=0, zmax=1)
    if not master:
        return
    fig.output()
n_iterations = 3


def n_samples(iiter):
    """Return how many samples to draw in global iteration ``iiter``.

    For ``iiter < 1`` (the first iteration, assuming 0-based counting) no
    samples are drawn, i.e. a MAP / point estimate is computed. Every later
    iteration draws 2 samples.
    """
    return 0 if iiter < 1 else 2


samples = ift.optimize_kl(
    lh,
    n_iterations,
    n_samples,
    minimizer,
    ic_sampling,
    # NOTE(review): presumably the nonlinear-sampling minimizer slot; None is
    # meant to keep sampling linear. Confirm against the optimize_kl signature.
    None,
    overwrite=True,
    comm=comm,
    callback=callback,
)