# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np

from .. import utilities
from ..domain_tuple import DomainTuple
from ..field import Field
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from ..sugar import makeDomain, makeOp
from .linear_operator import LinearOperator
from .operator import Operator
from .sampling_enabler import SamplingDtypeSetter, SamplingEnabler
from .scaling_operator import ScalingOperator
from .simple_linear_operators import VdotOperator


def _check_sampling_dtype(domain, dtypes):
    if dtypes is None:
        return
    if isinstance(domain, DomainTuple):
        dtypes = {'': dtypes}
    elif isinstance(domain, MultiDomain):
        if dtypes in [np.float64, np.complex128]:
            return
        dtypes = dtypes.values()
        if set(domain.keys()) != set(dtypes.keys()):
            raise ValueError
    else:
        raise TypeError
    for dt in dtypes.values():
        if dt not in [np.float64, np.complex128]:
            raise ValueError


def _field_to_dtype(field):
    """Extract the dtype description of a Field or MultiField.

    Returns `field.dtype` for a Field and a dict mapping each key to the
    dtype of the corresponding sub-field for a MultiField. The result is
    passed through `_check_sampling_dtype` before being returned.

    Raises
    ------
    TypeError
        If `field` is neither a Field nor a MultiField.
    """
    if not isinstance(field, (Field, MultiField)):
        raise TypeError
    if isinstance(field, Field):
        dtype = field.dtype
    else:
        dtype = dict((key, sub.dtype) for key, sub in field.items())
    _check_sampling_dtype(field.domain, dtype)
    return dtype


class EnergyOperator(Operator):
    """Operator which has a scalar domain as target domain.

    It is intended as an objective function for field inference.

    Examples
    --------
     - Information Hamiltonian, i.e. negative-log-probabilities.
     - Gibbs free energy, i.e. an averaged Hamiltonian, aka Kullback-Leibler
       divergence.
    """
    # Shared by all energies: the target is always the scalar domain.
    _target = DomainTuple.scalar_domain()


class Squared2NormOperator(EnergyOperator):
    """Computes the square of the L2-norm of its input.

    Parameters
    ----------
    domain : Domain, DomainTuple, tuple of Domain or MultiDomain
        Domain of the operator in which the L2-norm shall be computed.
    """

    def __init__(self, domain):
        # Normalize the argument so that a bare Domain or a tuple of Domains
        # is accepted as documented; DomainTuple/MultiDomain pass through.
        self._domain = makeDomain(domain)

    def apply(self, x):
        self._check_input(x)
        if x.jac is None:
            return x.vdot(x)
        res = x.val.vdot(x.val)
        # The Jacobian is represented by the vdot with 2*x.
        return x.new(res, VdotOperator(2*x.val))


class QuadraticFormOperator(EnergyOperator):
    """Computes a quadratic form of a Field or MultiField with respect to a
    specific kernel given by `endo`.

    .. math ::
        E(f) = \\frac12 f^\\dagger \\text{endo}(f)

    Parameters
    ----------
    endo : EndomorphicOperator
         Kernel of the quadratic form

    Raises
    ------
    TypeError
        If `endo` is not an EndomorphicOperator.
    """

    def __init__(self, endo):
        # Imported locally, as in the original code.
        from .endomorphic_operator import EndomorphicOperator
        if not isinstance(endo, EndomorphicOperator):
            raise TypeError("op must be an EndomorphicOperator")
        self._op = endo
        self._domain = endo.domain

    def apply(self, x):
        self._check_input(x)
        if x.jac is None:
            return 0.5*x.vdot(self._op(x))
        # Apply the kernel only once and reuse the result for both the value
        # and the Jacobian.
        kernel_x = self._op(x.val)
        res = 0.5*x.val.vdot(kernel_x)
        return x.new(res, VdotOperator(kernel_x))


class VariableCovarianceGaussianEnergy(EnergyOperator):
    """Computes the negative log pdf of a Gaussian with unknown covariance.

    The covariance is assumed to be diagonal.

    .. math ::
        E(s,D) = - \\log G(s, D) = 0.5 (s)^\\dagger D^{-1} (s) + 0.5 tr log(D),

    an information energy for a Gaussian distribution with residual s and
    diagonal covariance D.
    The domain of this energy will be a MultiDomain with two keys,
    the target will be the scalar domain. Note that the input carries the
    diagonal of the *inverse* covariance :math:`D^{-1}`.

    Parameters
    ----------
    domain : Domain, DomainTuple, tuple of Domain
        domain of the residual and domain of the covariance diagonal.

    residual_key : key
        Residual key of the Gaussian.

    inverse_covariance_key : key
        Inverse covariance diagonal key of the Gaussian.

    sampling_dtype : dtype or dict of dtype
        Validated against the operator domain and handed to the
        :class:`SamplingDtypeSetter` that wraps the metric.
    """

    def __init__(self, domain, residual_key, inverse_covariance_key, sampling_dtype):
        self._r = str(residual_key)
        self._icov = str(inverse_covariance_key)
        dom = DomainTuple.make(domain)
        # Residual and inverse covariance diagonal live on the same domain.
        self._domain = MultiDomain.make({self._r: dom, self._icov: dom})
        self._sampling_dtype = sampling_dtype
        _check_sampling_dtype(self._domain, sampling_dtype)

    def apply(self, x):
        self._check_input(x)
        r, icov = x[self._r], x[self._icov]
        res = 0.5*(r.vdot(r*icov).real - icov.ptw("log").sum())
        if not x.want_metric:
            return res
        icov_val = x.val[self._icov]
        mf = {self._r: icov_val, self._icov: .5*icov_val**(-2)}
        met = makeOp(MultiField.from_dict(mf))
        return res.add_metric(SamplingDtypeSetter(met, self._sampling_dtype))


class GaussianEnergy(EnergyOperator):
    """Computes a negative-log Gaussian.

    Represents up to constants in :math:`m`:

    .. math ::
        E(f) = - \\log G(f-m, D) = 0.5 (f-m)^\\dagger D^{-1} (f-m),

    an information energy for a Gaussian distribution with mean m and
    covariance D.

    Parameters
    ----------
    mean : Field
        Mean of the Gaussian. Default is 0.
    inverse_covariance : LinearOperator
        Inverse covariance of the Gaussian. Default is the identity operator.
    domain : Domain, DomainTuple, tuple of Domain or MultiDomain
        Operator domain. By default it is inferred from `mean` or
        `inverse_covariance` if specified
    sampling_dtype : dtype or dict of dtype, optional
        Dtype handed to the :class:`SamplingDtypeSetter` that wraps the
        metric. If `mean` is given and `sampling_dtype` is None, it is taken
        from the dtype of `mean`; if both are given they have to agree.

    Raises
    ------
    TypeError
        If `mean` is not a Field/MultiField or `inverse_covariance` is not a
        LinearOperator.
    ValueError
        If no domain can be determined, if the domains of the arguments
        differ, or if `sampling_dtype` is incompatible with `mean`.

    Note
    ----
    At least one of the arguments has to be provided.
    """

    def __init__(self, mean=None, inverse_covariance=None, domain=None, sampling_dtype=None):
        if mean is not None and not isinstance(mean, (Field, MultiField)):
            raise TypeError("mean must be a Field or MultiField")
        if inverse_covariance is not None and not isinstance(inverse_covariance, LinearOperator):
            raise TypeError("inverse_covariance must be a LinearOperator")

        # The first domain seen is stored, all later ones must match it.
        self._domain = None
        if mean is not None:
            self._checkEquivalence(mean.domain)
        if inverse_covariance is not None:
            self._checkEquivalence(inverse_covariance.domain)
        if domain is not None:
            self._checkEquivalence(domain)
        if self._domain is None:
            raise ValueError("no domain given")
        self._mean = mean

        # Infer sampling dtype
        if self._mean is None:
            _check_sampling_dtype(self._domain, sampling_dtype)
        else:
            mean_dtype = _field_to_dtype(self._mean)
            if sampling_dtype is None:
                sampling_dtype = mean_dtype
            elif sampling_dtype != mean_dtype:
                raise ValueError("Sampling dtype and mean not compatible")

        if inverse_covariance is None:
            self._op = Squared2NormOperator(self._domain).scale(0.5)
            self._met = ScalingOperator(self._domain, 1)
        else:
            self._op = QuadraticFormOperator(inverse_covariance)
            self._met = inverse_covariance
        if sampling_dtype is not None:
            self._met = SamplingDtypeSetter(self._met, sampling_dtype)

    def _checkEquivalence(self, newdom):
        """Store `newdom` as operator domain or check it against the stored one."""
        newdom = makeDomain(newdom)
        if self._domain is None:
            self._domain = newdom
        else:
            if self._domain != newdom:
                raise ValueError("domain mismatch")

    def apply(self, x):
        self._check_input(x)
        residual = x if self._mean is None else x - self._mean
        res = self._op(residual).real
        if x.want_metric:
            return res.add_metric(self._met)
        return res


class PoissonianEnergy(EnergyOperator):
    """Computes likelihood Hamiltonians of expected count field constrained by
    Poissonian count data.

    Represents up to an f-independent term :math:`log(d!)`:

    .. math ::
        E(f) = -\\log \\text{Poisson}(d|f) = \\sum f - d^\\dagger \\log(f),

    where f is a :class:`Field` in data space with the expectation values for
    the counts.

    Parameters
    ----------
    d : Field
        Data field with counts. Needs to have integer dtype and all field
        values need to be non-negative.

    Raises
    ------
    TypeError
        If `d` is not a Field with integer dtype.
    ValueError
        If `d` contains negative values.
    """

    def __init__(self, d):
        if not isinstance(d, Field) or not np.issubdtype(d.dtype, np.integer):
            raise TypeError
        if np.any(d.val < 0):
            raise ValueError
        self._d = d
        self._domain = DomainTuple.make(d.domain)

    def apply(self, x):
        self._check_input(x)
        res = x.sum() - x.ptw("log").vdot(self._d)
        if not x.want_metric:
            return res
        # The metric is the diagonal operator 1/x; samples are drawn as float64.
        return res.add_metric(SamplingDtypeSetter(makeOp(1./x.val), np.float64))


class InverseGammaLikelihood(EnergyOperator):
    """Computes the negative log-likelihood of the inverse gamma distribution.

    Its negative log-pdf(x) is given by

    .. math ::

        \\sum_i (\\alpha_i+1)*\\ln(x_i) + \\beta_i/x_i

    This is the likelihood for the variance :math:`x=S_k` given data
    :math:`\\beta = 0.5 |s_k|^2` where the Field :math:`s` is known to have
    the covariance :math:`S_k`.

    Parameters
    ----------
    beta : Field
        beta parameter of the inverse gamma distribution. Needs to have
        dtype `np.float64`.
    alpha : Scalar, Field, optional
        alpha parameter of the inverse gamma distribution

    Raises
    ------
    TypeError
        If `beta` is not a float64 Field or `alpha` is neither a scalar nor
        a Field.
    """

    def __init__(self, beta, alpha=-0.5):
        if not isinstance(beta, Field):
            raise TypeError
        self._domain = DomainTuple.make(beta.domain)
        self._beta = beta
        if np.isscalar(alpha):
            # Broadcast a scalar alpha to a Field on the domain of beta.
            alpha = Field(beta.domain, np.full(beta.shape, alpha))
        elif not isinstance(alpha, Field):
            raise TypeError
        self._alphap1 = alpha+1
        if not self._beta.dtype == np.float64:
            # FIXME Add proper complex support for this energy
            raise TypeError
        self._sampling_dtype = _field_to_dtype(self._beta)

    def apply(self, x):
        self._check_input(x)
        res = x.ptw("log").vdot(self._alphap1) + x.ptw("reciprocal").vdot(self._beta)
        if not x.want_metric:
            return res
        met = makeOp(self._alphap1/(x.val**2))
        if self._sampling_dtype is not None:
            met = SamplingDtypeSetter(met, self._sampling_dtype)
        return res.add_metric(met)


class StudentTEnergy(EnergyOperator):
    """Computes likelihood energy corresponding to Student's t-distribution.

    .. math ::
        E_\\theta(f) = -\\log \\text{StudentT}_\\theta(f)
                     = \\frac{\\theta + 1}{2} \\log(1 + \\frac{f^2}{\\theta}),

    where f is a field defined on `domain`.

    Parameters
    ----------
    domain : `Domain` or `DomainTuple`
        Domain of the operator
    theta : Scalar or Field
        Degree of freedom parameter for the student t distribution
    sampling_dtype : dtype, optional
        Dtype handed to the :class:`SamplingDtypeSetter` that wraps the
        metric; `None` leaves the metric unwrapped. Default is `np.float64`.

    Raises
    ------
    NotImplementedError
        If `sampling_dtype` is `np.complex128`.
    """

    def __init__(self, domain, theta, sampling_dtype=np.float64):
        self._domain = DomainTuple.make(domain)
        self._theta = theta
        self._sampling_dtype = sampling_dtype
        if sampling_dtype == np.complex128:
            raise NotImplementedError('Complex data not supported yet')

    def apply(self, x):
        self._check_input(x)
        res = (((self._theta+1)/2)*(x**2/self._theta).ptw("log1p")).sum()
        if not x.want_metric:
            return res
        met = makeOp((self._theta+1) / (self._theta+3), self.domain)
        if self._sampling_dtype is not None:
            met = SamplingDtypeSetter(met, self._sampling_dtype)
        return res.add_metric(met)


class BernoulliEnergy(EnergyOperator):
    """Computes likelihood energy of expected event frequency constrained by
    event data.

    .. math ::
        E(f) = -\\log \\text{Bernoulli}(d|f)
             = -d^\\dagger \\log f  - (1-d)^\\dagger \\log(1-f),

    where f is a field defined on `d.domain` with the expected
    frequencies of events.

    Parameters
    ----------
    d : Field
        Data field with events (1) or non-events (0).

    Raises
    ------
    TypeError
        If `d` is not a Field with integer dtype.
    ValueError
        If `d` contains values other than 0 and 1.
    """

    def __init__(self, d):
        if not isinstance(d, Field) or not np.issubdtype(d.dtype, np.integer):
            raise TypeError
        if np.any(np.logical_and(d.val != 0, d.val != 1)):
            raise ValueError
        self._d = d
        self._domain = DomainTuple.make(d.domain)

    def apply(self, x):
        self._check_input(x)
        # The second term uses (d-1) instead of -(1-d), hence the plus sign.
        res = -x.ptw("log").vdot(self._d) + (1.-x).ptw("log").vdot(self._d-1.)
        if not x.want_metric:
            return res
        met = makeOp(1./(x.val*(1. - x.val)))
        return res.add_metric(SamplingDtypeSetter(met, np.float64))


class StandardHamiltonian(EnergyOperator):
    """Computes an information Hamiltonian in its standard form, i.e. with the
    prior being a Gaussian with unit covariance.

    Let the likelihood energy be :math:`E_{lh}`. Then this operator computes:

    .. math ::
         H(f) = 0.5 f^\\dagger f + E_{lh}(f).

    Other field priors can be represented via transformations of a white
    Gaussian field into a field with the desired prior probability structure.

    By implementing prior information this way, the field prior is represented
    by a generative model, from which NIFTy can draw samples and infer a field
    using the Maximum a Posteriori (MAP) or the Variational Bayes (VB) method.

    The metric of this operator can be used as covariance for drawing Gaussian
    samples.

    Parameters
    ----------
    lh : EnergyOperator
        The likelihood energy.
    ic_samp : IterationController
        Tells an internal :class:`SamplingEnabler` which convergence criterion
        to use to draw Gaussian samples. If None, no :class:`SamplingEnabler`
        is attached to the result.
    sampling_dtype : dtype or dict of dtype, optional
        Forwarded as `sampling_dtype` to the internal unit-covariance
        :class:`GaussianEnergy` prior. Default is `np.float64`.

    See also
    --------
    `Encoding prior knowledge in the structure of the likelihood`,
    Jakob Knollmüller, Torsten A. Ensslin,
    `<https://arxiv.org/abs/1812.04403>`_
    """

    def __init__(self, lh, ic_samp=None, _c_inp=None, sampling_dtype=np.float64):
        self._lh = lh
        # Remembered so that simplified copies use the same prior dtype.
        self._sampling_dtype = sampling_dtype
        self._prior = GaussianEnergy(domain=lh.domain, sampling_dtype=sampling_dtype)
        if _c_inp is not None:
            _, self._prior = self._prior.simplify_for_constant_input(_c_inp)
        self._ic_samp = ic_samp
        self._domain = lh.domain

    def apply(self, x):
        self._check_input(x)
        if not x.want_metric or self._ic_samp is None:
            return (self._lh + self._prior)(x)
        lhx, prx = self._lh(x), self._prior(x)
        return (lhx+prx).add_metric(SamplingEnabler(lhx.metric, prx.metric, self._ic_samp))

    def __repr__(self):
        subs = 'Likelihood:\n{}'.format(utilities.indent(self._lh.__repr__()))
        subs += '\nPrior:\n{}'.format(self._prior)
        return 'StandardHamiltonian:\n' + utilities.indent(subs)

    def _simplify_for_constant_input_nontrivial(self, c_inp):
        out, lh1 = self._lh.simplify_for_constant_input(c_inp)
        # Pass the sampling dtype on; otherwise the simplified Hamiltonian
        # would silently fall back to the np.float64 default.
        return out, StandardHamiltonian(lh1, self._ic_samp, _c_inp=c_inp,
                                        sampling_dtype=self._sampling_dtype)


class AveragedEnergy(EnergyOperator):
    """Averages an energy over samples.

    Parameters
    ----------
    h : Hamiltonian
       The energy to be averaged.
    res_samples : iterable of Fields
       Set of residual sample points to be added to mean field for
       approximate estimation of the KL.

    Notes
    -----
    - Having symmetrized residual samples, with both :math:`v_i` and
      :math:`-v_i` being present, ensures that the distribution mean is
      exactly represented.

    - :class:`AveragedEnergy(h)` approximates
      :math:`\\left< H(f) \\right>_{G(f-m,D)}` if the residuals :math:`f-m`
      are drawn from a Gaussian distribution with covariance :math:`D`.
    """

    def __init__(self, h, res_samples):
        self._h = h
        self._domain = h.domain
        # Stored as a tuple so the samples can be iterated repeatedly and
        # counted with len().
        self._res_samples = tuple(res_samples)

    def apply(self, x):
        self._check_input(x)
        energies = (self._h(x+v) for v in self._res_samples)
        return utilities.my_sum(energies)/len(self._res_samples)