Commit 3d6b510e authored by Martin Reinecke's avatar Martin Reinecke
Browse files

try to simplify energies, stage 1

parent 55a5717b
from builtins import *
from ..minimization.energy import Energy
from ..operators.inversion_enabler import InversionEnabler
from ..operators.scaling_operator import ScalingOperator
from ..utilities import memo
class SampledKullbachLeiblerDivergence(Energy):
def __init__(self, h, res_samples, iteration_controller):
def __init__(self, h, res_samples):
"""
h: Hamiltonian
N: Number of samples to be used
......@@ -13,42 +12,28 @@ class SampledKullbachLeiblerDivergence(Energy):
super(SampledKullbachLeiblerDivergence, self).__init__(h.position)
self._h = h
self._res_samples = res_samples
self._iteration_controller = iteration_controller
self._energy_list = []
for ss in res_samples:
e = h.at(self.position+ss)
self._energy_list.append(e)
self._energy_list = tuple(h.at(self.position+ss)
for ss in res_samples)
def at(self, position):
return self.__class__(self._h.at(position), self._res_samples,
self._iteration_controller)
return self.__class__(self._h.at(position), self._res_samples)
@property
@memo
def value(self):
v = self._energy_list[0].value
for energy in self._energy_list[1:]:
v += energy.value
return v / len(self._energy_list)
return (my_sum(map(lambda v: v.value, self._energy_list)) /
len(self._energy_list))
@property
@memo
def gradient(self):
g = self._energy_list[0].gradient
for energy in self._energy_list[1:]:
g += energy.gradient
return g / len(self._energy_list)
return (my_sum(map(lambda v: v.gradient, self._energy_list)) /
len(self._energy_list))
@property
@memo
def curvature(self):
# MR FIXME: This looks a bit strange...
approx = self._energy_list[-1]._prior.curvature
curvature_list = [e.curvature for e in self._energy_list]
op = curvature_list[0]
for curv in curvature_list[1:]:
op = op + curv
op = op * ScalingOperator(1./len(curvature_list), op.domain)
return InversionEnabler(op, self._iteration_controller, approx)
return (my_sum(map(lambda v: v.curvature, self._energy_list)) *
(1./len(self._energy_list)))
......@@ -129,6 +129,11 @@ class Energy(NiftyMetaBase()):
"""
return None
def makeInvertible(self, controller, preconditioner=None):
if not isinstance(controller, IterationController):
raise TypeError
return CurvatureInversionEnabler(self, controller, preconditioner)
def __mul__(self, factor):
from .energy_sum import EnergySum
if isinstance(factor, (float, int)):
......@@ -153,3 +158,45 @@ class Energy(NiftyMetaBase()):
def __neg__(self):
from .energy_sum import EnergySum
return EnergySum.make([self], [-1.])
class CurvatureInversionEnabler(Energy):
def __init__(self, ene, controller, preconditioner):
super(CurvatureInversionEnabler, self).__init__(ene.position)
self._energy = ene
self._controller = controller
self._preconditioner = preconditioner
def at(self, position):
if self._position.isSubsetOf(position):
return self
return CurvatureInversionEnabler(
self._energy.at(position), self._controller, self._preconditioner)
@property
def position(self):
return self._energy.position
@property
def value(self):
return self._energy.value
@property
def gradient(self):
return self._energy.gradient
@property
def curvature(self):
from ..operators.linear_operator import LinearOperator
from ..operators.inversion_enabler import InversionEnabler
curv = self._energy.curvature
if self._preconditioner is None:
precond = None
elif isinstance(self._preconditioner, LinearOperator):
precond = self._preconditioner
elif isinstance(self._preconditioner, Energy):
precond = self._preconditioner.at(self.position).curvature
return InversionEnabler(curv, self._controller, precond)
def longest_step(self, dir):
return self._energy.longest_step(dir)
......@@ -16,7 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from ..utilities import memo
from builtins import *
from ..utilities import memo, my_lincomb_simple, my_lincomb
from .energy import Energy
......@@ -55,26 +56,17 @@ class EnergySum(Energy):
@property
@memo
def value(self):
res = self._energies[0].value * self._factors[0]
for e, f in zip(self._energies[1:], self._factors[1:]):
res += e.value * f
return res
return my_lincomb_simple(map(lambda v: v.value, self._energies),
self._factors)
@property
@memo
def gradient(self):
res = self._energies[0].gradient.copy() if self._factors[0] == 1. \
else self._energies[0].gradient * self._factors[0]
for e, f in zip(self._energies[1:], self._factors[1:]):
res += e.gradient if f == 1. else f*e.gradient
return res.lock()
return my_lincomb(map(lambda v: v.gradient, self._energies),
self._factors).lock()
@property
@memo
def curvature(self):
res = self._energies[0].curvature if self._factors[0] == 1. \
else self._energies[0].curvature * self._factors[0]
for e, f in zip(self._energies[1:], self._factors[1:]):
res = res + (e.curvature if f == 1. else e.curvature*f)
return res
return my_lincomb(map(lambda v: v.curvature, self._energies),
self._factors)
......@@ -16,15 +16,36 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from builtins import next, range
from builtins import *
import numpy as np
from itertools import product
import abc
from future.utils import with_metaclass
from functools import reduce
__all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space",
"memo", "NiftyMetaBase", "fft_prep", "hartley", "my_fftn_r2c",
"my_fftn"]
"my_fftn", "my_sum", "my_lincomb_simple", "my_lincomb",
"my_product"]
def my_sum(terms):
return reduce(lambda x, y: x+y, terms)
def my_lincomb_simple(terms, factors):
terms2 = map(lambda v: v[0]*v[1], zip(terms, factors))
return my_sum(terms2)
def my_lincomb(terms, factors):
terms2 = map(lambda v: v[0] if v[1] == 1. else v[0]*v[1],
zip(terms, factors))
return my_sum(terms2)
def my_product(iterable):
return reduce(lambda x, y: x*y, iterable)
def get_slice_list(shape, axes):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment