Skip to content
Snippets Groups Projects
Commit 30df8975 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'kl_energy_mirrored_samples' into 'NIFTy_5'

KL_Energy: add keyword argument to mirror samples

See merge request ift/nifty-dev!140
parents eaf710b7 ef60a287
No related branches found
No related tags found
No related merge requests found
......@@ -8,7 +8,8 @@ from .. import utilities
class KL_Energy(Energy):
def __init__(self, position, h, nsamp, constants=[],
constants_samples=None, _samples=None):
constants_samples=None, gen_mirrored_samples=False,
_samples=None):
super(KL_Energy, self).__init__(position)
if h.domain is not position.domain:
raise TypeError
......@@ -22,6 +23,8 @@ class KL_Energy(Energy):
position, constants_samples, True)).metric
_samples = tuple(met.draw_sample(from_inverse=True)
for _ in range(nsamp))
if gen_mirrored_samples:
_samples += tuple(-s for s in _samples)
self._samples = _samples
self._lin = Linearization.make_partial_var(position, constants)
......@@ -39,8 +42,9 @@ class KL_Energy(Energy):
self._metric = None
def at(self, position):
return KL_Energy(position, self._h, 0, self._constants,
self._constants_samples, self._samples)
return KL_Energy(position, self._h, 0,
self._constants, self._constants_samples,
_samples = self._samples)
@property
def value(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment