wiener_filter_energy.py 1.79 KB
Newer Older
Jakob Knollmueller's avatar
Jakob Knollmueller committed
1
from nifty.energies.energy import Energy
2
from nifty.energies.memoization import memo
Jakob Knollmueller's avatar
Jakob Knollmueller committed
3
from nifty.library.operator_library import WienerFilterCurvature
4
5
6
7
8

class WienerFilterEnergy(Energy):
    """The Energy for the Wiener filter.

    It describes the situation of a linear measurement with
    Gaussian noise and a Gaussian signal prior with known covariance.
    At position ``s`` the energy evaluated by ``value`` is

        E(s) = 1/2 s^dagger S^{-1} s + 1/2 (d - R s)^dagger N^{-1} (d - R s)

    (only the real part is returned).

    Parameters
    ----------
    position : Field
        The current position.
    d : Field
        The data.
    R : Operator
        The response operator, description of the measurement process.
    N : EndomorphicOperator
        The noise covariance in data space.
    S : EndomorphicOperator
        The prior signal covariance in harmonic space.
    inverter : optional
        Forwarded unchanged to ``WienerFilterCurvature`` by ``curvature``;
        presumably the solver used to apply the inverse curvature (confirm
        against ``WienerFilterCurvature``). Defaults to None.
    """

    def __init__(self, position, d, R, N, S, inverter=None):
        super(WienerFilterEnergy, self).__init__(position=position)
        self.d = d
        self.R = R
        self.N = N
        self.S = S
        self.inverter = inverter

    def at(self, position):
        """Return a new energy of the same class evaluated at `position`.

        The data, operators and inverter of this energy are reused.
        """
        # Forward the inverter too: without it, every energy created via
        # at() would silently revert to inverter=None and its curvature
        # would no longer use the configured inverter.
        return self.__class__(position, self.d, self.R, self.N, self.S,
                              inverter=self.inverter)

    @property
    def value(self):
        """Real part of 1/2 s^dagger S^{-1} s + 1/2 r^dagger N^{-1} r,
        with s the position and r = d - R(s) the residual."""
        residual = self._residual()
        energy = 0.5 * self.position.dot(self.S.inverse_times(self.position))
        energy += 0.5 * residual.dot(self.N.inverse_times(residual))
        # The dot products may be complex-valued; only the real part is
        # reported as the energy.
        return energy.real

    @property
    def gradient(self):
        """Gradient S^{-1} s - R^dagger N^{-1} (d - R s) at the position s."""
        residual = self._residual()
        gradient = self.S.inverse_times(self.position)
        gradient -= self.R.adjoint_times(self.N.inverse_times(residual))
        return gradient

    @property
    def curvature(self):
        """A ``WienerFilterCurvature`` built from R, N, S and the inverter."""
        curvature = WienerFilterCurvature(R=self.R, N=self.N, S=self.S,
                                          inverter=self.inverter)
        return curvature

    @memo
    def _residual(self):
        """Data residual d - R(position), cached via ``memo``."""
        residual = self.d - self.R(self.position)
        return residual