critical_power_energy.py 5.03 KB
Newer Older
1 2 3
from ...energies.energy import Energy
from ...operators.smoothness_operator import SmoothnessOperator
from . import CriticalPowerCurvature
4
from ...memoization import memo
5

6 7
from ...sugar import generate_posterior_sample
from ... import Field, exp
8

9

10
class CriticalPowerEnergy(Energy):
    """The Energy of the power spectrum according to the critical filter.

    It describes the energy of the logarithmic amplitudes of the power spectrum
    of a Gaussian random field under reconstruction uncertainty with smoothness
    and inverse gamma prior. It is used to infer the correlation structure of a
    correlated signal.

    Parameters
    ----------
    position : Field
        The current position of this energy (the logarithmic amplitudes of
        the power spectrum).
    m : Field
        The map whose power spectrum has to be inferred.
    D : EndomorphicOperator
        The curvature of the Gaussian encoding the posterior covariance.
        If not specified, the map is taken as exact (not a reconstruction)
        and no uncertainty correction is applied.
        default : None
    alpha : float or Field
        The spectral prior of the inverse gamma distribution. 1.0 corresponds
        to non-informative.
        default : 1.0
    q : float or Field
        The cutoff parameter of the inverse gamma distribution. 0.0
        corresponds to non-informative.
        default : 0.0
    smoothness_prior : float
        Controls the strength of the smoothness prior.
        default : 0.0
    logarithmic : bool
        Whether smoothness acts on linear or logarithmic scale.
        default : True
    samples : int
        Number of posterior samples used for the estimation of the uncertainty
        corrections. Only used when `D` is given and `w` is not; must then be
        at least 1.
        default : 3
    w : Field
        The contribution from the map with or without uncertainty. It is used
        to pass on the result of the costly sampling during the minimization.
        default : None
    inverter : ConjugateGradient
        The inversion strategy handed to the curvature
        (`CriticalPowerCurvature`) to invert it.
        default : None

    Raises
    ------
    ValueError
        On first access of `w` if `D` is given, `w` was not supplied and
        `samples` is smaller than 1.
    """

    # ---Overwritten properties and methods---

    def __init__(self, position, m, D=None, alpha=1.0, q=0.,
                 smoothness_prior=0., logarithmic=True, samples=3, w=None,
                 inverter=None):
        super(CriticalPowerEnergy, self).__init__(position=position)
        self.m = m
        self.D = D
        self.samples = samples
        # Store the inverse-gamma hyperparameters as Fields on the position's
        # domain so they combine pixel-wise with the position.
        self.alpha = Field(self.position.domain, val=alpha)
        self.q = Field(self.position.domain, val=q)
        self.T = SmoothnessOperator(domain=self.position.domain[0],
                                    strength=smoothness_prior,
                                    logarithmic=logarithmic)
        # NOTE(review): presumably the number of modes per power-spectrum
        # bin; used to weight `w` and in `_rho_prime`.
        self.rho = self.position.domain[0].rho
        # Cached map contribution; computed lazily by `w` if not supplied.
        self._w = w
        self._inverter = inverter

    # ---Mandatory properties and methods---

    def at(self, position):
        """Return a new energy at `position` with otherwise identical settings.

        `self.w` is evaluated here and handed to the new instance, so the
        (possibly sampled) map contribution is computed only once.
        """
        return self.__class__(position, self.m, D=self.D, alpha=self.alpha,
                              q=self.q, smoothness_prior=self.smoothness_prior,
                              logarithmic=self.logarithmic,
                              w=self.w, samples=self.samples,
                              inverter=self._inverter)

    @property
    def value(self):
        """The real part of the energy at the current position.

        Sum of the inverse-gamma/likelihood term (`_theta`), the
        `_rho_prime` term and half the smoothness quadratic form.
        """
        energy = self._theta.sum()
        energy += self.position.weight(-1).vdot(self._rho_prime)
        energy += 0.5 * self.position.vdot(self._Tt)
        return energy.real

    @property
    def gradient(self):
        """The real part of the energy's gradient at the current position."""
        gradient = -self._theta.weight(-1)
        gradient += self._rho_prime.weight(-1)
        gradient += self._Tt
        gradient = gradient.real
        return gradient

    @property
    def curvature(self):
        """The curvature of the energy as a `CriticalPowerCurvature`."""
        curvature = CriticalPowerCurvature(theta=self._theta.weight(-1),
                                           T=self.T, inverter=self._inverter)
        return curvature

    # ---Added properties and methods---

    @property
    def logarithmic(self):
        """Whether the smoothness operator acts on a logarithmic scale."""
        return self.T.logarithmic

    @property
    def smoothness_prior(self):
        """The strength of the smoothness operator."""
        return self.T.strength

    @property
    def w(self):
        """Contribution of the map, weighted by `rho`; computed once, cached.

        Without `D` this is `rho` times the power spectrum of `m`. With `D`
        it is the average over `samples` posterior samples (drawn from `m`
        and `D`) of `rho` times their power spectra, which accounts for the
        reconstruction uncertainty.

        Raises
        ------
        ValueError
            If `D` is given but `samples` is smaller than 1; averaging over
            zero samples would divide by zero.
        """
        if self._w is None:
            binbounds = self.position.domain[0].binbounds
            if self.D is None:
                w = self.m.power_analyze(binbounds=binbounds)
                w *= self.rho
            else:
                if self.samples < 1:
                    raise ValueError(
                        "samples must be at least 1 when D is given, got %r"
                        % (self.samples,))
                w = Field(domain=self.position.domain, val=0.,
                          dtype=self.m.dtype)
                for _ in range(self.samples):
                    posterior_sample = generate_posterior_sample(self.m,
                                                                 self.D)
                    projected_sample = posterior_sample.power_analyze(
                        binbounds=binbounds)
                    w += projected_sample * self.rho
                w /= float(self.samples)
            self._w = w
        return self._w

    @property
    @memo
    def _theta(self):
        # exp(-position) * (q + w/2), evaluated pixel-wise.
        return exp(-self.position) * (self.q + self.w / 2.)

    @property
    @memo
    def _rho_prime(self):
        # alpha - 1 + rho/2; independent of the position.
        return self.alpha - 1. + self.rho / 2.

    @property
    @memo
    def _Tt(self):
        # The smoothness operator applied to the current position.
        return self.T(self.position)