Skip to content
Snippets Groups Projects
Commit 746a9122 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

cleanup

parent dbe83a62
No related branches found
No related tags found
1 merge request!207new operator convenience functionality
Pipeline #
from ..operators import InversionEnabler from ..operators import InversionEnabler
def LogNormalWienerFilterCurvature(R, N, S, fft, expp_sspace, inverter): def LogNormalWienerFilterCurvature(R, N, S, fft, expp_sspace, inverter):
part1 = S.inverse part1 = S.inverse
part3 = (fft.adjoint * expp_sspace * fft * part3 = (fft.adjoint * expp_sspace * fft *
......
...@@ -25,7 +25,7 @@ class LogNormalWienerFilterEnergy(Energy): ...@@ -25,7 +25,7 @@ class LogNormalWienerFilterEnergy(Energy):
The prior signal covariance in harmonic space. The prior signal covariance in harmonic space.
""" """
def __init__(self, position, d, R, N, S, inverter, fft4exp=None): def __init__(self, position, d, R, N, S, inverter, fft=None):
super(LogNormalWienerFilterEnergy, self).__init__(position=position) super(LogNormalWienerFilterEnergy, self).__init__(position=position)
self.d = d self.d = d
self.R = R self.R = R
...@@ -33,59 +33,38 @@ class LogNormalWienerFilterEnergy(Energy): ...@@ -33,59 +33,38 @@ class LogNormalWienerFilterEnergy(Energy):
self.S = S self.S = S
self._inverter = inverter self._inverter = inverter
if fft4exp is None: if fft is None:
self._fft = create_composed_fft_operator(self.S.domain, self._fft = create_composed_fft_operator(self.S.domain,
all_to='position') all_to='position')
else: else:
self._fft = fft4exp self._fft = fft
self._expp_sspace = exp(self._fft(self.position))
Sp = self.S.inverse_times(self.position)
expp = self._fft.adjoint_times(self._expp_sspace)
Rexppd = self.R(expp) - self.d
NRexppd = self.N.inverse_times(Rexppd)
self._value = 0.5*(self.position.vdot(Sp) + Rexppd.vdot(NRexppd))
exppRNRexppd = self._fft.adjoint_times(
self._expp_sspace * self._fft(self.R.adjoint_times(NRexppd)))
self._gradient = Sp + exppRNRexppd
def at(self, position): def at(self, position):
return self.__class__(position=position, d=self.d, R=self.R, N=self.N, return self.__class__(position, self.d, self.R, self.N, self.S,
S=self.S, fft4exp=self._fft, self._inverter, self._fft)
inverter=self._inverter)
@property @property
@memo
def value(self): def value(self):
return 0.5*(self.position.vdot(self._Sp) + return self._value
self._Rexppd.vdot(self._NRexppd))
@property @property
@memo
def gradient(self): def gradient(self):
return self._Sp + self._exppRNRexppd return self._gradient
@property @property
@memo @memo
def curvature(self): def curvature(self):
return LogNormalWienerFilterCurvature( return LogNormalWienerFilterCurvature(
R=self.R, N=self.N, S=self.S, fft=self._fft, R=self.R, N=self.N, S=self.S, fft=self._fft,
expp_sspace = self._expp_sspace, inverter=self._inverter) expp_sspace=self._expp_sspace, inverter=self._inverter)
@property
@memo
def _Sp(self):
return self.S.inverse_times(self.position)
@property
@memo
def _expp_sspace(self):
return exp(self._fft(self.position))
@property
@memo
def _Rexppd(self):
expp = self._fft.adjoint_times(self._expp_sspace)
return self.R(expp) - self.d
@property
@memo
def _NRexppd(self):
return self.N.inverse_times(self._Rexppd)
@property
@memo
def _exppRNRexppd(self):
return self._fft.adjoint_times(
self._expp_sspace *
self._fft(self.R.adjoint_times(self._NRexppd)))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment