energy_operators.py 18.1 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
Philipp Arras's avatar
Philipp Arras committed
14
# Copyright(C) 2013-2020 Max-Planck-Society
Martin Reinecke's avatar
Martin Reinecke committed
15
#
16
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
Martin Reinecke's avatar
Martin Reinecke committed
17

Philipp Arras's avatar
Philipp Arras committed
18
19
import numpy as np

Philipp Arras's avatar
Philipp Arras committed
20
from .. import utilities
Martin Reinecke's avatar
Martin Reinecke committed
21
from ..domain_tuple import DomainTuple
Philipp Arras's avatar
Philipp Arras committed
22
from ..field import Field
Philipp Arras's avatar
Philipp Arras committed
23
24
25
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from ..sugar import makeDomain, makeOp
Philipp Arras's avatar
Philipp Arras committed
26
from .linear_operator import LinearOperator
Martin Reinecke's avatar
Martin Reinecke committed
27
from .operator import Operator
28
from .sampling_enabler import SamplingDtypeSetter, SamplingEnabler
29
from .scaling_operator import ScalingOperator
Philipp Arras's avatar
Cleanup    
Philipp Arras committed
30
from .simple_linear_operators import VdotOperator
Philipp Arras's avatar
Philipp Arras committed
31
32
33
34
35
36


def _check_sampling_dtype(domain, dtypes):
    if dtypes is None:
        return
    if isinstance(domain, DomainTuple):
Philipp Arras's avatar
Philipp Arras committed
37
38
        np.dtype(dtypes)
        return
Philipp Arras's avatar
Philipp Arras committed
39
    elif isinstance(domain, MultiDomain):
Philipp Arras's avatar
Philipp Arras committed
40
41
42
43
44
45
46
        if isinstance(dtypes, dict):
            for dt in dtypes.values():
                np.dtype(dt)
            if set(domain.keys()) == set(dtypes.keys()):
                return
        else:
            np.dtype(dtypes)
Philipp Arras's avatar
Philipp Arras committed
47
            return
Philipp Arras's avatar
Philipp Arras committed
48
    raise TypeError
Philipp Arras's avatar
Philipp Arras committed
49
50


51
52
53
54
def _iscomplex(dtype):
    return np.issubdtype(dtype, np.complexfloating)


Philipp Arras's avatar
Philipp Arras committed
55
56
57
58
59
60
61
62
63
64
65
def _field_to_dtype(field):
    """Derive the sampling dtype specification matching a (multi-)field.

    Parameters
    ----------
    field : Field or MultiField
        The field whose value dtype(s) are read.

    Returns
    -------
    dtype or dict of dtype
        `field.dtype` for a `Field`, and a dict mapping every key to the dtype
        of the corresponding sub-field for a `MultiField`. The result is
        validated with `_check_sampling_dtype` before being returned.

    Raises
    ------
    TypeError
        If `field` is neither a `Field` nor a `MultiField`.
    """
    if not isinstance(field, (Field, MultiField)):
        raise TypeError
    if isinstance(field, Field):
        result = field.dtype
    else:
        result = {key: sub.dtype for key, sub in field.items()}
    _check_sampling_dtype(field.domain, result)
    return result


Martin Reinecke's avatar
Martin Reinecke committed
66
class EnergyOperator(Operator):
    """Operator whose target is the scalar domain.

    Such an operator maps its input onto a single number and is meant to be
    used as an objective function for field inference.

    Examples
    --------
     - An information Hamiltonian, i.e. a negative log-probability.
     - A Gibbs free energy, i.e. an averaged Hamiltonian, also known as
       Kullback-Leibler divergence.
    """

    # All energy operators share the same zero-dimensional target domain.
    _target = DomainTuple.scalar_domain()


80
81
class Squared2NormOperator(EnergyOperator):
    """Computes the square of the L2-norm of the output of an operator.

    .. math ::
        E(f) = f^\\dagger f

    Parameters
    ----------
    domain : Domain, DomainTuple, tuple of Domain or MultiDomain
        Domain of the operator in which the L2-norm shall be computed.
    """

    def __init__(self, domain):
        # Normalize the domain: raw Domains or tuples of Domains (allowed by
        # the docstring) would otherwise never compare equal to the domain of
        # an input field in `_check_input`. Already-normalized domains are
        # returned unchanged by `makeDomain`.
        self._domain = makeDomain(domain)

    def apply(self, x):
        self._check_input(x)
        if x.jac is None:
            return x.vdot(x)
        res = x.val.vdot(x.val)
        # Gradient of f^dagger f with respect to (real) f is 2 f.
        return x.new(res, VdotOperator(2*x.val))
Martin Reinecke's avatar
Martin Reinecke committed
98

Martin Reinecke's avatar
Martin Reinecke committed
99

Martin Reinecke's avatar
Martin Reinecke committed
100
class QuadraticFormOperator(EnergyOperator):
    """Computes the quadratic form of a Field or MultiField with respect to a
    specific kernel given by `endo`.

    .. math ::
        E(f) = \\frac12 f^\\dagger \\text{endo}(f)

    Parameters
    ----------
    endo : EndomorphicOperator
         Kernel of the quadratic form

    Notes
    -----
    The gradient is taken to be :math:`\\text{endo}(f)`, which is the exact
    gradient only if `endo` is self-adjoint.
    """

    def __init__(self, endo):
        from .endomorphic_operator import EndomorphicOperator
        if not isinstance(endo, EndomorphicOperator):
            raise TypeError("op must be an EndomorphicOperator")
        self._op = endo
        self._domain = endo.domain

    def apply(self, x):
        self._check_input(x)
        if x.jac is None:
            return 0.5*x.vdot(self._op(x))
        # Apply the kernel only once: it is needed both for the value and for
        # the gradient, and it may be expensive (e.g. an iterative inverse).
        op_x = self._op(x.val)
        res = 0.5*x.val.vdot(op_x)
        return x.new(res, VdotOperator(op_x))
Martin Reinecke's avatar
Martin Reinecke committed
126

Philipp Arras's avatar
Philipp Arras committed
127

128
class VariableCovarianceGaussianEnergy(EnergyOperator):
    """Computes the negative log pdf of a Gaussian with unknown covariance.

    The covariance is assumed to be diagonal. Up to constants, the energy
    reads

    .. math ::
        E(s,C) = - \\log G(s, C) = 0.5 s^\\dagger C s - 0.5 \\text{tr} \\log(C)

    for a real residual and

    .. math ::
        E(s,C) = 0.5 s^\\dagger C s - \\text{tr} \\log(C)

    for a complex residual. This is an information energy for a Gaussian
    distribution with residual s and inverse diagonal covariance C.

    The domain of this energy is a MultiDomain with two keys (residual and
    inverse covariance diagonal), both defined on `domain`; the target is the
    scalar domain.

    Parameters
    ----------
    domain : Domain, DomainTuple, tuple of Domain
        Domain of the residual and of the inverse covariance diagonal.

    residual_key : key
        Residual key of the Gaussian. It is converted with `str`.

    inverse_covariance_key : key
        Inverse covariance diagonal key of the Gaussian. It is converted with
        `str`.

    sampling_dtype : np.dtype
        Data type of the residual samples. Usually either 'np.float*' or
        'np.complex*'. The inverse covariance diagonal is always sampled as
        `np.float64`.
    """

    def __init__(self, domain, residual_key, inverse_covariance_key, sampling_dtype):
        self._kr = str(residual_key)
        self._ki = str(inverse_covariance_key)
        dom = DomainTuple.make(domain)
        self._domain = MultiDomain.make({self._kr: dom, self._ki: dom})
        self._dt = {self._kr: sampling_dtype, self._ki: np.float64}
        _check_sampling_dtype(self._domain, self._dt)
        # The complex form drops the factor 1/2 in front of the log-term.
        self._cplx = _iscomplex(sampling_dtype)

    def apply(self, x):
        self._check_input(x)
        r, i = x[self._kr], x[self._ki]
        if self._cplx:
            res = 0.5*r.vdot(r*i.real).real - i.ptw("log").sum()
        else:
            res = 0.5*(r.vdot(r*i) - i.ptw("log").sum())
        if not x.want_metric:
            return res
        # Diagonal metric: C on the residual and c*C^{-2} on the inverse
        # covariance, with c = 1 (complex) or c = 1/2 (real).
        met = 1. if self._cplx else 0.5
        met = MultiField.from_dict({self._kr: i.val, self._ki: met*i.val**(-2)},
                                    domain=self._domain)
        return res.add_metric(SamplingDtypeSetter(makeOp(met), self._dt))

    def _simplify_for_constant_input_nontrivial(self, c_inp):
        from .simplify_for_const import ConstantEnergyOperator
        # Exactly one of the two keys is expected to be constant here.
        assert len(c_inp.keys()) == 1
        key = c_inp.keys()[0]
        assert key in self._domain.keys()
        cst = c_inp[key]
        if key == self._kr:
            # Constant residual: only the inverse covariance remains variable.
            res = _SpecialGammaEnergy(cst).ducktape(self._ki)
        else:
            # Constant inverse covariance: a plain Gaussian in the residual
            # plus the constant trace-log term (halved in the real case).
            dt = self._dt[self._kr]
            res = GaussianEnergy(inverse_covariance=makeOp(cst),
                                 sampling_dtype=dt).ducktape(self._kr)
            trlog = cst.log().sum().val_rw()
            if not _iscomplex(dt):
                trlog /= 2
            res = res + ConstantEnergyOperator(-trlog)
        # NOTE(review): adding a zero constant presumably makes the result's
        # target the shared scalar domain object required by the assert below.
        res = res + ConstantEnergyOperator(0.)
        assert res.target is self.target
        return None, res
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218


class _SpecialGammaEnergy(EnergyOperator):
    """Energy of a diagonal Gaussian as a function of its inverse covariance.

    The residual r is held fixed and the input is the inverse covariance
    diagonal C. Up to constants this computes
    :math:`0.5 r^\\dagger C r - 0.5 \\text{tr} \\log(C)` for a real residual
    and :math:`0.5 r^\\dagger C r - \\text{tr} \\log(C)` for a complex one.

    Parameters
    ----------
    residual : Field
        The fixed residual. Its dtype selects the real or the complex form.
    """

    def __init__(self, residual):
        self._domain = DomainTuple.make(residual.domain)
        self._resi = residual
        self._cplx = _iscomplex(self._resi.dtype)
        # Prefactor of the C^{-2} metric: 1 for complex, 1/2 for real
        # residuals (the second derivative of the log-term).
        self._scale = ScalingOperator(self._domain, 1 if self._cplx else .5)

    def apply(self, x):
        self._check_input(x)
        r = self._resi
        if self._cplx:
            res = 0.5*(r*x.real).vdot(r).real - x.ptw("log").sum()
        else:
            res = 0.5*((r*x).vdot(r) - x.ptw("log").sum())
        if not x.want_metric:
            return res
        # Scale C^{-2} by the prefactor. Scaling C *before* squaring would
        # yield (0.5*C)^{-2} = 4*C^{-2} in the real case instead of
        # 0.5*C^{-2}, which is what VariableCovarianceGaussianEnergy uses for
        # its inverse-covariance entry.
        met = makeOp(self._scale(x.val**(-2)))
        return res.add_metric(SamplingDtypeSetter(met, self._resi.dtype))

Martin Reinecke's avatar
Martin Reinecke committed
219
220

class GaussianEnergy(EnergyOperator):
    """Computes a negative-log Gaussian.

    Represents up to constants in :math:`m`:

    .. math ::
        E(f) = - \\log G(f-m, D) = 0.5 (f-m)^\\dagger D^{-1} (f-m),

    an information energy for a Gaussian distribution with mean m and
    covariance D.

    Parameters
    ----------
    mean : Field or MultiField
        Mean of the Gaussian. Default is 0.
    inverse_covariance : LinearOperator
        Inverse covariance of the Gaussian. Default is the identity operator.
    domain : Domain, DomainTuple, tuple of Domain or MultiDomain
        Operator domain. By default it is inferred from `mean` or
        `inverse_covariance` if specified.
    sampling_dtype : dtype or dict of dtype
        Here one can specify whether the distribution is a complex Gaussian or
        not. Note that for a complex Gaussian the inverse_covariance is
        :math:`\\langle f f^\\dagger \\rangle_{P(f)}^{-1}/2`, where the
        additional factor of 2 is necessary because the domain of f has
        twice as many dimensions as in the real case. If `mean` is given and
        `sampling_dtype` is not, the dtype(s) of `mean` are used. A single
        dtype is accepted for a MultiField `mean` if all its entries share it.

    Raises
    ------
    TypeError
        If `mean` is not a Field/MultiField or `inverse_covariance` is not a
        LinearOperator.
    ValueError
        If no domain can be determined, the given domains disagree, or
        `sampling_dtype` does not agree with the dtype of `mean`.

    Note
    ----
    At least one of `mean`, `inverse_covariance` and `domain` has to be
    provided.
    """

    def __init__(self, mean=None, inverse_covariance=None, domain=None, sampling_dtype=None):
        if mean is not None and not isinstance(mean, (Field, MultiField)):
            raise TypeError
        if inverse_covariance is not None and not isinstance(inverse_covariance, LinearOperator):
            raise TypeError

        self._domain = None
        if mean is not None:
            self._checkEquivalence(mean.domain)
        if inverse_covariance is not None:
            self._checkEquivalence(inverse_covariance.domain)
        if domain is not None:
            self._checkEquivalence(domain)
        if self._domain is None:
            raise ValueError("no domain given")
        self._mean = mean

        # Infer sampling dtype
        if self._mean is None:
            _check_sampling_dtype(self._domain, sampling_dtype)
        else:
            mean_dtype = _field_to_dtype(self._mean)
            if sampling_dtype is None:
                sampling_dtype = mean_dtype
            elif not self._sampling_dtypes_agree(sampling_dtype, mean_dtype):
                raise ValueError("Sampling dtype and mean not compatible")

        self._icov = inverse_covariance
        if inverse_covariance is None:
            self._op = Squared2NormOperator(self._domain).scale(0.5)
            self._met = ScalingOperator(self._domain, 1)
        else:
            self._op = QuadraticFormOperator(inverse_covariance)
            self._met = inverse_covariance
        if sampling_dtype is not None:
            self._met = SamplingDtypeSetter(self._met, sampling_dtype)

    @staticmethod
    def _sampling_dtypes_agree(requested, actual):
        """Check whether a user-supplied sampling dtype matches `actual`.

        `actual` is the output of `_field_to_dtype` for the mean: a single
        dtype for a Field and a dict of dtypes for a MultiField. A single
        `requested` dtype matches a dict if every entry equals it. A dict
        `requested` must have the same keys and equal entries. Unparsable
        dtype specifications count as a mismatch.
        """
        def same(a, b):
            try:
                return np.dtype(a) == np.dtype(b)
            except (TypeError, ValueError):
                return False

        if isinstance(actual, dict):
            if not isinstance(requested, dict):
                return all(same(requested, dt) for dt in actual.values())
            return (set(requested.keys()) == set(actual.keys())
                    and all(same(requested[kk], actual[kk]) for kk in actual))
        if isinstance(requested, dict):
            return False
        return same(requested, actual)

    def _checkEquivalence(self, newdom):
        # Adopt the first domain seen; every later one must be equal to it.
        newdom = makeDomain(newdom)
        if self._domain is None:
            self._domain = newdom
        else:
            if self._domain != newdom:
                raise ValueError("domain mismatch")

    def apply(self, x):
        self._check_input(x)
        residual = x if self._mean is None else x - self._mean
        res = self._op(residual).real
        if x.want_metric:
            return res.add_metric(self._met)
        return res

    def __repr__(self):
        dom = '()' if isinstance(self.domain, DomainTuple) else self.domain.keys()
        return f'GaussianEnergy {dom}'

Martin Reinecke's avatar
Martin Reinecke committed
310
311

class PoissonianEnergy(EnergyOperator):
    """Likelihood Hamiltonian of an expected count field given Poissonian
    count data.

    Up to the f-independent term :math:`\\log(d!)` it evaluates

    .. math ::
        E(f) = -\\log \\text{Poisson}(d|f) = \\sum f - d^\\dagger \\log(f),

    where f is a :class:`Field` in data space holding the expected counts.

    Parameters
    ----------
    d : Field
        Observed counts. Must have an integer dtype and contain no negative
        values.

    Raises
    ------
    TypeError
        If `d` is not a Field with integer dtype.
    ValueError
        If `d` contains negative values.
    """

    def __init__(self, d):
        is_int_field = isinstance(d, Field) and np.issubdtype(d.dtype, np.integer)
        if not is_int_field:
            raise TypeError
        has_negative = np.any(d.val < 0)
        if has_negative:
            raise ValueError
        self._d = d
        self._domain = DomainTuple.make(d.domain)

    def apply(self, x):
        self._check_input(x)
        energy = x.sum() - x.ptw("log").vdot(self._d)
        if x.want_metric:
            # Fisher metric of the Poisson distribution: diag(1/f).
            fisher = makeOp(1./x.val)
            return energy.add_metric(SamplingDtypeSetter(fisher, np.float64))
        return energy
Martin Reinecke's avatar
Martin Reinecke committed
344

345

346
class InverseGammaLikelihood(EnergyOperator):
    """Negative log-likelihood of the inverse gamma distribution.

    Its negative log-pdf(x) is given by

    .. math ::

        \\sum_i (\\alpha_i+1)*\\ln(x_i) + \\beta_i/x_i

    This is the likelihood for the variance :math:`x=S_k` given data
    :math:`\\beta = 0.5 |s_k|^2` where the Field :math:`s` is known to have
    the covariance :math:`S_k`.

    Parameters
    ----------
    beta : Field
        beta parameter of the inverse gamma distribution. Must have dtype
        float64.
    alpha : Scalar or Field, optional
        alpha parameter of the inverse gamma distribution. A scalar is
        broadcast onto the domain of `beta`. Default is -0.5.

    Raises
    ------
    TypeError
        If `beta` is not a float64 Field, or `alpha` is neither a scalar nor
        a Field.
    """

    def __init__(self, beta, alpha=-0.5):
        if not isinstance(beta, Field):
            raise TypeError
        self._domain = DomainTuple.make(beta.domain)
        self._beta = beta
        if not np.isscalar(alpha) and not isinstance(alpha, Field):
            raise TypeError
        alpha_field = (Field(beta.domain, np.full(beta.shape, alpha))
                       if np.isscalar(alpha) else alpha)
        self._alphap1 = alpha_field + 1
        if self._beta.dtype != np.float64:
            # FIXME Add proper complex support for this energy
            raise TypeError
        self._sampling_dtype = _field_to_dtype(self._beta)

    def apply(self, x):
        self._check_input(x)
        log_term = x.ptw("log").vdot(self._alphap1)
        energy = log_term + x.ptw("reciprocal").vdot(self._beta)
        if not x.want_metric:
            return energy
        metric = makeOp(self._alphap1/(x.val**2))
        if self._sampling_dtype is not None:
            metric = SamplingDtypeSetter(metric, self._sampling_dtype)
        return energy.add_metric(metric)
391
392


393
class StudentTEnergy(EnergyOperator):
    """Likelihood energy of Student's t-distribution.

    .. math ::
        E_\\theta(f) = -\\log \\text{StudentT}_\\theta(f)
                     = \\sum \\frac{\\theta + 1}{2} \\log(1 + \\frac{f^2}{\\theta}),

    up to constants, where f is a field defined on `domain`. Samples are
    drawn with dtype `float64`.

    Parameters
    ----------
    domain : `Domain` or `DomainTuple`
        Domain of the operator
    theta : Scalar or Field
        Degree of freedom parameter for the student t distribution
    """

    def __init__(self, domain, theta):
        self._domain = DomainTuple.make(domain)
        self._theta = theta

    def apply(self, x):
        self._check_input(x)
        theta = self._theta
        energy = (((theta+1)/2)*(x**2/theta).ptw("log1p")).sum()
        if x.want_metric:
            fisher = makeOp((theta+1) / (theta+3), self.domain)
            energy = energy.add_metric(SamplingDtypeSetter(fisher, np.float64))
        return energy
422
423


Martin Reinecke's avatar
Martin Reinecke committed
424
class BernoulliEnergy(EnergyOperator):
    """Likelihood energy of an expected event frequency given event data.

    .. math ::
        E(f) = -\\log \\text{Bernoulli}(d|f)
             = -d^\\dagger \\log f - (1-d)^\\dagger \\log(1-f),

    where f is a field defined on `d.domain` holding the expected
    frequencies of events.

    Parameters
    ----------
    d : Field
        Integer-typed data field with events (1) or non-events (0).

    Raises
    ------
    TypeError
        If `d` is not a Field with integer dtype.
    ValueError
        If `d` contains values other than 0 and 1.
    """

    def __init__(self, d):
        if not isinstance(d, Field) or not np.issubdtype(d.dtype, np.integer):
            raise TypeError
        is_binary = np.logical_or(d.val == 0, d.val == 1)
        if not np.all(is_binary):
            raise ValueError
        self._d = d
        self._domain = DomainTuple.make(d.domain)

    def apply(self, x):
        self._check_input(x)
        event_term = -x.ptw("log").vdot(self._d)
        energy = event_term + (1.-x).ptw("log").vdot(self._d-1.)
        if x.want_metric:
            # Fisher metric of the Bernoulli distribution: diag(1/(f(1-f))).
            fisher = makeOp(1./(x.val*(1. - x.val)))
            return energy.add_metric(SamplingDtypeSetter(fisher, np.float64))
        return energy
Martin Reinecke's avatar
Martin Reinecke committed
456
457


458
class StandardHamiltonian(EnergyOperator):
    """Computes an information Hamiltonian in its standard form, i.e. with the
    prior being a Gaussian with unit covariance.

    Let the likelihood energy be :math:`E_{lh}`. Then this operator computes:

    .. math ::
         H(f) = 0.5 f^\\dagger f + E_{lh}(f),

    Other field priors can be represented via transformations of a white
    Gaussian field into a field with the desired prior probability structure.

    By implementing prior information this way, the field prior is represented
    by a generative model, from which NIFTy can draw samples and infer a field
    using the Maximum a Posteriori (MAP) or the Variational Bayes (VB) method.

    The metric of this operator can be used as covariance for drawing Gaussian
    samples.

    Parameters
    ----------
    lh : EnergyOperator
        The likelihood energy.
    ic_samp : IterationController
        Tells an internal :class:`SamplingEnabler` which convergence criterion
        to use to draw Gaussian samples.
    prior_dtype : numpy.dtype or dict of numpy.dtype, optional
        Data type of prior used for sampling. It is carried over when the
        Hamiltonian is simplified for constant input.

    See also
    --------
    `Encoding prior knowledge in the structure of the likelihood`,
    Jakob Knollmüller, Torsten A. Ensslin,
    `<https://arxiv.org/abs/1812.04403>`_
    """

    def __init__(self, lh, ic_samp=None, prior_dtype=np.float64):
        self._lh = lh
        # Kept so that simplified copies use the same prior sampling dtype.
        self._prior_dtype = prior_dtype
        self._prior = GaussianEnergy(domain=lh.domain, sampling_dtype=prior_dtype)
        self._ic_samp = ic_samp
        self._domain = lh.domain

    def apply(self, x):
        self._check_input(x)
        if not x.want_metric or self._ic_samp is None:
            return (self._lh + self._prior)(x)
        lhx, prx = self._lh(x), self._prior(x)
        return (lhx+prx).add_metric(SamplingEnabler(lhx.metric, prx.metric, self._ic_samp))

    def __repr__(self):
        subs = 'Likelihood:\n{}'.format(utilities.indent(self._lh.__repr__()))
        subs += '\nPrior:\n{}'.format(self._prior)
        return 'StandardHamiltonian:\n' + utilities.indent(subs)

    def _simplify_for_constant_input_nontrivial(self, c_inp):
        out, lh1 = self._lh.simplify_for_constant_input(c_inp)
        prior_dtype = self._prior_dtype
        # A per-key dtype dict has to be restricted to the keys that remain
        # variable, otherwise the prior of the simplified Hamiltonian would
        # reject it because its keys do not match the reduced domain.
        if isinstance(prior_dtype, dict) and isinstance(lh1.domain, MultiDomain):
            prior_dtype = {kk: vv for kk, vv in prior_dtype.items()
                           if kk in lh1.domain.keys()}
        return out, StandardHamiltonian(lh1, self._ic_samp, prior_dtype)
515

Martin Reinecke's avatar
Martin Reinecke committed
516

Martin Reinecke's avatar
Martin Reinecke committed
517
class AveragedEnergy(EnergyOperator):
    """Averages an energy over samples.

    Parameters
    ----------
    h : EnergyOperator
       The energy to be averaged.
    res_samples : iterable of Fields
       Set of residual sample points to be added to mean field for
       approximate estimation of the KL. Must contain at least one sample.

    Raises
    ------
    ValueError
        If `res_samples` is empty.

    Notes
    -----
    - Having symmetrized residual samples, with both :math:`v_i` and
      :math:`-v_i` being present, ensures that the distribution mean is
      exactly represented.

    - :class:`AveragedEnergy(h)` approximates
      :math:`\\left< H(f) \\right>_{G(f-m,D)}` if the residuals :math:`f-m`
      are drawn from a Gaussian distribution with covariance :math:`D`.
    """

    def __init__(self, h, res_samples):
        self._h = h
        self._domain = h.domain
        # Materialize once so that one-shot iterators can be reused on every
        # call to `apply`.
        self._res_samples = tuple(res_samples)
        if len(self._res_samples) == 0:
            # Averaging over zero samples would divide by zero in `apply`.
            raise ValueError("AveragedEnergy needs at least one residual sample")

    def apply(self, x):
        self._check_input(x)
        energies = (self._h(x+v) for v in self._res_samples)
        return utilities.my_sum(energies)/len(self._res_samples)