energy_operators.py 16.6 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
Philipp Arras's avatar
Philipp Arras committed
14
# Copyright(C) 2013-2020 Max-Planck-Society
Martin Reinecke's avatar
Martin Reinecke committed
15
#
16
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
Martin Reinecke's avatar
Martin Reinecke committed
17

Philipp Arras's avatar
Philipp Arras committed
18
19
import numpy as np

Philipp Arras's avatar
Philipp Arras committed
20
from .. import utilities
Martin Reinecke's avatar
Martin Reinecke committed
21
from ..domain_tuple import DomainTuple
Philipp Arras's avatar
Philipp Arras committed
22
from ..field import Field
Philipp Arras's avatar
Philipp Arras committed
23
24
25
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from ..sugar import makeDomain, makeOp
Philipp Arras's avatar
Philipp Arras committed
26
from .linear_operator import LinearOperator
Martin Reinecke's avatar
Martin Reinecke committed
27
from .operator import Operator
28
from .sampling_enabler import SamplingDtypeSetter, SamplingEnabler
29
from .scaling_operator import ScalingOperator
30
from .simple_linear_operators import VdotOperator, FieldAdapter
Philipp Arras's avatar
Philipp Arras committed
31
32
33
34
35
36


def _check_sampling_dtype(domain, dtypes):
    if dtypes is None:
        return
    if isinstance(domain, DomainTuple):
Philipp Arras's avatar
Philipp Arras committed
37
38
        np.dtype(dtypes)
        return
Philipp Arras's avatar
Philipp Arras committed
39
    elif isinstance(domain, MultiDomain):
Philipp Arras's avatar
Philipp Arras committed
40
41
42
43
44
45
46
        if isinstance(dtypes, dict):
            for dt in dtypes.values():
                np.dtype(dt)
            if set(domain.keys()) == set(dtypes.keys()):
                return
        else:
            np.dtype(dtypes)
Philipp Arras's avatar
Philipp Arras committed
47
            return
Philipp Arras's avatar
Philipp Arras committed
48
    raise TypeError
Philipp Arras's avatar
Philipp Arras committed
49
50
51
52
53
54
55
56
57
58
59
60
61


def _field_to_dtype(field):
    """Return the dtype specification matching a Field or MultiField.

    Parameters
    ----------
    field : Field or MultiField
        Object whose dtype(s) shall be extracted.

    Returns
    -------
    dtype or dict
        `field.dtype` for a `Field`; a dict mapping every key to the dtype of
        the corresponding entry for a `MultiField`.

    Raises
    ------
    TypeError
        If `field` is neither a `Field` nor a `MultiField`.
    """
    if not isinstance(field, (Field, MultiField)):
        raise TypeError
    if isinstance(field, Field):
        result = field.dtype
    else:
        result = {key: sub.dtype for key, sub in field.items()}
    # Make sure the extracted specification is consistent with the domain.
    _check_sampling_dtype(field.domain, result)
    return result


Martin Reinecke's avatar
Martin Reinecke committed
62
class EnergyOperator(Operator):
    """Operator whose target domain is the scalar domain.

    It is meant to serve as an objective function for field inference.

    Examples
    --------
     - Information Hamiltonian, i.e. negative-log-probabilities.
     - Gibbs free energy, i.e. an averaged Hamiltonian, aka Kullback-Leibler
       divergence.
    """

    # The target is fixed once here; subclasses only need to set `_domain`.
    _target = DomainTuple.scalar_domain()


76
77
class Squared2NormOperator(EnergyOperator):
    """Computes the square of the L2-norm of its input.

    Parameters
    ----------
    domain : Domain, DomainTuple, tuple of Domain or MultiDomain
        Domain of the operator in which the L2-norm shall be computed.
    """

    def __init__(self, domain):
        # Normalize the argument so that a bare Domain or a tuple of Domains
        # (both advertised above) ends up as a DomainTuple/MultiDomain, like
        # the domains stored by the other energies in this module.
        self._domain = makeDomain(domain)

    def apply(self, x):
        """Return `x^dagger x`; for a linearized input also attach the
        Jacobian `VdotOperator(2*x.val)`."""
        self._check_input(x)
        if x.jac is None:
            return x.vdot(x)
        res = x.val.vdot(x.val)
        return x.new(res, VdotOperator(2*x.val))
Martin Reinecke's avatar
Martin Reinecke committed
94

Martin Reinecke's avatar
Martin Reinecke committed
95

Martin Reinecke's avatar
Martin Reinecke committed
96
class QuadraticFormOperator(EnergyOperator):
    """Computes half the quadratic form of a Field or MultiField with respect
    to a specific kernel given by `endo`.

    .. math ::
        E(f) = \\frac12 f^\\dagger \\text{endo}(f)

    Parameters
    ----------
    endo : EndomorphicOperator
         Kernel of the quadratic form

    Raises
    ------
    TypeError
        If `endo` is not an `EndomorphicOperator`.
    """

    def __init__(self, endo):
        from .endomorphic_operator import EndomorphicOperator
        if not isinstance(endo, EndomorphicOperator):
            raise TypeError("op must be an EndomorphicOperator")
        self._op = endo
        self._domain = endo.domain

    def apply(self, x):
        """Evaluate the quadratic form; for a linearized input also attach
        the Jacobian `VdotOperator(endo(x.val))`."""
        self._check_input(x)
        if x.jac is None:
            return 0.5*x.vdot(self._op(x))
        # Apply the kernel only once and reuse the result for both the value
        # and the Jacobian.
        kernel_applied = self._op(x.val)
        res = 0.5*x.val.vdot(kernel_applied)
        return x.new(res, VdotOperator(kernel_applied))
Martin Reinecke's avatar
Martin Reinecke committed
122

Philipp Arras's avatar
Philipp Arras committed
123

124
class VariableCovarianceGaussianEnergy(EnergyOperator):
    """Computes the negative log pdf of a Gaussian with unknown covariance.

    The covariance is assumed to be diagonal.

    .. math ::
        E(s,D) = - \\log G(s, D) = 0.5 (s)^\\dagger D^{-1} (s) + 0.5 tr log(D),

    an information energy for a Gaussian distribution with residual s and
    diagonal covariance D.
    The domain of this energy will be a MultiDomain with two keys,
    the target will be the scalar domain.

    Parameters
    ----------
    domain : Domain, DomainTuple, tuple of Domain
        domain of the residual and domain of the covariance diagonal.

    residual_key : key
        Residual key of the Gaussian.

    inverse_covariance_key : key
        Inverse covariance diagonal key of the Gaussian.

    sampling_dtype : np.dtype
        Data type of the samples. Usually either 'np.float*' or 'np.complex*'
    """

    def __init__(self, domain, residual_key, inverse_covariance_key, sampling_dtype):
        self._r = str(residual_key)
        self._icov = str(inverse_covariance_key)
        # Residual and inverse covariance diagonal live on the same space.
        single = DomainTuple.make(domain)
        self._domain = MultiDomain.make({self._r: single, self._icov: single})
        self._sampling_dtype = sampling_dtype
        _check_sampling_dtype(self._domain, sampling_dtype)

    def apply(self, x):
        self._check_input(x)
        resid = x[self._r]
        icov = x[self._icov]
        quadratic = resid.vdot(resid*icov).real
        res = 0.5*(quadratic - icov.ptw("log").sum())
        if not x.want_metric:
            return res
        icov_val = x.val[self._icov]
        diag = {self._r: icov_val, self._icov: .5*icov_val**(-2)}
        metric = SamplingDtypeSetter(makeOp(MultiField.from_dict(diag)),
                                     self._sampling_dtype)
        return res.add_metric(metric)
168

Martin Reinecke's avatar
Martin Reinecke committed
169

170
171
172
173
174
175
176
177
178
def _build_MultiScalingOperator(domain, scales):
    op = None
    for k, dom in domain.items():
        o = ScalingOperator(dom, scales[k])
        FA = FieldAdapter(dom, k)
        o = FA.adjoint @ o @ FA
        op = o if op is None else op + o
    return op

Martin Reinecke's avatar
Martin Reinecke committed
179
class GaussianEnergy(EnergyOperator):
    """Computes a negative-log Gaussian.

    Represents up to constants in :math:`m`:

    .. math ::
        E(f) = - \\log G(f-m, D) = 0.5 (f-m)^\\dagger D^{-1} (f-m),

    an information energy for a Gaussian distribution with mean m and
    covariance D.

    Parameters
    ----------
    mean : Field
        Mean of the Gaussian. Default is 0.
    inverse_covariance : LinearOperator
        Inverse covariance of the Gaussian. Default is the identity operator.
    domain : Domain, DomainTuple, tuple of Domain or MultiDomain
        Operator domain. By default it is inferred from `mean` or
        `inverse_covariance` if specified
    sampling_dtype : type
        Here one can specify whether the distribution is a complex Gaussian
        or not. Note that for a complex Gaussian the inverse_covariance is

        .. math ::
            (<ff^\\dagger>)^{-1}_P(f)/2,

        where the additional factor of 2 is necessary because the domain of
        f has double as many dimensions as in the real case.

    Note
    ----
    At least one of the arguments has to be provided.
    """

    def __init__(self, mean=None, inverse_covariance=None, domain=None, sampling_dtype=None):
        if not (mean is None or isinstance(mean, (Field, MultiField))):
            raise TypeError
        if not (inverse_covariance is None
                or isinstance(inverse_covariance, LinearOperator)):
            raise TypeError

        # Collect the domain from every source that was given and make sure
        # that all of them agree.
        self._domain = None
        for source in (mean, inverse_covariance):
            if source is not None:
                self._checkEquivalence(source.domain)
        if domain is not None:
            self._checkEquivalence(domain)
        if self._domain is None:
            raise ValueError("no domain given")
        self._mean = mean

        # Infer sampling dtype
        if mean is None:
            _check_sampling_dtype(self._domain, sampling_dtype)
        elif sampling_dtype is None:
            sampling_dtype = _field_to_dtype(mean)
        elif sampling_dtype != _field_to_dtype(mean):
            raise ValueError("Sampling dtype and mean not compatible")

        if inverse_covariance is not None:
            self._op = QuadraticFormOperator(inverse_covariance)
            self._met = inverse_covariance
        else:
            self._op = Squared2NormOperator(self._domain).scale(0.5)
            self._met = ScalingOperator(self._domain, 1)
        if sampling_dtype is not None:
            self._met = SamplingDtypeSetter(self._met, sampling_dtype)

    def _checkEquivalence(self, newdom):
        """Adopt `newdom` as domain if none is set yet, otherwise require it
        to equal the domain already stored."""
        newdom = makeDomain(newdom)
        if self._domain is None:
            self._domain = newdom
            return
        if self._domain != newdom:
            raise ValueError("domain mismatch")

    def apply(self, x):
        self._check_input(x)
        if self._mean is None:
            residual = x
        else:
            residual = x - self._mean
        res = self._op(residual).real
        if not x.want_metric:
            return res
        return res.add_metric(self._met)
Martin Reinecke's avatar
Martin Reinecke committed
263
264
265


class PoissonianEnergy(EnergyOperator):
    """Computes likelihood Hamiltonians of expected count field constrained by
    Poissonian count data.

    Represents up to an f-independent term :math:`log(d!)`:

    .. math ::
        E(f) = -\\log \\text{Poisson}(d|f) = \\sum f - d^\\dagger \\log(f),

    where f is a :class:`Field` in data space with the expectation values for
    the counts.

    Parameters
    ----------
    d : Field
        Data field with counts. Needs to have integer dtype and all field
        values need to be non-negative.

    Raises
    ------
    TypeError
        If `d` is not a `Field` with integer dtype.
    ValueError
        If `d` contains negative values.
    """

    def __init__(self, d):
        if not (isinstance(d, Field) and np.issubdtype(d.dtype, np.integer)):
            raise TypeError
        if np.any(d.val < 0):
            raise ValueError
        self._d = d
        self._domain = DomainTuple.make(d.domain)

    def apply(self, x):
        self._check_input(x)
        log_term = x.ptw("log").vdot(self._d)
        res = x.sum() - log_term
        if x.want_metric:
            metric = SamplingDtypeSetter(makeOp(1./x.val), np.float64)
            return res.add_metric(metric)
        return res
Martin Reinecke's avatar
Martin Reinecke committed
298

299

300
class InverseGammaLikelihood(EnergyOperator):
    """Computes the negative log-likelihood of the inverse gamma distribution.

    Its negative log-pdf(x) is given by

    .. math ::

        \\sum_i (\\alpha_i+1)*\\ln(x_i) + \\beta_i/x_i

    This is the likelihood for the variance :math:`x=S_k` given data
    :math:`\\beta = 0.5 |s_k|^2` where the Field :math:`s` is known to have
    the covariance :math:`S_k`.

    Parameters
    ----------
    beta : Field
        beta parameter of the inverse gamma distribution
    alpha : Scalar, Field, optional
        alpha parameter of the inverse gamma distribution

    Raises
    ------
    TypeError
        If `beta` is not a `Field`, if `alpha` is neither a scalar nor a
        `Field`, or if the dtype of `beta` is not `np.float64`.
    """

    def __init__(self, beta, alpha=-0.5):
        if not isinstance(beta, Field):
            raise TypeError
        self._domain = DomainTuple.make(beta.domain)
        self._beta = beta
        if np.isscalar(alpha):
            # Promote a scalar alpha to a constant Field on the domain of beta.
            alpha = Field(beta.domain, np.full(beta.shape, alpha))
        elif not isinstance(alpha, Field):
            raise TypeError
        self._alphap1 = alpha+1
        if not self._beta.dtype == np.float64:
            # FIXME Add proper complex support for this energy
            raise TypeError
        self._sampling_dtype = _field_to_dtype(self._beta)

    def apply(self, x):
        self._check_input(x)
        log_part = x.ptw("log").vdot(self._alphap1)
        reciprocal_part = x.ptw("reciprocal").vdot(self._beta)
        res = log_part + reciprocal_part
        if not x.want_metric:
            return res
        metric = makeOp(self._alphap1/(x.val**2))
        if self._sampling_dtype is not None:
            metric = SamplingDtypeSetter(metric, self._sampling_dtype)
        return res.add_metric(metric)
345
346


347
class StudentTEnergy(EnergyOperator):
    """Computes likelihood energy corresponding to Student's t-distribution.

    .. math ::
        E_\\theta(f) = -\\log \\text{StudentT}_\\theta(f)
                     = \\frac{\\theta + 1}{2} \\log(1 + \\frac{f^2}{\\theta}),

    where f is a field defined on `domain`. Assumes that the data is `float64`
    for sampling.

    Parameters
    ----------
    domain : `Domain` or `DomainTuple`
        Domain of the operator
    theta : Scalar or Field
        Degree of freedom parameter for the student t distribution
    """

    def __init__(self, domain, theta):
        self._domain = DomainTuple.make(domain)
        self._theta = theta

    def apply(self, x):
        self._check_input(x)
        theta = self._theta
        # log1p(f^2/theta) == log(1 + f^2/theta)
        log_term = (x**2/theta).ptw("log1p")
        res = (((theta+1)/2)*log_term).sum()
        if x.want_metric:
            metric = makeOp((theta+1) / (theta+3), self.domain)
            return res.add_metric(SamplingDtypeSetter(metric, np.float64))
        return res
376
377


Martin Reinecke's avatar
Martin Reinecke committed
378
class BernoulliEnergy(EnergyOperator):
    """Computes likelihood energy of expected event frequency constrained by
    event data.

    .. math ::
        E(f) = -\\log \\text{Bernoulli}(d|f)
             = -d^\\dagger \\log f  - (1-d)^\\dagger \\log(1-f),

    where f is a field defined on `d.domain` with the expected
    frequencies of events.

    Parameters
    ----------
    d : Field
        Data field with events (1) or non-events (0).

    Raises
    ------
    TypeError
        If `d` is not a `Field` with integer dtype.
    ValueError
        If `d` contains values other than 0 and 1.
    """

    def __init__(self, d):
        if not (isinstance(d, Field) and np.issubdtype(d.dtype, np.integer)):
            raise TypeError
        neither_zero_nor_one = np.logical_and(d.val != 0, d.val != 1)
        if np.any(neither_zero_nor_one):
            raise ValueError
        self._d = d
        self._domain = DomainTuple.make(d.domain)

    def apply(self, x):
        self._check_input(x)
        event_term = -x.ptw("log").vdot(self._d)
        # (d - 1) carries the minus sign of the -(1-d)^dagger log(1-f) term.
        non_event_term = (1.-x).ptw("log").vdot(self._d-1.)
        res = event_term + non_event_term
        if x.want_metric:
            metric = makeOp(1./(x.val*(1. - x.val)))
            return res.add_metric(SamplingDtypeSetter(metric, np.float64))
        return res
Martin Reinecke's avatar
Martin Reinecke committed
410
411


412
class StandardHamiltonian(EnergyOperator):
    """Computes an information Hamiltonian in its standard form, i.e. with the
    prior being a Gaussian with unit covariance.

    Let the likelihood energy be :math:`E_{lh}`. Then this operator computes:

    .. math ::
         H(f) = 0.5 f^\\dagger f + E_{lh}(f)

    Other field priors can be represented via transformations of a white
    Gaussian field into a field with the desired prior probability structure.

    By implementing prior information this way, the field prior is represented
    by a generative model, from which NIFTy can draw samples and infer a field
    using the Maximum a Posteriori (MAP) or the Variational Bayes (VB) method.

    The metric of this operator can be used as covariance for drawing Gaussian
    samples.

    Parameters
    ----------
    lh : EnergyOperator
        The likelihood energy.
    ic_samp : IterationController
        Tells an internal :class:`SamplingEnabler` which convergence criterion
        to use to draw Gaussian samples.
    prior_dtype : numpy.dtype or dict of numpy.dtype, optional
        Data type of prior used for sampling.

    See also
    --------
    `Encoding prior knowledge in the structure of the likelihood`,
    Jakob Knollmüller, Torsten A. Ensslin,
    `<https://arxiv.org/abs/1812.04403>`_
    """

    def __init__(self, lh, ic_samp=None, _c_inp=None, prior_dtype=np.float64):
        self._lh = lh
        # Remember the prior dtype so that it survives a rebuild in
        # `_simplify_for_constant_input_nontrivial`.
        self._prior_dtype = prior_dtype
        self._prior = GaussianEnergy(domain=lh.domain, sampling_dtype=prior_dtype)
        if _c_inp is not None:
            _, self._prior = self._prior.simplify_for_constant_input(_c_inp)
        self._ic_samp = ic_samp
        self._domain = lh.domain

    def apply(self, x):
        self._check_input(x)
        if not x.want_metric or self._ic_samp is None:
            return (self._lh + self._prior)(x)
        lhx, prx = self._lh(x), self._prior(x)
        return (lhx+prx).add_metric(SamplingEnabler(lhx.metric, prx.metric, self._ic_samp))

    def __repr__(self):
        subs = 'Likelihood:\n{}'.format(utilities.indent(self._lh.__repr__()))
        subs += '\nPrior:\n{}'.format(self._prior)
        return 'StandardHamiltonian:\n' + utilities.indent(subs)

    def _simplify_for_constant_input_nontrivial(self, c_inp):
        out, lh1 = self._lh.simplify_for_constant_input(c_inp)
        prior_dtype = self._prior_dtype
        if isinstance(prior_dtype, dict) and isinstance(lh1.domain, MultiDomain):
            # A dict specification must match the keys of the prior's domain;
            # drop entries for keys the simplified likelihood no longer has.
            # NOTE(review): presumably lh1.domain equals lh.domain here, in
            # which case this is a no-op - confirm against Operator.
            prior_dtype = {key: dt for key, dt in prior_dtype.items()
                           if key in lh1.domain.keys()}
        # Forward the prior dtype; otherwise the rebuilt Hamiltonian would
        # silently fall back to the default np.float64.
        return out, StandardHamiltonian(lh1, self._ic_samp, _c_inp=c_inp,
                                        prior_dtype=prior_dtype)

Martin Reinecke's avatar
Martin Reinecke committed
472

Martin Reinecke's avatar
Martin Reinecke committed
473
class AveragedEnergy(EnergyOperator):
    """Averages an energy over samples.

    Parameters
    ----------
    h: Hamiltonian
       The energy to be averaged.
    res_samples : iterable of Fields
       Set of residual sample points to be added to mean field for
       approximate estimation of the KL.

    Notes
    -----
    - Having symmetrized residual samples, with both :math:`v_i` and
      :math:`-v_i` being present, ensures that the distribution mean is
      exactly represented.

    - :class:`AveragedEnergy(h)` approximates
      :math:`\\left< H(f) \\right>_{G(f-m,D)}` if the residuals :math:`f-m`
      are drawn from a Gaussian distribution with covariance :math:`D`.
    """

    def __init__(self, h, res_samples):
        self._h = h
        self._domain = h.domain
        # Materialize the iterable so it can be traversed on every call and
        # its length is known.
        self._res_samples = tuple(res_samples)

    def apply(self, x):
        self._check_input(x)
        shifted_energies = (self._h(x+sample) for sample in self._res_samples)
        n_samples = len(self._res_samples)
        return utilities.my_sum(shifted_energies)/n_samples