log_normal_wiener_filter_energy.py 2.21 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
3
4
from ..minimization.energy import Energy
from ..utilities import memo
from .log_normal_wiener_filter_curvature import LogNormalWienerFilterCurvature
from ..sugar import create_composed_fft_operator
Martin Reinecke's avatar
Martin Reinecke committed
5
from ..field import exp
Martin Reinecke's avatar
Martin Reinecke committed
6

7
8
9
10
11
12
13
14
15
16

class LogNormalWienerFilterEnergy(Energy):
    """The Energy for the log-normal Wiener filter.

    It covers the case of linear measurement with
    Gaussian noise and Gaussian signal prior with known covariance.

    The value and the gradient are evaluated once, at construction time,
    and are afterwards only read back through the `value` and `gradient`
    properties.

    Parameters
    ----------
    position: Field,
       The current position.
    d: Field,
       The data.
    R: Operator,
       The response operator, description of the measurement process.
    N: EndomorphicOperator,
       The noise covariance in data space.
    S: EndomorphicOperator,
       The prior signal covariance in harmonic space.
    inverter:
       Not used by the energy itself; it is stored and handed on to
       LogNormalWienerFilterCurvature and to the energies created by `at`.
    fft: optional,
       Operator applied to the position before exponentiation. If None
       (default), it is built with
       create_composed_fft_operator(S.domain, all_to='position').
    """

    def __init__(self, position, d, R, N, S, inverter, fft=None):
        super(LogNormalWienerFilterEnergy, self).__init__(position=position)
        self.d = d
        self.R = R
        self.N = N
        self.S = S
        self._inverter = inverter

        if fft is None:
            self._fft = create_composed_fft_operator(self.S.domain,
                                                     all_to='position')
        else:
            # Reuse the operator handed in (e.g. by `at`) instead of
            # constructing a new one for every position.
            self._fft = fft

        # exp(fft(position)); kept on the instance because the curvature
        # needs it as well.
        self._expp_sspace = exp(self._fft(self.position))

        # Prior part: S^{-1} position.
        Sp = self.S.inverse_times(self.position)
        # Likelihood part: residual R(expp) - d, weighted with N^{-1}.
        expp = self._fft.adjoint_times(self._expp_sspace)
        Rexppd = self.R(expp) - self.d
        NRexppd = self.N.inverse_times(Rexppd)
        self._value = 0.5*(self.position.vdot(Sp) + Rexppd.vdot(NRexppd))
        # Gradient of the likelihood part: the back-projected weighted
        # residual is multiplied pointwise with exp(fft(position)) and
        # mapped back with the adjoint of the fft operator.
        exppRNRexppd = self._fft.adjoint_times(
            self._expp_sspace * self._fft(self.R.adjoint_times(NRexppd)))
        self._gradient = Sp + exppRNRexppd

    def at(self, position):
        """Return a new energy of the same class at `position`.

        Data, operators, inverter and the fft operator of this instance are
        passed on unchanged.
        """
        return self.__class__(position, self.d, self.R, self.N, self.S,
                              self._inverter, self._fft)

    @property
    def value(self):
        """The energy value computed in the constructor."""
        return self._value

    @property
    def gradient(self):
        """The gradient computed in the constructor."""
        return self._gradient

    @property
    @memo
    def curvature(self):
        """A LogNormalWienerFilterCurvature built from this energy's
        operators, fft operator, exp(fft(position)) and inverter."""
        return LogNormalWienerFilterCurvature(
            R=self.R, N=self.N, S=self.S, fft=self._fft,
            expp_sspace=self._expp_sspace, inverter=self._inverter)