# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

Philipp Arras's avatar
Philipp Arras committed
18
19
import numpy as np

Philipp Arras's avatar
Philipp Arras committed
20
from .. import utilities
Martin Reinecke's avatar
Martin Reinecke committed
21
from ..domain_tuple import DomainTuple
22
from ..multi_domain import MultiDomain
Philipp Arras's avatar
Philipp Arras committed
23
from ..field import Field
24
from ..multi_field import MultiField
Philipp Arras's avatar
Philipp Arras committed
25
from ..linearization import Linearization
26
from ..sugar import makeDomain, makeOp, full
Philipp Arras's avatar
Philipp Arras committed
27
from .linear_operator import LinearOperator
Martin Reinecke's avatar
Martin Reinecke committed
28
from .operator import Operator
Martin Reinecke's avatar
fix    
Martin Reinecke committed
29
from .sampling_enabler import SamplingEnabler
Philipp Arras's avatar
Philipp Arras committed
30
from .sandwich_operator import SandwichOperator
31
from .scaling_operator import ScalingOperator
32
from .simple_linear_operators import VdotOperator, FieldAdapter
Martin Reinecke's avatar
Martin Reinecke committed
33
34
35


class EnergyOperator(Operator):
    """Operator whose target is the scalar domain.

    Such operators are meant to serve as objective functions for field
    inference.

    Examples
    --------
     - Information Hamiltonian, i.e. negative-log-probabilities.
     - Gibbs free energy, i.e. an averaged Hamiltonian, aka Kullback-Leibler
       divergence.
    """

    # Every energy maps its input onto a single scalar value.
    _target = DomainTuple.scalar_domain()


49
50
class Squared2NormOperator(EnergyOperator):
    """Computes the square of the L2-norm of the output of an operator.

    .. math ::
        E(f) = f^\\dagger f

    Parameters
    ----------
    domain : Domain, DomainTuple, tuple of Domain or MultiDomain
        Domain of the operator in which the L2-norm shall be computed.

    Notes
    -----
    When applied to a :class:`Linearization`, only the value and the
    Jacobian are passed on; no metric is supplied here.
    """

    def __init__(self, domain):
        # Normalize so that plain `Domain` objects and tuples (as promised by
        # the docstring) are accepted; `_check_input` compares against this.
        self._domain = makeDomain(domain)

    def apply(self, x):
        self._check_input(x)
        if isinstance(x, Linearization):
            val = Field.scalar(x.val.vdot(x.val))
            # Gradient of f^dagger f is 2 f, contracted with the Jacobian.
            jac = VdotOperator(2*x.val)(x.jac)
            return x.new(val, jac)
        return Field.scalar(x.vdot(x))
Martin Reinecke's avatar
Martin Reinecke committed
68

Martin Reinecke's avatar
Martin Reinecke committed
69

Martin Reinecke's avatar
Martin Reinecke committed
70
class QuadraticFormOperator(EnergyOperator):
    """Computes the L2-norm of a Field or MultiField with respect to a
    specific kernel given by `endo`.

    .. math ::
        E(f) = \\frac12 f^\\dagger \\text{endo}(f)

    Parameters
    ----------
    endo : EndomorphicOperator
         Kernel of the quadratic form. The gradient used in :meth:`apply`
         is :math:`\\text{endo}(f)`, which is the exact derivative of
         :math:`E` only if `endo` is self-adjoint.

    Raises
    ------
    TypeError
        If `endo` is not an EndomorphicOperator.
    """

    def __init__(self, endo):
        from .endomorphic_operator import EndomorphicOperator
        if not isinstance(endo, EndomorphicOperator):
            raise TypeError("endo must be an EndomorphicOperator")
        self._op = endo
        self._domain = endo.domain

    def apply(self, x):
        self._check_input(x)
        if isinstance(x, Linearization):
            # Evaluate the kernel once and reuse it for value and gradient.
            t1 = self._op(x.val)
            jac = VdotOperator(t1)(x.jac)
            val = Field.scalar(0.5*x.val.vdot(t1))
            return x.new(val, jac)
        return Field.scalar(0.5*x.vdot(self._op(x)))
Martin Reinecke's avatar
Martin Reinecke committed
98

Philipp Arras's avatar
Philipp Arras committed
99

100
class VariableCovarianceGaussianEnergy(EnergyOperator):
    """Computes the negative log pdf of a Gaussian with unknown covariance.

    The covariance is assumed to be diagonal.

    .. math ::
        E(s,D) = - \\log G(s, D) = 0.5 (s)^\\dagger D^{-1} (s) + 0.5 tr log(D),

    an information energy for a Gaussian distribution with residual s and
    diagonal covariance D.

    The domain of this energy will be a MultiDomain with two keys,
    the target will be the scalar domain.

    Parameters
    ----------
    domain : Domain, DomainTuple, tuple of Domain
        domain of the residual and domain of the covariance diagonal.

    residual_key : key
        Residual key of the Gaussian.

    inverse_covariance_key : key
        Inverse covariance diagonal key of the Gaussian.
    """

    def __init__(self, domain, residual_key, inverse_covariance_key):
        self._r = str(residual_key)
        self._icov = str(inverse_covariance_key)
        dom = DomainTuple.make(domain)
        self._domain = MultiDomain.make({self._r: dom, self._icov: dom})
        # The energy as an operator chain depends only on domain and keys,
        # so it is built once here instead of on every call to `apply`.
        r = FieldAdapter(self._domain[self._r], self._r)
        icov = FieldAdapter(self._domain[self._icov], self._icov)
        res0 = r.vdot(r*icov).real
        res1 = icov.log().sum()
        self._energy_op = 0.5*(res0-res1)

    def apply(self, x):
        self._check_input(x)
        # `self._energy_op` has the scalar target, so applying it yields the
        # scalar result directly; no additional `Field.scalar` wrapping.
        res = self._energy_op(x)
        if not isinstance(x, Linearization) or not x.want_metric:
            return res
        # Diagonal metric: icov for the residual, 0.5*icov**(-2) for icov.
        mf = {self._r: x.val[self._icov],
              self._icov: .5*x.val[self._icov]**(-2)}
        metric = makeOp(MultiField.from_dict(mf))
        return res.add_metric(SandwichOperator.make(x.jac, metric))
148

Martin Reinecke's avatar
Martin Reinecke committed
149
150

class GaussianEnergy(EnergyOperator):
    """Computes a negative-log Gaussian.

    Represents, up to terms independent of :math:`f`:

    .. math ::
        E(f) = - \\log G(f-m, D) = 0.5 (f-m)^\\dagger D^{-1} (f-m),

    an information energy for a Gaussian distribution with mean m and
    covariance D.

    Parameters
    ----------
    mean : Field or MultiField
        Mean of the Gaussian. Default is 0.
    inverse_covariance : LinearOperator
        Inverse covariance of the Gaussian. Default is the identity operator.
    domain : Domain, DomainTuple, tuple of Domain or MultiDomain
        Operator domain. By default it is inferred from `mean` or
        `inverse_covariance` if specified.

    Raises
    ------
    TypeError
        If `mean` or `inverse_covariance` has the wrong type.
    ValueError
        If no domain can be determined or the given domains disagree.

    Note
    ----
    At least one of the arguments has to be provided. If several are given,
    their domains must agree.
    """

    def __init__(self, mean=None, inverse_covariance=None, domain=None):
        if mean is not None and not isinstance(mean, (Field, MultiField)):
            raise TypeError("mean must be a Field or MultiField")
        if (inverse_covariance is not None
                and not isinstance(inverse_covariance, LinearOperator)):
            raise TypeError("inverse_covariance must be a LinearOperator")

        # Collect the domain from every argument that carries one and make
        # sure they all agree.
        self._domain = None
        if mean is not None:
            self._checkEquivalence(mean.domain)
        if inverse_covariance is not None:
            self._checkEquivalence(inverse_covariance.domain)
        if domain is not None:
            self._checkEquivalence(domain)
        if self._domain is None:
            raise ValueError("no domain given")

        self._mean = mean
        self._icov = inverse_covariance
        if inverse_covariance is None:
            # Unit covariance: E = 0.5 * |f - m|^2.
            self._op = Squared2NormOperator(self._domain).scale(0.5)
        else:
            self._op = QuadraticFormOperator(inverse_covariance)

    def _checkEquivalence(self, newdom):
        """Record `newdom` as the domain, or check it against the stored one.

        Raises
        ------
        ValueError
            If a domain is already stored and differs from `newdom`.
        """
        newdom = makeDomain(newdom)
        if self._domain is None:
            self._domain = newdom
        elif self._domain != newdom:
            raise ValueError("domain mismatch")

    def apply(self, x):
        self._check_input(x)
        residual = x if self._mean is None else x - self._mean
        res = self._op(residual).real
        if not isinstance(x, Linearization) or not x.want_metric:
            return res
        # In the unit-covariance case `self._icov` is None; SandwichOperator
        # is presumably expected to treat a missing kernel as the identity.
        metric = SandwichOperator.make(x.jac, self._icov)
        return res.add_metric(metric)


class PoissonianEnergy(EnergyOperator):
    """Computes likelihood Hamiltonians of expected count field constrained by
    Poissonian count data.

    Represents up to an f-independent term :math:`\\log(d!)`:

    .. math ::
        E(f) = -\\log \\text{Poisson}(d|f) = \\sum f - d^\\dagger \\log(f),

    where f is a :class:`Field` in data space with the expectation values for
    the counts.

    Parameters
    ----------
    d : Field
        Data field with counts. Needs to have integer dtype and all field
        values need to be non-negative.

    Raises
    ------
    TypeError
        If `d` is not a Field with integer dtype.
    ValueError
        If `d` contains negative values.
    """

    def __init__(self, d):
        if not isinstance(d, Field) or not np.issubdtype(d.dtype, np.integer):
            raise TypeError("d must be a Field with integer dtype")
        if np.any(d.val < 0):
            raise ValueError("d must not contain negative counts")
        self._d = d
        self._domain = DomainTuple.make(d.domain)

    def apply(self, x):
        self._check_input(x)
        res = x.sum()
        tmp = res.val.val if isinstance(res, Linearization) else res
        # If sum(f) is infinite, the energy is taken to be infinite and the
        # log term is skipped (it could otherwise yield inf - inf = nan).
        if not np.isinf(tmp):
            res = res - x.log().vdot(self._d)
        if not isinstance(x, Linearization):
            return Field.scalar(res)
        if not x.want_metric:
            return res
        # Diagonal metric 1/f.
        metric = SandwichOperator.make(x.jac, makeOp(1./x.val))
        return res.add_metric(metric)

258

259
class InverseGammaLikelihood(EnergyOperator):
    """Computes the negative log-likelihood of the inverse gamma distribution.

    Its negative log-pdf(x) is given, up to x-independent terms, by

    .. math ::

        \\sum_i (\\alpha_i+1)*\\ln(x_i) + \\beta_i/x_i

    This is the likelihood for the variance :math:`x=S_k` given data
    :math:`\\beta = 0.5 |s_k|^2` where the Field :math:`s` is known to have
    the covariance :math:`S_k`.

    Parameters
    ----------
    beta : Field
        beta parameter of the inverse gamma distribution
    alpha : Scalar, Field, optional
        alpha parameter of the inverse gamma distribution. A scalar is
        broadcast to the domain of `beta`. Default is -0.5.

    Raises
    ------
    TypeError
        If `beta` is not a Field, or `alpha` is neither a scalar nor a Field.
    """

    def __init__(self, beta, alpha=-0.5):
        if not isinstance(beta, Field):
            raise TypeError("beta must be a Field")
        self._beta = beta
        if np.isscalar(alpha):
            # Broadcast a scalar alpha onto the domain of beta.
            alpha = Field(beta.domain, np.full(beta.shape, alpha))
        elif not isinstance(alpha, Field):
            raise TypeError("alpha must be a scalar or a Field")
        # Only alpha+1 enters the energy and the metric.
        self._alphap1 = alpha+1
        self._domain = DomainTuple.make(beta.domain)

    def apply(self, x):
        self._check_input(x)
        res = x.log().vdot(self._alphap1) + (1./x).vdot(self._beta)
        if not isinstance(x, Linearization):
            return Field.scalar(res)
        if not x.want_metric:
            return res
        # Diagonal metric (alpha+1)/x**2.
        metric = SandwichOperator.make(x.jac, makeOp(self._alphap1/(x.val**2)))
        return res.add_metric(metric)


302
class StudentTEnergy(EnergyOperator):
    """Computes likelihood energy corresponding to Student's t-distribution.

    .. math ::
        E_\\theta(f) = -\\log \\text{StudentT}_\\theta(f)
                     = \\frac{\\theta + 1}{2} \\log(1 + \\frac{f^2}{\\theta}),

    where f is a field defined on `domain`.

    Parameters
    ----------
    domain : `Domain` or `DomainTuple`
        Domain of the operator
    theta : Scalar
        Degree of freedom parameter for the student t distribution
    """

    def __init__(self, domain, theta):
        self._domain = DomainTuple.make(domain)
        self._theta = theta

    def apply(self, x):
        self._check_input(x)
        theta = self._theta
        prefactor = (theta+1)/2
        energy = prefactor*(x**2/theta).log1p().sum()
        if not isinstance(x, Linearization):
            return Field.scalar(energy)
        if not x.want_metric:
            return energy
        # Constant diagonal metric with value (theta+1)/(theta+3).
        kernel = ScalingOperator(self.domain, (theta+1) / (theta+3))
        return energy.add_metric(SandwichOperator.make(x.jac, kernel))


Martin Reinecke's avatar
Martin Reinecke committed
335
class BernoulliEnergy(EnergyOperator):
    """Computes likelihood energy of expected event frequency constrained by
    event data.

    .. math ::
        E(f) = -\\log \\text{Bernoulli}(d|f)
             = -d^\\dagger \\log f  - (1-d)^\\dagger \\log(1-f),

    where f is a field defined on `d.domain` with the expected
    frequencies of events.

    Parameters
    ----------
    d : Field
        Data field with events (1) or non-events (0). Must have integer
        dtype.

    Raises
    ------
    TypeError
        If `d` is not a Field with integer dtype.
    ValueError
        If `d` contains values other than 0 and 1.
    """

    def __init__(self, d):
        if not isinstance(d, Field) or not np.issubdtype(d.dtype, np.integer):
            raise TypeError("d must be a Field with integer dtype")
        if not np.all(np.logical_or(d.val == 0, d.val == 1)):
            raise ValueError("d must only contain the values 0 and 1")
        self._d = d
        self._domain = DomainTuple.make(d.domain)

    def apply(self, x):
        self._check_input(x)
        v = -(x.log().vdot(self._d) + (1. - x).log().vdot(1. - self._d))
        if not isinstance(x, Linearization):
            return Field.scalar(v)
        if not x.want_metric:
            return v
        # Diagonal metric 1/(f*(1-f)).
        met = makeOp(1./(x.val*(1. - x.val)))
        met = SandwichOperator.make(x.jac, met)
        return v.add_metric(met)


372
class StandardHamiltonian(EnergyOperator):
    """Computes an information Hamiltonian in its standard form, i.e. with the
    prior being a Gaussian with unit covariance.

    Let the likelihood energy be :math:`E_{lh}`. Then this operator computes:

    .. math ::
         H(f) = 0.5 f^\\dagger f + E_{lh}(f).

    Other field priors can be represented via transformations of a white
    Gaussian field into a field with the desired prior probability structure.

    By implementing prior information this way, the field prior is represented
    by a generative model, from which NIFTy can draw samples and infer a field
    using the Maximum a Posteriori (MAP) or the Variational Bayes (VB) method.

    The metric of this operator can be used as covariance for drawing Gaussian
    samples.

    Parameters
    ----------
    lh : EnergyOperator
        The likelihood energy.
    ic_samp : IterationController
        Tells an internal :class:`SamplingEnabler` which convergence criterion
        to use to draw Gaussian samples. If None, no SamplingEnabler is
        attached.

    See also
    --------
    `Encoding prior knowledge in the structure of the likelihood`,
    Jakob Knollmüller, Torsten A. Ensslin,
    `<https://arxiv.org/abs/1812.04403>`_
    """

    def __init__(self, lh, ic_samp=None, _c_inp=None):
        self._lh = lh
        # Unit-covariance Gaussian prior on the likelihood's domain,
        # optionally simplified for a set of constant inputs.
        prior = GaussianEnergy(domain=lh.domain)
        if _c_inp is not None:
            _, prior = prior.simplify_for_constant_input(_c_inp)
        self._prior = prior
        self._ic_samp = ic_samp
        self._domain = lh.domain

    def apply(self, x):
        self._check_input(x)
        needs_sampler = (self._ic_samp is not None
                         and isinstance(x, Linearization)
                         and x.want_metric)
        if not needs_sampler:
            return self._lh(x) + self._prior(x)
        lhx = self._lh(x)
        prx = self._prior(x)
        # Combine both metrics into one object that can also draw samples.
        sampling_metric = SamplingEnabler(lhx.metric, prx.metric,
                                          self._ic_samp)
        return (lhx + prx).add_metric(sampling_metric)

    def __repr__(self):
        lh_text = utilities.indent(self._lh.__repr__())
        subs = ('Likelihood:\n{}'.format(lh_text)
                + '\nPrior:\n{}'.format(self._prior))
        return 'StandardHamiltonian:\n' + utilities.indent(subs)

    def _simplify_for_constant_input_nontrivial(self, c_inp):
        out, simplified_lh = self._lh.simplify_for_constant_input(c_inp)
        simplified = StandardHamiltonian(simplified_lh, self._ic_samp,
                                         _c_inp=c_inp)
        return out, simplified

Martin Reinecke's avatar
Martin Reinecke committed
434

Martin Reinecke's avatar
Martin Reinecke committed
435
class AveragedEnergy(EnergyOperator):
    """Averages an energy over samples.

    Parameters
    ----------
    h : Hamiltonian
       The energy to be averaged.
    res_samples : iterable of Fields
       Set of residual sample points to be added to mean field for
       approximate estimation of the KL. Must contain at least one sample.

    Raises
    ------
    ValueError
        If `res_samples` is empty.

    Notes
    -----
    - Having symmetrized residual samples, with both :math:`v_i` and
      :math:`-v_i` being present, ensures that the distribution mean is
      exactly represented.

    - :class:`AveragedEnergy(h)` approximates
      :math:`\\left< H(f) \\right>_{G(f-m,D)}` if the residuals :math:`f-m`
      are drawn from a Gaussian distribution with covariance :math:`D`.
    """

    def __init__(self, h, res_samples):
        self._h = h
        self._domain = h.domain
        # Materialize once so that one-shot iterables (e.g. generators) can
        # be traversed on every call to `apply`.
        self._res_samples = tuple(res_samples)
        if not self._res_samples:
            # An average over zero samples is undefined; fail here instead
            # of with a division by zero inside `apply`.
            raise ValueError("res_samples must contain at least one sample")

    def apply(self, x):
        self._check_input(x)
        energies = (self._h(x + v) for v in self._res_samples)
        return utilities.my_sum(energies)*(1./len(self._res_samples))