Commit 0eaf078f authored by Theo Steininger's avatar Theo Steininger

Fixed bugs in LogNormalWienerFilterCurvature

parent 106d199b
Pipeline #15809 failed with stage
in 8 minutes and 8 seconds
...@@ -19,8 +19,8 @@ if __name__ == "__main__": ...@@ -19,8 +19,8 @@ if __name__ == "__main__":
# Setting up the geometry |\label{code:wf_geometry}| # Setting up the geometry |\label{code:wf_geometry}|
L = 2. # Total side-length of the domain L = 2. # Total side-length of the domain
N_pixels = 128 # Grid resolution (pixels per axis) N_pixels = 128 # Grid resolution (pixels per axis)
signal_space = RGSpace([N_pixels, N_pixels], distances=L/N_pixels) #signal_space = RGSpace([N_pixels, N_pixels], distances=L/N_pixels)
#signal_space = HPSpace(16) signal_space = HPSpace(16)
harmonic_space = FFTOperator.get_default_codomain(signal_space) harmonic_space = FFTOperator.get_default_codomain(signal_space)
fft = FFTOperator(harmonic_space, target=signal_space, target_dtype=np.float) fft = FFTOperator(harmonic_space, target=signal_space, target_dtype=np.float)
power_space = PowerSpace(harmonic_space) power_space = PowerSpace(harmonic_space)
...@@ -49,26 +49,26 @@ if __name__ == "__main__": ...@@ -49,26 +49,26 @@ if __name__ == "__main__":
energy = library.LogNormalWienerFilterEnergy(m0, data, R_harmonic, N, S) energy = library.LogNormalWienerFilterEnergy(m0, data, R_harmonic, N, S)
minimizer1 = VL_BFGS(convergence_tolerance=1e-4, minimizer1 = VL_BFGS(convergence_tolerance=1e-5,
iteration_limit=1000, iteration_limit=3000,
#callback=convergence_measure, #callback=convergence_measure,
max_history_length=3) max_history_length=20)
minimizer2 = RelaxedNewton(convergence_tolerance=1e-4, minimizer2 = RelaxedNewton(convergence_tolerance=1e-5,
iteration_limit=1, iteration_limit=10,
#callback=convergence_measure #callback=convergence_measure
) )
minimizer3 = SteepestDescent(convergence_tolerance=1e-4, iteration_limit=1000) minimizer3 = SteepestDescent(convergence_tolerance=1e-5, iteration_limit=1000)
me1 = minimizer1(energy) # me1 = minimizer1(energy)
me2 = minimizer2(energy) # me2 = minimizer2(energy)
me3 = minimizer3(energy) # me3 = minimizer3(energy)
m1 = fft(me1[0].position)
m2 = fft(me2[0].position)
m3 = fft(me3[0].position)
# m1 = fft(me1[0].position)
# m2 = fft(me2[0].position)
# m3 = fft(me3[0].position)
#
# # Probing the variance # # Probing the variance
...@@ -80,20 +80,20 @@ if __name__ == "__main__": ...@@ -80,20 +80,20 @@ if __name__ == "__main__":
# variance = sm(proby.diagonal.weight(-1)) # variance = sm(proby.diagonal.weight(-1))
#Plotting #|\label{code:wf_plotting}| #Plotting #|\label{code:wf_plotting}|
plotter = plotting.RG2DPlotter(color_map=plotting.colormaps.PlankCmap()) #plotter = plotting.RG2DPlotter(color_map=plotting.colormaps.PlankCmap())
#plotter = plotting.HealpixPlotter(color_map=plotting.colormaps.PlankCmap()) plotter = plotting.HealpixPlotter(color_map=plotting.colormaps.PlankCmap())
plotter.figure.xaxis = plotting.Axis(label='Pixel Index') plotter.figure.xaxis = plotting.Axis(label='Pixel Index')
plotter.figure.yaxis = plotting.Axis(label='Pixel Index') plotter.figure.yaxis = plotting.Axis(label='Pixel Index')
plotter.plot.zmax = 5; plotter.plot.zmin = -5 plotter.plot.zmax = 5; plotter.plot.zmin = -5
# plotter(variance, path = 'variance.html') ## plotter(variance, path = 'variance.html')
#plotter.plot.zmin = exp(mock_signal.min()); # #plotter.plot.zmin = exp(mock_signal.min());
plotter(mock_signal.real, path='mock_signal.html') # plotter(mock_signal.real, path='mock_signal.html')
plotter(Field(signal_space, val=np.log(data.val.get_full_data().real).reshape(signal_space.shape)), # plotter(Field(signal_space, val=np.log(data.val.get_full_data().real).reshape(signal_space.shape)),
path = 'log_of_data.html') # path = 'log_of_data.html')
#
plotter(m1.real, path='m_LBFGS.html') # plotter(m1.real, path='m_LBFGS.html')
plotter(m2.real, path='m_Newton.html') # plotter(m2.real, path='m_Newton.html')
plotter(m3.real, path='m_SteepestDescent.html') # plotter(m3.real, path='m_SteepestDescent.html')
#
import numpy as np
from nifty.energies.energy import Energy from nifty.energies.energy import Energy
from nifty.operators.smoothness_operator import SmoothnessOperator from nifty.operators.smoothness_operator import SmoothnessOperator
......
...@@ -2,6 +2,7 @@ from nifty.operators import EndomorphicOperator,\ ...@@ -2,6 +2,7 @@ from nifty.operators import EndomorphicOperator,\
InvertibleOperatorMixin InvertibleOperatorMixin
from nifty.energies.memoization import memo from nifty.energies.memoization import memo
from nifty.basic_arithmetics import clipped_exp from nifty.basic_arithmetics import clipped_exp
from nifty.sugar import create_composed_fft_operator
class LogNormalWienerFilterCurvature(InvertibleOperatorMixin, class LogNormalWienerFilterCurvature(InvertibleOperatorMixin,
...@@ -25,7 +26,7 @@ class LogNormalWienerFilterCurvature(InvertibleOperatorMixin, ...@@ -25,7 +26,7 @@ class LogNormalWienerFilterCurvature(InvertibleOperatorMixin,
""" """
def __init__(self, R, N, S, d, position, inverter=None, def __init__(self, R, N, S, d, position, inverter=None,
preconditioner=None, **kwargs): preconditioner=None, fft4exp=None, **kwargs):
self._cache = {} self._cache = {}
self.R = R self.R = R
self.N = N self.N = N
...@@ -35,6 +36,13 @@ class LogNormalWienerFilterCurvature(InvertibleOperatorMixin, ...@@ -35,6 +36,13 @@ class LogNormalWienerFilterCurvature(InvertibleOperatorMixin,
if preconditioner is None: if preconditioner is None:
preconditioner = self.S.times preconditioner = self.S.times
self._domain = self.S.domain self._domain = self.S.domain
if fft4exp is None:
self._fft = create_composed_fft_operator(self.domain,
all_to='position')
else:
self._fft = fft4exp
super(LogNormalWienerFilterCurvature, self).__init__( super(LogNormalWienerFilterCurvature, self).__init__(
inverter=inverter, inverter=inverter,
preconditioner=preconditioner, preconditioner=preconditioner,
...@@ -57,14 +65,22 @@ class LogNormalWienerFilterCurvature(InvertibleOperatorMixin, ...@@ -57,14 +65,22 @@ class LogNormalWienerFilterCurvature(InvertibleOperatorMixin,
def _times(self, x, spaces): def _times(self, x, spaces):
part1 = self.S.inverse_times(x) part1 = self.S.inverse_times(x)
# part2 = self._exppRNRexppd * x # part2 = self._exppRNRexppd * x
part3 = self._expp * self.R.adjoint_times( part3 = self._fft.adjoint_times(self._expp_sspace * self._fft(x))
self.N.inverse_times(self.R(self._expp * x))) part3 = self._fft.adjoint_times(
self._expp_sspace *
self._fft(self.R.adjoint_times(
self.N.inverse_times(self.R(part3)))))
return part1 + part3 # + part2 return part1 + part3 # + part2
@property
@memo
def _expp_sspace(self):
return clipped_exp(self._fft(self.position))
@property @property
@memo @memo
def _expp(self): def _expp(self):
return clipped_exp(self.position) return self._fft.adjoint_times(self._expp_sspace)
@property @property
@memo @memo
...@@ -79,4 +95,6 @@ class LogNormalWienerFilterCurvature(InvertibleOperatorMixin, ...@@ -79,4 +95,6 @@ class LogNormalWienerFilterCurvature(InvertibleOperatorMixin,
@property @property
@memo @memo
def _exppRNRexppd(self): def _exppRNRexppd(self):
return self._expp * self.R.adjoint_times(self._NRexppd) return self._fft.adjoint_times(
self._expp_sspace *
self._fft(self.R.adjoint_times(self._NRexppd)))
...@@ -2,6 +2,7 @@ from nifty.energies.energy import Energy ...@@ -2,6 +2,7 @@ from nifty.energies.energy import Energy
from nifty.energies.memoization import memo from nifty.energies.memoization import memo
from nifty.library.log_normal_wiener_filter import \ from nifty.library.log_normal_wiener_filter import \
LogNormalWienerFilterCurvature LogNormalWienerFilterCurvature
from nifty.sugar import create_composed_fft_operator
class LogNormalWienerFilterEnergy(Energy): class LogNormalWienerFilterEnergy(Energy):
...@@ -24,16 +25,22 @@ class LogNormalWienerFilterEnergy(Energy): ...@@ -24,16 +25,22 @@ class LogNormalWienerFilterEnergy(Energy):
The prior signal covariance in harmonic space. The prior signal covariance in harmonic space.
""" """
def __init__(self, position, d, R, N, S): def __init__(self, position, d, R, N, S, fft4exp=None):
super(LogNormalWienerFilterEnergy, self).__init__(position=position) super(LogNormalWienerFilterEnergy, self).__init__(position=position)
self.d = d self.d = d
self.R = R self.R = R
self.N = N self.N = N
self.S = S self.S = S
if fft4exp is None:
self._fft = create_composed_fft_operator(self.S.domain,
all_to='position')
else:
self._fft = fft4exp
def at(self, position): def at(self, position):
return self.__class__(position=position, d=self.d, R=self.R, N=self.N, return self.__class__(position=position, d=self.d, R=self.R, N=self.N,
S=self.S) S=self.S, fft4exp=self._fft)
@property @property
@memo @memo
...@@ -50,7 +57,8 @@ class LogNormalWienerFilterEnergy(Energy): ...@@ -50,7 +57,8 @@ class LogNormalWienerFilterEnergy(Energy):
@memo @memo
def curvature(self): def curvature(self):
return LogNormalWienerFilterCurvature(R=self.R, N=self.N, S=self.S, return LogNormalWienerFilterCurvature(R=self.R, N=self.N, S=self.S,
d=self.d, position=self.position) d=self.d, position=self.position,
fft4exp=self._fft)
@property @property
def _expp(self): def _expp(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment