Commit 0eaf078f authored by Theo Steininger's avatar Theo Steininger

Fixed bugs in LogNormalWienerFilterCurvature

parent 106d199b
Pipeline #15809 failed with stage
in 8 minutes and 8 seconds
......@@ -19,8 +19,8 @@ if __name__ == "__main__":
# Setting up the geometry |\label{code:wf_geometry}|
L = 2. # Total side-length of the domain
N_pixels = 128 # Grid resolution (pixels per axis)
signal_space = RGSpace([N_pixels, N_pixels], distances=L/N_pixels)
#signal_space = HPSpace(16)
#signal_space = RGSpace([N_pixels, N_pixels], distances=L/N_pixels)
signal_space = HPSpace(16)
harmonic_space = FFTOperator.get_default_codomain(signal_space)
fft = FFTOperator(harmonic_space, target=signal_space, target_dtype=np.float)
power_space = PowerSpace(harmonic_space)
......@@ -49,26 +49,26 @@ if __name__ == "__main__":
energy = library.LogNormalWienerFilterEnergy(m0, data, R_harmonic, N, S)
minimizer1 = VL_BFGS(convergence_tolerance=1e-4,
iteration_limit=1000,
minimizer1 = VL_BFGS(convergence_tolerance=1e-5,
iteration_limit=3000,
#callback=convergence_measure,
max_history_length=3)
max_history_length=20)
minimizer2 = RelaxedNewton(convergence_tolerance=1e-4,
iteration_limit=1,
minimizer2 = RelaxedNewton(convergence_tolerance=1e-5,
iteration_limit=10,
#callback=convergence_measure
)
minimizer3 = SteepestDescent(convergence_tolerance=1e-4, iteration_limit=1000)
minimizer3 = SteepestDescent(convergence_tolerance=1e-5, iteration_limit=1000)
me1 = minimizer1(energy)
me2 = minimizer2(energy)
me3 = minimizer3(energy)
m1 = fft(me1[0].position)
m2 = fft(me2[0].position)
m3 = fft(me3[0].position)
# me1 = minimizer1(energy)
# me2 = minimizer2(energy)
# me3 = minimizer3(energy)
# m1 = fft(me1[0].position)
# m2 = fft(me2[0].position)
# m3 = fft(me3[0].position)
#
# # Probing the variance
......@@ -80,20 +80,20 @@ if __name__ == "__main__":
# variance = sm(proby.diagonal.weight(-1))
#Plotting #|\label{code:wf_plotting}|
plotter = plotting.RG2DPlotter(color_map=plotting.colormaps.PlankCmap())
#plotter = plotting.HealpixPlotter(color_map=plotting.colormaps.PlankCmap())
#plotter = plotting.RG2DPlotter(color_map=plotting.colormaps.PlankCmap())
plotter = plotting.HealpixPlotter(color_map=plotting.colormaps.PlankCmap())
plotter.figure.xaxis = plotting.Axis(label='Pixel Index')
plotter.figure.yaxis = plotting.Axis(label='Pixel Index')
plotter.plot.zmax = 5; plotter.plot.zmin = -5
# plotter(variance, path = 'variance.html')
#plotter.plot.zmin = exp(mock_signal.min());
plotter(mock_signal.real, path='mock_signal.html')
plotter(Field(signal_space, val=np.log(data.val.get_full_data().real).reshape(signal_space.shape)),
path = 'log_of_data.html')
plotter(m1.real, path='m_LBFGS.html')
plotter(m2.real, path='m_Newton.html')
plotter(m3.real, path='m_SteepestDescent.html')
## plotter(variance, path = 'variance.html')
# #plotter.plot.zmin = exp(mock_signal.min());
# plotter(mock_signal.real, path='mock_signal.html')
# plotter(Field(signal_space, val=np.log(data.val.get_full_data().real).reshape(signal_space.shape)),
# path = 'log_of_data.html')
#
# plotter(m1.real, path='m_LBFGS.html')
# plotter(m2.real, path='m_Newton.html')
# plotter(m3.real, path='m_SteepestDescent.html')
#
import numpy as np
from nifty.energies.energy import Energy
from nifty.operators.smoothness_operator import SmoothnessOperator
......
......@@ -2,6 +2,7 @@ from nifty.operators import EndomorphicOperator,\
InvertibleOperatorMixin
from nifty.energies.memoization import memo
from nifty.basic_arithmetics import clipped_exp
from nifty.sugar import create_composed_fft_operator
class LogNormalWienerFilterCurvature(InvertibleOperatorMixin,
......@@ -25,7 +26,7 @@ class LogNormalWienerFilterCurvature(InvertibleOperatorMixin,
"""
def __init__(self, R, N, S, d, position, inverter=None,
preconditioner=None, **kwargs):
preconditioner=None, fft4exp=None, **kwargs):
self._cache = {}
self.R = R
self.N = N
......@@ -35,6 +36,13 @@ class LogNormalWienerFilterCurvature(InvertibleOperatorMixin,
if preconditioner is None:
preconditioner = self.S.times
self._domain = self.S.domain
if fft4exp is None:
self._fft = create_composed_fft_operator(self.domain,
all_to='position')
else:
self._fft = fft4exp
super(LogNormalWienerFilterCurvature, self).__init__(
inverter=inverter,
preconditioner=preconditioner,
......@@ -57,14 +65,22 @@ class LogNormalWienerFilterCurvature(InvertibleOperatorMixin,
def _times(self, x, spaces):
part1 = self.S.inverse_times(x)
# part2 = self._exppRNRexppd * x
part3 = self._expp * self.R.adjoint_times(
self.N.inverse_times(self.R(self._expp * x)))
part3 = self._fft.adjoint_times(self._expp_sspace * self._fft(x))
part3 = self._fft.adjoint_times(
self._expp_sspace *
self._fft(self.R.adjoint_times(
self.N.inverse_times(self.R(part3)))))
return part1 + part3 # + part2
@property
@memo
def _expp_sspace(self):
return clipped_exp(self._fft(self.position))
@property
@memo
def _expp(self):
return clipped_exp(self.position)
return self._fft.adjoint_times(self._expp_sspace)
@property
@memo
......@@ -79,4 +95,6 @@ class LogNormalWienerFilterCurvature(InvertibleOperatorMixin,
@property
@memo
def _exppRNRexppd(self):
return self._expp * self.R.adjoint_times(self._NRexppd)
return self._fft.adjoint_times(
self._expp_sspace *
self._fft(self.R.adjoint_times(self._NRexppd)))
......@@ -2,6 +2,7 @@ from nifty.energies.energy import Energy
from nifty.energies.memoization import memo
from nifty.library.log_normal_wiener_filter import \
LogNormalWienerFilterCurvature
from nifty.sugar import create_composed_fft_operator
class LogNormalWienerFilterEnergy(Energy):
......@@ -24,16 +25,22 @@ class LogNormalWienerFilterEnergy(Energy):
The prior signal covariance in harmonic space.
"""
def __init__(self, position, d, R, N, S):
def __init__(self, position, d, R, N, S, fft4exp=None):
super(LogNormalWienerFilterEnergy, self).__init__(position=position)
self.d = d
self.R = R
self.N = N
self.S = S
if fft4exp is None:
self._fft = create_composed_fft_operator(self.S.domain,
all_to='position')
else:
self._fft = fft4exp
def at(self, position):
return self.__class__(position=position, d=self.d, R=self.R, N=self.N,
S=self.S)
S=self.S, fft4exp=self._fft)
@property
@memo
......@@ -50,7 +57,8 @@ class LogNormalWienerFilterEnergy(Energy):
@memo
def curvature(self):
return LogNormalWienerFilterCurvature(R=self.R, N=self.N, S=self.S,
d=self.d, position=self.position)
d=self.d, position=self.position,
fft4exp=self._fft)
@property
def _expp(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment