Commit b2410f61 authored by Theo Steininger's avatar Theo Steininger

Refactored LogNormalWienerFilterCurvature and LogNormalWienerFilterEnergy....

Refactored LogNormalWienerFilterCurvature and LogNormalWienerFilterEnergy. Added log_normal_wiener_filter.py
parent ef9b17c9
Pipeline #15682 failed with stage
in 8 minutes and 1 second
# -*- coding: utf-8 -*-
from nifty import *
if __name__ == "__main__":
nifty_configuration['default_distribution_strategy'] = 'fftw'
# Setting up parameters |\label{code:wf_parameters}|
correlation_length = 1. # Typical distance over which the field is correlated
field_variance = 2. # Variance of field in position space
response_sigma = 0.05 # Smoothing length of response (in same unit as L)
signal_to_noise = 1.5 # The signal to noise ratio
np.random.seed(43) # Fixing the random seed
def power_spectrum(k): # Defining the power spectrum
a = 4 * correlation_length * field_variance**2
return a / (1 + k * correlation_length) ** 4
# Setting up the geometry |\label{code:wf_geometry}|
L = 2. # Total side-length of the domain
N_pixels = 512 # Grid resolution (pixels per axis)
signal_space = RGSpace([N_pixels, N_pixels], distances=L/N_pixels)
harmonic_space = FFTOperator.get_default_codomain(signal_space)
fft = FFTOperator(harmonic_space, target=signal_space, target_dtype=np.float)
power_space = PowerSpace(harmonic_space)
# Creating the mock signal |\label{code:wf_mock_signal}|
S = create_power_operator(harmonic_space, power_spectrum=power_spectrum)
mock_power = Field(power_space, val=power_spectrum)
mock_signal = fft(mock_power.power_synthesize(real_signal=True))
# Setting up an exemplary response
mask = Field(signal_space, val=1.)
N10 = int(N_pixels/10)
mask.val[N10*5:N10*9, N10*5:N10*9] = 0.
R = ResponseOperator(signal_space, sigma=(response_sigma,), exposure=(mask,)) #|\label{code:wf_response}|
data_domain = R.target[0]
R_harmonic = ComposedOperator([fft, R], default_spaces=[0, 0])
# Setting up the noise covariance and drawing a random noise realization
N = DiagonalOperator(data_domain, diagonal=mock_signal.var()/signal_to_noise, bare=True)
noise = Field.from_random(domain=data_domain, random_type='normal',
std=mock_signal.std()/np.sqrt(signal_to_noise), mean=0)
data = R(exp(mock_signal)) + noise #|\label{code:wf_mock_data}|
# Wiener filter
m0 = Field(harmonic_space, val=0.j)
energy = library.LogNormalWienerFilterEnergy(m0, data, R_harmonic, N, S)
minimizer = VL_BFGS(convergence_tolerance=0,
iteration_limit=50,
#callback=convergence_measure,
max_history_length=3)
minimizer = RelaxedNewton(convergence_tolerance=0,
iteration_limit=1,
#callback=convergence_measure
)
# # Probing the variance
# class Proby(DiagonalProberMixin, Prober): pass
# proby = Proby(signal_space, probe_count=100)
# proby(lambda z: fft(wiener_curvature.inverse_times(fft.inverse_times(z))))
#
# sm = SmoothingOperator(signal_space, sigma=0.02)
# variance = sm(proby.diagonal.weight(-1))
#Plotting #|\label{code:wf_plotting}|
plotter = plotting.RG2DPlotter(color_map=plotting.colormaps.PlankCmap())
plotter.figure.xaxis = plotting.Axis(label='Pixel Index')
plotter.figure.yaxis = plotting.Axis(label='Pixel Index')
#plotter.plot.zmax = exp(mock_signal.max()); plotter.plot.zmin = 0
# plotter(variance, path = 'variance.html')
#plotter.plot.zmin = exp(mock_signal.min());
plotter(mock_signal, path='mock_signal.html')
plotter(Field(signal_space, val=np.log(data.val.get_full_data()).reshape(signal_space.shape)),
path = 'data.html')
# plotter(m, path = 'map.html')
from nifty.operators import EndomorphicOperator,\ from nifty.operators import EndomorphicOperator,\
InvertibleOperatorMixin InvertibleOperatorMixin
from nifty.energies.memoization import memo
from nifty.basic_arithmetics import exp from nifty.basic_arithmetics import exp
class WienerFilterCurvature(InvertibleOperatorMixin, EndomorphicOperator):
class LogNormalWienerFilterCurvature(InvertibleOperatorMixin,
EndomorphicOperator):
"""The curvature of the LogNormalWienerFilterEnergy. """The curvature of the LogNormalWienerFilterEnergy.
This operator implements the second derivative of the This operator implements the second derivative of the
...@@ -21,15 +24,18 @@ class WienerFilterCurvature(InvertibleOperatorMixin, EndomorphicOperator): ...@@ -21,15 +24,18 @@ class WienerFilterCurvature(InvertibleOperatorMixin, EndomorphicOperator):
""" """
def __init__(self, R, N, S, inverter=None, preconditioner=None, **kwargs): def __init__(self, R, N, S, d, position, inverter=None,
preconditioner=None, **kwargs):
self._cache = {}
self.R = R self.R = R
self.N = N self.N = N
self.S = S self.S = S
self.d = d
self.position = position
if preconditioner is None: if preconditioner is None:
preconditioner = self.S.times preconditioner = self.S.times
self._domain = self.S.domain self._domain = self.S.domain
super(WienerFilterCurvature, self).__init__( super(LogNormalWienerFilterCurvature, self).__init__(
inverter=inverter, inverter=inverter,
preconditioner=preconditioner, preconditioner=preconditioner,
**kwargs) **kwargs)
...@@ -49,11 +55,28 @@ class WienerFilterCurvature(InvertibleOperatorMixin, EndomorphicOperator): ...@@ -49,11 +55,28 @@ class WienerFilterCurvature(InvertibleOperatorMixin, EndomorphicOperator):
# ---Added properties and methods--- # ---Added properties and methods---
def _times(self, x, spaces): def _times(self, x, spaces):
expx = exp(x)
expxx = expx*x
part1 = self.S.inverse_times(x) part1 = self.S.inverse_times(x)
part2 = (expx * # is an adjoint necessary here? part2 = self._exppRNRexppd * x
self.R.adjoint_times(self.N.inverse_times(self.R(expxx)))) part3 = self._expp * self.R.adjoint_times(
part3 = (expxx * # is an adjoint necessary here? self.N.inverse_times(self.R(self._expp * x)))
self.R.adjoint_times(self.N.inverse_times(self.R(expx))))
return part1 + part2 + part3 return part1 + part2 + part3
@property
@memo
def _expp(self):
return exp(self.position)
@property
@memo
def _Rexppd(self):
return self.R(self._expp) - self.d
@property
@memo
def _NRexppd(self):
return self.N.inverse_times(self._Rexppd)
@property
@memo
def _exppRNRexppd(self):
return self._expp * self.R.adjoint_times(self._NRexppd)
from nifty.energies.energy import Energy from nifty.energies.energy import Energy
from nifty.energies.memoization import memo from nifty.energies.memoization import memo
from nifty.library.wiener_filter import WienerFilterCurvature from nifty.library.log_normal_wiener_filter import \
from nifty.basic_arithmetics import exp LogNormalWienerFilterCurvature
class LogNormalWienerFilterEnergy(Energy): class LogNormalWienerFilterEnergy(Energy):
"""The Energy for the log-normal Wiener filter. """The Energy for the log-normal Wiener filter.
...@@ -24,7 +25,7 @@ class LogNormalWienerFilterEnergy(Energy): ...@@ -24,7 +25,7 @@ class LogNormalWienerFilterEnergy(Energy):
""" """
def __init__(self, position, d, R, N, S): def __init__(self, position, d, R, N, S):
super(WienerFilterEnergy, self).__init__(position=position) super(LogNormalWienerFilterEnergy, self).__init__(position=position)
self.d = d self.d = d
self.R = R self.R = R
self.N = N self.N = N
...@@ -37,35 +38,37 @@ class LogNormalWienerFilterEnergy(Energy): ...@@ -37,35 +38,37 @@ class LogNormalWienerFilterEnergy(Energy):
@property @property
@memo @memo
def value(self): def value(self):
return (0.5*self.position.vdot(self._Sx) - return 0.5*(self.position.vdot(self._Sp) -
self._Rexpxd.vdot(self._NRexpxd)) self._Rexppd.vdot(self._NRexppd))
@property @property
@memo @memo
def gradient(self): def gradient(self):
return self._Sx + self._expx * self.R.adjoint_times(self._NRexpxd) return self._Sp + self._exppRNRexppd
@property @property
@memo @memo
def curvature(self): def curvature(self):
return WienerFilterCurvature(R=self.R, N=self.N, S=self.S) return LogNormalWienerFilterCurvature(R=self.R, N=self.N, S=self.S,
d=self.d, position=self.position)
@property @property
@memo def _expp(self):
def _expx(self): return self.curvature._expp
return exp(self.position)
@property @property
@memo def _Rexppd(self):
def _Rexpxd(self): return self.curvature._Rexppd
return self.R(self._expx) - self.d
@property @property
@memo def _NRexppd(self):
def _NRexpxd(self): return self.curvature._NRexppd
return self.N.inverse_times(self._Rexpxd)
@property
def _exppRNRexppd(self):
return self.curvature._exppRNRexppd
@property @property
@memo @memo
def _Sx(self): def _Sp(self):
return self.S.inverse_times(self.position) return self.S.inverse_times(self.position)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment