Commit 45dcfca7 authored by Martin Reinecke's avatar Martin Reinecke

tweak KL_Energy

parent 735e589b
......@@ -72,8 +72,8 @@ if __name__ == '__main__':
# set up minimization and inversion schemes
ic_sampling = ift.GradientNormController(iteration_limit=100)
ic_newton = ift.DeltaEnergyController(
name='Newton', tol_rel_deltaE=1e-8, iteration_limit=100)
ic_newton = ift.GradInfNormController(
name='Newton', tol=1e-7, iteration_limit=1000)
minimizer = ift.NewtonCG(ic_newton)
# build model Hamiltonian
......@@ -91,7 +91,7 @@ if __name__ == '__main__':
# number of samples used to estimate the KL
N_samples = 20
for i in range(2):
KL = ift.KL_Energy(position, H, N_samples, want_metric=True)
KL = ift.KL_Energy(position, H, N_samples)
KL, convergence = minimizer(KL)
position = KL.position
......
......@@ -152,6 +152,9 @@ class Linearization(object):
def add_metric(self, metric):
return self.new(self._val, self._jac, metric)
def with_want_metric(self):
return Linearization(self._val, self._jac, self._metric, True)
@staticmethod
def make_var(field, want_metric=False):
from .operators.scaling_operator import ScalingOperator
......@@ -163,3 +166,15 @@ class Linearization(object):
from .operators.simple_linear_operators import NullOperator
return Linearization(field, NullOperator(field.domain, field.domain),
want_metric=want_metric)
@staticmethod
def make_partial_var(field, constants, want_metric=False):
from .operators.scaling_operator import ScalingOperator
from .operators.simple_linear_operators import NullOperator
if len(constants) == 0:
return Linearization.make_var(field, want_metric)
else:
ops = [ScalingOperator(0. if key in constants else 1., dom)
for key, dom in field.domain.items()]
bdop = BlockDiagonalOperator(fielld.domain, tuple(ops))
return Linearization(field, bdop, want_metric=want_metric)
......@@ -13,14 +13,8 @@ class EnergyAdapter(Energy):
self._op = op
self._constants = constants
self._want_metric = want_metric
if len(self._constants) == 0:
tmp = self._op(Linearization.make_var(self._position, want_metric))
else:
ops = [ScalingOperator(0. if key in self._constants else 1., dom)
for key, dom in self._position.domain.items()]
bdop = BlockDiagonalOperator(self._position.domain, tuple(ops))
tmp = self._op(Linearization(self._position, bdop,
want_metric=want_metric))
lin = Linearization.make_partial_var(position, constants, want_metric)
tmp = self._op(lin)
self._val = tmp.val.local_data[()]
self._grad = tmp.gradient
self._metric = tmp._metric
......
......@@ -9,33 +9,32 @@ from .. import utilities
class KL_Energy(Energy):
def __init__(self, position, h, nsamp, constants=[], _samples=None,
want_metric=False):
def __init__(self, position, h, nsamp, constants=[], _samples=None):
super(KL_Energy, self).__init__(position)
self._h = h
self._constants = constants
self._want_metric = want_metric
if _samples is None:
met = h(Linearization.make_var(position, True)).metric
_samples = tuple(met.draw_sample(from_inverse=True)
for _ in range(nsamp))
self._samples = _samples
if len(constants) == 0:
tmp = Linearization.make_var(position, want_metric)
else:
ops = [ScalingOperator(0. if key in constants else 1., dom)
for key, dom in position.domain.items()]
bdop = BlockDiagonalOperator(position.domain, tuple(ops))
tmp = Linearization(position, bdop, want_metric=want_metric)
mymap = map(lambda v: self._h(tmp+v), self._samples)
tmp = utilities.my_sum(mymap) * (1./len(self._samples))
self._val = tmp.val.local_data[()]
self._grad = tmp.gradient
self._metric = tmp.metric
self._lin = Linearization.make_partial_var(position, constants)
v, g = None, None
for s in self._samples:
tmp = self._h(self._lin+s)
if v is None:
v = tmp.val.local_data[()]
g = tmp.gradient
else:
v += tmp.val.local_data[()]
g = g + tmp.gradient
self._val = v / len(self._samples)
self._grad = g * (1./len(self._samples))
self._metric = None
def at(self, position):
return KL_Energy(position, self._h, 0, self._constants, self._samples,
self._want_metric)
return KL_Energy(position, self._h, 0, self._constants, self._samples)
@property
def value(self):
......@@ -45,11 +44,20 @@ class KL_Energy(Energy):
def gradient(self):
return self._grad
def _get_metric(self):
if self._metric is None:
lin = self._lin.with_want_metric()
mymap = map(lambda v: self._h(lin+v).metric, self._samples)
self._metric = utilities.my_sum(mymap)
self._metric = self._metric.scale(1./len(self._samples))
def apply_metric(self, x):
self._get_metric()
return self._metric(x)
@property
def metric(self):
self._get_metric()
return self._metric
@property
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment