diff --git a/nifty5/minimization/kl_energy.py b/nifty5/minimization/kl_energy.py index 19732e3f18de35ea0f48242d90f5bb06cc3c7adc..510bda7d46ea989ff670cd2dd693b16d7bd3ec4c 100644 --- a/nifty5/minimization/kl_energy.py +++ b/nifty5/minimization/kl_energy.py @@ -45,16 +45,20 @@ class KL_Energy_MPI(Energy): h, nsamp, constants=[], + constants_samples=None, _samples=None, want_metric=False): super(KL_Energy_MPI, self).__init__(position) self._h = h self._nsamp = nsamp self._constants = constants + if constants_samples is None: + constants_samples = constants + self._constants_samples = constants_samples self._want_metric = want_metric if _samples is None: lo, hi = _shareRange(nsamp, ntask, rank) - met = h(Linearization.make_var(position, True)).metric + met = h(Linearization.make_partial_var(position, constants_samples, True)).metric _samples = [] for i in range(lo, hi): np.random.seed(i) @@ -73,7 +77,7 @@ class KL_Energy_MPI(Energy): def at(self, position): return KL_Energy_MPI(position, self._h, self._nsamp, self._constants, - self._samples, self._want_metric) + self._constants_samples, self._samples, self._want_metric) @property def value(self):