Commit 30df8975 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'kl_energy_mirrored_samples' into 'NIFTy_5'

KL_Energy: add keyword argument to mirror samples

See merge request ift/nifty-dev!140
parents eaf710b7 ef60a287
...@@ -8,7 +8,8 @@ from .. import utilities ...@@ -8,7 +8,8 @@ from .. import utilities
class KL_Energy(Energy): class KL_Energy(Energy):
def __init__(self, position, h, nsamp, constants=[], def __init__(self, position, h, nsamp, constants=[],
constants_samples=None, _samples=None): constants_samples=None, gen_mirrored_samples=False,
_samples=None):
super(KL_Energy, self).__init__(position) super(KL_Energy, self).__init__(position)
if h.domain is not position.domain: if h.domain is not position.domain:
raise TypeError raise TypeError
...@@ -22,6 +23,8 @@ class KL_Energy(Energy): ...@@ -22,6 +23,8 @@ class KL_Energy(Energy):
position, constants_samples, True)).metric position, constants_samples, True)).metric
_samples = tuple(met.draw_sample(from_inverse=True) _samples = tuple(met.draw_sample(from_inverse=True)
for _ in range(nsamp)) for _ in range(nsamp))
if gen_mirrored_samples:
_samples += tuple(-s for s in _samples)
self._samples = _samples self._samples = _samples
self._lin = Linearization.make_partial_var(position, constants) self._lin = Linearization.make_partial_var(position, constants)
...@@ -39,8 +42,9 @@ class KL_Energy(Energy): ...@@ -39,8 +42,9 @@ class KL_Energy(Energy):
self._metric = None self._metric = None
def at(self, position): def at(self, position):
return KL_Energy(position, self._h, 0, self._constants, return KL_Energy(position, self._h, 0,
self._constants_samples, self._samples) self._constants, self._constants_samples,
_samples = self._samples)
@property @property
def value(self): def value(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment