energy_operators.py 11.7 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
14
# Copyright(C) 2013-2019 Max-Planck-Society
Martin Reinecke's avatar
Martin Reinecke committed
15
#
16
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
Martin Reinecke's avatar
Martin Reinecke committed
17

Philipp Arras's avatar
Philipp Arras committed
18
from .. import utilities
Martin Reinecke's avatar
Martin Reinecke committed
19
from ..domain_tuple import DomainTuple
Philipp Arras's avatar
Philipp Arras committed
20
21
from ..field import Field
from ..linearization import Linearization
Martin Reinecke's avatar
Martin Reinecke committed
22
from ..sugar import makeOp, makeDomain
Martin Reinecke's avatar
Martin Reinecke committed
23
from .operator import Operator
Martin Reinecke's avatar
fix    
Martin Reinecke committed
24
from .sampling_enabler import SamplingEnabler
Philipp Arras's avatar
Philipp Arras committed
25
from .sandwich_operator import SandwichOperator
Martin Reinecke's avatar
Martin Reinecke committed
26
from .simple_linear_operators import VdotOperator
Martin Reinecke's avatar
Martin Reinecke committed
27
28
29


class EnergyOperator(Operator):
Martin Reinecke's avatar
Martin Reinecke committed
30
    """Abstract class from which
31
    other specific EnergyOperator subclasses are derived.
32

33
34
    An EnergyOperator has a scalar domain as target domain.

Martin Reinecke's avatar
Martin Reinecke committed
35
    It is intended as an objective function for field inference.
36

Martin Reinecke's avatar
Martin Reinecke committed
37
    Typical usage in IFT:
Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
38

Martin Reinecke's avatar
Martin Reinecke committed
39
40
41
     - as an information Hamiltonian (i.e. a negative log probability)
     - or as a Gibbs free energy (i.e. an averaged Hamiltonian),
       aka Kullbach-Leibler divergence.
42
    """
Martin Reinecke's avatar
Martin Reinecke committed
43
44
45
46
    _target = DomainTuple.scalar_domain()


class SquaredNormOperator(EnergyOperator):
47
    """ Class for squared field norm energy.
Martin Reinecke's avatar
Martin Reinecke committed
48

49
50
    Usage
    -----
Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
51
52
    ``E = SquaredNormOperator()`` represents a field energy E that is the
    L2 norm of a field f:
53

Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
54
    :math:`E(f) = f^\\dagger f`
Martin Reinecke's avatar
Martin Reinecke committed
55
    """
Martin Reinecke's avatar
Martin Reinecke committed
56
57
58
59
    def __init__(self, domain):
        self._domain = domain

    def apply(self, x):
60
        self._check_input(x)
Martin Reinecke's avatar
Martin Reinecke committed
61
        if isinstance(x, Linearization):
Martin Reinecke's avatar
Martin Reinecke committed
62
            val = Field.scalar(x.val.vdot(x.val))
Martin Reinecke's avatar
Martin Reinecke committed
63
            jac = VdotOperator(2*x.val)(x.jac)
64
            return x.new(val, jac)
Martin Reinecke's avatar
Martin Reinecke committed
65
        return Field.scalar(x.vdot(x))
Martin Reinecke's avatar
Martin Reinecke committed
66
67
68


class QuadraticFormOperator(EnergyOperator):
Martin Reinecke's avatar
Martin Reinecke committed
69
    """Class for quadratic field energies.
70
71
72
73
74

    Parameters
    ----------
    op : EndomorphicOperator
         kernel of quadratic form
75

Martin Reinecke's avatar
Martin Reinecke committed
76
    Notes
77
    -----
Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
78
    ``E = QuadraticFormOperator(op)`` represents a field energy that is a
Martin Reinecke's avatar
Martin Reinecke committed
79
    quadratic form in a field f with kernel op:
80

Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
81
    :math:`E(f) = 0.5 f^\\dagger op f`
Martin Reinecke's avatar
Martin Reinecke committed
82
    """
Martin Reinecke's avatar
Martin Reinecke committed
83
84
85
86
87
    def __init__(self, op):
        from .endomorphic_operator import EndomorphicOperator
        if not isinstance(op, EndomorphicOperator):
            raise TypeError("op must be an EndomorphicOperator")
        self._op = op
Martin Reinecke's avatar
Martin Reinecke committed
88
        self._domain = op.domain
Martin Reinecke's avatar
Martin Reinecke committed
89
90

    def apply(self, x):
91
        self._check_input(x)
Martin Reinecke's avatar
Martin Reinecke committed
92
        if isinstance(x, Linearization):
Martin Reinecke's avatar
Martin Reinecke committed
93
94
            t1 = self._op(x.val)
            jac = VdotOperator(t1)(x.jac)
Martin Reinecke's avatar
Martin Reinecke committed
95
            val = Field.scalar(0.5*x.val.vdot(t1))
96
            return x.new(val, jac)
Martin Reinecke's avatar
Martin Reinecke committed
97
        return Field.scalar(0.5*x.vdot(self._op(x)))
Martin Reinecke's avatar
Martin Reinecke committed
98
99
100


class GaussianEnergy(EnergyOperator):
Martin Reinecke's avatar
Martin Reinecke committed
101
    """Class for energies of fields with Gaussian probability distribution.
102
103
104

    Attributes
    ----------
Martin Reinecke's avatar
Martin Reinecke committed
105
106
107
108
109
110
111
112
    mean : Field
        mean of the Gaussian, (default 0)
    covariance : LinearOperator
        covariance of the Gaussian (default = identity operator)
    domain : Domainoid
        operator domain, inferred from mean or covariance if specified

    Notes
113
    -----
Martin Reinecke's avatar
Martin Reinecke committed
114
    - At least one of the arguments has to be provided.
Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
115
116
    - ``E = GaussianEnergy(mean=m, covariance=D)`` represents (up to constants)

Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
117
        :math:`E(f) = - \\log G(f-m, D) = 0.5 (f-m)^\\dagger D^{-1} (f-m)`,
118

Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
119
120
        an information energy for a Gaussian distribution with mean m and
        covariance D.
Martin Reinecke's avatar
Martin Reinecke committed
121
    """
Martin Reinecke's avatar
Martin Reinecke committed
122

Martin Reinecke's avatar
Martin Reinecke committed
123
124
125
126
127
128
129
130
131
132
133
    def __init__(self, mean=None, covariance=None, domain=None):
        self._domain = None
        if mean is not None:
            self._checkEquivalence(mean.domain)
        if covariance is not None:
            self._checkEquivalence(covariance.domain)
        if domain is not None:
            self._checkEquivalence(domain)
        if self._domain is None:
            raise ValueError("no domain given")
        self._mean = mean
Martin Reinecke's avatar
Martin Reinecke committed
134
135
136
137
        if covariance is None:
            self._op = SquaredNormOperator(self._domain).scale(0.5)
        else:
            self._op = QuadraticFormOperator(covariance.inverse)
Martin Reinecke's avatar
Martin Reinecke committed
138
139
140
        self._icov = None if covariance is None else covariance.inverse

    def _checkEquivalence(self, newdom):
Martin Reinecke's avatar
fix    
Martin Reinecke committed
141
        newdom = makeDomain(newdom)
Martin Reinecke's avatar
Martin Reinecke committed
142
        if self._domain is None:
Philipp Arras's avatar
Philipp Arras committed
143
            self._domain = newdom
Martin Reinecke's avatar
Martin Reinecke committed
144
        else:
Philipp Arras's avatar
Philipp Arras committed
145
            if self._domain != newdom:
Martin Reinecke's avatar
Martin Reinecke committed
146
147
148
                raise ValueError("domain mismatch")

    def apply(self, x):
149
        self._check_input(x)
Martin Reinecke's avatar
Martin Reinecke committed
150
        residual = x if self._mean is None else x-self._mean
Philipp Arras's avatar
Changes    
Philipp Arras committed
151
        res = self._op(residual).real
152
        if not isinstance(x, Linearization) or not x.want_metric:
Martin Reinecke's avatar
Martin Reinecke committed
153
154
155
156
157
158
            return res
        metric = SandwichOperator.make(x.jac, self._icov)
        return res.add_metric(metric)


class PoissonianEnergy(EnergyOperator):
Martin Reinecke's avatar
Martin Reinecke committed
159
    """Class for likelihood-energies of expected count field constrained by
160
161
162
163
    Poissonian count data.

    Parameters
    ----------
Martin Reinecke's avatar
Martin Reinecke committed
164
    d : Field
165
166
        data field with counts

Martin Reinecke's avatar
Martin Reinecke committed
167
    Notes
168
    -----
Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
169
170
    ``E = PoissonianEnergy(d)`` represents (up to an f-independent term
    log(d!))
171

Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
172
    :math:`E(f) = -\\log \\text{Poisson}(d|f) = \\sum f - d^\\dagger \\log(f)`,
173

Martin Reinecke's avatar
Martin Reinecke committed
174
175
176
    where f is a Field in data space with the expectation values for
    the counts.
    """
177
178
179
    def __init__(self, d):
        self._d = d
        self._domain = DomainTuple.make(d.domain)
Martin Reinecke's avatar
Martin Reinecke committed
180
181

    def apply(self, x):
182
        self._check_input(x)
Martin Reinecke's avatar
Martin Reinecke committed
183
184
        res = x.sum() - x.log().vdot(self._d)
        if not isinstance(x, Linearization):
Martin Reinecke's avatar
Martin Reinecke committed
185
            return Field.scalar(res)
186
187
        if not x.want_metric:
            return res
Martin Reinecke's avatar
Martin Reinecke committed
188
189
190
        metric = SandwichOperator.make(x.jac, makeOp(1./x.val))
        return res.add_metric(metric)

191

192
class InverseGammaLikelihood(EnergyOperator):
Martin Reinecke's avatar
Martin Reinecke committed
193
    """Special class for inverse Gamma distributed covariances.
194
195
196

    RL FIXME: To be documented.
    """
Martin Reinecke's avatar
Martin Reinecke committed
197
    def __init__(self, d):
198
199
        self._d = d
        self._domain = DomainTuple.make(d.domain)
200
201

    def apply(self, x):
202
        self._check_input(x)
Philipp Frank's avatar
Philipp Frank committed
203
        res = 0.5*(x.log().sum() + (1./x).vdot(self._d))
204
205
        if not isinstance(x, Linearization):
            return Field.scalar(res)
206
207
        if not x.want_metric:
            return res
208
209
210
211
        metric = SandwichOperator.make(x.jac, makeOp(0.5/(x.val**2)))
        return res.add_metric(metric)


Martin Reinecke's avatar
Martin Reinecke committed
212
class BernoulliEnergy(EnergyOperator):
Martin Reinecke's avatar
Martin Reinecke committed
213
    """Class for likelihood-energies of expected event frequency constrained by
214
215
216
217
    event data.

    Parameters
    ----------
Martin Reinecke's avatar
Martin Reinecke committed
218
    d : Field
219
220
        data field with events (=1) or non-events (=0)

Martin Reinecke's avatar
Martin Reinecke committed
221
    Notes
222
    -----
Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
223
    ``E = BernoulliEnergy(d)`` represents
224

Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
225
226
    :math:`E(f) = -\\log \\text{Bernoulli}(d|f) =
    -d^\\dagger \\log f  - (1-d)^\\dagger \\log(1-f)`,
227

Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
228
229
    where f is a field in data space (d.domain) with the expected
    frequencies of events.
Martin Reinecke's avatar
Martin Reinecke committed
230
    """
231
    def __init__(self, d):
Martin Reinecke's avatar
Martin Reinecke committed
232
        self._d = d
233
        self._domain = DomainTuple.make(d.domain)
Martin Reinecke's avatar
Martin Reinecke committed
234
235

    def apply(self, x):
236
        self._check_input(x)
Martin Reinecke's avatar
Martin Reinecke committed
237
238
        v = x.log().vdot(-self._d) - (1.-x).log().vdot(1.-self._d)
        if not isinstance(x, Linearization):
Martin Reinecke's avatar
Martin Reinecke committed
239
            return Field.scalar(v)
240
241
        if not x.want_metric:
            return v
Martin Reinecke's avatar
Martin Reinecke committed
242
243
244
245
246
247
        met = makeOp(1./(x.val*(1.-x.val)))
        met = SandwichOperator.make(x.jac, met)
        return v.add_metric(met)


class Hamiltonian(EnergyOperator):
248
249
250
251
252
    """Class for information Hamiltonians.

    Parameters
    ----------
    lh : EnergyOperator
Martin Reinecke's avatar
Martin Reinecke committed
253
        a likelihood energy
Martin Reinecke's avatar
Martin Reinecke committed
254
    ic_samp : IterationController
Martin Reinecke's avatar
Martin Reinecke committed
255
256
257
        is passed to SamplingEnabler to draw Gaussian distributed samples
        with covariance = metric of Hamiltonian
        (= Hessian without terms that generate negative eigenvalues)
258

Martin Reinecke's avatar
Martin Reinecke committed
259
    Notes
260
    -----
Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
261
    ``H = Hamiltonian(E_lh)`` represents
262

Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
263
    :math:`H(f) = 0.5 f^\\dagger f + E_{lh}(f)`
264

Martin Reinecke's avatar
Martin Reinecke committed
265
266
    an information Hamiltonian for a field f with a white Gaussian prior
    (unit covariance) and the likelihood energy :math:`E_{lh}`.
267

Martin Reinecke's avatar
Martin Reinecke committed
268
    Other field priors can be represented via transformations of a white
269
270
    Gaussian field into a field with the desired prior probability structure.

Martin Reinecke's avatar
Martin Reinecke committed
271
    By implementing prior information this way, the field prior is represented
272
273
274
    by a generative model, from which NIFTy can draw samples and infer a field
    using the Maximum a Posteriori (MAP) or the Variational Bayes (VB) method.

Martin Reinecke's avatar
Martin Reinecke committed
275
    For more details see:
276
277
    "Encoding prior knowledge in the structure of the likelihood"
    Jakob Knollmüller, Torsten A. Ensslin, submitted, arXiv:1812.04403
Martin Reinecke's avatar
Martin Reinecke committed
278
    `<https://arxiv.org/abs/1812.04403>`_
Martin Reinecke's avatar
Martin Reinecke committed
279
    """
Martin Reinecke's avatar
Martin Reinecke committed
280
281
282
283
    def __init__(self, lh, ic_samp=None):
        self._lh = lh
        self._prior = GaussianEnergy(domain=lh.domain)
        self._ic_samp = ic_samp
Martin Reinecke's avatar
Martin Reinecke committed
284
        self._domain = lh.domain
Martin Reinecke's avatar
Martin Reinecke committed
285
286

    def apply(self, x):
287
        self._check_input(x)
288
289
        if (self._ic_samp is None or not isinstance(x, Linearization) or
                not x.want_metric):
Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
290
            return self._lh(x)+self._prior(x)
Martin Reinecke's avatar
Martin Reinecke committed
291
        else:
292
            lhx, prx = self._lh(x), self._prior(x)
Martin Reinecke's avatar
Martin Reinecke committed
293
294
295
296
            mtr = SamplingEnabler(lhx.metric, prx.metric.inverse,
                                  self._ic_samp, prx.metric.inverse)
            return (lhx+prx).add_metric(mtr)

Philipp Arras's avatar
Philipp Arras committed
297
298
299
300
301
    def __repr__(self):
        subs = 'Likelihood:\n{}'.format(utilities.indent(self._lh.__repr__()))
        subs += '\nPrior: Quadratic{}'.format(self._lh.domain.keys())
        return 'Hamiltonian:\n' + utilities.indent(subs)

Martin Reinecke's avatar
Martin Reinecke committed
302
303

class SampledKullbachLeiblerDivergence(EnergyOperator):
304
    """Class for Kullbach Leibler (KL) Divergence or Gibbs free energies
Martin Reinecke's avatar
Martin Reinecke committed
305
306

    Precisely a sample averaged Hamiltonian (or other energy) that represents
307
308
309
    approximatively the relevant part of a KL to be used in Variational Bayes
    inference if the samples are drawn from the approximating Gaussian.

Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
310
311
    Let :math:`Q(f) = G(f-m,D)` Gaussian used to approximate
    :math:`P(f|d)`, the correct posterior with information Hamiltonian
Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
312
    :math:`H(d,f) = -\\log P(d,f) = -\\log P(f|d) + \\text{const.}`
313
314
315

    The KL divergence between those should then be optimized for m. It is

Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
316
317
318
    :math:`KL(Q,P) = \\int Df Q(f) \\log Q(f)/P(f)\\\\
    = \\left< \\log Q(f) \\right>_Q(f) - \\left< \\log P(f) \\right>_Q(f)\\\\
    = \\text{const} + \\left< H(f) \\right>_G(f-m,D)`
Martin Reinecke's avatar
Martin Reinecke committed
319

Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
320
321
    in essence the information Hamiltonian averaged over a Gaussian
    distribution centered on the mean m.
322

Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
323
    SampledKullbachLeiblerDivergence(H) approximates
Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
324
    :math:`\\left< H(f) \\right>_{G(f-m,D)}` if the residuals
Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
325
    :math:`f-m` are drawn from covariance :math:`D`.
Martin Reinecke's avatar
Martin Reinecke committed
326

327
328
329
330
    Parameters
    ----------
    h: Hamiltonian
       the Hamiltonian/energy to be averaged
Martin Reinecke's avatar
Martin Reinecke committed
331
332
333
    res_samples : iterable of Fields
       set of residual sample points to be added to mean field
       for approximate estimation of the KL
334

Martin Reinecke's avatar
Martin Reinecke committed
335
336
    Notes
    -----
Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
337
    ``KL = SampledKullbachLeiblerDivergence(H, samples)`` represents
Martin Reinecke's avatar
Martin Reinecke committed
338

Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
339
    :math:`\\text{KL}(m) = \\sum_i H(m+v_i) / N`,
340

Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
341
342
    where :math:`v_i` are the residual samples, :math:`N` is their number,
    and :math:`m` is the mean field around which the samples are drawn.
343

Martin Reinecke's avatar
cleanup    
Martin Reinecke committed
344
345
346
347
    Having symmetrized residual samples, with both v_i and -v_i being present
    ensures that the distribution mean is exactly represented. This reduces
    sampling noise and helps the numerics of the KL minimization process in the
    variational Bayes inference.
Martin Reinecke's avatar
Martin Reinecke committed
348
    """
Martin Reinecke's avatar
Martin Reinecke committed
349
350
    def __init__(self, h, res_samples):
        self._h = h
Martin Reinecke's avatar
Martin Reinecke committed
351
        self._domain = h.domain
Martin Reinecke's avatar
Martin Reinecke committed
352
353
354
        self._res_samples = tuple(res_samples)

    def apply(self, x):
355
        self._check_input(x)
Martin Reinecke's avatar
Martin Reinecke committed
356
357
        mymap = map(lambda v: self._h(x+v), self._res_samples)
        return utilities.my_sum(mymap) * (1./len(self._res_samples))