Commit 35afcda0 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch '31-rewrite-unitloggauss' into 'NIFTy_5'

Resolve "Rewrite UnitLogGauss"

Closes #31

See merge request ift/nifty-dev!21
parents 0de313ef 2dfc4494
...@@ -172,7 +172,7 @@ ...@@ -172,7 +172,7 @@
" tol_abs_gradnorm=0.1)\n", " tol_abs_gradnorm=0.1)\n",
" # WienerFilterCurvature is (R.adjoint*N.inverse*R + Sh.inverse) plus some handy\n", " # WienerFilterCurvature is (R.adjoint*N.inverse*R + Sh.inverse) plus some handy\n",
" # helper methods.\n", " # helper methods.\n",
" return ift.library.WienerFilterCurvature(R,N,Sh,iteration_controller=IC,iteration_controller_sampling=IC)" " return ift.WienerFilterCurvature(R,N,Sh,iteration_controller=IC,iteration_controller_sampling=IC)"
] ]
}, },
{ {
......
...@@ -2,7 +2,9 @@ import nifty5 as ift ...@@ -2,7 +2,9 @@ import nifty5 as ift
import numpy as np import numpy as np
from global_newton.models_other.apply_data import ApplyData from global_newton.models_other.apply_data import ApplyData
from global_newton.models_energy.hamiltonian import Hamiltonian from global_newton.models_energy.hamiltonian import Hamiltonian
from nifty5.library.unit_log_gauss import UnitLogGauss from nifty5 import GaussianEnergy
if __name__ == '__main__': if __name__ == '__main__':
# s_space = ift.RGSpace([1024]) # s_space = ift.RGSpace([1024])
s_space = ift.RGSpace([128,128]) s_space = ift.RGSpace([128,128])
...@@ -45,7 +47,7 @@ if __name__ == '__main__': ...@@ -45,7 +47,7 @@ if __name__ == '__main__':
NWR = ApplyData(data, ift.Field(d_space,val=noise), Rs) NWR = ApplyData(data, ift.Field(d_space,val=noise), Rs)
INITIAL_POSITION = ift.from_random('normal',total_domain) INITIAL_POSITION = ift.from_random('normal',total_domain)
likelihood = UnitLogGauss(INITIAL_POSITION, NWR) likelihood = GaussianEnergy(INITIAL_POSITION, NWR)
IC = ift.GradientNormController(iteration_limit=500, tol_abs_gradnorm=1e-3) IC = ift.GradientNormController(iteration_limit=500, tol_abs_gradnorm=1e-3)
inverter = ift.ConjugateGradient(controller=IC) inverter = ift.ConjugateGradient(controller=IC)
......
...@@ -16,7 +16,7 @@ from .minimization import * ...@@ -16,7 +16,7 @@ from .minimization import *
from .sugar import * from .sugar import *
from .plotting.plot import plot from .plotting.plot import plot
from . import library from .library import *
from . import extra from . import extra
from .utilities import memo from .utilities import memo
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from ..library.gaussian_energy import GaussianEnergy
from ..minimization.energy import Energy from ..minimization.energy import Energy
from ..operators import InversionEnabler, SamplingEnabler
from ..models.variable import Variable from ..models.variable import Variable
from ..operators import InversionEnabler, SamplingEnabler
from ..utilities import memo from ..utilities import memo
from ..library.unit_log_gauss import UnitLogGauss
class Hamiltonian(Energy): class Hamiltonian(Energy):
...@@ -15,11 +33,8 @@ class Hamiltonian(Energy): ...@@ -15,11 +33,8 @@ class Hamiltonian(Energy):
super(Hamiltonian, self).__init__(lh.position) super(Hamiltonian, self).__init__(lh.position)
self._lh = lh self._lh = lh
self._ic = iteration_controller self._ic = iteration_controller
if iteration_controller_sampling is None: self._ic_samp = iteration_controller_sampling
self._ic_samp = iteration_controller self._prior = GaussianEnergy(Variable(self.position))
else:
self._ic_samp = iteration_controller_sampling
self._prior = UnitLogGauss(Variable(self.position))
self._precond = self._prior.curvature self._precond = self._prior.curvature
def at(self, position): def at(self, position):
...@@ -39,8 +54,11 @@ class Hamiltonian(Energy): ...@@ -39,8 +54,11 @@ class Hamiltonian(Energy):
@memo @memo
def curvature(self): def curvature(self):
prior_curv = self._prior.curvature prior_curv = self._prior.curvature
c = SamplingEnabler(self._lh.curvature, prior_curv.inverse, if self._ic_samp is None:
self._ic_samp, prior_curv.inverse) c = self._lh.curvature + prior_curv
else:
c = SamplingEnabler(self._lh.curvature, prior_curv.inverse,
self._ic_samp, prior_curv.inverse)
return InversionEnabler(c, self._ic, self._precond) return InversionEnabler(c, self._ic, self._precond)
def __str__(self): def __str__(self):
......
from .amplitude_model import make_amplitude_model from .amplitude_model import make_amplitude_model
from .apply_data import ApplyData from .apply_data import ApplyData
from .gaussian_energy import GaussianEnergy
from .los_response import LOSResponse from .los_response import LOSResponse
from .nonlinear_wiener_filter_energy import NonlinearWienerFilterEnergy
from .unit_log_gauss import UnitLogGauss
from .point_sources import PointSources from .point_sources import PointSources
from .poisson_log_likelihood import PoissonLogLikelihood from .poissonian_energy import PoissonianEnergy
from .smooth_sky import make_smooth_mf_sky_model, make_smooth_sky_model from .smooth_sky import make_smooth_mf_sky_model, make_smooth_sky_model
from .wiener_filter_curvature import WienerFilterCurvature from .wiener_filter_curvature import WienerFilterCurvature
from .wiener_filter_energy import WienerFilterEnergy from .wiener_filter_energy import WienerFilterEnergy
def ApplyData(data, var, model_data): def ApplyData(data, var, model_data):
from .. import DiagonalOperator, Constant, sqrt # TODO This is rather confusing. Delete that eventually.
from ..operators.diagonal_operator import DiagonalOperator
from ..models.constant import Constant
from ..sugar import sqrt
sqrt_n = DiagonalOperator(sqrt(var)) sqrt_n = DiagonalOperator(sqrt(var))
data = Constant(model_data.position, data) data = Constant(model_data.position, data)
return sqrt_n.inverse(model_data - data) return sqrt_n.inverse(model_data - data)
...@@ -17,45 +17,50 @@ ...@@ -17,45 +17,50 @@
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from ..minimization.energy import Energy from ..minimization.energy import Energy
from ..operators.inversion_enabler import InversionEnabler
from ..operators.sandwich_operator import SandwichOperator from ..operators.sandwich_operator import SandwichOperator
from ..utilities import memo from ..utilities import memo
class UnitLogGauss(Energy): class GaussianEnergy(Energy):
def __init__(self, s, inverter=None): def __init__(self, inp, mean=None, covariance=None):
""" """
s: Sky model object inp: Model object
value = 0.5 * s.vdot(s), i.e. a log-Gauss distribution with unit value = 0.5 * s.vdot(s), i.e. a log-Gauss distribution with unit
covariance covariance
""" """
super(UnitLogGauss, self).__init__(s.position) super(GaussianEnergy, self).__init__(inp.position)
self._s = s self._inp = inp
self._inverter = inverter self._mean = mean
self._cov = covariance
def at(self, position): def at(self, position):
return self.__class__(self._s.at(position), self._inverter) return self.__class__(self._inp.at(position), self._mean, self._cov)
@property @property
@memo @memo
def _gradient_helper(self): def residual(self):
return self._s.gradient if self._mean is not None:
return self._inp.value - self._mean
return self._inp.value
@property @property
@memo @memo
def value(self): def value(self):
return .5 * self._s.value.squared_norm() if self._cov is None:
return .5 * self.residual.vdot(self.residual).real
return .5 * self.residual.vdot(self._cov.inverse(self.residual)).real
@property @property
@memo @memo
def gradient(self): def gradient(self):
return self._gradient_helper.adjoint(self._s.value) if self._cov is None:
return self._inp.gradient.adjoint(self.residual)
return self._inp.gradient.adjoint(self._cov.inverse(self.residual))
@property @property
@memo @memo
def curvature(self): def curvature(self):
c = SandwichOperator.make(self._gradient_helper) if self._cov is None:
if self._inverter is None: return SandwichOperator.make(self._inp.gradient, None)
return c return SandwichOperator.make(self._inp.gradient, self._cov.inverse)
return InversionEnabler(c, self._inverter)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from ..models.constant import Constant
from .unit_log_gauss import UnitLogGauss
from ..energies.hamiltonian import Hamiltonian
def NonlinearWienerFilterEnergy(measured_data, data_model, sqrtN, iteration_controller):
d = measured_data.lock()
residual = Constant(data_model.position, d) - data_model
lh = UnitLogGauss(sqrtN.inverse(residual))
return Hamiltonian(lh, iteration_controller)
...@@ -23,7 +23,7 @@ from ..operators.sandwich_operator import SandwichOperator ...@@ -23,7 +23,7 @@ from ..operators.sandwich_operator import SandwichOperator
from ..sugar import log, makeOp from ..sugar import log, makeOp
class PoissonLogLikelihood(Energy): class PoissonianEnergy(Energy):
def __init__(self, lamb, d): def __init__(self, lamb, d):
""" """
lamb: Sky model object lamb: Sky model object
...@@ -31,7 +31,7 @@ class PoissonLogLikelihood(Energy): ...@@ -31,7 +31,7 @@ class PoissonLogLikelihood(Energy):
value = 0.5 * s.vdot(s), i.e. a log-Gauss distribution with unit value = 0.5 * s.vdot(s), i.e. a log-Gauss distribution with unit
covariance covariance
""" """
super(PoissonLogLikelihood, self).__init__(lamb.position) super(PoissonianEnergy, self).__init__(lamb.position)
self._lamb = lamb self._lamb = lamb
self._d = d self._d = d
......
...@@ -164,10 +164,25 @@ class MultiField(object): ...@@ -164,10 +164,25 @@ class MultiField(object):
def __neg__(self): def __neg__(self):
return MultiField({key: -val for key, val in self.items()}) return MultiField({key: -val for key, val in self.items()})
def __abs__(self):
return MultiField({key: abs(val) for key, val in self.items()})
def conjugate(self): def conjugate(self):
return MultiField({key: sub_field.conjugate() return MultiField({key: sub_field.conjugate()
for key, sub_field in self.items()}) for key, sub_field in self.items()})
def all(self):
for v in self.values():
if not v.all():
return False
return True
def any(self):
for v in self.values():
if v.any():
return True
return False
def isEquivalentTo(self, other): def isEquivalentTo(self, other):
"""Determines (as quickly as possible) whether `self`'s content is """Determines (as quickly as possible) whether `self`'s content is
identical to `other`'s content.""" identical to `other`'s content."""
......
...@@ -57,7 +57,7 @@ class Energy_Tests(unittest.TestCase): ...@@ -57,7 +57,7 @@ class Energy_Tests(unittest.TestCase):
tol_abs_gradnorm=1e-5) tol_abs_gradnorm=1e-5)
S = ift.create_power_operator(hspace, power_spectrum=_flat_PS) S = ift.create_power_operator(hspace, power_spectrum=_flat_PS)
energy = ift.library.WienerFilterEnergy( energy = ift.WienerFilterEnergy(
position=s0, d=d, R=R, N=N, S=S, iteration_controller=IC) position=s0, d=d, R=R, N=N, S=S, iteration_controller=IC)
ift.extra.check_value_gradient_curvature_consistency( ift.extra.check_value_gradient_curvature_consistency(
energy, ntries=10) energy, ntries=10)
...@@ -66,10 +66,10 @@ class Energy_Tests(unittest.TestCase): ...@@ -66,10 +66,10 @@ class Energy_Tests(unittest.TestCase):
ift.RGSpace(64, distances=.789), ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)], ift.RGSpace([32, 32], distances=.789)],
[ift.Tanh, ift.Exponential, ift.Linear], [ift.Tanh, ift.Exponential, ift.Linear],
[1, 1e-2, 1e2],
[4, 78, 23])) [4, 78, 23]))
def testNonlinearMap(self, space, nonlinearity, seed): def testGaussianEnergy(self, space, nonlinearity, noise, seed):
np.random.seed(seed) np.random.seed(seed)
f = nonlinearity()
dim = len(space.shape) dim = len(space.shape)
hspace = space.get_default_codomain() hspace = space.get_default_codomain()
ht = ift.HarmonicTransformOperator(hspace, target=space) ht = ift.HarmonicTransformOperator(hspace, target=space)
...@@ -77,23 +77,23 @@ class Energy_Tests(unittest.TestCase): ...@@ -77,23 +77,23 @@ class Energy_Tests(unittest.TestCase):
pspace = ift.PowerSpace(hspace, binbounds=binbounds) pspace = ift.PowerSpace(hspace, binbounds=binbounds)
Dist = ift.PowerDistributor(target=hspace, power_space=pspace) Dist = ift.PowerDistributor(target=hspace, power_space=pspace)
xi0 = ift.Field.from_random(domain=hspace, random_type='normal') xi0 = ift.Field.from_random(domain=hspace, random_type='normal')
xi0_var = ift.Variable(ift.MultiField({'xi':xi0}))['xi'] xi0_var = ift.Variable(ift.MultiField({'xi': xi0}))['xi']
def pspec(k): return 1 / (1 + k**2)**dim def pspec(k): return 1 / (1 + k**2)**dim
pspec = ift.PS_field(pspace, pspec) pspec = ift.PS_field(pspace, pspec)
A = Dist(ift.sqrt(pspec)) A = Dist(ift.sqrt(pspec))
n = ift.Field.from_random(domain=space, random_type='normal') N = ift.ScalingOperator(noise, space)
n = N.draw_sample()
s = ht(ift.makeOp(A)(xi0_var)) s = ht(ift.makeOp(A)(xi0_var))
R = ift.ScalingOperator(10., space) R = ift.ScalingOperator(10., space)
sqrtN = ift.ScalingOperator(1., space)
d_model = R(ift.LocalModel(s, nonlinearity())) d_model = R(ift.LocalModel(s, nonlinearity()))
d = d_model.value + n d = d_model.value + n
IC = ift.GradientNormController(iteration_limit=100, if noise == 1:
tol_abs_gradnorm=1e-5) N = None
energy = ift.library.NonlinearWienerFilterEnergy(
d, d_model, sqrtN, IC) energy = ift.GaussianEnergy(d_model, d, N)
if isinstance(nonlinearity, ift.Linear): if isinstance(nonlinearity(), ift.Linear):
ift.extra.check_value_gradient_curvature_consistency( ift.extra.check_value_gradient_curvature_consistency(
energy, ntries=10) energy, ntries=10)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment