Commit 35afcda0 by Martin Reinecke

### Merge branch '31-rewrite-unitloggauss' into 'NIFTy_5'

```Resolve "Rewrite UnitLogGauss"

Closes #31

See merge request ift/nifty-dev!21```
parents 0de313ef 2dfc4494
 ... @@ -172,7 +172,7 @@ ... @@ -172,7 +172,7 @@ " tol_abs_gradnorm=0.1)\n", " tol_abs_gradnorm=0.1)\n", " # WienerFilterCurvature is (R.adjoint*N.inverse*R + Sh.inverse) plus some handy\n", " # WienerFilterCurvature is (R.adjoint*N.inverse*R + Sh.inverse) plus some handy\n", " # helper methods.\n", " # helper methods.\n", " return ift.library.WienerFilterCurvature(R,N,Sh,iteration_controller=IC,iteration_controller_sampling=IC)" " return ift.WienerFilterCurvature(R,N,Sh,iteration_controller=IC,iteration_controller_sampling=IC)" ] ] }, }, { { ... ...
 ... @@ -2,7 +2,9 @@ import nifty5 as ift ... @@ -2,7 +2,9 @@ import nifty5 as ift import numpy as np import numpy as np from global_newton.models_other.apply_data import ApplyData from global_newton.models_other.apply_data import ApplyData from global_newton.models_energy.hamiltonian import Hamiltonian from global_newton.models_energy.hamiltonian import Hamiltonian from nifty5.library.unit_log_gauss import UnitLogGauss from nifty5 import GaussianEnergy if __name__ == '__main__': if __name__ == '__main__': # s_space = ift.RGSpace([1024]) # s_space = ift.RGSpace([1024]) s_space = ift.RGSpace([128,128]) s_space = ift.RGSpace([128,128]) ... @@ -45,7 +47,7 @@ if __name__ == '__main__': ... @@ -45,7 +47,7 @@ if __name__ == '__main__': NWR = ApplyData(data, ift.Field(d_space,val=noise), Rs) NWR = ApplyData(data, ift.Field(d_space,val=noise), Rs) INITIAL_POSITION = ift.from_random('normal',total_domain) INITIAL_POSITION = ift.from_random('normal',total_domain) likelihood = UnitLogGauss(INITIAL_POSITION, NWR) likelihood = GaussianEnergy(INITIAL_POSITION, NWR) IC = ift.GradientNormController(iteration_limit=500, tol_abs_gradnorm=1e-3) IC = ift.GradientNormController(iteration_limit=500, tol_abs_gradnorm=1e-3) inverter = ift.ConjugateGradient(controller=IC) inverter = ift.ConjugateGradient(controller=IC) ... ...
 ... @@ -16,7 +16,7 @@ from .minimization import * ... @@ -16,7 +16,7 @@ from .minimization import * from .sugar import * from .sugar import * from .plotting.plot import plot from .plotting.plot import plot from . import library from .library import * from . import extra from . import extra from .utilities import memo from .utilities import memo ... ...
 # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program. If not, see . # # Copyright(C) 2013-2018 Max-Planck-Society # # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # and financially supported by the Studienstiftung des deutschen Volkes. from ..library.gaussian_energy import GaussianEnergy from ..minimization.energy import Energy from ..minimization.energy import Energy from ..operators import InversionEnabler, SamplingEnabler from ..models.variable import Variable from ..models.variable import Variable from ..operators import InversionEnabler, SamplingEnabler from ..utilities import memo from ..utilities import memo from ..library.unit_log_gauss import UnitLogGauss class Hamiltonian(Energy): class Hamiltonian(Energy): ... @@ -15,11 +33,8 @@ class Hamiltonian(Energy): ... @@ -15,11 +33,8 @@ class Hamiltonian(Energy): super(Hamiltonian, self).__init__(lh.position) super(Hamiltonian, self).__init__(lh.position) self._lh = lh self._lh = lh self._ic = iteration_controller self._ic = iteration_controller if iteration_controller_sampling is None: self._ic_samp = iteration_controller_sampling self._ic_samp = iteration_controller self._prior = GaussianEnergy(Variable(self.position)) else: self._ic_samp = iteration_controller_sampling self._prior = UnitLogGauss(Variable(self.position)) self._precond = self._prior.curvature self._precond = self._prior.curvature def at(self, position): def at(self, position): ... @@ -39,8 +54,11 @@ class Hamiltonian(Energy): ... @@ -39,8 +54,11 @@ class Hamiltonian(Energy): @memo @memo def curvature(self): def curvature(self): prior_curv = self._prior.curvature prior_curv = self._prior.curvature c = SamplingEnabler(self._lh.curvature, prior_curv.inverse, if self._ic_samp is None: self._ic_samp, prior_curv.inverse) c = self._lh.curvature + prior_curv else: c = SamplingEnabler(self._lh.curvature, prior_curv.inverse, self._ic_samp, prior_curv.inverse) return InversionEnabler(c, self._ic, self._precond) return InversionEnabler(c, self._ic, self._precond) def __str__(self): def __str__(self): ... ...
 from .amplitude_model import make_amplitude_model from .amplitude_model import make_amplitude_model from .apply_data import ApplyData from .apply_data import ApplyData from .gaussian_energy import GaussianEnergy from .los_response import LOSResponse from .los_response import LOSResponse from .nonlinear_wiener_filter_energy import NonlinearWienerFilterEnergy from .unit_log_gauss import UnitLogGauss from .point_sources import PointSources from .point_sources import PointSources from .poisson_log_likelihood import PoissonLogLikelihood from .poissonian_energy import PoissonianEnergy from .smooth_sky import make_smooth_mf_sky_model, make_smooth_sky_model from .smooth_sky import make_smooth_mf_sky_model, make_smooth_sky_model from .wiener_filter_curvature import WienerFilterCurvature from .wiener_filter_curvature import WienerFilterCurvature from .wiener_filter_energy import WienerFilterEnergy from .wiener_filter_energy import WienerFilterEnergy
 def ApplyData(data, var, model_data): def ApplyData(data, var, model_data): from .. import DiagonalOperator, Constant, sqrt # TODO This is rather confusing. Delete that eventually. from ..operators.diagonal_operator import DiagonalOperator from ..models.constant import Constant from ..sugar import sqrt sqrt_n = DiagonalOperator(sqrt(var)) sqrt_n = DiagonalOperator(sqrt(var)) data = Constant(model_data.position, data) data = Constant(model_data.position, data) return sqrt_n.inverse(model_data - data) return sqrt_n.inverse(model_data - data)