# critical_power_energy.py
1
from nifty.energies.energy import Energy
2
3
from nifty.library.operator_library import CriticalPowerCurvature,\
                                            SmoothnessOperator
4
from nifty.sugar import generate_posterior_sample
5
from nifty import Field, exp
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24

class CriticalPowerEnergy(Energy):
    """Energy of a power spectrum given a signal estimate.

    The energy is evaluated at ``position`` (presumably the logarithm of
    the power spectrum, since it enters through ``exp(-position)``)::

        E(x) = sum(theta) + x . (alpha - 1 + rho / 2) + 0.5 * x . T(x)

    with ``theta = exp(-x) * (q + w / 2)`` and ``T`` a
    ``SmoothnessOperator`` of strength ``sigma`` on ``position.domain``.

    Parameters
    ----------
    position : Field
        Point at which the energy is evaluated. Its domain must provide a
        ``rho`` attribute (presumably the number of modes per power bin --
        confirm).
    m : Field
        Signal estimate. Its power, projected onto ``position.domain``,
        is used to build ``w``; when ``D`` is given, posterior samples
        around ``m`` are used instead.
    D : EndomorphicOperator, optional
        Covariance passed with ``m`` to ``generate_posterior_sample``.
        If None, ``w`` is the projected power of ``m`` alone.
    alpha : float, default 1.0
        Enters the linear term as ``alpha - 1``; presumably the shape
        parameter of an inverse-gamma prior on the power -- confirm.
    q : float, default 0
        Added to ``w / 2`` in ``theta``; presumably the scale parameter of
        an inverse-gamma prior on the power -- confirm.
    sigma : float, default 0
        Strength of the smoothness prior handed to ``SmoothnessOperator``.
    w : Field, optional
        Precomputed projected power. If None it is computed from ``m``,
        ``D`` and ``samples``.
    samples : int, default 3
        Number of posterior samples averaged when ``D`` is given.
    """

    def __init__(self, position, m, D=None, alpha=1.0, q=0, sigma=0, w=None,
                 samples=3):
        super(CriticalPowerEnergy, self).__init__(position=position)
        self.m = m
        self.D = D
        self.samples = samples
        self.sigma = sigma
        self.alpha = alpha
        self.q = q
        self.T = SmoothnessOperator(domain=self.position.domain,
                                    sigma=self.sigma)
        self.rho = self.position.domain.rho

        # Resolve w once and always store it: `theta` needs the actual
        # value, and `at` forwards self.w so moving to a new position does
        # not redraw posterior samples.
        if w is None:
            w = self._calculate_w(self.m, self.D, self.samples)
        self.w = w
        self.theta = exp(-self.position) * (self.q + self.w / 2.)

    def at(self, position):
        """Return a new energy at ``position`` with all other inputs
        (including the already-computed ``w``) carried over."""
        return self.__class__(position, self.m, D=self.D, alpha=self.alpha,
                              q=self.q, sigma=self.sigma, w=self.w,
                              samples=self.samples)

    @property
    def value(self):
        """Real part of the scalar energy at the current position."""
        energy = self.theta.sum()
        energy += self.position.dot(self.alpha - 1 + self.rho / 2.)
        energy += 0.5 * self.position.dot(self.T(self.position))
        return energy.real

    @property
    def gradient(self):
        """Gradient of ``value`` with respect to ``position``.

        Term by term: ``-theta`` from ``sum(theta)``, the constant
        ``alpha - 1 + rho / 2`` from the linear term, and ``T(position)``
        from the quadratic smoothness term (assuming ``T`` is symmetric).
        """
        # `-self.theta` is a fresh Field, so the in-place adds below do
        # not modify self.theta.
        gradient = -self.theta
        gradient += self.alpha - 1 + self.rho / 2.
        gradient += self.T(self.position)
        return gradient

    @property
    def curvature(self):
        """Curvature operator built from ``theta`` and the smoothness
        operator ``T``."""
        return CriticalPowerCurvature(theta=self.theta, T=self.T)

    def _calculate_w(self, m, D, samples):
        """Return the power of ``m`` projected onto ``position.domain``.

        If ``D`` is None, this is simply ``m.project_power``. Otherwise it
        is the average projected power of ``samples`` posterior samples
        drawn with ``generate_posterior_sample(m, D)``.

        Raises
        ------
        ValueError
            If ``D`` is given and ``samples`` is not positive (the average
            would otherwise divide by zero).
        """
        domain = self.position.domain
        if D is None:
            return m.project_power(domain=domain)

        if samples <= 0:
            raise ValueError("samples must be positive when D is given")

        w = Field(domain=domain, val=0)
        for _ in range(samples):
            posterior_sample = generate_posterior_sample(m, D)
            w += posterior_sample.project_power(domain=domain)
        w /= float(samples)
        return w
77
78
79