# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

Philipp Arras's avatar
Philipp Arras committed
18
from .. import utilities
Martin Reinecke's avatar
Martin Reinecke committed
19
from ..domain_tuple import DomainTuple
Philipp Arras's avatar
Philipp Arras committed
20
21
from ..field import Field
from ..linearization import Linearization
Martin Reinecke's avatar
Martin Reinecke committed
22
from ..sugar import makeOp, makeDomain
Martin Reinecke's avatar
Martin Reinecke committed
23
from .operator import Operator
Martin Reinecke's avatar
fix    
Martin Reinecke committed
24
from .sampling_enabler import SamplingEnabler
Philipp Arras's avatar
Philipp Arras committed
25
from .sandwich_operator import SandwichOperator
Martin Reinecke's avatar
Martin Reinecke committed
26
from .simple_linear_operators import VdotOperator
Martin Reinecke's avatar
Martin Reinecke committed
27
28
29


class EnergyOperator(Operator):
    """Abstract class from which other specific EnergyOperator subclasses
    are derived.

    An EnergyOperator has a scalar domain as target domain.

    It is intended as an objective function for field inference.

    Typical usage in IFT:
     - as an information Hamiltonian (i.e. a negative log probability)
     - or as a Gibbs free energy (i.e. an averaged Hamiltonian),
       aka Kullback-Leibler divergence.
    """

    # All energies share the same target: the scalar domain, fixed once at
    # class level so subclasses only need to set their input domain.
    _target = DomainTuple.scalar_domain()


class SquaredNormOperator(EnergyOperator):
    r"""Class for squared field norm energy.

    Parameters
    ----------
    domain : Domain, tuple of Domain, DomainTuple or MultiDomain
        Domain of the fields whose squared norm is computed. It is
        normalized with `makeDomain`.

    Notes
    -----
    `E = SquaredNormOperator(domain)` represents a field energy E that is the
    squared L2 norm of a field f:

    :math:`E(f) = f^\dagger f`
    """

    def __init__(self, domain):
        # Normalize the domain the same way GaussianEnergy does, so that plain
        # Domain objects are accepted and compare equal to the (normalized)
        # domain of incoming fields in _check_input.
        self._domain = makeDomain(domain)

    def apply(self, x):
        self._check_input(x)
        if isinstance(x, Linearization):
            val = Field.scalar(x.val.vdot(x.val))
            # Gradient of f^dagger f is 2 f; contract it with the incoming
            # Jacobian to get the Jacobian of the energy.
            jac = VdotOperator(2*x.val)(x.jac)
            return x.new(val, jac)
        return Field.scalar(x.vdot(x))


class QuadraticFormOperator(EnergyOperator):
    r"""Class for quadratic field energies.

    Parameters
    ----------
    op : EndomorphicOperator
        Kernel of the quadratic form.

    Raises
    ------
    TypeError
        If `op` is not an EndomorphicOperator.

    Notes
    -----
    `E = QuadraticFormOperator(op)` represents a field energy that is a
    quadratic form in a field f with kernel op:

    :math:`E(f) = 0.5 f^\dagger op f`

    The Jacobian uses `op(f)` as the gradient, which is exact only for a
    self-adjoint kernel.
    """

    def __init__(self, op):
        from .endomorphic_operator import EndomorphicOperator
        if not isinstance(op, EndomorphicOperator):
            raise TypeError("op must be an EndomorphicOperator")
        self._op = op
        self._domain = op.domain

    def apply(self, x):
        self._check_input(x)
        if not isinstance(x, Linearization):
            return Field.scalar(0.5*x.vdot(self._op(x)))
        # Apply the kernel once and reuse the result for value and Jacobian.
        op_of_f = self._op(x.val)
        energy = Field.scalar(0.5*x.val.vdot(op_of_f))
        energy_jac = VdotOperator(op_of_f)(x.jac)
        return x.new(energy, energy_jac)


class GaussianEnergy(EnergyOperator):
    r"""Class for energies of fields with Gaussian probability distribution.

    Parameters
    ----------
    mean : Field, optional
        Mean of the Gaussian (default 0).
    covariance : LinearOperator, optional
        Covariance of the Gaussian (default = identity operator).
    domain : Domainoid, optional
        Operator domain, inferred from mean or covariance if specified.

    Raises
    ------
    ValueError
        If none of `mean`, `covariance` and `domain` is given, or if the
        domains of the given arguments do not agree.

    Notes
    -----
    - At least one of the arguments has to be provided.
    - `E = GaussianEnergy(mean=m, covariance=D)` represents (up to constants)
      :math:`E(f) = - \log G(f-m, D) = 0.5 (f-m)^\dagger D^{-1} (f-m)`,
      an information energy for a Gaussian distribution with mean m and
      covariance D.
    """

    def __init__(self, mean=None, covariance=None, domain=None):
        self._domain = None
        if mean is not None:
            self._checkEquivalence(mean.domain)
        if covariance is not None:
            self._checkEquivalence(covariance.domain)
        if domain is not None:
            self._checkEquivalence(domain)
        if self._domain is None:
            raise ValueError("no domain given")
        self._mean = mean
        # Build the inverse covariance only once: it is both the kernel of the
        # quadratic form and the metric returned by apply().
        if covariance is None:
            self._icov = None
            self._op = SquaredNormOperator(self._domain).scale(0.5)
        else:
            self._icov = covariance.inverse
            self._op = QuadraticFormOperator(self._icov)

    def _checkEquivalence(self, newdom):
        """Record `newdom` as the operator domain, or verify that it matches
        the domain recorded earlier.

        Raises
        ------
        ValueError
            If a domain was already recorded and differs from `newdom`.
        """
        newdom = makeDomain(newdom)
        if self._domain is None:
            self._domain = newdom
        elif self._domain != newdom:
            raise ValueError("domain mismatch")

    def apply(self, x):
        self._check_input(x)
        residual = x if self._mean is None else x-self._mean
        res = self._op(residual).real
        if not isinstance(x, Linearization) or not x.want_metric:
            return res
        # Metric: inverse covariance pulled back through the Jacobian of x
        # (self._icov is None when no covariance was given).
        metric = SandwichOperator.make(x.jac, self._icov)
        return res.add_metric(metric)


class PoissonianEnergy(EnergyOperator):
    r"""Class for likelihood-energies of expected count field constrained by
    Poissonian count data.

    Parameters
    ----------
    d : Field
        data field with counts

    Notes
    -----
    E = PoissonianEnergy(d) represents (up to an f-independent term log(d!))

    :math:`E(f) = -\log \mathrm{Poisson}(d|f) = \sum f - d^\dagger \log(f)`,

    where f is a Field in data space with the expectation values for
    the counts.
    """

    def __init__(self, d):
        self._d = d
        self._domain = DomainTuple.make(d.domain)

    def apply(self, x):
        self._check_input(x)
        res = x.sum() - x.log().vdot(self._d)
        if not isinstance(x, Linearization):
            return Field.scalar(res)
        if not x.want_metric:
            return res
        # Fisher metric of the Poisson likelihood, diag(1/f), pulled back
        # through the Jacobian of x.
        metric = SandwichOperator.make(x.jac, makeOp(1./x.val))
        return res.add_metric(metric)

187

188
class InverseGammaLikelihood(EnergyOperator):
    r"""Special class for inverse Gamma distributed covariances.

    Parameters
    ----------
    d : Field
        Data field; its domain defines the domain of this energy.

    Notes
    -----
    E = InverseGammaLikelihood(d) represents

    :math:`E(x) = 0.5 \left( \sum \log x + d^\dagger x^{-1} \right)`,

    where x is a field on d.domain and the logarithm and inverse act
    pointwise. If a metric is requested, it is :math:`\mathrm{diag}(0.5/x^2)`
    pulled back through the Jacobian of x.

    NOTE(review): up to constants this is the negative log-likelihood of
    pointwise variances x given squared residuals d of a zero-mean Gaussian;
    confirm the intended meaning of d.
    """

    def __init__(self, d):
        self._d = d
        self._domain = DomainTuple.make(d.domain)

    def apply(self, x):
        self._check_input(x)
        res = 0.5*(x.log().sum() + (1./x).vdot(self._d))
        if not isinstance(x, Linearization):
            return Field.scalar(res)
        if not x.want_metric:
            return res
        # Diagonal metric 0.5/x^2 in x, pulled back through the Jacobian.
        metric = SandwichOperator.make(x.jac, makeOp(0.5/(x.val**2)))
        return res.add_metric(metric)


Martin Reinecke's avatar
Martin Reinecke committed
208
class BernoulliEnergy(EnergyOperator):
    r"""Class for likelihood-energies of expected event frequency constrained
    by event data.

    Parameters
    ----------
    d : Field
        data field with events (=1) or non-events (=0)

    Notes
    -----
    E = BernoulliEnergy(d) represents

    :math:`E(f) = -\log \mathrm{Bernoulli}(d|f) = -d^\dagger \log f - (1-d)^\dagger \log(1-f)`,

    where f is a field in data space (d.domain) with the expected
    frequencies of events.
    """

    def __init__(self, d):
        self._d = d
        self._domain = DomainTuple.make(d.domain)

    def apply(self, x):
        self._check_input(x)
        events = self._d
        energy = (x.log().vdot(-events)
                  - (1.-x).log().vdot(1.-events))
        if not isinstance(x, Linearization):
            return Field.scalar(energy)
        if x.want_metric:
            # Bernoulli Fisher metric 1/(f(1-f)), pulled back through the
            # Jacobian of x.
            fisher = makeOp(1./(x.val*(1.-x.val)))
            energy = energy.add_metric(SandwichOperator.make(x.jac, fisher))
        return energy


class Hamiltonian(EnergyOperator):
    r"""Class for information Hamiltonians.

    Parameters
    ----------
    lh : EnergyOperator
        a likelihood energy
    ic_samp : IterationController, optional
        is passed to SamplingEnabler to draw Gaussian distributed samples
        with covariance = metric of Hamiltonian
        (= Hessian without terms that generate negative eigenvalues).
        If None (default), no SamplingEnabler is attached to the metric.

    Notes
    -----
    H = Hamiltonian(E_lh) represents

    :math:`H(f) = 0.5 f^\dagger f + E_{lh}(f)`

    an information Hamiltonian for a field f with a white Gaussian prior
    (unit covariance) and the likelihood energy :math:`E_{lh}`.

    Other field priors can be represented via transformations of a white
    Gaussian field into a field with the desired prior probability structure.

    By implementing prior information this way, the field prior is represented
    by a generative model, from which NIFTy can draw samples and infer a field
    using the Maximum a Posteriori (MAP) or the Variational Bayes (VB) method.

    For more details see:
    "Encoding prior knowledge in the structure of the likelihood"
    Jakob Knollmüller, Torsten A. Ensslin, submitted, arXiv:1812.04403
    https://arxiv.org/abs/1812.04403
    """

    def __init__(self, lh, ic_samp=None):
        self._lh = lh
        self._prior = GaussianEnergy(domain=lh.domain)
        self._ic_samp = ic_samp
        self._domain = lh.domain

    def apply(self, x):
        self._check_input(x)
        if (self._ic_samp is None or not isinstance(x, Linearization) or
                not x.want_metric):
            return self._lh(x)+self._prior(x)
        lhx, prx = self._lh(x), self._prior(x)
        # Replace the plain summed metric by a SamplingEnabler driven by
        # ic_samp, so that samples can be drawn from it.
        # NOTE(review): the prior term is passed as prx.metric.inverse; for a
        # unit-covariance prior with identity Jacobian this equals prx.metric.
        # Confirm against SamplingEnabler's expected arguments.
        mtr = SamplingEnabler(lhx.metric, prx.metric.inverse,
                              self._ic_samp, prx.metric.inverse)
        return (lhx+prx).add_metric(mtr)

    def __repr__(self):
        subs = 'Likelihood:\n{}'.format(utilities.indent(self._lh.__repr__()))
        dom = self._lh.domain
        # keys() is presumably only provided by MultiDomain; for any other
        # domain fall back to its repr instead of raising AttributeError.
        if hasattr(dom, 'keys'):
            subs += '\nPrior: Quadratic{}'.format(dom.keys())
        else:
            subs += '\nPrior: Quadratic\n{}'.format(
                utilities.indent(repr(dom)))
        return 'Hamiltonian:\n' + utilities.indent(subs)

Martin Reinecke's avatar
Martin Reinecke committed
297
298

class SampledKullbachLeiblerDivergence(EnergyOperator):
    r"""Class for Kullback-Leibler (KL) divergence or Gibbs free energies.

    Precisely a sample averaged Hamiltonian (or other energy) that represents
    approximately the relevant part of a KL to be used in Variational Bayes
    inference if the samples are drawn from the approximating Gaussian.

    Let :math:`Q(f) = G(f-m,D)` be the Gaussian used to approximate
    :math:`P(f|d)`, the correct posterior with information Hamiltonian

    :math:`H(d,f) = - \log P(d,f) = - \log P(f|d) + \mathrm{const}`.

    The KL divergence between those should then be optimized for m. It is

    .. math::

        KL(Q,P) = \int Df \, Q(f) \log \frac{Q(f)}{P(f)}
        = \left< \log Q(f) \right>_{Q(f)} - \left< \log P(f) \right>_{Q(f)}
        = \mathrm{const} + \left< H(f) \right>_{G(f-m,D)},

    in essence the information Hamiltonian averaged over a Gaussian
    distribution centered on the mean m.

    SampledKullbachLeiblerDivergence(H) approximates
    :math:`\left< H(f) \right>_{G(f-m,D)}` if the residuals f-m are drawn
    from covariance D.

    Parameters
    ----------
    h : Hamiltonian
        the Hamiltonian/energy to be averaged
    res_samples : iterable of Fields
        set of residual sample points to be added to mean field
        for approximate estimation of the KL; must not be empty

    Raises
    ------
    ValueError
        If `res_samples` contains no samples.

    Notes
    -----
    KL = SampledKullbachLeiblerDivergence(H, samples) represents

    :math:`KL(m) = \sum_i H(m+v_i) / N`,

    where v_i are the residual samples, N is their number, and m is the mean
    field around which the samples are drawn.

    Having symmetrized residual samples, with both v_i and -v_i being present,
    ensures that the distribution mean is exactly represented. This reduces
    sampling noise and helps the numerics of the KL minimization process in
    the variational Bayes inference.
    """

    def __init__(self, h, res_samples):
        self._h = h
        self._domain = h.domain
        # Materialize once so the samples can be iterated on every apply().
        self._res_samples = tuple(res_samples)
        # An empty sample set would otherwise only fail later inside apply()
        # (empty sum / division by zero); reject it up front.
        if not self._res_samples:
            raise ValueError("res_samples must contain at least one sample")

    def apply(self, x):
        self._check_input(x)
        # Sample average of the energy around the mean x.
        terms = (self._h(x+v) for v in self._res_samples)
        return utilities.my_sum(terms) * (1./len(self._res_samples))