# critical_power_energy.py
from ...energies.energy import Energy
from ...operators.smoothness_operator import SmoothnessOperator
from . import CriticalPowerCurvature
from ...energies.memoization import memo
from ...sugar import generate_posterior_sample
from ... import Field, exp


class CriticalPowerEnergy(Energy):
    """The Energy of the power spectrum according to the critical filter.

    It describes the energy of the logarithmic amplitudes of the power spectrum
    of a Gaussian random field under reconstruction uncertainty, with a
    smoothness prior and an inverse-gamma prior. It is used to infer the
    correlation structure of a correlated signal.

    Parameters
    ----------
    position : Field
        The current position of this energy. Its first domain supplies the
        ``rho`` and ``binbounds`` used below.
    m : Field
        The map whose power spectrum is to be inferred.
    D : EndomorphicOperator, optional
        The curvature of the Gaussian encoding the posterior covariance of
        ``m``. If None, ``m`` is treated as exact (no reconstruction
        uncertainty) and its power spectrum is used directly.
        default : None
    alpha : float or Field
        The spectral prior of the inverse gamma distribution. 1.0 corresponds
        to non-informative.
        default : 1.0
    q : float or Field
        The cutoff parameter of the inverse gamma distribution. 0.0
        corresponds to non-informative.
        default : 0.0
    smoothness_prior : float
        Controls the strength of the smoothness prior.
        default : 0.0
    logarithmic : bool
        Whether smoothness acts on linear or logarithmic scale.
        default : True
    samples : int
        Number of posterior samples used to estimate the uncertainty
        corrections. Only used when ``D`` is given and ``w`` is None; must
        then be at least 1.
        default : 3
    w : Field, optional
        The contribution from the map with or without uncertainty. It is used
        to pass on the result of the costly sampling during the minimization.
        If None, it is computed lazily on first access.
        default : None
    old_curvature : CriticalPowerCurvature, optional
        A curvature from a previous energy. If given, the curvature at this
        position is obtained via its ``copy(theta=..., T=...)`` method instead
        of constructing a new ``CriticalPowerCurvature``.
        default : None
    """

    # ---Overwritten properties and methods---

    def __init__(self, position, m, D=None, alpha=1.0, q=0.,
                 smoothness_prior=0., logarithmic=True, samples=3, w=None,
                 old_curvature=None):
        super(CriticalPowerEnergy, self).__init__(position=position)
        self.m = m
        self.D = D
        self.samples = samples
        # alpha and q are stored as Fields over the position's domain so they
        # combine element-wise with position-shaped quantities.
        self.alpha = Field(self.position.domain, val=alpha)
        self.q = Field(self.position.domain, val=q)
        self.T = SmoothnessOperator(domain=self.position.domain[0],
                                    strength=smoothness_prior,
                                    logarithmic=logarithmic)
        self.rho = self.position.domain[0].rho
        # None means "not computed yet"; see the `w` property.
        self._w = w
        self._old_curvature = old_curvature
        self._curvature = None

    # ---Mandatory properties and methods---

    def at(self, position):
        """Return a new energy of the same kind at `position`.

        All settings are carried over. ``w`` is evaluated here and handed on,
        so the costly sampling is not repeated for the new position.
        """
        # Hand on the most recent curvature available. If this energy's
        # curvature was never requested (e.g. when `at` is called repeatedly
        # without asking for curvature), fall back to the inherited one
        # instead of dropping it.
        if self._curvature is not None:
            old_curvature = self._curvature
        else:
            old_curvature = self._old_curvature
        return self.__class__(position, self.m, D=self.D, alpha=self.alpha,
                              q=self.q, smoothness_prior=self.smoothness_prior,
                              logarithmic=self.logarithmic,
                              w=self.w, samples=self.samples,
                              old_curvature=old_curvature)

    @property
    @memo
    def value(self):
        """Real part of the energy.

        Computed as ``sum(theta) + <position, rho_prime> (bare)
        + 0.5 * <position, T(position)>``.
        """
        energy = self._theta.sum()
        energy += self.position.vdot(self._rho_prime, bare=True)
        energy += 0.5 * self.position.vdot(self._Tt)
        return energy.real

    @property
    @memo
    def gradient(self):
        """Gradient of the energy; only its real part is kept."""
        gradient = -self._theta.weight(-1)
        gradient += (self._rho_prime).weight(-1)
        gradient += self._Tt
        gradient.val = gradient.val.real
        return gradient

    @property
    def curvature(self):
        """Curvature at the current position, built once and cached.

        If an ``old_curvature`` was supplied, it is copied with the current
        ``theta`` and ``T``; otherwise a new ``CriticalPowerCurvature`` is
        created.
        """
        if self._curvature is None:
            if self._old_curvature is None:
                self._curvature = CriticalPowerCurvature(
                                        theta=self._theta.weight(-1), T=self.T)
            else:
                self._curvature = self._old_curvature.copy(
                                        theta=self._theta.weight(-1), T=self.T)
        return self._curvature

    # ---Added properties and methods---

    @property
    def logarithmic(self):
        """Whether the smoothness operator acts on logarithmic scale."""
        return self.T.logarithmic

    @property
    def smoothness_prior(self):
        """Strength of the smoothness operator."""
        return self.T.strength

    @property
    def w(self):
        """Power-spectrum contribution of the map, computed lazily and cached.

        Without ``D``, this is ``m``'s power spectrum (``power_analyze`` on
        the position's binbounds) multiplied by ``rho``. With ``D``, it is the
        average over ``samples`` posterior samples, each drawn via
        ``generate_posterior_sample(m, D)``, of their power spectra times
        ``rho``. A ``w`` passed to the constructor is returned unchanged.

        Raises
        ------
        ValueError
            If ``w`` must be computed from samples (``D`` is given) and
            ``samples`` is smaller than 1.
        """
        if self._w is None:
            self.logger.info("Initializing w")
            binbounds = self.position.domain[0].binbounds
            if self.D is not None:
                if self.samples < 1:
                    # Otherwise the average below divides by zero.
                    raise ValueError(
                        "samples must be at least 1 when D is given, got %r"
                        % (self.samples,))
                w = Field(domain=self.position.domain, val=0.,
                          dtype=self.m.dtype)
                for i in range(self.samples):
                    self.logger.info("Drawing sample %i", i)
                    posterior_sample = generate_posterior_sample(
                                                            self.m, self.D)
                    projected_sample = posterior_sample.power_analyze(
                                                        binbounds=binbounds)
                    w += projected_sample * self.rho
                w /= float(self.samples)
            else:
                w = self.m.power_analyze(binbounds=binbounds)
                w *= self.rho
            self._w = w
        return self._w

    @property
    @memo
    def _theta(self):
        """``exp(-position) * (q + w / 2)``."""
        return exp(-self.position) * (self.q + self.w / 2.)

    @property
    @memo
    def _rho_prime(self):
        """``alpha - 1 + rho / 2``; the linear coefficient of the energy."""
        return self.alpha - 1. + self.rho / 2.

    @property
    @memo
    def _Tt(self):
        """The smoothness operator applied to the current position."""
        return self.T(self.position)