Skip to content
Snippets Groups Projects
Commit b35351ab authored by Jakob Knollmueller's avatar Jakob Knollmueller
Browse files

implemented WienerFilterEnergy and WienerFilterCurvature, updated wiener_filter_hamiltonian

parent b60f3b77
No related branches found
No related tags found
1 merge request!120Working on demos
Pipeline #
...@@ -10,45 +10,20 @@ rank = comm.rank ...@@ -10,45 +10,20 @@ rank = comm.rank
np.random.seed(42) np.random.seed(42)
class WienerFilterEnergy(Energy): class AdjointFFTResponse(LinearOperator):
def __init__(self, position, D, j): def __init__(self, FFT, R, default_spaces=None):
# in principle not necessary, but useful in order to make the signature super(ResponseOperator, self).__init__(default_spaces)
# explicit self._domain = FFT.target
super(WienerFilterEnergy, self).__init__(position) self.target = R.target
self.D = D self.R = R
self.j = j self.FFT = FFT
def at(self, position): def _times(self, x):
return self.__class__(position, D=self.D, j=self.j) return self.R(self.FFT.adjoint_times(x))
@property def _adjoint_times(self, x):
def value(self): return self.FFT(self.R.adjoint_times(x))
D_inv_x = self.D_inverse_x()
H = 0.5 * D_inv_x.dot(self.position) - self.j.dot(self.position)
return H.real
@property
def gradient(self):
D_inv_x = self.D_inverse_x()
g = D_inv_x - self.j
return_g = g.copy_empty(dtype=np.float)
return_g.val = g.val.real
return return_g
@property
def curvature(self):
class Dummy(object):
def __init__(self, x):
self.x = x
def inverse_times(self, *args, **kwargs):
return self.x.times(*args, **kwargs)
my_dummy = Dummy(self.D)
return my_dummy
@memo
def D_inverse_x(self):
return D.inverse_times(self.position)
if __name__ == "__main__": if __name__ == "__main__":
...@@ -72,10 +47,11 @@ if __name__ == "__main__": ...@@ -72,10 +47,11 @@ if __name__ == "__main__":
ss = fft.inverse_times(sh) ss = fft.inverse_times(sh)
# model the measurement process # model the measurement process
R = SmoothingOperator(s_space, sigma=0.01) Instrument = SmoothingOperator(s_space, sigma=0.01)
# R = DiagonalOperator(s_space, diagonal=1.)
# R._diagonal.val[200:400, 200:400] = 0
# Instrument = DiagonalOperator(s_space, diagonal=1.)
# Instrument._diagonal.val[200:400, 200:400] = 0
R = AdjointFFTResponse(fft, Instrument)
signal_to_noise = 1 signal_to_noise = 1
N = DiagonalOperator(s_space, diagonal=ss.var()/signal_to_noise, bare=True) N = DiagonalOperator(s_space, diagonal=ss.var()/signal_to_noise, bare=True)
n = Field.from_random(domain=s_space, n = Field.from_random(domain=s_space,
...@@ -84,15 +60,11 @@ if __name__ == "__main__": ...@@ -84,15 +60,11 @@ if __name__ == "__main__":
mean=0) mean=0)
# create mock data # create mock data
d = R(ss) + n d = R(sh) + n
# set up reconstruction objects
j = R.adjoint_times(N.inverse_times(d))
D = PropagatorOperator(S=S, N=N, R=R)
def distance_measure(energy, iteration): def distance_measure(energy, iteration):
x = energy.position x = energy.value
print (iteration, ((x-ss).norm()/ss.norm()).real) print (x, iteration)
# minimizer = SteepestDescent(convergence_tolerance=0, # minimizer = SteepestDescent(convergence_tolerance=0,
# iteration_limit=50, # iteration_limit=50,
...@@ -107,13 +79,14 @@ if __name__ == "__main__": ...@@ -107,13 +79,14 @@ if __name__ == "__main__":
# callback=distance_measure, # callback=distance_measure,
# max_history_length=3) # max_history_length=3)
solution = energy.analytic_solution()
m0 = Field(s_space, val=1.) m0 = Field(s_space, val=1.)
energy = WienerFilterEnergy(position=m0, D=D, j=j) energy = WienerFilterEnergy(position=m0, D=D, j=j)
(energy, convergence) = minimizer(energy) (energy, convergence) = minimizer(energy)
m = energy.position m = fft.adjoint_times(energy.position)
d_data = d.val.get_full_data().real d_data = d.val.get_full_data().real
if rank == 0: if rank == 0:
......
...@@ -19,3 +19,4 @@ ...@@ -19,3 +19,4 @@
from energy import Energy from energy import Energy
from line_energy import LineEnergy from line_energy import LineEnergy
from memoization import memo from memoization import memo
from wiener_filter_energy import WienerFilterEnergy
\ No newline at end of file
from .energy import Energy
from nifty.operators import WienerFilterCurvature
class WienerFilterEnergy(Energy):
"""The Energy for the Wiener filter.
It describes the situation of linear measurement with
Gaussian noise and Gaussain signal prior.
Parameters
----------
d : Field,
the data.
R : Operator,
The response operator, describtion of the measurement process.
N : EndomorphicOperator,
The noise covariance in data space.
S : EndomorphicOperator,
The prior signal covariance in harmonic space.
"""
def __init__(self, position, d, R, N, S):
super(WienerFilterEnergy, self).__init__(position)
self.d = d
self.R = R
self.N = N
self.S = S
def at(self, position):
return self.__class__(position, self.d, self.R, self.N, self.S)
@property
def value(self):
energy = 0.5 * self.position.dot(self.S.inverse_times(self.position))
energy += 0.5 * (self.d - self.R(self.position)).dot(
self.N.inverse_times(self.d - self.R(self.position)))
return energy
@property
def gradient(self):
gradient = self.S.inverse_times(self.position)
gradient -= self.N.inverse_times(self.d - self.R(self.position))
return gradient
@property
def curvature(self):
curvature = WienerFilterCurvature(R=self.R, N=self.N, S=self.S)
return curvature
def analytic_solution(self):
D_inverse = self.curvature()
j = self.R.adjoint_times(self.N.inverse_times(self.d))
new_position = D_inverse.inverse_times(j)
return self.at(new_position)
...@@ -39,3 +39,5 @@ from propagator_operator import HarmonicPropagatorOperator ...@@ -39,3 +39,5 @@ from propagator_operator import HarmonicPropagatorOperator
from composed_operator import ComposedOperator from composed_operator import ComposedOperator
from response_operator import ResponseOperator from response_operator import ResponseOperator
from curvature_operators import WienerFilterCurvature
from wiener_filter_curvature import WienerFilterCurvature
\ No newline at end of file
from nifty.operators import EndomorphicOperator,\
InvertibleOperatorMixin
class WienerFilterCurvature(InvertibleOperatorMixin, EndomorphicOperator):
def __init__(self, R, N, S, inverter=None, preconditioner=None):
self.R = R
self.N = N
self.S = S
if preconditioner is None:
preconditioner = self.S.times
self._domain = self.S.domain
super(WienerFilterCurvature, self).__init__(inverter=inverter,
preconditioner=preconditioner)
@property
def domain(self):
return self._domain
@property
def self_adjoint(self):
return True
@property
def unitary(self):
return False
# ---Added properties and methods---
def _times(self, x, spaces):
return self.R.adjoint_times(self.N.inverse_times(self.R(x)))\
+ self.S.inverse_times(x)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment