Commit 743626ab authored by Philipp Arras's avatar Philipp Arras
Browse files

Rename unitloggauss -> gaussianenergy

parent cb66a789
......@@ -2,7 +2,9 @@ import nifty5 as ift
import numpy as np
from global_newton.models_other.apply_data import ApplyData
from global_newton.models_energy.hamiltonian import Hamiltonian
from nifty5.library.unit_log_gauss import UnitLogGauss
from nifty5.library.gaussian_energy import GaussianEnergy
if __name__ == '__main__':
# s_space = ift.RGSpace([1024])
s_space = ift.RGSpace([128,128])
......@@ -45,7 +47,7 @@ if __name__ == '__main__':
NWR = ApplyData(data, ift.Field(d_space,val=noise), Rs)
INITIAL_POSITION = ift.from_random('normal',total_domain)
likelihood = UnitLogGauss(INITIAL_POSITION, NWR)
likelihood = GaussianEnergy(INITIAL_POSITION, NWR)
IC = ift.GradientNormController(iteration_limit=500, tol_abs_gradnorm=1e-3)
inverter = ift.ConjugateGradient(controller=IC)
......
......@@ -2,7 +2,7 @@ from ..minimization.energy import Energy
from ..operators import InversionEnabler, SamplingEnabler
from ..models.variable import Variable
from ..utilities import memo
from ..library.unit_log_gauss import UnitLogGauss
from ..library.gaussian_energy import GaussianEnergy
class Hamiltonian(Energy):
......@@ -19,7 +19,7 @@ class Hamiltonian(Energy):
self._ic_samp = iteration_controller
else:
self._ic_samp = iteration_controller_sampling
self._prior = UnitLogGauss(Variable(self.position))
self._prior = GaussianEnergy(Variable(self.position))
self._precond = self._prior.curvature
def at(self, position):
......
......@@ -2,7 +2,7 @@ from .amplitude_model import make_amplitude_model
from .apply_data import ApplyData
from .los_response import LOSResponse
from .nonlinear_wiener_filter_energy import NonlinearWienerFilterEnergy
from .unit_log_gauss import UnitLogGauss
from .gaussian_energy import GaussianEnergy
from .point_sources import PointSources
from .poisson_log_likelihood import PoissonLogLikelihood
from .smooth_sky import make_smooth_mf_sky_model, make_smooth_sky_model
......
......@@ -22,7 +22,7 @@ from ..operators.sandwich_operator import SandwichOperator
from ..utilities import memo
class UnitLogGauss(Energy):
class GaussianEnergy(Energy):
def __init__(self, s, inverter=None):
"""
s: Sky model object
......@@ -30,7 +30,7 @@ class UnitLogGauss(Energy):
value = 0.5 * s.vdot(s), i.e. a log-Gauss distribution with unit
covariance
"""
super(UnitLogGauss, self).__init__(s.position)
super(GaussianEnergy, self).__init__(s.position)
self._s = s
self._inverter = inverter
......
......@@ -17,12 +17,12 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from ..models.constant import Constant
from .unit_log_gauss import UnitLogGauss
from .gaussian_energy import GaussianEnergy
from ..energies.hamiltonian import Hamiltonian
def NonlinearWienerFilterEnergy(measured_data, data_model, sqrtN, iteration_controller):
d = measured_data.lock()
residual = Constant(data_model.position, d) - data_model
lh = UnitLogGauss(sqrtN.inverse(residual))
lh = GaussianEnergy(sqrtN.inverse(residual))
return Hamiltonian(lh, iteration_controller)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment