log_normal_wiener_filter_energy.py 2.28 KB
Newer Older
1 2
from nifty.energies.energy import Energy
from nifty.energies.memoization import memo
3 4
from nifty.library.log_normal_wiener_filter import \
    LogNormalWienerFilterCurvature
5
from nifty.sugar import create_composed_fft_operator
6

7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27

class LogNormalWienerFilterEnergy(Energy):
    """The Energy for the log-normal Wiener filter.

    It covers the case of a linear measurement with Gaussian noise and a
    Gaussian prior with known covariance ``S`` placed on ``position``
    (the log of the signal). The terms involving the exponentiated field
    are computed by ``LogNormalWienerFilterCurvature`` and read back from
    it.

    Parameters
    ----------
    position : Field
        The current position (the field the prior ``S`` acts on).
    d : Field
        The data.
    R : Operator
        The response operator, a description of the measurement process.
    N : EndomorphicOperator
        The noise covariance in data space.
    S : EndomorphicOperator
        The prior signal covariance in harmonic space.
    fft4exp : Operator, optional
        FFT operator forwarded to ``LogNormalWienerFilterCurvature``.
        If None, one is built with
        ``create_composed_fft_operator(S.domain, all_to='position')``.
        ``at`` passes the existing operator on, so it is built only once
        per chain of energies.
    """

    def __init__(self, position, d, R, N, S, fft4exp=None):
        super(LogNormalWienerFilterEnergy, self).__init__(position=position)
        self.d = d
        self.R = R
        self.N = N
        self.S = S

        # Reuse a caller-supplied FFT (``at`` always passes its own) rather
        # than constructing a new composed operator for every energy.
        if fft4exp is None:
            self._fft = create_composed_fft_operator(self.S.domain,
                                                     all_to='position')
        else:
            self._fft = fft4exp

    def at(self, position):
        """Return a new energy of the same class at ``position``.

        The data, operators and FFT operator are shared with ``self``.
        """
        return self.__class__(position=position, d=self.d, R=self.R, N=self.N,
                              S=self.S, fft4exp=self._fft)

    @property
    @memo
    def value(self):
        """Energy value at the current position.

        Half the sum of the prior term ``position . S^{-1} position`` and
        the likelihood term ``_Rexppd . _NRexppd``. Both likelihood factors
        come from the curvature object (presumably the residual
        R(exp(position)) - d and its noise-weighted version; confirm in
        ``LogNormalWienerFilterCurvature``).
        """
        return 0.5*(self.position.vdot(self._Sp) +
                    self._Rexppd.vdot(self._NRexppd))

    @property
    @memo
    def gradient(self):
        """Gradient of the energy.

        The sum of the prior term ``S^{-1} position`` and the likelihood
        term ``_exppRNRexppd`` supplied by the curvature object.
        """
        return self._Sp + self._exppRNRexppd

    @property
    @memo
    def curvature(self):
        """Curvature operator at the current position (memoized).

        The same object also provides the intermediate terms that ``value``
        and ``gradient`` use, so they are computed only once.
        """
        return LogNormalWienerFilterCurvature(R=self.R, N=self.N, S=self.S,
                                              d=self.d, position=self.position,
                                              fft4exp=self._fft)

    # The following properties forward intermediate results cached on the
    # curvature object, so value, gradient and curvature share them.

    @property
    def _expp(self):
        return self.curvature._expp

    @property
    def _Rexppd(self):
        return self.curvature._Rexppd

    @property
    def _NRexppd(self):
        return self.curvature._NRexppd

    @property
    def _exppRNRexppd(self):
        return self.curvature._exppRNRexppd

    @property
    @memo
    def _Sp(self):
        """Prior term ``S^{-1} position`` (memoized)."""
        return self.S.inverse_times(self.position)