Commit 45dcfca7 authored by Martin Reinecke

tweak KL_Energy

parent 735e589b
@@ -72,8 +72,8 @@ if __name__ == '__main__':
     # set up minimization and inversion schemes
     ic_sampling = ift.GradientNormController(iteration_limit=100)
-    ic_newton = ift.DeltaEnergyController(
-        name='Newton', tol_rel_deltaE=1e-8, iteration_limit=100)
+    ic_newton = ift.GradInfNormController(
+        name='Newton', tol=1e-7, iteration_limit=1000)
     minimizer = ift.NewtonCG(ic_newton)

     # build model Hamiltonian
@@ -91,7 +91,7 @@ if __name__ == '__main__':
     # number of samples used to estimate the KL
     N_samples = 20
     for i in range(2):
-        KL = ift.KL_Energy(position, H, N_samples, want_metric=True)
+        KL = ift.KL_Energy(position, H, N_samples)
         KL, convergence = minimizer(KL)
         position = KL.position
...
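With the metric now built lazily inside KL_Energy (see the KL_Energy hunks further down), the demo only has to drop want_metric=True; NewtonCG can still obtain metric applications through apply_metric / metric when it needs them for its inner conjugate-gradient steps. A minimal sketch of the updated loop, assuming H and position are set up as earlier in the demo:

    ic_newton = ift.GradInfNormController(
        name='Newton', tol=1e-7, iteration_limit=1000)
    minimizer = ift.NewtonCG(ic_newton)
    N_samples = 20
    for _ in range(2):
        KL = ift.KL_Energy(position, H, N_samples)  # no want_metric argument any more
        KL, convergence = minimizer(KL)
        position = KL.position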
@@ -152,6 +152,9 @@ class Linearization(object):
     def add_metric(self, metric):
         return self.new(self._val, self._jac, metric)

+    def with_want_metric(self):
+        return Linearization(self._val, self._jac, self._metric, True)
+
     @staticmethod
     def make_var(field, want_metric=False):
         from .operators.scaling_operator import ScalingOperator
@@ -163,3 +166,15 @@ class Linearization(object):
         from .operators.simple_linear_operators import NullOperator
         return Linearization(field, NullOperator(field.domain, field.domain),
                              want_metric=want_metric)
+
+    @staticmethod
+    def make_partial_var(field, constants, want_metric=False):
+        from .operators.scaling_operator import ScalingOperator
+        from .operators.block_diagonal_operator import BlockDiagonalOperator
+        if len(constants) == 0:
+            return Linearization.make_var(field, want_metric)
+        else:
+            ops = [ScalingOperator(0. if key in constants else 1., dom)
+                   for key, dom in field.domain.items()]
+            bdop = BlockDiagonalOperator(field.domain, tuple(ops))
+            return Linearization(field, bdop, want_metric=want_metric)
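A usage sketch (not part of the commit) for the new make_partial_var helper: the two-component domain and the key names 'a' and 'b' are invented for illustration, and it assumes the NIFTy5 package layout with Linearization exported at package level.

    import nifty5 as ift

    dom = ift.MultiDomain.make({'a': ift.RGSpace(16), 'b': ift.RGSpace(16)})
    pos = ift.from_random('normal', dom)
    delta = ift.from_random('normal', dom)

    # 'a' is declared constant, so its Jacobian block is a ScalingOperator(0.)
    lin = ift.Linearization.make_partial_var(pos, ['a'])
    print(lin.jac(delta)['a'])   # zero: variations of 'a' are suppressed
    print(lin.jac(delta)['b'])   # the 'b' component passes through unchanged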
@@ -13,14 +13,8 @@ class EnergyAdapter(Energy):
         self._op = op
         self._constants = constants
         self._want_metric = want_metric
-        if len(self._constants) == 0:
-            tmp = self._op(Linearization.make_var(self._position, want_metric))
-        else:
-            ops = [ScalingOperator(0. if key in self._constants else 1., dom)
-                   for key, dom in self._position.domain.items()]
-            bdop = BlockDiagonalOperator(self._position.domain, tuple(ops))
-            tmp = self._op(Linearization(self._position, bdop,
-                                         want_metric=want_metric))
+        lin = Linearization.make_partial_var(position, constants, want_metric)
+        tmp = self._op(lin)
         self._val = tmp.val.local_data[()]
         self._grad = tmp.gradient
         self._metric = tmp._metric
...
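The EnergyAdapter change is a pure refactor: the inline zero/one block-diagonal Jacobian moves into Linearization.make_partial_var, so behaviour with constants should be unchanged. A hedged sketch, where op (an operator with scalar output), pos (a MultiField in its domain) and the key 'a' are invented for illustration:

    e = ift.EnergyAdapter(pos, op, constants=['a'], want_metric=False)
    print(e.value)            # op evaluated at pos
    print(e.gradient['a'])    # zero block, since 'a' was declared constant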
@@ -9,33 +9,32 @@ from .. import utilities
 class KL_Energy(Energy):
-    def __init__(self, position, h, nsamp, constants=[], _samples=None,
-                 want_metric=False):
+    def __init__(self, position, h, nsamp, constants=[], _samples=None):
         super(KL_Energy, self).__init__(position)
         self._h = h
         self._constants = constants
-        self._want_metric = want_metric
         if _samples is None:
             met = h(Linearization.make_var(position, True)).metric
             _samples = tuple(met.draw_sample(from_inverse=True)
                              for _ in range(nsamp))
         self._samples = _samples
-        if len(constants) == 0:
-            tmp = Linearization.make_var(position, want_metric)
-        else:
-            ops = [ScalingOperator(0. if key in constants else 1., dom)
-                   for key, dom in position.domain.items()]
-            bdop = BlockDiagonalOperator(position.domain, tuple(ops))
-            tmp = Linearization(position, bdop, want_metric=want_metric)
-        mymap = map(lambda v: self._h(tmp+v), self._samples)
-        tmp = utilities.my_sum(mymap) * (1./len(self._samples))
-        self._val = tmp.val.local_data[()]
-        self._grad = tmp.gradient
-        self._metric = tmp.metric
+
+        self._lin = Linearization.make_partial_var(position, constants)
+        v, g = None, None
+        for s in self._samples:
+            tmp = self._h(self._lin+s)
+            if v is None:
+                v = tmp.val.local_data[()]
+                g = tmp.gradient
+            else:
+                v += tmp.val.local_data[()]
+                g = g + tmp.gradient
+        self._val = v / len(self._samples)
+        self._grad = g * (1./len(self._samples))
+        self._metric = None

     def at(self, position):
-        return KL_Energy(position, self._h, 0, self._constants, self._samples,
-                         self._want_metric)
+        return KL_Energy(position, self._h, 0, self._constants, self._samples)

     @property
     def value(self):
@@ -45,11 +44,20 @@ class KL_Energy(Energy):
     def gradient(self):
         return self._grad

+    def _get_metric(self):
+        if self._metric is None:
+            lin = self._lin.with_want_metric()
+            mymap = map(lambda v: self._h(lin+v).metric, self._samples)
+            self._metric = utilities.my_sum(mymap)
+            self._metric = self._metric.scale(1./len(self._samples))
+
     def apply_metric(self, x):
+        self._get_metric()
         return self._metric(x)

     @property
     def metric(self):
+        self._get_metric()
         return self._metric

     @property
...
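Net effect on KL_Energy: the constructor now only averages value and gradient over the samples, while the metric is assembled on first access by _get_metric() (re-evaluating h through with_want_metric) and cached in self._metric. A small sketch of the resulting behaviour, reusing the demo's names:

    KL = ift.KL_Energy(position, H, N_samples)   # no metric is computed here
    g = KL.gradient                              # available immediately
    m = KL.metric                                # first access builds and caches the sample-averaged metric
    y = KL.apply_metric(g)                       # reuses the cached metric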