Thread safety in optimize_kl
The following code attempts to run a number of optimize_kl instances in parallel, first with a process pool and then with a thread pool.
import numpy as np
import nifty8 as ift
import concurrent.futures as cf
def function_calling_nifty(seed):
# set up correlated field model
npix = 512
pospace = ift.RGSpace((npix,))
spfreq = ift.RGSpace((npix,))
cfmaker = ift.CorrelatedFieldMaker('amplitude')
cfmaker.add_fluctuations(spfreq, (0.1, 1e-2), None, None, (-3, 1),
'f')
cfmaker.set_amplitude_total_offset(0., (1e-2, 1e-6))
cf = cfmaker.finalize()
normalized_amp = cfmaker.get_normalized_amplitudes()
pspec = normalized_amp[0]**2
# signal + fake data
signal = ift.exp(cf.real)
mask = np.random.binomial(1, 0.5, size=npix).astype(bool)
tmp = ift.makeField(signal.target, ~mask)
R = ift.MaskOperator(tmp)
signal_response = R(signal)
mock_position = ift.from_random(signal_response.domain, 'normal')
dspace = R.target
noise = .001
N = ift.ScalingOperator(dspace, noise, float)
data = signal_response(mock_position) + N.draw_sample()
# Minimization parameters
ic_sampling = ift.AbsDeltaEnergyController(deltaE=0.01, iteration_limit=50)
ic_newton = ift.AbsDeltaEnergyController(deltaE=0.01, iteration_limit=15)
minimizer = ift.NewtonCG(ic_newton, enable_logging=False)
# Set up likelihood energy and information Hamiltonian
likelihood_energy = ift.GaussianEnergy(data, inverse_covariance=N.inverse) @ signal_response
n_samples = 3
n_iterations = 5
with ift.random.Context(seed):
samples = ift.optimize_kl(likelihood_energy,
n_iterations,
n_samples,
minimizer,
ic_sampling,
None, # for GeoVI
plot_energy_history=False,
plot_minisanity_history=False)
return samples
if __name__=='__main__':
# processes
print("___________________________________Running with processes___________________________________")
nrun = 8
futures = []
with cf.ProcessPoolExecutor(max_workers=nrun) as executor:
for i in range(nrun):
future = executor.submit(function_calling_nifty, i)
futures.append(future)
for f in cf.as_completed(futures):
samples = f.result()
# threads
print("___________________________________Running with threads___________________________________")
futures = []
with cf.ThreadPoolExecutor(max_workers=nrun) as executor:
for i in range(nrun):
future = executor.submit(function_calling_nifty, i)
futures.append(future)
for f in cf.as_completed(futures):
samples = f.result()
The first run with processes is successful but the second one falls over with
Traceback (most recent call last):
File "/home/landman/software/scratch/nifty_threadsafety.py", line 43, in function_calling_nifty
samples = ift.optimize_kl(likelihood_energy,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/landman/venvs/qcal/lib/python3.11/site-packages/nifty8/minimization/optimize_kl.py", line 368, in optimize_kl
e = SampledKLEnergy(
^^^^^^^^^^^^^^^^
File "/home/landman/venvs/qcal/lib/python3.11/site-packages/nifty8/minimization/kl_energies.py", line 290, in SampledKLEnergy
sample_list = draw_samples(position, ham_sampling, minimizer_sampling, n_samples,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/landman/venvs/qcal/lib/python3.11/site-packages/nifty8/minimization/kl_energies.py", line 142, in draw_samples
with random.Context(sseq[i]):
File "/home/landman/venvs/qcal/lib/python3.11/site-packages/nifty8/random.py", line 290, in __exit__
raise RuntimeError("inconsistent RNG usage detected")
RuntimeError: inconsistent RNG usage detected
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/landman/software/scratch/nifty_threadsafety.py", line 78, in <module>
samples = f.result()
^^^^^^^^^^
File "/usr/lib/python3.11/concurrent/futures/_base.py", line 449, in result
return self.__get_result()
^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.11/concurrent/futures/_base.py", line 401, in __get_result
raise self._exception
File "/usr/lib/python3.11/concurrent/futures/thread.py", line 58, in run
result = self.fn(*self.args, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/landman/software/scratch/nifty_threadsafety.py", line 42, in function_calling_nifty
with ift.random.Context(seed):
File "/home/landman/venvs/qcal/lib/python3.11/site-packages/nifty8/random.py", line 290, in __exit__
raise RuntimeError("inconsistent RNG usage detected")
RuntimeError: inconsistent RNG usage detected
I tried using the Context class as suggested in src/random.py
but this doesn't seem to have the desired effect. Any suggestions on how to get this working with a thread pool would be much appreciated.