Sample list `average` crashes in MPI runs with a MAP estimator

When computing a MAP estimate with multiple MPI tasks, the sample list's `average` method crashes. The example below works when executed normally with Python, but crashes when run under MPI:

import numpy as np
import nifty8 as ift

# Seed NIFTy's random number generator (seed 27) so that the random draws
# further down (the mock position and the noise realisation) are reproducible
# between runs. Keep this before any random draw.
ift.random.push_sseq_from_seed(27)

# Detect whether the script runs under MPI. If mpi4py is unavailable, the
# script runs as a single task, which is then the root ("master") task.
try:
    from mpi4py import MPI
except ImportError:
    comm = None
    master = True
else:
    comm = MPI.COMM_WORLD
    master = comm.Get_rank() == 0

# --- Problem setup ----------------------------------------------------------
# Regular 128 x 128 grid on which the signal lives.
position_space = ift.RGSpace([128, 128])
# Forward model: an operator built from a field that is 10 everywhere on
# position_space, i.e. the model output is the input scaled by 10.
op = ift.makeOp(ift.full(position_space, 10.))
# Noise covariance N = 0.1 * identity (N.inverse is used as the inverse
# covariance of the Gaussian likelihood below).
noise = 0.1
N = ift.ScalingOperator(position_space, noise, np.float64)
# Synthetic ground truth and data. The order of these two random draws
# matters for reproducibility under the seed pushed at the top of the script.
mock_position = ift.from_random(op.domain)
data = op(mock_position) + N.draw_sample()
# Gaussian likelihood of `data` given the model output, composed with `op` so
# it is a function of the latent position.
lh = ift.GaussianEnergy(mean=data, inverse_covariance=N.inverse) @ op

# Iteration controller used when drawing (linear) samples.
ic_sampling = ift.AbsDeltaEnergyController(
    name="Sampling (linear)", deltaE=0.05, iteration_limit=10
)
# Iteration controller for the Newton-CG minimizer of the KL / MAP objective.
ic_newton = ift.AbsDeltaEnergyController(
    name="Newton", deltaE=0.5, convergence_level=2, iteration_limit=5
)
minimizer = ift.NewtonCG(ic_newton)


def callback(samples, i):
    """Plot the mean of ``op`` over the current samples after an iteration.

    Parameters
    ----------
    samples
        Sample list passed in by ``ift.optimize_kl``. For iterations that
        draw no samples, presumably this holds only the point (MAP)
        estimate; confirm against the NIFTy documentation.
    i
        Iteration index. It is unused here but part of the two-argument
        callback signature used with ``ift.optimize_kl``.
    """
    # Evaluate the average on every MPI task, not only on the master. With a
    # communicator the samples are presumably distributed across tasks, so
    # the average is likely a collective operation that all tasks must join.
    mean = samples.average(op)
    # Only the master task writes output, so only it needs to build the plot.
    if not master:
        return
    plot = ift.Plot()
    plot.add(mean, title="Reconstruction", zmin=0, zmax=1)
    plot.output()


n_iterations = 3


def n_samples(iiter):
    """Return the number of samples to draw in KL iteration ``iiter``.

    Iteration 0 draws no samples, which yields a point (MAP) estimate; this
    is the configuration in which the reported MPI crash occurs. Every later
    iteration draws 2 samples.
    """
    return 0 if iiter < 1 else 2


samples = ift.optimize_kl(
    lh,
    n_iterations,
    n_samples,
    minimizer,
    ic_sampling,
    # NOTE(review): presumably the nonlinear sampling minimizer; None should
    # mean purely linear samples. Confirm against the optimize_kl signature.
    None,
    overwrite=True,
    comm=comm,
    callback=callback,
)
Assignee Loading
Time tracking Loading