Commit b55598c9 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'flexible_KL' into 'NIFTy_5'

allow specification of constants for samples

See merge request ift/nifty-dev!103
parents eeea9171 695b9d9d
...@@ -7,12 +7,17 @@ from .. import utilities ...@@ -7,12 +7,17 @@ from .. import utilities
class KL_Energy(Energy): class KL_Energy(Energy):
def __init__(self, position, h, nsamp, constants=[], _samples=None): def __init__(self, position, h, nsamp, constants=[],
constants_samples=None, _samples=None):
super(KL_Energy, self).__init__(position) super(KL_Energy, self).__init__(position)
self._h = h self._h = h
self._constants = constants self._constants = constants
if constants_samples is None:
constants_samples = constants
self._constants_samples = constants_samples
if _samples is None: if _samples is None:
met = h(Linearization.make_var(position, True)).metric met = h(Linearization.make_partial_var(
position, constants_samples, True)).metric
_samples = tuple(met.draw_sample(from_inverse=True) _samples = tuple(met.draw_sample(from_inverse=True)
for _ in range(nsamp)) for _ in range(nsamp))
self._samples = _samples self._samples = _samples
...@@ -32,7 +37,8 @@ class KL_Energy(Energy): ...@@ -32,7 +37,8 @@ class KL_Energy(Energy):
self._metric = None self._metric = None
def at(self, position): def at(self, position):
return KL_Energy(position, self._h, 0, self._constants, self._samples) return KL_Energy(position, self._h, 0, self._constants,
self._constants_samples, self._samples)
@property @property
def value(self): def value(self):
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment