Commit 35afcda0 by Martin Reinecke

### Merge branch '31-rewrite-unitloggauss' into 'NIFTy_5'

```Resolve "Rewrite UnitLogGauss"

Closes #31

See merge request ift/nifty-dev!21```
parents 0de313ef 2dfc4494
 ... ... @@ -172,7 +172,7 @@ " tol_abs_gradnorm=0.1)\n", " # WienerFilterCurvature is (R.adjoint*N.inverse*R + Sh.inverse) plus some handy\n", " # helper methods.\n", " return ift.library.WienerFilterCurvature(R,N,Sh,iteration_controller=IC,iteration_controller_sampling=IC)" " return ift.WienerFilterCurvature(R,N,Sh,iteration_controller=IC,iteration_controller_sampling=IC)" ] }, { ... ...
 ... ... @@ -2,7 +2,9 @@ import nifty5 as ift import numpy as np from global_newton.models_other.apply_data import ApplyData from global_newton.models_energy.hamiltonian import Hamiltonian from nifty5.library.unit_log_gauss import UnitLogGauss from nifty5 import GaussianEnergy if __name__ == '__main__': # s_space = ift.RGSpace([1024]) s_space = ift.RGSpace([128,128]) ... ... @@ -45,7 +47,7 @@ if __name__ == '__main__': NWR = ApplyData(data, ift.Field(d_space,val=noise), Rs) INITIAL_POSITION = ift.from_random('normal',total_domain) likelihood = UnitLogGauss(INITIAL_POSITION, NWR) likelihood = GaussianEnergy(INITIAL_POSITION, NWR) IC = ift.GradientNormController(iteration_limit=500, tol_abs_gradnorm=1e-3) inverter = ift.ConjugateGradient(controller=IC) ... ...
 ... ... @@ -16,7 +16,7 @@ from .minimization import * from .sugar import * from .plotting.plot import plot from . import library from .library import * from . import extra from .utilities import memo ... ...
 # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program. If not, see . # # Copyright(C) 2013-2018 Max-Planck-Society # # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # and financially supported by the Studienstiftung des deutschen Volkes. from ..library.gaussian_energy import GaussianEnergy from ..minimization.energy import Energy from ..operators import InversionEnabler, SamplingEnabler from ..models.variable import Variable from ..operators import InversionEnabler, SamplingEnabler from ..utilities import memo from ..library.unit_log_gauss import UnitLogGauss class Hamiltonian(Energy): ... ... @@ -15,11 +33,8 @@ class Hamiltonian(Energy): super(Hamiltonian, self).__init__(lh.position) self._lh = lh self._ic = iteration_controller if iteration_controller_sampling is None: self._ic_samp = iteration_controller else: self._ic_samp = iteration_controller_sampling self._prior = UnitLogGauss(Variable(self.position)) self._prior = GaussianEnergy(Variable(self.position)) self._precond = self._prior.curvature def at(self, position): ... ... @@ -39,6 +54,9 @@ class Hamiltonian(Energy): @memo def curvature(self): prior_curv = self._prior.curvature if self._ic_samp is None: c = self._lh.curvature + prior_curv else: c = SamplingEnabler(self._lh.curvature, prior_curv.inverse, self._ic_samp, prior_curv.inverse) return InversionEnabler(c, self._ic, self._precond) ... ...
 from .amplitude_model import make_amplitude_model from .apply_data import ApplyData from .gaussian_energy import GaussianEnergy from .los_response import LOSResponse from .nonlinear_wiener_filter_energy import NonlinearWienerFilterEnergy from .unit_log_gauss import UnitLogGauss from .point_sources import PointSources from .poisson_log_likelihood import PoissonLogLikelihood from .poissonian_energy import PoissonianEnergy from .smooth_sky import make_smooth_mf_sky_model, make_smooth_sky_model from .wiener_filter_curvature import WienerFilterCurvature from .wiener_filter_energy import WienerFilterEnergy
 def ApplyData(data, var, model_data): from .. import DiagonalOperator, Constant, sqrt # TODO This is rather confusing. Delete that eventually. from ..operators.diagonal_operator import DiagonalOperator from ..models.constant import Constant from ..sugar import sqrt sqrt_n = DiagonalOperator(sqrt(var)) data = Constant(model_data.position, data) return sqrt_n.inverse(model_data - data)