# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

Philipp Arras's avatar
Philipp Arras committed
18
from .. import utilities
Martin Reinecke's avatar
Martin Reinecke committed
19
from ..domain_tuple import DomainTuple
Philipp Arras's avatar
Philipp Arras committed
20
21
from ..field import Field
from ..linearization import Linearization
Martin Reinecke's avatar
Martin Reinecke committed
22
from ..sugar import makeOp, makeDomain
Martin Reinecke's avatar
Martin Reinecke committed
23
from .operator import Operator
Martin Reinecke's avatar
fix    
Martin Reinecke committed
24
from .sampling_enabler import SamplingEnabler
Philipp Arras's avatar
Philipp Arras committed
25
from .sandwich_operator import SandwichOperator
Martin Reinecke's avatar
Martin Reinecke committed
26
from .simple_linear_operators import VdotOperator
Martin Reinecke's avatar
Martin Reinecke committed
27
28
29


class EnergyOperator(Operator):
    """An abstract class from which other specific EnergyOperator subclasses
    are derived.

    An EnergyOperator has a scalar domain as target domain.

    It is intended as an objective function for field inference.

    Typical usage in IFT:
     - as an information Hamiltonian (i.e. a negative log probability)
     - or as a Gibbs free energy (i.e. an averaged Hamiltonian),
       aka Kullback-Leibler divergence.
    """
    # Every energy maps its input onto a single scalar value.
    _target = DomainTuple.scalar_domain()


class SquaredNormOperator(EnergyOperator):
    """Class for squared field norm energy.

    Parameters
    ----------
    domain : Domain, tuple of Domain, DomainTuple or MultiDomain
        domain of the fields this energy acts on; it is normalized with
        `makeDomain` so raw domains are accepted as well.

    Notes
    -----
    `E = SquaredNormOperator(domain)` represents a field energy E that is the
    squared L2 norm of a field f:

    E(f) = f^dagger f
    """

    def __init__(self, domain):
        # Normalize so that _check_input compares the stored domain against
        # the (already normalized) domain of incoming fields.
        self._domain = makeDomain(domain)

    def apply(self, x):
        self._check_input(x)
        if isinstance(x, Linearization):
            val = Field.scalar(x.val.vdot(x.val))
            # Linear response of f^dagger f around x.val: contraction with
            # 2*x.val, chained with the incoming Jacobian.
            jac = VdotOperator(2*x.val)(x.jac)
            return x.new(val, jac)
        return Field.scalar(x.vdot(x))


class QuadraticFormOperator(EnergyOperator):
    r"""Energy that is a quadratic form of the input field.

    Parameters
    ----------
    op : EndomorphicOperator
         kernel of the quadratic form; its domain becomes the domain of
         this energy.

    Raises
    ------
    TypeError
        if `op` is not an EndomorphicOperator.

    Notes
    -----
    `E = QuadraticFormOperator(op)` represents the field energy

    :math:`E(f) = 0.5 f^\dagger op f`

    The Jacobian is formed by contracting with `op(f)`, which equals the
    derivative of E only when `op` is self-adjoint.
    """

    def __init__(self, op):
        # Local import, as in the original module layout.
        from .endomorphic_operator import EndomorphicOperator
        if isinstance(op, EndomorphicOperator):
            self._op = op
            self._domain = op.domain
            return
        raise TypeError("op must be an EndomorphicOperator")

    def apply(self, x):
        self._check_input(x)
        if not isinstance(x, Linearization):
            # Plain field: just evaluate 0.5 f^dagger op f.
            return Field.scalar(0.5*x.vdot(self._op(x)))
        kernel_applied = self._op(x.val)
        energy = Field.scalar(0.5*x.val.vdot(kernel_applied))
        gradient = VdotOperator(kernel_applied)(x.jac)
        return x.new(energy, gradient)


class GaussianEnergy(EnergyOperator):
    r"""Class for energies of fields with Gaussian probability distribution.

    Parameters
    ----------
    mean : Field
        mean of the Gaussian, (default 0)
    covariance : LinearOperator
        covariance of the Gaussian (default = identity operator)
    domain : Domainoid
        operator domain, inferred from mean or covariance if specified

    Raises
    ------
    ValueError
        if none of `mean`, `covariance`, `domain` is given, or if the
        domains of the given arguments disagree.

    Notes
    -----
    - At least one of the arguments has to be provided.
    - `E = GaussianEnergy(mean=m, covariance=D)` represents (up to constants)
        :math:`E(f) = - \log G(f-m, D) = 0.5 (f-m)^\dagger D^{-1} (f-m)`,
        an information energy for a Gaussian distribution with mean m and
        covariance D.
    """

    def __init__(self, mean=None, covariance=None, domain=None):
        self._domain = None
        if mean is not None:
            self._checkEquivalence(mean.domain)
        if covariance is not None:
            self._checkEquivalence(covariance.domain)
        if domain is not None:
            self._checkEquivalence(domain)
        if self._domain is None:
            raise ValueError("no domain given")
        self._mean = mean
        if covariance is None:
            # Unit covariance: E(f) = 0.5 f^dagger f; metric is jac^dagger jac.
            self._icov = None
            self._op = SquaredNormOperator(self._domain).scale(0.5)
        else:
            # Build the inverse covariance once and share it between the
            # energy's quadratic form and the metric in apply().
            self._icov = covariance.inverse
            self._op = QuadraticFormOperator(self._icov)

    def _checkEquivalence(self, newdom):
        """Adopt `newdom` as domain, or check it matches the one already set."""
        newdom = makeDomain(newdom)
        if self._domain is None:
            self._domain = newdom
        else:
            if self._domain != newdom:
                raise ValueError("domain mismatch")

    def apply(self, x):
        self._check_input(x)
        residual = x if self._mean is None else x-self._mean
        res = self._op(residual).real
        if not isinstance(x, Linearization) or not x.want_metric:
            return res
        # Metric: jac^dagger D^{-1} jac (identity kernel if no covariance).
        metric = SandwichOperator.make(x.jac, self._icov)
        return res.add_metric(metric)


class PoissonianEnergy(EnergyOperator):
    r"""Class for likelihood-energies of expected count field constrained by
    Poissonian count data.

    Parameters
    ----------
    d : Field
        data field with counts

    Notes
    -----
    E = PoissonianEnergy(d) represents (up to an f-independent term log(d!))

    E(f) = -\log Poisson(d|f) = sum(f) - d^\dagger \log(f),

    where f is a Field in data space with the expectation values for
    the counts.
    """

    def __init__(self, d):
        self._d = d
        self._domain = DomainTuple.make(d.domain)

    def apply(self, x):
        self._check_input(x)
        res = x.sum() - x.log().vdot(self._d)
        if not isinstance(x, Linearization):
            return Field.scalar(res)
        if not x.want_metric:
            return res
        # Metric kernel diag(1/f), sandwiched by the Jacobian of the input.
        metric = SandwichOperator.make(x.jac, makeOp(1./x.val))
        return res.add_metric(metric)


class InverseGammaLikelihood(EnergyOperator):
    """Special class for inverse Gamma distributed covariances.

    Parameters
    ----------
    d : Field
        data field; its domain (as a DomainTuple) is the operator domain.

    Notes
    -----
    For an input field x the energy evaluated is

        E(x) = 0.5 * (sum(log(x)) + d^dagger (1/x)),

    and, when a metric is requested, the metric is the Jacobian of the input
    sandwiched around diag(0.5/x^2).

    NOTE(review): the statistical derivation of this energy is not
    documented in this file -- TODO confirm and expand.
    """

    def __init__(self, d):
        self._d = d
        self._domain = DomainTuple.make(d.domain)

    def apply(self, x):
        self._check_input(x)
        log_term = x.log().sum()
        data_term = (1./x).vdot(self._d)
        energy = 0.5*(log_term + data_term)
        if isinstance(x, Linearization):
            if x.want_metric:
                kernel = makeOp(0.5/(x.val**2))
                return energy.add_metric(SandwichOperator.make(x.jac, kernel))
            return energy
        return Field.scalar(energy)


class BernoulliEnergy(EnergyOperator):
    r"""Likelihood energy of an expected event frequency given event data.

    Parameters
    ----------
    d : Field
        data field with events (=1) or non-events (=0)

    Notes
    -----
    E = BernoulliEnergy(d) represents

    :math:`E(f) = -\log \mbox{Bernoulli}(d|f) = -d^\dagger \log f  - (1-d)^\dagger \log(1-f)`,

    where f is a field in data space (d.domain) holding the expected
    frequencies of events. When a metric is requested, it is the Jacobian of
    the input sandwiched around diag(1/(f(1-f))).
    """

    def __init__(self, d):
        self._d = d
        self._domain = DomainTuple.make(d.domain)

    def apply(self, x):
        self._check_input(x)
        event_part = x.log().vdot(-self._d)
        non_event_part = (1.-x).log().vdot(1.-self._d)
        energy = event_part - non_event_part
        if not isinstance(x, Linearization):
            return Field.scalar(energy)
        if x.want_metric:
            kernel = makeOp(1./(x.val*(1.-x.val)))
            return energy.add_metric(SandwichOperator.make(x.jac, kernel))
        return energy


class Hamiltonian(EnergyOperator):
    r"""Class for information Hamiltonians.

    Parameters
    ----------
    lh : EnergyOperator
         a likelihood energy
    ic_samp : IterationController
              is passed to SamplingEnabler to draw Gaussian distributed samples
              with covariance = metric of Hamiltonian
                   (= Hessian without terms that generate negative eigenvalues)
              default = None

    Notes
    -----
    H = Hamiltonian(E_lh) represents

    :math:`H(f) = 0.5 f^\dagger f + E_{lh}(f)`

    an information Hamiltonian for a field f with a white Gaussian prior
    (unit covariance) and the likelihood energy :math:`E_{lh}`.

    Other field priors can be represented via transformations of a white
    Gaussian field into a field with the desired prior probability structure.

    By implementing prior information this way, the field prior is represented
    by a generative model, from which NIFTy can draw samples and infer a field
    using the Maximum a Posteriori (MAP) or the Variational Bayes (VB) method.

    For more details see:
    "Encoding prior knowledge in the structure of the likelihood"
    Jakob Knollmüller, Torsten A. Ensslin, submitted, arXiv:1812.04403
    https://arxiv.org/abs/1812.04403
    """

    def __init__(self, lh, ic_samp=None):
        self._lh = lh
        # White (unit covariance) Gaussian prior on the likelihood's domain.
        self._prior = GaussianEnergy(domain=lh.domain)
        self._ic_samp = ic_samp
        self._domain = lh.domain

    def apply(self, x):
        self._check_input(x)
        if (self._ic_samp is None or not isinstance(x, Linearization) or
                not x.want_metric):
            return self._lh(x)+self._prior(x)
        lhx, prx = self._lh(x), self._prior(x)
        # The prior's inverse metric is used both as the prior part and as
        # the approximation for SamplingEnabler; compute it once.
        prior_inv = prx.metric.inverse
        mtr = SamplingEnabler(lhx.metric, prior_inv, self._ic_samp, prior_inv)
        return (lhx+prx).add_metric(mtr)

    def __repr__(self):
        subs = 'Likelihood:\n{}'.format(utilities.indent(self._lh.__repr__()))
        dom = self._lh.domain
        # Only multi-domains are guaranteed to expose keys(); a plain domain
        # (presumably a DomainTuple) would make an unconditional .keys() call
        # raise AttributeError, so fall back to an unlabeled prior line.
        if hasattr(dom, 'keys'):
            subs += '\nPrior: Quadratic{}'.format(dom.keys())
        else:
            subs += '\nPrior: Quadratic'
        return 'Hamiltonian:\n' + utilities.indent(subs)


class SampledKullbachLeiblerDivergence(EnergyOperator):
    """Class for Kullback-Leibler (KL) Divergence or Gibbs free energies

    Precisely a sample averaged Hamiltonian (or other energy) that represents
    approximately the relevant part of a KL to be used in Variational Bayes
    inference if the samples are drawn from the approximating Gaussian.

    Let Q(f) = G(f-m,D) Gaussian used to approximate
    P(f|d), the correct posterior with information Hamiltonian
    H(d,f) = - log P(d,f) = - log P(f|d) + const.

    The KL divergence between those should then be optimized for m. It is

    KL(Q,P) = int Df Q(f) log Q(f)/P(f)
            = < log Q(f) >_Q(f) - < log P(f) >_Q(f)
            = const + < H(f) >_G(f-m,D)

    in essence the information Hamiltonian averaged over a Gaussian distribution
    centered on the mean m.

    SampledKullbachLeiblerDivergence(H) approximates < H(f) >_G(f-m,D) if the
    residuals f-m are drawn from covariance D.

    Parameters
    ----------
    h: Hamiltonian
       the Hamiltonian/energy to be averaged
    res_samples : iterable Field
                  set of residual sample points to be added to mean field
                  for approximate estimation of the KL; must not be empty

    Raises
    ------
    ValueError
        if `res_samples` contains no samples.

    Notes
    -----
    KL = SampledKullbachLeiblerDivergence(H, samples) represents

    KL(m) = sum_i H(m+v_i) / N,

    where v_i are the residual samples, N is their number, and m is the mean field
    around which the samples are drawn.

    Having symmetrized residual samples, with both, v_i and -v_i being present,
    ensures that the distribution mean is exactly represented. This reduces sampling
    noise and helps the numerics of the KL minimization process in the variational
    Bayes inference.
    """

    def __init__(self, h, res_samples):
        self._h = h
        self._domain = h.domain
        # Materialize once: res_samples may be a one-shot iterator, and apply()
        # needs to iterate it repeatedly and know its length.
        self._res_samples = tuple(res_samples)
        # An empty sample set would only fail later inside apply() (summing
        # nothing / dividing by zero); reject it up front instead.
        if not self._res_samples:
            raise ValueError("res_samples must contain at least one sample")

    def apply(self, x):
        self._check_input(x)
        total = utilities.my_sum(self._h(x+v) for v in self._res_samples)
        return total * (1./len(self._res_samples))