Skip to content
Snippets Groups Projects
Commit 743626ab authored by Philipp Arras's avatar Philipp Arras
Browse files

Rename unitloggauss -> gaussianenergy

parent cb66a789
No related branches found
No related tags found
No related merge requests found
...@@ -2,7 +2,9 @@ import nifty5 as ift ...@@ -2,7 +2,9 @@ import nifty5 as ift
import numpy as np import numpy as np
from global_newton.models_other.apply_data import ApplyData from global_newton.models_other.apply_data import ApplyData
from global_newton.models_energy.hamiltonian import Hamiltonian from global_newton.models_energy.hamiltonian import Hamiltonian
from nifty5.library.unit_log_gauss import UnitLogGauss from nifty5.library.gaussian_energy import GaussianEnergy
if __name__ == '__main__': if __name__ == '__main__':
# s_space = ift.RGSpace([1024]) # s_space = ift.RGSpace([1024])
s_space = ift.RGSpace([128,128]) s_space = ift.RGSpace([128,128])
...@@ -45,7 +47,7 @@ if __name__ == '__main__': ...@@ -45,7 +47,7 @@ if __name__ == '__main__':
NWR = ApplyData(data, ift.Field(d_space,val=noise), Rs) NWR = ApplyData(data, ift.Field(d_space,val=noise), Rs)
INITIAL_POSITION = ift.from_random('normal',total_domain) INITIAL_POSITION = ift.from_random('normal',total_domain)
likelihood = UnitLogGauss(INITIAL_POSITION, NWR) likelihood = GaussianEnergy(INITIAL_POSITION, NWR)
IC = ift.GradientNormController(iteration_limit=500, tol_abs_gradnorm=1e-3) IC = ift.GradientNormController(iteration_limit=500, tol_abs_gradnorm=1e-3)
inverter = ift.ConjugateGradient(controller=IC) inverter = ift.ConjugateGradient(controller=IC)
......
...@@ -2,7 +2,7 @@ from ..minimization.energy import Energy ...@@ -2,7 +2,7 @@ from ..minimization.energy import Energy
from ..operators import InversionEnabler, SamplingEnabler from ..operators import InversionEnabler, SamplingEnabler
from ..models.variable import Variable from ..models.variable import Variable
from ..utilities import memo from ..utilities import memo
from ..library.unit_log_gauss import UnitLogGauss from ..library.gaussian_energy import GaussianEnergy
class Hamiltonian(Energy): class Hamiltonian(Energy):
...@@ -19,7 +19,7 @@ class Hamiltonian(Energy): ...@@ -19,7 +19,7 @@ class Hamiltonian(Energy):
self._ic_samp = iteration_controller self._ic_samp = iteration_controller
else: else:
self._ic_samp = iteration_controller_sampling self._ic_samp = iteration_controller_sampling
self._prior = UnitLogGauss(Variable(self.position)) self._prior = GaussianEnergy(Variable(self.position))
self._precond = self._prior.curvature self._precond = self._prior.curvature
def at(self, position): def at(self, position):
......
...@@ -2,7 +2,7 @@ from .amplitude_model import make_amplitude_model ...@@ -2,7 +2,7 @@ from .amplitude_model import make_amplitude_model
from .apply_data import ApplyData from .apply_data import ApplyData
from .los_response import LOSResponse from .los_response import LOSResponse
from .nonlinear_wiener_filter_energy import NonlinearWienerFilterEnergy from .nonlinear_wiener_filter_energy import NonlinearWienerFilterEnergy
from .unit_log_gauss import UnitLogGauss from .gaussian_energy import GaussianEnergy
from .point_sources import PointSources from .point_sources import PointSources
from .poisson_log_likelihood import PoissonLogLikelihood from .poisson_log_likelihood import PoissonLogLikelihood
from .smooth_sky import make_smooth_mf_sky_model, make_smooth_sky_model from .smooth_sky import make_smooth_mf_sky_model, make_smooth_sky_model
......
...@@ -22,7 +22,7 @@ from ..operators.sandwich_operator import SandwichOperator ...@@ -22,7 +22,7 @@ from ..operators.sandwich_operator import SandwichOperator
from ..utilities import memo from ..utilities import memo
class UnitLogGauss(Energy): class GaussianEnergy(Energy):
def __init__(self, s, inverter=None): def __init__(self, s, inverter=None):
""" """
s: Sky model object s: Sky model object
...@@ -30,7 +30,7 @@ class UnitLogGauss(Energy): ...@@ -30,7 +30,7 @@ class UnitLogGauss(Energy):
value = 0.5 * s.vdot(s), i.e. a log-Gauss distribution with unit value = 0.5 * s.vdot(s), i.e. a log-Gauss distribution with unit
covariance covariance
""" """
super(UnitLogGauss, self).__init__(s.position) super(GaussianEnergy, self).__init__(s.position)
self._s = s self._s = s
self._inverter = inverter self._inverter = inverter
......
...@@ -17,12 +17,12 @@ ...@@ -17,12 +17,12 @@
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from ..models.constant import Constant from ..models.constant import Constant
from .unit_log_gauss import UnitLogGauss from .gaussian_energy import GaussianEnergy
from ..energies.hamiltonian import Hamiltonian from ..energies.hamiltonian import Hamiltonian
def NonlinearWienerFilterEnergy(measured_data, data_model, sqrtN, iteration_controller): def NonlinearWienerFilterEnergy(measured_data, data_model, sqrtN, iteration_controller):
d = measured_data.lock() d = measured_data.lock()
residual = Constant(data_model.position, d) - data_model residual = Constant(data_model.position, d) - data_model
lh = UnitLogGauss(sqrtN.inverse(residual)) lh = GaussianEnergy(sqrtN.inverse(residual))
return Hamiltonian(lh, iteration_controller) return Hamiltonian(lh, iteration_controller)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment