# critical_power_energy.py
from nifty.energies.energy import Energy
from nifty.library.operator_library import CriticalPowerCurvature
from nifty.sugar import generate_posterior_sample
from nifty import Field

class CriticalPowerEnergy(Energy):
    """Energy evaluated at a position, built from a map ``m`` and ``D``.

    On construction the energy stores its inputs and, unless ``w`` is
    supplied, estimates ``w`` by averaging the squared power analysis of
    samples drawn with ``generate_posterior_sample(m, D)``.

    Parameters
    ----------
    position : Field,
        The position at which the energy is evaluated; it is forwarded
        to the ``Energy`` base class and its domain is used for ``w``.
    m : Field,
        First argument handed to ``generate_posterior_sample``
        (presumably the posterior mean -- TODO confirm).
    D : EndomorphicOperator, optional
        Second argument handed to ``generate_posterior_sample``
        (presumably the posterior covariance -- TODO confirm).
        If None, no samples are drawn and ``w`` is a zero field.
    alpha : float, optional
        Stored on the instance and passed on by ``at``; not used by any
        computation visible in this class (default: 1.0).
    q : float, optional
        Stored on the instance and passed on by ``at``; not used by any
        computation visible in this class (default: 0).
    w : Field, optional
        Precomputed sample average. If None it is computed via
        ``_calculate_w``.
    samples : int, optional
        Number of samples drawn when ``w`` has to be computed
        (default: 3).

    Notes
    -----
    NOTE(review): ``value``, ``gradient`` and ``curvature`` read
    ``self.d``, ``self.R``, ``self.N`` and ``self.S``. None of these is
    assigned by ``__init__``; unless they are provided elsewhere these
    properties cannot be evaluated -- TODO confirm the intended
    formulas for this energy.
    """

    def __init__(self, position, m, D=None, alpha=1.0, q=0, w=None,
                 samples=3):
        super(CriticalPowerEnergy, self).__init__(position=position)
        # Keep every constructor input so that `at` can rebuild an
        # equivalent energy at another position.
        self.m = m
        self.D = D
        self.alpha = alpha
        self.q = q
        self.samples = samples
        if w is None:
            w = self._calculate_w(self.m, self.D, self.samples)
        self.w = w

    def at(self, position):
        """Return a new energy of the same class at `position`.

        The already computed `w` is handed over, so no new samples are
        drawn for the new instance.
        """
        return self.__class__(position, self.m, D=self.D,
                              alpha=self.alpha, q=self.q, w=self.w,
                              samples=self.samples)

    @property
    def value(self):
        """Real part of the energy at the current position.

        Computed as ``0.5 * x.dot(S^-1 x) + 0.5 * r.dot(N^-1 r)`` with
        ``x = self.position`` and ``r = self.d - self.R(x)``.
        """
        # NOTE(review): d, R, N and S are not set by __init__ (see the
        # class notes).
        residual = self.d - self.R(self.position)
        energy = 0.5 * self.position.dot(self.S.inverse_times(self.position))
        energy += 0.5 * residual.dot(self.N.inverse_times(residual))
        return energy.real

    @property
    def gradient(self):
        """Gradient of the energy at the current position.

        Computed as ``S^-1 x`` minus
        ``R.derived_adjoint_times(N^-1 (d - R(x)), x)``.
        """
        # NOTE(review): d, R, N and S are not set by __init__ (see the
        # class notes).
        residual = self.d - self.R(self.position)
        gradient = self.S.inverse_times(self.position)
        gradient -= self.R.derived_adjoint_times(
            self.N.inverse_times(residual), self.position)
        return gradient

    @property
    def curvature(self):
        """A ``CriticalPowerCurvature`` built from R, N, S and position."""
        # NOTE(review): R, N and S are not set by __init__ (see the
        # class notes).
        return CriticalPowerCurvature(R=self.R,
                                      N=self.N,
                                      S=self.S,
                                      position=self.position)

    def _calculate_w(self, m, D, samples):
        """Average the squared power analysis over drawn samples.

        Parameters
        ----------
        m : Field,
            First argument for ``generate_posterior_sample``.
        D : EndomorphicOperator or None,
            Second argument for ``generate_posterior_sample``. If None,
            nothing is sampled.
        samples : int,
            Number of samples to draw; must be at least 1 when `D` is
            given.

        Returns
        -------
        Field
            The mean of ``sample.power_analyze()**2`` over all drawn
            samples, or a zero field on the domain of the current
            position if `D` is None.

        Raises
        ------
        ValueError
            If `D` is given and `samples` is smaller than 1.
        """
        w = Field(domain=self.position.domain, val=0)
        if D is None:
            # Without D no samples can be drawn; w stays zero.
            return w
        if samples < 1:
            raise ValueError("samples must be at least 1.")
        for _ in range(samples):
            posterior_sample = generate_posterior_sample(m, D)
            w += posterior_sample.power_analyze()**2
        # Normalise exactly once to obtain the sample mean.
        w /= float(samples)
        return w