From b2410f61d7821cbd8df0a5ce39645e8ec31a8860 Mon Sep 17 00:00:00 2001 From: Theo Steininger Date: Sun, 30 Jul 2017 03:24:27 +0200 Subject: [PATCH] Refactored LogNormalWienerFilterCurvature and LogNormalWienerFilterEnergy. Added log_normal_wiener_filter.py --- demos/log_normal_wiener_filter.py | 84 +++++++++++++++++++ .../log_normal_wiener_filter_curvature.py | 43 +++++++--- .../log_normal_wiener_filter_energy.py | 37 ++++---- 3 files changed, 137 insertions(+), 27 deletions(-) create mode 100644 demos/log_normal_wiener_filter.py diff --git a/demos/log_normal_wiener_filter.py b/demos/log_normal_wiener_filter.py new file mode 100644 index 00000000..7b4c6201 --- /dev/null +++ b/demos/log_normal_wiener_filter.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- + +from nifty import * + +if __name__ == "__main__": + nifty_configuration['default_distribution_strategy'] = 'fftw' + + # Setting up parameters |\label{code:wf_parameters}| + correlation_length = 1. # Typical distance over which the field is correlated + field_variance = 2. # Variance of field in position space + response_sigma = 0.05 # Smoothing length of response (in same unit as L) + signal_to_noise = 1.5 # The signal to noise ratio + np.random.seed(43) # Fixing the random seed + def power_spectrum(k): # Defining the power spectrum + a = 4 * correlation_length * field_variance**2 + return a / (1 + k * correlation_length) ** 4 + + # Setting up the geometry |\label{code:wf_geometry}| + L = 2. # Total side-length of the domain + N_pixels = 512 # Grid resolution (pixels per axis) + signal_space = RGSpace([N_pixels, N_pixels], distances=L/N_pixels) + harmonic_space = FFTOperator.get_default_codomain(signal_space) + fft = FFTOperator(harmonic_space, target=signal_space, target_dtype=np.float) + power_space = PowerSpace(harmonic_space) + + # Creating the mock signal |\label{code:wf_mock_signal}| + S = create_power_operator(harmonic_space, power_spectrum=power_spectrum) + mock_power = Field(power_space, val=power_spectrum) + mock_signal = fft(mock_power.power_synthesize(real_signal=True)) + + # Setting up an exemplary response + mask = Field(signal_space, val=1.) + N10 = int(N_pixels/10) + mask.val[N10*5:N10*9, N10*5:N10*9] = 0. + R = ResponseOperator(signal_space, sigma=(response_sigma,), exposure=(mask,)) #|\label{code:wf_response}| + data_domain = R.target[0] + R_harmonic = ComposedOperator([fft, R], default_spaces=[0, 0]) + + # Setting up the noise covariance and drawing a random noise realization + N = DiagonalOperator(data_domain, diagonal=mock_signal.var()/signal_to_noise, bare=True) + noise = Field.from_random(domain=data_domain, random_type='normal', + std=mock_signal.std()/np.sqrt(signal_to_noise), mean=0) + data = R(exp(mock_signal)) + noise #|\label{code:wf_mock_data}| + + # Wiener filter + m0 = Field(harmonic_space, val=0.j) + energy = library.LogNormalWienerFilterEnergy(m0, data, R_harmonic, N, S) + + + minimizer = VL_BFGS(convergence_tolerance=0, + iteration_limit=50, + #callback=convergence_measure, + max_history_length=3) + + minimizer = RelaxedNewton(convergence_tolerance=0, + iteration_limit=1, + #callback=convergence_measure + ) + +# # Probing the variance +# class Proby(DiagonalProberMixin, Prober): pass +# proby = Proby(signal_space, probe_count=100) +# proby(lambda z: fft(wiener_curvature.inverse_times(fft.inverse_times(z)))) +# +# sm = SmoothingOperator(signal_space, sigma=0.02) +# variance = sm(proby.diagonal.weight(-1)) + + #Plotting #|\label{code:wf_plotting}| + plotter = plotting.RG2DPlotter(color_map=plotting.colormaps.PlankCmap()) + plotter.figure.xaxis = plotting.Axis(label='Pixel Index') + plotter.figure.yaxis = plotting.Axis(label='Pixel Index') + + #plotter.plot.zmax = exp(mock_signal.max()); plotter.plot.zmin = 0 +# plotter(variance, path = 'variance.html') + #plotter.plot.zmin = exp(mock_signal.min()); + plotter(mock_signal, path='mock_signal.html') + plotter(Field(signal_space, val=np.log(data.val.get_full_data()).reshape(signal_space.shape)), + path = 'data.html') +# plotter(m, path = 'map.html') + + + + + diff --git a/nifty/library/log_normal_wiener_filter/log_normal_wiener_filter_curvature.py b/nifty/library/log_normal_wiener_filter/log_normal_wiener_filter_curvature.py index 9d6f82a2..a7f63fcd 100644 --- a/nifty/library/log_normal_wiener_filter/log_normal_wiener_filter_curvature.py +++ b/nifty/library/log_normal_wiener_filter/log_normal_wiener_filter_curvature.py @@ -1,8 +1,11 @@ from nifty.operators import EndomorphicOperator,\ InvertibleOperatorMixin +from nifty.energies.memoization import memo from nifty.basic_arithmetics import exp -class WienerFilterCurvature(InvertibleOperatorMixin, EndomorphicOperator): + +class LogNormalWienerFilterCurvature(InvertibleOperatorMixin, + EndomorphicOperator): """The curvature of the LogNormalWienerFilterEnergy. This operator implements the second derivative of the @@ -21,15 +24,18 @@ class WienerFilterCurvature(InvertibleOperatorMixin, EndomorphicOperator): """ - def __init__(self, R, N, S, inverter=None, preconditioner=None, **kwargs): - + def __init__(self, R, N, S, d, position, inverter=None, + preconditioner=None, **kwargs): + self._cache = {} self.R = R self.N = N self.S = S + self.d = d + self.position = position if preconditioner is None: preconditioner = self.S.times self._domain = self.S.domain - super(WienerFilterCurvature, self).__init__( + super(LogNormalWienerFilterCurvature, self).__init__( inverter=inverter, preconditioner=preconditioner, **kwargs) @@ -49,11 +55,28 @@ class WienerFilterCurvature(InvertibleOperatorMixin, EndomorphicOperator): # ---Added properties and methods--- def _times(self, x, spaces): - expx = exp(x) - expxx = expx*x part1 = self.S.inverse_times(x) - part2 = (expx * # is an adjoint necessary here? - self.R.adjoint_times(self.N.inverse_times(self.R(expxx)))) - part3 = (expxx * # is an adjoint necessary here? - self.R.adjoint_times(self.N.inverse_times(self.R(expx)))) + part2 = self._exppRNRexppd * x + part3 = self._expp * self.R.adjoint_times( + self.N.inverse_times(self.R(self._expp * x))) return part1 + part2 + part3 + + @property + @memo + def _expp(self): + return exp(self.position) + + @property + @memo + def _Rexppd(self): + return self.R(self._expp) - self.d + + @property + @memo + def _NRexppd(self): + return self.N.inverse_times(self._Rexppd) + + @property + @memo + def _exppRNRexppd(self): + return self._expp * self.R.adjoint_times(self._NRexppd) diff --git a/nifty/library/log_normal_wiener_filter/log_normal_wiener_filter_energy.py b/nifty/library/log_normal_wiener_filter/log_normal_wiener_filter_energy.py index 3a51a948..da1ba5b6 100644 --- a/nifty/library/log_normal_wiener_filter/log_normal_wiener_filter_energy.py +++ b/nifty/library/log_normal_wiener_filter/log_normal_wiener_filter_energy.py @@ -1,7 +1,8 @@ from nifty.energies.energy import Energy from nifty.energies.memoization import memo -from nifty.library.wiener_filter import WienerFilterCurvature -from nifty.basic_arithmetics import exp +from nifty.library.log_normal_wiener_filter import \ + LogNormalWienerFilterCurvature + class LogNormalWienerFilterEnergy(Energy): """The Energy for the log-normal Wiener filter. @@ -24,7 +25,7 @@ class LogNormalWienerFilterEnergy(Energy): """ def __init__(self, position, d, R, N, S): - super(WienerFilterEnergy, self).__init__(position=position) + super(LogNormalWienerFilterEnergy, self).__init__(position=position) self.d = d self.R = R self.N = N @@ -37,35 +38,37 @@ class LogNormalWienerFilterEnergy(Energy): @property @memo def value(self): - return (0.5*self.position.vdot(self._Sx) - - self._Rexpxd.vdot(self._NRexpxd)) + return 0.5*(self.position.vdot(self._Sp) - + self._Rexppd.vdot(self._NRexppd)) @property @memo def gradient(self): - return self._Sx + self._expx * self.R.adjoint_times(self._NRexpxd) + return self._Sp + self._exppRNRexppd @property @memo def curvature(self): - return WienerFilterCurvature(R=self.R, N=self.N, S=self.S) + return LogNormalWienerFilterCurvature(R=self.R, N=self.N, S=self.S, + d=self.d, position=self.position) @property - @memo - def _expx(self): - return exp(self.position) + def _expp(self): + return self.curvature._expp @property - @memo - def _Rexpxd(self): - return self.R(self._expx) - self.d + def _Rexppd(self): + return self.curvature._Rexppd @property - @memo - def _NRexpxd(self): - return self.N.inverse_times(self._Rexpxd) + def _NRexppd(self): + return self.curvature._NRexppd + + @property + def _exppRNRexppd(self): + return self.curvature._exppRNRexppd @property @memo - def _Sx(self): + def _Sp(self): return self.S.inverse_times(self.position) -- GitLab