From aba5df2095e9384a3c724ea896456261063cffb6 Mon Sep 17 00:00:00 2001 From: Reimar Leike <reimar@leike.name> Date: Tue, 11 Sep 2018 15:40:54 +0200 Subject: [PATCH] introducing constants_sampling for MPI paralllel KL --- nifty5/minimization/kl_energy.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/nifty5/minimization/kl_energy.py b/nifty5/minimization/kl_energy.py index 19732e3f1..510bda7d4 100644 --- a/nifty5/minimization/kl_energy.py +++ b/nifty5/minimization/kl_energy.py @@ -45,16 +45,20 @@ class KL_Energy_MPI(Energy): h, nsamp, constants=[], + constants_samples=None, _samples=None, want_metric=False): super(KL_Energy_MPI, self).__init__(position) self._h = h self._nsamp = nsamp self._constants = constants + if constants_samples is None: + constants_samples = constants + self._constants_samples = constants_samples self._want_metric = want_metric if _samples is None: lo, hi = _shareRange(nsamp, ntask, rank) - met = h(Linearization.make_var(position, True)).metric + met = h(Linearization.make_partial_var(position, constants_samples, True)).metric _samples = [] for i in range(lo, hi): np.random.seed(i) @@ -73,7 +77,7 @@ class KL_Energy_MPI(Energy): def at(self, position): return KL_Energy_MPI(position, self._h, self._nsamp, self._constants, - self._samples, self._want_metric) + self._constants_samples, self._samples, self._want_metric) @property def value(self): -- GitLab