energy_operators.py 15.6 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
14
# Copyright(C) 2013-2019 Max-Planck-Society
Martin Reinecke's avatar
Martin Reinecke committed
15
#
16
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
Martin Reinecke's avatar
Martin Reinecke committed
17

Philipp Arras's avatar
Philipp Arras committed
18
19
import numpy as np

Philipp Arras's avatar
Philipp Arras committed
20
from .. import utilities
Martin Reinecke's avatar
Martin Reinecke committed
21
from ..domain_tuple import DomainTuple
22
from ..multi_domain import MultiDomain
Philipp Arras's avatar
Philipp Arras committed
23
from ..field import Field
24
from ..multi_field import MultiField
Philipp Arras's avatar
Philipp Arras committed
25
from ..linearization import Linearization
Philipp Arras's avatar
Philipp Arras committed
26
27
from ..sugar import makeDomain, makeOp
from .linear_operator import LinearOperator
Martin Reinecke's avatar
Martin Reinecke committed
28
from .operator import Operator
Martin Reinecke's avatar
fix    
Martin Reinecke committed
29
from .sampling_enabler import SamplingEnabler
Philipp Arras's avatar
Philipp Arras committed
30
from .sandwich_operator import SandwichOperator
31
from .scaling_operator import ScalingOperator
32
from .simple_linear_operators import VdotOperator, FieldAdapter
Martin Reinecke's avatar
Martin Reinecke committed
33
34
35


class EnergyOperator(Operator):
    """Base class for operators whose target is the scalar domain.

    Subclasses map a Field or MultiField to a single number and are meant to
    be used as objective functions for field inference.

    Examples
    --------
     - Information Hamiltonian, i.e. negative-log-probabilities.
     - Gibbs free energy, i.e. an averaged Hamiltonian, aka Kullback-Leibler
       divergence.
    """

    # Every energy evaluates to one number, so the target is fixed for the
    # whole class hierarchy.
    _target = DomainTuple.scalar_domain()


49
50
class Squared2NormOperator(EnergyOperator):
    """Computes the square of the L2-norm of the output of an operator.

    .. math ::
        E(f) = f^\\dagger f

    Parameters
    ----------
    domain : Domain, DomainTuple, tuple of Domain or MultiDomain
        Domain of the operator in which the L2-norm shall be computed.
    """

    def __init__(self, domain):
        # Normalise the domain like the other energies do, so that
        # `_check_input` (which compares domains) also accepts a plain Domain
        # or a tuple of Domains as advertised above.
        self._domain = makeDomain(domain)

    def apply(self, x):
        self._check_input(x)
        if isinstance(x, Linearization):
            val = Field.scalar(x.val.vdot(x.val))
            # Gradient of f^dagger f is 2f, chained with the input Jacobian.
            jac = VdotOperator(2*x.val)(x.jac)
            return x.new(val, jac)
        return Field.scalar(x.vdot(x))
Martin Reinecke's avatar
Martin Reinecke committed
68

Martin Reinecke's avatar
Martin Reinecke committed
69

Martin Reinecke's avatar
Martin Reinecke committed
70
class QuadraticFormOperator(EnergyOperator):
    """Computes the quadratic form of a Field or MultiField with respect to
    the kernel `endo`.

    .. math ::
        E(f) = \\frac12 f^\\dagger \\text{endo}(f)

    Parameters
    ----------
    endo : EndomorphicOperator
         Kernel of the quadratic form. Its domain becomes the operator domain.

    Raises
    ------
    TypeError
        If `endo` is not an EndomorphicOperator.

    Notes
    -----
    The Jacobian used is :math:`\\text{endo}(f)^\\dagger`, which equals the
    exact gradient of :math:`E` only if `endo` is self-adjoint.
    """

    def __init__(self, endo):
        from .endomorphic_operator import EndomorphicOperator
        if not isinstance(endo, EndomorphicOperator):
            raise TypeError("op must be an EndomorphicOperator")
        self._domain = endo.domain
        self._op = endo

    def apply(self, x):
        self._check_input(x)
        if not isinstance(x, Linearization):
            return Field.scalar(0.5*x.vdot(self._op(x)))
        kernel_f = self._op(x.val)
        energy = Field.scalar(0.5*x.val.vdot(kernel_f))
        return x.new(energy, VdotOperator(kernel_f)(x.jac))
Martin Reinecke's avatar
Martin Reinecke committed
98

99
100
101
class VariableCovarianceGaussianEnergy(EnergyOperator):
    """Computes a negative-log Gaussian with unknown (inverse) covariance.

    Both the residual :math:`s` and the pointwise inverse covariance
    :math:`D^{-1}` are inputs, read from two keys of a MultiField. Up to
    constants the energy is

    .. math ::
        E(s, D^{-1}) = 0.5 s^\\dagger D^{-1} s - 0.5 \\sum \\log(D^{-1}),

    an information energy for a Gaussian distribution with residual s and
    diagonal covariance D.

    Parameters
    ----------
    domain : Domain, DomainTuple, tuple of Domain
        Domain of both the residual field and the inverse covariance field.
        It is required; it is not inferred from the input.
    residual : key
        Key of the residual in the input MultiField.
    inverse_covariance : key
        Key of the (pointwise) inverse covariance in the input MultiField.
    """

    def __init__(self, domain, residual, inverse_covariance):
        self._residual = residual
        self._icov = inverse_covariance
        self._domain = MultiDomain.make({self._residual: domain,
                                         self._icov: domain})
        self._singledom = domain

    def apply(self, x):
        self._check_input(x)
        lin = isinstance(x, Linearization)
        xval = x.val if lin else x
        r = xval[self._residual]
        icov = xval[self._icov]
        res = .5*r.vdot(r*icov) - .5*icov.log().sum()
        if not lin:
            # Wrap the number in a scalar Field like every other energy does.
            return Field.scalar(res)

        fa_res = FieldAdapter(self._singledom, self._residual)
        fa_icov = FieldAdapter(self._singledom, self._icov)
        # dE/ds = D^{-1} s
        jac_res = VdotOperator(r*icov)(fa_res)
        # dE/dD^{-1} = 0.5 |s|^2 - 0.5 / D^{-1}
        jac_icov = (VdotOperator(.5*(r.absolute()**2))(fa_icov)
                    - .5*VdotOperator(1./icov)(fa_icov))
        jac = (jac_icov + jac_res)(x.jac)
        res = x.new(Field.scalar(res), jac)
        if not x.want_metric:
            return res
        # Fisher metric: D^{-1} for the residual, 0.5 (D^{-1})^{-2} for the
        # inverse covariance.
        mf = MultiField.from_dict({self._residual: icov,
                                   self._icov: .5*icov**(-2)})
        # Use the factory, as every other energy in this module does.
        metric = SandwichOperator.make(x.jac, makeOp(mf))
        return res.add_metric(metric)


Martin Reinecke's avatar
Martin Reinecke committed
162
163

class GaussianEnergy(EnergyOperator):
    """Computes a negative-log Gaussian.

    Represents up to constants in :math:`m`:

    .. math ::
        E(f) = - \\log G(f-m, D) = 0.5 (f-m)^\\dagger D^{-1} (f-m),

    an information energy for a Gaussian distribution with mean m and
    covariance D.

    Parameters
    ----------
    mean : Field or MultiField
        Mean of the Gaussian. Default is 0.
    inverse_covariance : LinearOperator
        Inverse covariance of the Gaussian. Default is the identity operator.
    domain : Domain, DomainTuple, tuple of Domain or MultiDomain
        Operator domain. By default it is inferred from `mean` or
        `inverse_covariance` if specified.

    Raises
    ------
    TypeError
        If `mean` or `inverse_covariance` has an unsupported type.
    ValueError
        If the domains of the given arguments disagree, or if no domain can
        be determined.

    Note
    ----
    At least one of the arguments has to be provided.
    """

    def __init__(self, mean=None, inverse_covariance=None, domain=None):
        if mean is not None and not isinstance(mean, (Field, MultiField)):
            raise TypeError("mean must be a Field or MultiField")
        if (inverse_covariance is not None
                and not isinstance(inverse_covariance, LinearOperator)):
            raise TypeError("inverse_covariance must be a LinearOperator")

        # Every provided argument must agree on the domain; the first one
        # seen fixes it.
        self._domain = None
        if mean is not None:
            self._checkEquivalence(mean.domain)
        if inverse_covariance is not None:
            self._checkEquivalence(inverse_covariance.domain)
        if domain is not None:
            self._checkEquivalence(domain)
        if self._domain is None:
            raise ValueError("no domain given")
        self._mean = mean
        # Passed unchanged to SandwichOperator.make in `apply`; None presumably
        # stands for the identity there. NOTE(review): confirm.
        self._icov = inverse_covariance
        if inverse_covariance is None:
            self._op = Squared2NormOperator(self._domain).scale(0.5)
        else:
            self._op = QuadraticFormOperator(inverse_covariance)

    def _checkEquivalence(self, newdom):
        """Merge `newdom` into the operator domain; raise ValueError if it
        disagrees with a domain set earlier."""
        newdom = makeDomain(newdom)
        if self._domain is None:
            self._domain = newdom
        else:
            if self._domain != newdom:
                raise ValueError("domain mismatch")

    def apply(self, x):
        self._check_input(x)
        residual = x if self._mean is None else x - self._mean
        res = self._op(residual).real
        if not isinstance(x, Linearization) or not x.want_metric:
            return res
        metric = SandwichOperator.make(x.jac, self._icov)
        return res.add_metric(metric)


class PoissonianEnergy(EnergyOperator):
    """Computes likelihood Hamiltonians of expected count field constrained by
    Poissonian count data.

    Represents up to an f-independent term :math:`log(d!)`:

    .. math ::
        E(f) = -\\log \\text{Poisson}(d|f) = \\sum f - d^\\dagger \\log(f),

    where f is a :class:`Field` in data space with the expectation values for
    the counts.

    Parameters
    ----------
    d : Field
        Data field with counts. Needs to have integer dtype and all field
        values need to be non-negative.

    Raises
    ------
    TypeError
        If `d` is not a Field of integer dtype.
    ValueError
        If `d` contains negative counts.
    """

    def __init__(self, d):
        is_int_field = (isinstance(d, Field)
                        and np.issubdtype(d.dtype, np.integer))
        if not is_int_field:
            raise TypeError
        if np.any(d.val < 0):
            raise ValueError
        self._domain = DomainTuple.make(d.domain)
        self._d = d

    def apply(self, x):
        self._check_input(x)
        energy = x.sum()
        if isinstance(energy, Linearization):
            total = energy.val.val
        else:
            total = energy
        # An infinite sum already fixes the energy at infinity, so the log
        # term is only added for a finite sum.
        if not np.isinf(total):
            energy = energy - x.log().vdot(self._d)
        if not isinstance(x, Linearization):
            return Field.scalar(energy)
        if not x.want_metric:
            return energy
        # Fisher metric of the Poisson likelihood: 1/f.
        fisher = makeOp(1./x.val)
        return energy.add_metric(SandwichOperator.make(x.jac, fisher))

271

272
class InverseGammaLikelihood(EnergyOperator):
    """Computes the negative log-likelihood of the inverse gamma distribution.

    Its negative log-pdf(x) is given by

    .. math ::

        \\sum_i (\\alpha_i+1)*\\ln(x_i) + \\beta_i/x_i

    This is the likelihood for the variance :math:`x=S_k` given data
    :math:`\\beta = 0.5 |s_k|^2` where the Field :math:`s` is known to have
    the covariance :math:`S_k`.

    Parameters
    ----------
    beta : Field
        beta parameter of the inverse gamma distribution
    alpha : Scalar, Field, optional
        alpha parameter of the inverse gamma distribution. A scalar is
        broadcast over the domain of `beta`. Default is -0.5.

    Raises
    ------
    TypeError
        If `beta` is not a Field, or `alpha` is neither a scalar nor a Field.
    """

    def __init__(self, beta, alpha=-0.5):
        if not isinstance(beta, Field):
            raise TypeError
        self._beta = beta
        if np.isscalar(alpha):
            # Broadcast a scalar alpha to a constant field on beta's domain.
            alpha = Field(beta.domain, np.full(beta.shape, alpha))
        elif not isinstance(alpha, Field):
            raise TypeError
        self._domain = DomainTuple.make(beta.domain)
        # Only alpha+1 enters the energy and the metric.
        self._alphap1 = alpha + 1

    def apply(self, x):
        self._check_input(x)
        energy = x.log().vdot(self._alphap1) + (1./x).vdot(self._beta)
        if not isinstance(x, Linearization):
            return Field.scalar(energy)
        if not x.want_metric:
            return energy
        curvature = makeOp(self._alphap1/(x.val**2))
        return energy.add_metric(SandwichOperator.make(x.jac, curvature))


315
class StudentTEnergy(EnergyOperator):
    """Computes likelihood energy corresponding to Student's t-distribution.

    .. math ::
        E_\\theta(f) = -\\log \\text{StudentT}_\\theta(f)
                     = \\frac{\\theta + 1}{2} \\log(1 + \\frac{f^2}{\\theta}),

    where f is a field defined on `domain`; the pointwise terms are summed.

    Parameters
    ----------
    domain : `Domain` or `DomainTuple`
        Domain of the operator
    theta : Scalar
        Degree of freedom parameter for the student t distribution
    """

    def __init__(self, domain, theta):
        self._theta = theta
        self._domain = DomainTuple.make(domain)

    def apply(self, x):
        self._check_input(x)
        prefactor = (self._theta+1)/2
        energy = prefactor*(x**2/self._theta).log1p().sum()
        if not isinstance(x, Linearization):
            return Field.scalar(energy)
        if not x.want_metric:
            return energy
        # Constant pointwise metric (theta+1)/(theta+3).
        fisher = ScalingOperator(self.domain,
                                 (self._theta+1) / (self._theta+3))
        return energy.add_metric(SandwichOperator.make(x.jac, fisher))


Martin Reinecke's avatar
Martin Reinecke committed
348
class BernoulliEnergy(EnergyOperator):
    """Computes likelihood energy of expected event frequency constrained by
    event data.

    .. math ::
        E(f) = -\\log \\text{Bernoulli}(d|f)
             = -d^\\dagger \\log f  - (1-d)^\\dagger \\log(1-f),

    where f is a field defined on `d.domain` with the expected
    frequencies of events.

    Parameters
    ----------
    d : Field
        Data field with events (1) or non-events (0).

    Raises
    ------
    TypeError
        If `d` is not a Field of integer dtype.
    ValueError
        If `d` contains values other than 0 and 1.
    """

    def __init__(self, d):
        if not isinstance(d, Field) or not np.issubdtype(d.dtype, np.integer):
            raise TypeError
        is_binary = np.logical_or(d.val == 0, d.val == 1)
        if not np.all(is_binary):
            raise ValueError
        self._domain = DomainTuple.make(d.domain)
        self._d = d

    def apply(self, x):
        self._check_input(x)
        hits = x.log().vdot(self._d)
        misses = (1. - x).log().vdot(1. - self._d)
        energy = -(hits + misses)
        if not isinstance(x, Linearization):
            return Field.scalar(energy)
        if not x.want_metric:
            return energy
        # Fisher metric of the Bernoulli likelihood: 1/(f(1-f)).
        fisher = makeOp(1./(x.val*(1. - x.val)))
        return energy.add_metric(SandwichOperator.make(x.jac, fisher))


385
class StandardHamiltonian(EnergyOperator):
    """Computes an information Hamiltonian in its standard form, i.e. with the
    prior being a Gaussian with unit covariance.

    Let the likelihood energy be :math:`E_{lh}`. Then this operator computes:

    .. math ::
         H(f) = 0.5 f^\\dagger f + E_{lh}(f).

    Other field priors can be represented via transformations of a white
    Gaussian field into a field with the desired prior probability structure.

    By implementing prior information this way, the field prior is represented
    by a generative model, from which NIFTy can draw samples and infer a field
    using the Maximum a Posteriori (MAP) or the Variational Bayes (VB) method.

    The metric of this operator can be used as covariance for drawing Gaussian
    samples.

    Parameters
    ----------
    lh : EnergyOperator
        The likelihood energy.
    ic_samp : IterationController, optional
        Tells an internal :class:`SamplingEnabler` which convergence criterion
        to use to draw Gaussian samples. If None (default), no
        :class:`SamplingEnabler` is attached and `lh(x) + prior(x)` is
        returned as is.
    _c_inp : optional
        Internal. Constant input used to simplify the prior; set by
        `_simplify_for_constant_input_nontrivial`.

    See also
    --------
    `Encoding prior knowledge in the structure of the likelihood`,
    Jakob Knollmüller, Torsten A. Ensslin,
    `<https://arxiv.org/abs/1812.04403>`_
    """

    def __init__(self, lh, ic_samp=None, _c_inp=None):
        self._lh = lh
        # Unit-covariance, zero-mean Gaussian prior on the likelihood domain.
        self._prior = GaussianEnergy(domain=lh.domain)
        if _c_inp is not None:
            _, self._prior = self._prior.simplify_for_constant_input(_c_inp)
        self._ic_samp = ic_samp
        self._domain = lh.domain

    def apply(self, x):
        self._check_input(x)
        if (self._ic_samp is None or not isinstance(x, Linearization)
                or not x.want_metric):
            return self._lh(x) + self._prior(x)
        else:
            lhx, prx = self._lh(x), self._prior(x)
            mtr = SamplingEnabler(lhx.metric, prx.metric,
                                  self._ic_samp)
            return (lhx + prx).add_metric(mtr)

    def __repr__(self):
        # Indent both sub-reprs so multi-line output nests consistently
        # under its header.
        subs = 'Likelihood:\n{}'.format(utilities.indent(self._lh.__repr__()))
        subs += '\nPrior:\n{}'.format(
            utilities.indent(self._prior.__repr__()))
        return 'StandardHamiltonian:\n' + utilities.indent(subs)

    def _simplify_for_constant_input_nontrivial(self, c_inp):
        out, lh1 = self._lh.simplify_for_constant_input(c_inp)
        return out, StandardHamiltonian(lh1, self._ic_samp, _c_inp=c_inp)

Martin Reinecke's avatar
Martin Reinecke committed
447

Martin Reinecke's avatar
Martin Reinecke committed
448
class AveragedEnergy(EnergyOperator):
    """Averages an energy over samples.

    .. math ::
        E(x) = \\frac{1}{N} \\sum_{i=1}^N h(x + v_i)

    Parameters
    ----------
    h : EnergyOperator
       The energy to be averaged (e.g. a Hamiltonian).
    res_samples : iterable of Fields
       Set of residual sample points to be added to mean field for
       approximate estimation of the KL. Must contain at least one sample.

    Raises
    ------
    ValueError
        If `res_samples` is empty.

    Notes
    -----
    - Having symmetrized residual samples, with both :math:`v_i` and
      :math:`-v_i` being present, ensures that the distribution mean is
      exactly represented.

    - :class:`AveragedEnergy(h)` approximates
      :math:`\\left< H(f) \\right>_{G(f-m,D)}` if the residuals :math:`f-m`
      are drawn from a Gaussian distribution with covariance :math:`D`.
    """

    def __init__(self, h, res_samples):
        self._h = h
        self._domain = h.domain
        self._res_samples = tuple(res_samples)
        # An empty sample set would otherwise only fail later in `apply`
        # (empty sum / division by zero); fail early with a clear message.
        if not self._res_samples:
            raise ValueError("res_samples must contain at least one sample")

    def apply(self, x):
        self._check_input(x)
        contributions = (self._h(x + v) for v in self._res_samples)
        return utilities.my_sum(contributions)*(1./len(self._res_samples))