# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

Philipp Arras's avatar
Philipp Arras committed
18
from .. import utilities
Martin Reinecke's avatar
Martin Reinecke committed
19
from ..domain_tuple import DomainTuple
Philipp Arras's avatar
Philipp Arras committed
20
21
from ..field import Field
from ..linearization import Linearization
Martin Reinecke's avatar
Martin Reinecke committed
22
from ..sugar import makeOp, makeDomain
Martin Reinecke's avatar
Martin Reinecke committed
23
from .operator import Operator
Martin Reinecke's avatar
fix    
Martin Reinecke committed
24
from .sampling_enabler import SamplingEnabler
Philipp Arras's avatar
Philipp Arras committed
25
from .sandwich_operator import SandwichOperator
Martin Reinecke's avatar
Martin Reinecke committed
26
from .simple_linear_operators import VdotOperator
Martin Reinecke's avatar
Martin Reinecke committed
27
28
29


class EnergyOperator(Operator):
    """Abstract base class from which all specific energy operators derive.

    The target domain of an EnergyOperator is the scalar domain: it turns a
    field into a scalar and a linearization into a linearization.  It is
    intended to be used as an objective function for field inference.

    Typical usage in IFT:
    as an information Hamiltonian (= negative log probability)
    or as a Gibbs free energy (= averaged Hamiltonian),
    aka Kullback-Leibler divergence.
    """

    # Shared by all subclasses: every energy maps into the scalar domain.
    _target = DomainTuple.scalar_domain()


class SquaredNormOperator(EnergyOperator):
    """Energy given by the squared L2 norm of a field.

    Parameters
    ----------
    domain : domain object
        Domain of the fields this energy acts on.  It is stored as given,
        without any conversion.

    Usage
    -----
    E = SquaredNormOperator(domain) represents the field energy

    E(f) = f^dagger f
    """

    def __init__(self, domain):
        self._domain = domain

    def apply(self, x):
        """Evaluate f^dagger f for a field or a linearization `x`."""
        self._check_input(x)
        if not isinstance(x, Linearization):
            return Field.scalar(x.vdot(x))
        position = x.val
        energy = Field.scalar(position.vdot(position))
        # Gradient 2 f^dagger, chained with the Jacobian carried by `x`.
        gradient = VdotOperator(2*position)(x.jac)
        return x.new(energy, gradient)
Martin Reinecke's avatar
Martin Reinecke committed
66
67
68


class QuadraticFormOperator(EnergyOperator):
    """Energy that is a quadratic form in a field.

    Parameters
    ----------
    op : EndomorphicOperator
        Kernel of the quadratic form.  The domain of the energy is taken
        from `op.domain`.

    Raises
    ------
    TypeError
        If `op` is not an EndomorphicOperator.

    Usage
    -----
    E = QuadraticFormOperator(op) represents a field energy that is a
    quadratic form in a field f with kernel op:

    E(f) = 0.5 f^dagger op f
    """

    def __init__(self, op):
        # Function-scope import (presumably to avoid a circular import at
        # module load time -- confirm before moving it to the top).
        from .endomorphic_operator import EndomorphicOperator
        if not isinstance(op, EndomorphicOperator):
            raise TypeError("op must be an EndomorphicOperator")
        self._op = op
        self._domain = op.domain

    def apply(self, x):
        """Evaluate 0.5 f^dagger op f for a field or a linearization `x`."""
        self._check_input(x)
        if not isinstance(x, Linearization):
            return Field.scalar(0.5*x.vdot(self._op(x)))
        # `op f` is needed for both the value and the Jacobian.
        op_x = self._op(x.val)
        energy = Field.scalar(0.5*x.val.vdot(op_x))
        # NOTE(review): using (op f)^dagger as gradient looks like it assumes
        # a self-adjoint kernel -- confirm against the callers.
        gradient = VdotOperator(op_x)(x.jac)
        return x.new(energy, gradient)
Martin Reinecke's avatar
Martin Reinecke committed
98
99
100


class GaussianEnergy(EnergyOperator):
    """Energy of a field with Gaussian probability distribution.

    Parameters
    ----------
    mean : Field, optional
        Mean of the Gaussian.  Default: zero (no mean is subtracted).
    covariance : operator, optional
        Covariance of the Gaussian.  Its `inverse` is passed to
        QuadraticFormOperator and therefore has to be an
        EndomorphicOperator.  Default: identity operator.
    domain : domain-like, optional
        Domain of the field; anything accepted by `makeDomain`.
        Default: the domain of `mean` or `covariance`.

    At least one of the three arguments has to be specified when a
    GaussianEnergy is instantiated, since the domain is derived from them.
    If several are given, their domains have to agree.

    Raises
    ------
    ValueError
        If none of the arguments is given ("no domain given") or if the
        given domains differ ("domain mismatch").
    TypeError
        Raised by QuadraticFormOperator if `covariance.inverse` is not an
        EndomorphicOperator.

    Usage
    -----
    E = GaussianEnergy(mean=m, covariance=D) represents (up to constants)

    E(f) = - log G(f-m, D) = 0.5 (f-m)^dagger D^-1 (f-m)

    an information energy for a Gaussian distribution with mean m and
    covariance D.
    """

    def __init__(self, mean=None, covariance=None, domain=None):
        # The domain is collected from whichever arguments are present;
        # _checkEquivalence also verifies that they agree with each other.
        self._domain = None
        if mean is not None:
            self._checkEquivalence(mean.domain)
        if covariance is not None:
            self._checkEquivalence(covariance.domain)
        if domain is not None:
            self._checkEquivalence(domain)
        if self._domain is None:
            raise ValueError("no domain given")
        self._mean = mean
        if covariance is None:
            # Unit covariance: 0.5 f^dagger f, no explicit inverse needed.
            self._icov = None
            self._op = SquaredNormOperator(self._domain).scale(0.5)
        else:
            # Build the inverse covariance once and share it between the
            # quadratic form and the metric.
            self._icov = covariance.inverse
            self._op = QuadraticFormOperator(self._icov)

    def _checkEquivalence(self, newdom):
        """Adopt `newdom` as domain, or check it against the current one."""
        newdom = makeDomain(newdom)
        if self._domain is None:
            self._domain = newdom
        elif self._domain != newdom:
            raise ValueError("domain mismatch")

    def apply(self, x):
        """Evaluate the Gaussian energy for a field or a linearization `x`.

        If `x` is a Linearization with `want_metric` set, a metric built
        from its Jacobian and the inverse covariance is attached.
        """
        self._check_input(x)
        residual = x if self._mean is None else x - self._mean
        res = self._op(residual).real
        if isinstance(x, Linearization) and x.want_metric:
            # self._icov is None for unit covariance.
            # NOTE(review): relies on SandwichOperator.make accepting None
            # as the inner operator -- confirm.
            metric = SandwichOperator.make(x.jac, self._icov)
            return res.add_metric(metric)
        return res


class PoissonianEnergy(EnergyOperator):
    """Likelihood energy of an expected count field constrained by
    Poissonian count data.

    Parameters
    ----------
    d : Field
        Data field with counts.

    Usage
    -----
    E = PoissonianEnergy(d) represents (up to an f-independent term log(d!))

    E(f) = -log Poisson(d|f) = sum(f) - d^dagger log(f),

    where f is a field in data space (d.domain) with the expectation values
    for the counts.
    """

    def __init__(self, d):
        self._d = d
        self._domain = DomainTuple.make(d.domain)

    def apply(self, x):
        """Evaluate the Poissonian energy for a field or a linearization.

        If `x` is a Linearization with `want_metric` set, a metric built
        from its Jacobian and `makeOp(1./x.val)` is attached.
        """
        self._check_input(x)
        energy = x.sum() - x.log().vdot(self._d)
        if isinstance(x, Linearization):
            if x.want_metric:
                fisher = SandwichOperator.make(x.jac, makeOp(1./x.val))
                return energy.add_metric(fisher)
            return energy
        return Field.scalar(energy)

190

191
class InverseGammaLikelihood(EnergyOperator):
    """Special class for inverse Gamma distributed covariances.

    Parameters
    ----------
    d : Field
        Data field; the energy lives on `d.domain`.

    Usage
    -----
    E = InverseGammaLikelihood(d) represents

    E(f) = 0.5 (sum(log(f)) + d^dagger (1/f)),

    where f is a field on d.domain.
    """

    def __init__(self, d):
        self._d = d
        self._domain = DomainTuple.make(d.domain)

    def apply(self, x):
        """Evaluate the energy for a field or a linearization `x`.

        If `x` is a Linearization with `want_metric` set, a metric built
        from its Jacobian and `makeOp(0.5/(x.val**2))` is attached.
        """
        self._check_input(x)
        res = 0.5*(x.log().sum() + (1./x).vdot(self._d))
        if not isinstance(x, Linearization):
            return Field.scalar(res)
        if not x.want_metric:
            return res
        metric = SandwichOperator.make(x.jac, makeOp(0.5/(x.val**2)))
        return res.add_metric(metric)


Martin Reinecke's avatar
Martin Reinecke committed
211
class BernoulliEnergy(EnergyOperator):
    """Likelihood energy of an expected event frequency constrained by
    event data.

    Parameters
    ----------
    d : Field
        Data field with events (=1) or non-events (=0).

    Usage
    -----
    E = BernoulliEnergy(d) represents

    E(f) = -log Bernoulli(d|f) = -d^dagger log(f) - (1-d)^dagger log(1-f),

    where f is a field in data space (d.domain) with the expected
    frequencies of events.
    """

    def __init__(self, d):
        self._d = d
        self._domain = DomainTuple.make(d.domain)

    def apply(self, x):
        """Evaluate the Bernoulli energy for a field or a linearization.

        If `x` is a Linearization with `want_metric` set, a metric built
        from its Jacobian and `makeOp(1./(x.val*(1.-x.val)))` is attached.
        """
        self._check_input(x)
        event_term = x.log().vdot(-self._d)
        non_event_term = (1.-x).log().vdot(1.-self._d)
        energy = event_term - non_event_term
        if not isinstance(x, Linearization):
            return Field.scalar(energy)
        if x.want_metric:
            fisher = makeOp(1./(x.val*(1.-x.val)))
            return energy.add_metric(SandwichOperator.make(x.jac, fisher))
        return energy


class Hamiltonian(EnergyOperator):
    """Class for information Hamiltonians.

    Parameters
    ----------
    lh : EnergyOperator
        A likelihood energy.
    ic_samp : IterationController, optional
        Is passed to SamplingEnabler to draw Gaussian distributed samples
        with covariance = metric of the Hamiltonian
        (= Hessian without terms that generate negative eigenvalues).
        Default: None.

    Usage
    -----
    H = Hamiltonian(E_lh) represents

    H(f) = 0.5 f^dagger f + E_lh(f)

    an information Hamiltonian for a field f with a white Gaussian prior
    (unit covariance) and the likelihood energy E_lh.

    Tip
    ---
    Other field priors can be represented via transformations of a white
    Gaussian field into a field with the desired prior probability
    structure.

    By implementing prior information this way, the field prior is
    represented by a generative model, from which NIFTy can draw samples and
    infer a field using the Maximum a Posteriori (MAP) or the Variational
    Bayes (VB) method.

    For more details see:
    "Encoding prior knowledge in the structure of the likelihood"
    Jakob Knollmüller, Torsten A. Ensslin, submitted, arXiv:1812.04403
    https://arxiv.org/abs/1812.04403
    """

    def __init__(self, lh, ic_samp=None):
        self._lh = lh
        # White (unit covariance) Gaussian prior on the likelihood's domain.
        self._prior = GaussianEnergy(domain=lh.domain)
        self._ic_samp = ic_samp
        self._domain = lh.domain

    def apply(self, x):
        """Evaluate likelihood plus prior energy at `x`.

        A sampling-enabled metric is attached only if an iteration
        controller was supplied and `x` is a Linearization that requests a
        metric; otherwise the plain sum of both energies is returned.
        """
        self._check_input(x)
        if (self._ic_samp is None or not isinstance(x, Linearization) or
                not x.want_metric):
            return self._lh(x)+self._prior(x)
        lhx, prx = self._lh(x), self._prior(x)
        mtr = SamplingEnabler(lhx.metric, prx.metric.inverse,
                              self._ic_samp, prx.metric.inverse)
        return (lhx+prx).add_metric(mtr)

    def __repr__(self):
        subs = 'Likelihood:\n{}'.format(utilities.indent(repr(self._lh)))
        # Not every domain type offers keys(); __repr__ must not raise for
        # those (NOTE(review): presumably plain DomainTuple domains lack it
        # while MultiDomain has it -- confirm).
        keys = getattr(self._lh.domain, 'keys', None)
        subs += '\nPrior: Quadratic{}'.format(keys() if callable(keys) else '')
        return 'Hamiltonian:\n' + utilities.indent(subs)

Martin Reinecke's avatar
Martin Reinecke committed
303
304

class SampledKullbachLeiblerDivergence(EnergyOperator):
    """Class for Kullback-Leibler (KL) divergence or Gibbs free energies.

    Precisely a sample averaged Hamiltonian (or other energy) that
    approximately represents the relevant part of a KL to be used in
    Variational Bayes inference if the samples are drawn from the
    approximating Gaussian.

    Let Q(f) = G(f-m,D) be the Gaussian used to approximate
    P(f|d), the correct posterior with information Hamiltonian
    H(d,f) = - log P(d,f) = - log P(f|d) + const.

    The KL divergence between those should then be optimized for m. It is

    KL(Q,P) = int Df Q(f) log Q(f)/P(f)
            = < log Q(f) >_Q(f) - < log P(f) >_Q(f)
            = const + < H(f) >_G(f-m,D)

    in essence the information Hamiltonian averaged over a Gaussian
    distribution centered on the mean m.

    SampledKullbachLeiblerDivergence(H) approximates < H(f) >_G(f-m,D) if
    the residuals f-m are drawn from covariance D.

    Parameters
    ----------
    h : Hamiltonian
        The Hamiltonian/energy to be averaged.
    res_samples : iterable of Field
        Set of residual sample points to be added to the mean field for
        approximate estimation of the KL.  The iterable is stored as a
        tuple.

    Usage
    -----
    KL = SampledKullbachLeiblerDivergence(H, samples) represents

    KL(m) = sum_i H(m+v_i) / N,

    where v_i are the residual samples, N is their number, and m is the
    mean field around which the samples are drawn.

    Tip
    ---
    Having symmetrized residual samples, with both v_i and -v_i being
    present, ensures that the distribution mean is exactly represented.
    This reduces sampling noise and helps the numerics of the KL
    minimization process in the variational Bayes inference.
    """

    def __init__(self, h, res_samples):
        self._h = h
        self._domain = h.domain
        self._res_samples = tuple(res_samples)

    def apply(self, x):
        """Return the energy `h` averaged over all shifted positions x+v."""
        self._check_input(x)
        shifted_energies = (self._h(x + v) for v in self._res_samples)
        return utilities.my_sum(shifted_energies) * (1./len(self._res_samples))