diff --git a/nifty5/minimization/kl_energy.py b/nifty5/minimization/kl_energy.py index fe250f264371704c6b305b0df167a985fdab9f9f..0e8155f9e3d4bd36c40433bdb94a61459117010e 100644 --- a/nifty5/minimization/kl_energy.py +++ b/nifty5/minimization/kl_energy.py @@ -8,7 +8,8 @@ from .. import utilities class KL_Energy(Energy): def __init__(self, position, h, nsamp, constants=[], - constants_samples=None, _samples=None): + constants_samples=None, gen_mirrored_samples=False, + _samples=None): super(KL_Energy, self).__init__(position) if h.domain is not position.domain: raise TypeError @@ -22,6 +23,8 @@ class KL_Energy(Energy): position, constants_samples, True)).metric _samples = tuple(met.draw_sample(from_inverse=True) for _ in range(nsamp)) + if gen_mirrored_samples: + _samples += tuple(-s for s in _samples) self._samples = _samples self._lin = Linearization.make_partial_var(position, constants) @@ -39,8 +42,9 @@ class KL_Energy(Energy): self._metric = None def at(self, position): - return KL_Energy(position, self._h, 0, self._constants, - self._constants_samples, self._samples) + return KL_Energy(position, self._h, 0, + self._constants, self._constants_samples, + _samples = self._samples) @property def value(self):