energy_operators.py 17.4 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
Philipp Arras's avatar
Philipp Arras committed
14
# Copyright(C) 2013-2020 Max-Planck-Society
Martin Reinecke's avatar
Martin Reinecke committed
15
#
16
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
Martin Reinecke's avatar
Martin Reinecke committed
17

Philipp Arras's avatar
Philipp Arras committed
18
19
import numpy as np

Philipp Arras's avatar
Philipp Arras committed
20
from .. import utilities
Martin Reinecke's avatar
Martin Reinecke committed
21
from ..domain_tuple import DomainTuple
Philipp Arras's avatar
Philipp Arras committed
22
from ..field import Field
Philipp Arras's avatar
Philipp Arras committed
23
24
25
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from ..sugar import makeDomain, makeOp
Philipp Arras's avatar
Philipp Arras committed
26
from .linear_operator import LinearOperator
Martin Reinecke's avatar
Martin Reinecke committed
27
from .operator import Operator
Martin Reinecke's avatar
fix    
Martin Reinecke committed
28
from .sampling_enabler import SamplingEnabler
29
from .scaling_operator import ScalingOperator
Philipp Arras's avatar
Philipp Arras committed
30
from .simple_linear_operators import VdotOperator
Philipp Arras's avatar
Philipp Arras committed
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from .endomorphic_operator import EndomorphicOperator


def _check_sampling_dtype(domain, dtypes):
    if dtypes is None:
        return
    if isinstance(domain, DomainTuple):
        dtypes = {'': dtypes}
    elif isinstance(domain, MultiDomain):
        if dtypes in [np.float64, np.complex128]:
            return
        dtypes = dtypes.values()
        if set(domain.keys()) != set(dtypes.keys()):
            raise ValueError
    else:
        raise TypeError
    for dt in dtypes.values():
        if dt not in [np.float64, np.complex128]:
            raise ValueError


def _field_to_dtype(field):
    """Return the sampling dtype specification matching a (Multi)Field.

    For a :class:`Field` this is the field's dtype. For a
    :class:`MultiField` it is a dict mapping every key to the dtype of the
    corresponding sub-field. The result is validated with
    :func:`_check_sampling_dtype` before it is returned.

    Raises
    ------
    TypeError
        If `field` is neither a Field nor a MultiField.
    """
    if isinstance(field, Field):
        spec = field.dtype
    elif not isinstance(field, MultiField):
        raise TypeError
    else:
        spec = {key: sub.dtype for key, sub in field.items()}
    _check_sampling_dtype(field.domain, spec)
    return spec


class SamplingDtypeEnabler(EndomorphicOperator):
    """Wraps an endomorphic operator so that samples use a fixed dtype.

    Application is forwarded unchanged to the wrapped operator's ``apply``.
    ``draw_sample`` delegates to the wrapped operator's
    ``draw_sample_with_dtype`` with the dtype stored here.

    Parameters
    ----------
    endomorphic_operator : EndomorphicOperator
        Operator to wrap. It must provide ``draw_sample_with_dtype``.
    dtype : dtype or dict
        Sampling dtype. On a MultiDomain a single ``np.float64`` or
        ``np.complex128`` is expanded to a dict over all keys; a dict must
        cover exactly the keys of the domain.
    """

    def __init__(self, endomorphic_operator, dtype):
        wrapped = endomorphic_operator
        if not isinstance(wrapped, EndomorphicOperator):
            raise TypeError
        if not hasattr(wrapped, 'draw_sample_with_dtype'):
            raise TypeError
        domain = wrapped.domain
        if isinstance(domain, MultiDomain):
            if dtype in [np.float64, np.complex128]:
                # Broadcast the single dtype to every key.
                dtype = dict.fromkeys(domain.keys(), dtype)
            if set(dtype.keys()) != set(domain.keys()):
                raise TypeError
        self._dtype = dtype
        self._domain = domain
        self._capability = wrapped._capability
        # Bind the wrapped apply directly: this operator only changes sampling.
        self.apply = wrapped.apply
        self._op = wrapped

    def draw_sample(self, from_inverse=False):
        """Draw a zero-mean Gaussian sample with the stored dtype.

        Parameters
        ----------
        from_inverse : bool, optional
            If True, the sample is drawn from the inverse of the operator
            rather than from the operator itself. Default: False.

        Returns
        -------
        Field or MultiField
            Whatever the wrapped operator's ``draw_sample_with_dtype``
            returns for the stored dtype.
        """
        sampler = self._op.draw_sample_with_dtype
        return sampler(self._dtype, from_inverse=from_inverse)
Martin Reinecke's avatar
Martin Reinecke committed
98
99
100


class EnergyOperator(Operator):
    """Operator which has a scalar domain as target domain.

    It is intended as an objective function for field inference.

    Examples
    --------
     - Information Hamiltonian, i.e. negative-log-probabilities.
     - Gibbs free energy, i.e. an averaged Hamiltonian, aka Kullback-Leibler
       divergence.
    """
    # Every energy maps onto a single scalar value.
    _target = DomainTuple.scalar_domain()


114
115
class Squared2NormOperator(EnergyOperator):
    """Computes the square of the L2-norm of its input.

    .. math ::
        E(f) = f^\\dagger f

    Parameters
    ----------
    domain : Domain, DomainTuple or tuple of Domain
        Domain of the operator in which the L2-norm shall be computed.
    """

    def __init__(self, domain):
        # Normalise the documented Domain / tuple-of-Domain inputs into a
        # proper domain object, as done elsewhere in this module.
        self._domain = makeDomain(domain)

    def apply(self, x):
        self._check_input(x)
        if x.jac is None:
            return x.vdot(x)
        res = x.val.vdot(x.val)
        # Jacobian of <x, x> is 2 <x, .> (for real-valued x).
        return x.new(res, VdotOperator(2*x.val))
Martin Reinecke's avatar
Martin Reinecke committed
132

Martin Reinecke's avatar
Martin Reinecke committed
133

Martin Reinecke's avatar
Martin Reinecke committed
134
class QuadraticFormOperator(EnergyOperator):
    """Computes the L2-norm of a Field or MultiField with respect to a
    specific kernel given by `endo`.

    .. math ::
        E(f) = \\frac12 f^\\dagger \\text{endo}(f)

    Parameters
    ----------
    endo : EndomorphicOperator
         Kernel of the quadratic form

    Raises
    ------
    TypeError
        If `endo` is not an EndomorphicOperator.
    """

    def __init__(self, endo):
        if not isinstance(endo, EndomorphicOperator):
            raise TypeError("op must be an EndomorphicOperator")
        self._op = endo
        self._domain = endo.domain

    def apply(self, x):
        self._check_input(x)
        if x.jac is None:
            return 0.5*x.vdot(self._op(x))
        res = 0.5*x.val.vdot(self._op(x.val))
        # Gradient endo(x); this presumably assumes `endo` is self-adjoint.
        return x.new(res, VdotOperator(self._op(x.val)))
Martin Reinecke's avatar
Martin Reinecke committed
160

Philipp Arras's avatar
Philipp Arras committed
161

162
class VariableCovarianceGaussianEnergy(EnergyOperator):
    """Computes the negative log pdf of a Gaussian with unknown covariance.

    The covariance is assumed to be diagonal.

    .. math ::
        E(s,D) = - \\log G(s, D) = 0.5 (s)^\\dagger D^{-1} (s) + 0.5 tr log(D),

    an information energy for a Gaussian distribution with residual s and
    diagonal covariance D.
    The domain of this energy will be a MultiDomain with two keys,
    the target will be the scalar domain.

    Parameters
    ----------
    domain : Domain, DomainTuple, tuple of Domain
        domain of the residual and domain of the covariance diagonal.

    residual_key : key
        Residual key of the Gaussian.

    inverse_covariance_key : key
        Inverse covariance diagonal key of the Gaussian.

    sampling_dtype : np.float64, np.complex128, dict or None
        Dtype(s) with which samples are drawn from the metric. If None, the
        metric is returned without a sampling-dtype wrapper.

    Raises
    ------
    ValueError
        If both keys are identical or `sampling_dtype` is invalid.
    """

    def __init__(self, domain, residual_key, inverse_covariance_key,
                 sampling_dtype=None):
        self._r = str(residual_key)
        self._icov = str(inverse_covariance_key)
        if self._r == self._icov:
            # Both quantities must live under distinct keys of the MultiDomain.
            raise ValueError("residual_key and inverse_covariance_key must differ")
        dom = DomainTuple.make(domain)
        self._domain = MultiDomain.make({self._r: dom, self._icov: dom})
        self._sampling_dtype = sampling_dtype
        _check_sampling_dtype(self._domain, sampling_dtype)

    def apply(self, x):
        self._check_input(x)
        r, icov = x[self._r], x[self._icov]
        # 0.5 r^dagger D^{-1} r + 0.5 tr log(D), with D^{-1} = icov
        res = 0.5*(r.vdot(r*icov).real - icov.ptw("log").sum())
        if not x.want_metric:
            return res
        # Diagonal metric: icov for the residual, 1/(2 icov^2) for icov.
        mf = {self._r: x.val[self._icov], self._icov: .5*x.val[self._icov]**(-2)}
        met = makeOp(MultiField.from_dict(mf))
        # Only wrap when a dtype was given; the wrapper cannot handle None.
        if self._sampling_dtype is not None:
            met = SamplingDtypeEnabler(met, self._sampling_dtype)
        return res.add_metric(met)
203

Martin Reinecke's avatar
Martin Reinecke committed
204
205

class GaussianEnergy(EnergyOperator):
    """Computes a negative-log Gaussian.

    Represents, up to an :math:`f`-independent constant,

    .. math ::
        E(f) = - \\log G(f-m, D) = 0.5 (f-m)^\\dagger D^{-1} (f-m),

    an information energy for a Gaussian distribution with mean m and
    covariance D.

    Parameters
    ----------
    mean : Field or MultiField
        Mean of the Gaussian. Default is 0.
    inverse_covariance : LinearOperator
        Inverse covariance of the Gaussian. Default is the identity operator.
    domain : Domain, DomainTuple, tuple of Domain or MultiDomain
        Operator domain. By default it is inferred from `mean` or
        `inverse_covariance` if specified
    sampling_dtype : dtype or dict, optional
        Dtype(s) with which samples are drawn from the metric. If `mean` is
        given and `sampling_dtype` is None, it is taken from the dtype of
        `mean`; if both are given they must agree. On a MultiDomain a single
        dtype applies to every key.

    Note
    ----
    At least one of `mean`, `inverse_covariance` and `domain` has to be
    provided.
    """

    def __init__(self, mean=None, inverse_covariance=None, domain=None, sampling_dtype=None):
        if mean is not None and not isinstance(mean, (Field, MultiField)):
            raise TypeError
        if inverse_covariance is not None and not isinstance(inverse_covariance, LinearOperator):
            raise TypeError

        self._domain = None
        if mean is not None:
            self._checkEquivalence(mean.domain)
        if inverse_covariance is not None:
            self._checkEquivalence(inverse_covariance.domain)
        if domain is not None:
            self._checkEquivalence(domain)
        if self._domain is None:
            raise ValueError("no domain given")
        self._mean = mean

        # Infer sampling dtype
        if self._mean is None:
            _check_sampling_dtype(self._domain, sampling_dtype)
        elif sampling_dtype is None:
            sampling_dtype = _field_to_dtype(self._mean)
        else:
            mean_dtype = _field_to_dtype(self._mean)
            requested = sampling_dtype
            if isinstance(self._domain, MultiDomain) and not isinstance(requested, dict):
                # A single dtype on a MultiDomain applies to every key; expand
                # it so it can be compared with the per-key dtypes of `mean`.
                requested = {kk: requested for kk in self._domain.keys()}
            if requested != mean_dtype:
                raise ValueError("Sampling dtype and mean not compatible")

        if inverse_covariance is None:
            self._op = Squared2NormOperator(self._domain).scale(0.5)
            self._met = ScalingOperator(self._domain, 1)
        else:
            self._op = QuadraticFormOperator(inverse_covariance)
            self._met = inverse_covariance
        if sampling_dtype is not None:
            self._met = SamplingDtypeEnabler(self._met, sampling_dtype)

    def _checkEquivalence(self, newdom):
        """Set the domain on first call; afterwards require equality."""
        newdom = makeDomain(newdom)
        if self._domain is None:
            self._domain = newdom
        else:
            if self._domain != newdom:
                raise ValueError("domain mismatch")

    def apply(self, x):
        self._check_input(x)
        residual = x if self._mean is None else x - self._mean
        res = self._op(residual).real
        if x.want_metric:
            return res.add_metric(self._met)
        return res
Martin Reinecke's avatar
Martin Reinecke committed
282
283
284


class PoissonianEnergy(EnergyOperator):
    """Computes likelihood Hamiltonians of expected count field constrained by
    Poissonian count data.

    Represents up to an f-independent term :math:`log(d!)`:

    .. math ::
        E(f) = -\\log \\text{Poisson}(d|f) = \\sum f - d^\\dagger \\log(f),

    where f is a :class:`Field` in data space with the expectation values for
    the counts.

    Parameters
    ----------
    d : Field
        Data field with counts. Needs to have integer dtype and all field
        values need to be non-negative.

    Raises
    ------
    TypeError
        If `d` is not a Field of integer dtype.
    ValueError
        If `d` contains negative values.
    """

    def __init__(self, d):
        if not isinstance(d, Field) or not np.issubdtype(d.dtype, np.integer):
            raise TypeError("d must be a Field with integer dtype")
        if np.any(d.val < 0):
            raise ValueError("counts in d must be non-negative")
        self._d = d
        self._domain = DomainTuple.make(d.domain)

    def apply(self, x):
        self._check_input(x)
        res = x.sum() - x.ptw("log").vdot(self._d)
        if not x.want_metric:
            return res
        # Metric 1/f; samples from it are drawn as real (float64) fields.
        return res.add_metric(SamplingDtypeEnabler(makeOp(1./x.val), np.float64))
Martin Reinecke's avatar
Martin Reinecke committed
317

318

319
class InverseGammaLikelihood(EnergyOperator):
    """Computes the negative log-likelihood of the inverse gamma distribution.

    Its negative log-pdf(x) is given by

    .. math ::

        \\sum_i (\\alpha_i+1)*\\ln(x_i) + \\beta_i/x_i

    This is the likelihood for the variance :math:`x=S_k` given data
    :math:`\\beta = 0.5 |s_k|^2` where the Field :math:`s` is known to have
    the covariance :math:`S_k`.

    Parameters
    ----------
    beta : Field
        beta parameter of the inverse gamma distribution. Must have dtype
        float64.
    alpha : Scalar, Field, optional
        alpha parameter of the inverse gamma distribution. A scalar is
        broadcast to the domain of `beta`. Default: -0.5.

    Raises
    ------
    TypeError
        If `beta` is not a float64 Field or `alpha` is neither a scalar nor a
        Field.
    """

    def __init__(self, beta, alpha=-0.5):
        if not isinstance(beta, Field):
            raise TypeError
        self._domain = DomainTuple.make(beta.domain)
        self._beta = beta
        if np.isscalar(alpha):
            alpha = Field(beta.domain, np.full(beta.shape, alpha))
        elif not isinstance(alpha, Field):
            raise TypeError
        self._alphap1 = alpha+1
        if self._beta.dtype != np.float64:
            # FIXME Add proper complex support for this energy
            raise TypeError
        self._sampling_dtype = _field_to_dtype(self._beta)

    def apply(self, x):
        self._check_input(x)
        res = x.ptw("log").vdot(self._alphap1) + x.ptw("reciprocal").vdot(self._beta)
        if not x.want_metric:
            return res
        # Diagonal metric (alpha+1)/x**2.
        met = makeOp(self._alphap1/(x.val**2))
        if self._sampling_dtype is not None:
            met = SamplingDtypeEnabler(met, self._sampling_dtype)
        return res.add_metric(met)
364
365


366
class StudentTEnergy(EnergyOperator):
    """Computes likelihood energy corresponding to Student's t-distribution.

    .. math ::
        E_\\theta(f) = -\\log \\text{StudentT}_\\theta(f)
                     = \\frac{\\theta + 1}{2} \\log(1 + \\frac{f^2}{\\theta}),

    where f is a field defined on `domain`.

    Parameters
    ----------
    domain : `Domain` or `DomainTuple`
        Domain of the operator
    theta : Scalar or Field
        Degree of freedom parameter for the student t distribution
    sampling_dtype : dtype, optional
        Dtype with which samples are drawn from the metric. Only
        ``np.float64`` is currently supported. Default: ``np.float64``.

    Raises
    ------
    NotImplementedError
        If `sampling_dtype` is ``np.complex128``.
    ValueError
        If `sampling_dtype` is not a supported sampling dtype.
    """

    def __init__(self, domain, theta, sampling_dtype=np.float64):
        self._domain = DomainTuple.make(domain)
        self._theta = theta
        self._sampling_dtype = sampling_dtype
        if sampling_dtype == np.complex128:
            raise NotImplementedError('Complex data not supported yet')
        # Validate like the other energies in this module do.
        _check_sampling_dtype(self._domain, sampling_dtype)

    def apply(self, x):
        self._check_input(x)
        res = (((self._theta+1)/2)*(x**2/self._theta).ptw("log1p")).sum()
        if not x.want_metric:
            return res
        met = makeOp((self._theta+1) / (self._theta+3), self.domain)
        if self._sampling_dtype is not None:
            met = SamplingDtypeEnabler(met, self._sampling_dtype)
        return res.add_metric(met)
399
400


Martin Reinecke's avatar
Martin Reinecke committed
401
class BernoulliEnergy(EnergyOperator):
    """Computes likelihood energy of expected event frequency constrained by
    event data.

    .. math ::
        E(f) = -\\log \\text{Bernoulli}(d|f)
             = -d^\\dagger \\log f  - (1-d)^\\dagger \\log(1-f),

    where f is a field defined on `d.domain` with the expected
    frequencies of events.

    Parameters
    ----------
    d : Field
        Data field with events (1) or non-events (0).

    Raises
    ------
    TypeError
        If `d` is not a Field of integer dtype.
    ValueError
        If `d` contains values other than 0 and 1.
    """

    def __init__(self, d):
        if not isinstance(d, Field) or not np.issubdtype(d.dtype, np.integer):
            raise TypeError("d must be a Field with integer dtype")
        if np.any(np.logical_and(d.val != 0, d.val != 1)):
            raise ValueError("d must only contain 0 and 1")
        self._d = d
        self._domain = DomainTuple.make(d.domain)

    def apply(self, x):
        self._check_input(x)
        # -d log f + (d-1) log(1-f) == -d log f - (1-d) log(1-f)
        res = -x.ptw("log").vdot(self._d) + (1.-x).ptw("log").vdot(self._d-1.)
        if not x.want_metric:
            return res
        met = makeOp(1./(x.val*(1. - x.val)))
        return res.add_metric(SamplingDtypeEnabler(met, np.float64))
Martin Reinecke's avatar
Martin Reinecke committed
433
434


435
class StandardHamiltonian(EnergyOperator):
    """Computes an information Hamiltonian in its standard form, i.e. with the
    prior being a Gaussian with unit covariance.

    Let the likelihood energy be :math:`E_{lh}`. Then this operator computes:

    .. math ::
         H(f) = 0.5 f^\\dagger f + E_{lh}(f).

    Other field priors can be represented via transformations of a white
    Gaussian field into a field with the desired prior probability structure.

    By implementing prior information this way, the field prior is represented
    by a generative model, from which NIFTy can draw samples and infer a field
    using the Maximum a Posteriori (MAP) or the Variational Bayes (VB) method.

    The metric of this operator can be used as covariance for drawing Gaussian
    samples.

    Parameters
    ----------
    lh : EnergyOperator
        The likelihood energy.
    ic_samp : IterationController
        Tells an internal :class:`SamplingEnabler` which convergence criterion
        to use to draw Gaussian samples.
    sampling_dtype : dtype or dict, optional
        Sampling dtype of the unit-covariance prior; forwarded to the
        internal :class:`GaussianEnergy`. Default: ``np.float64``.

    See also
    --------
    `Encoding prior knowledge in the structure of the likelihood`,
    Jakob Knollmüller, Torsten A. Ensslin,
    `<https://arxiv.org/abs/1812.04403>`_
    """

    def __init__(self, lh, ic_samp=None, _c_inp=None, sampling_dtype=np.float64):
        self._lh = lh
        # Remembered so that simplified copies keep the same prior dtype.
        self._sampling_dtype = sampling_dtype
        self._prior = GaussianEnergy(domain=lh.domain, sampling_dtype=sampling_dtype)
        if _c_inp is not None:
            _, self._prior = self._prior.simplify_for_constant_input(_c_inp)
        self._ic_samp = ic_samp
        self._domain = lh.domain

    def apply(self, x):
        self._check_input(x)
        if not x.want_metric or self._ic_samp is None:
            return (self._lh + self._prior)(x)
        lhx, prx = self._lh(x), self._prior(x)
        return (lhx+prx).add_metric(SamplingEnabler(lhx.metric, prx.metric, self._ic_samp))

    def __repr__(self):
        subs = 'Likelihood:\n{}'.format(utilities.indent(repr(self._lh)))
        subs += '\nPrior:\n{}'.format(self._prior)
        return 'StandardHamiltonian:\n' + utilities.indent(subs)

    def _simplify_for_constant_input_nontrivial(self, c_inp):
        out, lh1 = self._lh.simplify_for_constant_input(c_inp)
        # Forward the prior sampling dtype; otherwise it silently resets to
        # the float64 default.
        return out, StandardHamiltonian(lh1, self._ic_samp, _c_inp=c_inp,
                                        sampling_dtype=self._sampling_dtype)

Martin Reinecke's avatar
Martin Reinecke committed
494

Martin Reinecke's avatar
Martin Reinecke committed
495
class AveragedEnergy(EnergyOperator):
    """Averages an energy over samples.

    Parameters
    ----------
    h: Hamiltonian
       The energy to be averaged.
    res_samples : iterable of Fields
       Set of residual sample points to be added to mean field for
       approximate estimation of the KL. Must contain at least one sample.

    Raises
    ------
    ValueError
        If `res_samples` is empty.

    Notes
    -----
    - Having symmetrized residual samples, with both :math:`v_i` and
      :math:`-v_i` being present, ensures that the distribution mean is
      exactly represented.

    - :class:`AveragedEnergy(h)` approximates
      :math:`\\left< H(f) \\right>_{G(f-m,D)}` if the residuals :math:`f-m`
      are drawn from a Gaussian distribution with covariance :math:`D`.
    """

    def __init__(self, h, res_samples):
        self._h = h
        self._domain = h.domain
        self._res_samples = tuple(res_samples)
        if len(self._res_samples) == 0:
            # An average over no samples is undefined; fail early with a clear
            # message instead of inside `apply`.
            raise ValueError("res_samples must contain at least one sample")

    def apply(self, x):
        self._check_input(x)
        energies = (self._h(x + v) for v in self._res_samples)
        return utilities.my_sum(energies)/len(self._res_samples)