# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

Philipp Arras's avatar
Philipp Arras committed
18
19
import numpy as np

Philipp Arras's avatar
Philipp Arras committed
20
from .. import utilities
Martin Reinecke's avatar
Martin Reinecke committed
21
from ..domain_tuple import DomainTuple
Philipp Arras's avatar
Philipp Arras committed
22
from ..field import Field
Philipp Arras's avatar
Philipp Arras committed
23
24
25
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from ..sugar import makeDomain, makeOp
26
from ..utilities import myassert
Philipp Arras's avatar
Philipp Arras committed
27
from .linear_operator import LinearOperator
Martin Reinecke's avatar
Martin Reinecke committed
28
from .operator import Operator
29
from .sampling_enabler import SamplingDtypeSetter, SamplingEnabler
30
from .scaling_operator import ScalingOperator
Philipp Arras's avatar
Cleanup    
Philipp Arras committed
31
from .simple_linear_operators import VdotOperator
Philipp Arras's avatar
Philipp Arras committed
32
33
34
35
36
37


def _check_sampling_dtype(domain, dtypes):
    """Validate a sampling dtype specification against a domain.

    Parameters
    ----------
    domain : DomainTuple or MultiDomain
        Domain on which samples are to be drawn.
    dtypes : None, dtype-like or dict of dtype-like
        Sampling dtype specification. ``None`` is always accepted. For a
        :class:`DomainTuple` a single dtype-like is expected. For a
        :class:`MultiDomain` either a single dtype-like or a dict whose keys
        coincide with the keys of `domain` is accepted.

    Raises
    ------
    TypeError
        If a dict is given whose keys do not match the MultiDomain, or if
        `domain` is neither a DomainTuple nor a MultiDomain. Invalid
        dtype-likes make ``np.dtype`` raise its own exception.
    """
    if dtypes is None:
        return
    if isinstance(domain, DomainTuple):
        np.dtype(dtypes)  # validation only
        return
    if isinstance(domain, MultiDomain):
        if not isinstance(dtypes, dict):
            np.dtype(dtypes)  # validation only
            return
        for dt in dtypes.values():
            np.dtype(dt)  # validation only
        if set(domain.keys()) == set(dtypes.keys()):
            return
        raise TypeError(
            f"keys of sampling dtype dict {set(dtypes.keys())} do not match "
            f"domain keys {set(domain.keys())}")
    raise TypeError(
        f"domain must be a DomainTuple or MultiDomain, got {type(domain).__name__}")


def _iscomplex(dtype):
    """Tell whether `dtype` denotes a complex floating-point type.

    Parameters
    ----------
    dtype : dtype-like
        Anything ``np.issubdtype`` accepts as its first argument.

    Returns
    -------
    bool
    """
    complex_kind = np.complexfloating
    return np.issubdtype(dtype, complex_kind)


def _field_to_dtype(field):
    """Return the sampling dtype specification matching `field`.

    Parameters
    ----------
    field : Field or MultiField

    Returns
    -------
    numpy.dtype or dict
        ``field.dtype`` for a :class:`Field`; for a :class:`MultiField` a dict
        mapping every key to the dtype of the corresponding entry.

    Raises
    ------
    TypeError
        If `field` is neither a Field nor a MultiField.
    """
    if isinstance(field, Field):
        dt = field.dtype
    elif isinstance(field, MultiField):
        dt = {kk: ff.dtype for kk, ff in field.items()}
    else:
        raise TypeError(
            f"expected Field or MultiField, got {type(field).__name__}")
    # The derived specification must be valid for the field's own domain.
    _check_sampling_dtype(field.domain, dt)
    return dt


class EnergyOperator(Operator):
    """Base class for operators that map their input to a scalar.

    Energy operators serve as objective functions for field inference.

    Examples
    --------
     - Information Hamiltonians, i.e. negative log-probabilities.
     - Gibbs free energies, i.e. averaged Hamiltonians, also known as
       Kullback-Leibler divergences.
    """

    # All energies share the same target: the scalar domain.
    _target = DomainTuple.scalar_domain()


class Squared2NormOperator(EnergyOperator):
    """Computes the square of the L2-norm of the output of an operator.

    Parameters
    ----------
    domain : Domain, DomainTuple, tuple of Domain or MultiDomain
        Domain of the operator in which the L2-norm shall be computed.
    """

    def __init__(self, domain):
        # Normalize the documented input forms (plain Domain, tuple of
        # Domains, ...) to a DomainTuple/MultiDomain so that the domain
        # comparison performed by `_check_input` works for all of them.
        self._domain = makeDomain(domain)

    def apply(self, x):
        self._check_input(x)
        if x.jac is None:
            return x.vdot(x)
        val = x.val
        res = val.vdot(val)
        # Gradient of f^dagger f is 2 f (for real f).
        return x.new(res, VdotOperator(2*val))


class QuadraticFormOperator(EnergyOperator):
    """Computes the quadratic form of a Field or MultiField with respect to a
    specific kernel given by `endo`.

    .. math ::
        E(f) = \\frac12 f^\\dagger \\text{endo}(f)

    Parameters
    ----------
    endo : EndomorphicOperator
         Kernel of the quadratic form

    Raises
    ------
    TypeError
        If `endo` is not an EndomorphicOperator.
    """

    def __init__(self, endo):
        from .endomorphic_operator import EndomorphicOperator
        if not isinstance(endo, EndomorphicOperator):
            raise TypeError("op must be an EndomorphicOperator")
        self._op = endo
        self._domain = endo.domain

    def apply(self, x):
        self._check_input(x)
        if x.jac is None:
            return 0.5*x.vdot(self._op(x))
        # Apply the (possibly expensive) kernel only once: its output is
        # needed both for the energy value and for the gradient.
        op_x = self._op(x.val)
        res = 0.5*x.val.vdot(op_x)
        # Using endo(f) as gradient presumes a self-adjoint kernel.
        return x.new(res, VdotOperator(op_x))


class VariableCovarianceGaussianEnergy(EnergyOperator):
    """Computes the negative log pdf of a Gaussian with unknown covariance.

    The covariance is assumed to be diagonal.

    .. math ::
        E(s,C) = - \\log G(s, C^{-1}) = 0.5 (s)^\\dagger C (s) - 0.5 tr log(C),

    an information energy for a Gaussian distribution with residual s and
    inverse diagonal covariance C. If `sampling_dtype` is complex, the trace
    term enters without the factor 0.5, i.e.
    :math:`E(s,C) = 0.5 (s)^\\dagger C (s) - tr log(C)`.

    The domain of this energy will be a MultiDomain with two keys,
    the target will be the scalar domain.

    Parameters
    ----------
    domain : Domain, DomainTuple, tuple of Domain
        domain of the residual and domain of the covariance diagonal.

    residual_key : key
        Residual key of the Gaussian. Converted to `str`.

    inverse_covariance_key : key
        Inverse covariance diagonal key of the Gaussian. Converted to `str`.

    sampling_dtype : np.dtype
        Data type of the residual samples. Usually either 'np.float*' or
        'np.complex*'. Samples for the inverse covariance are always drawn
        with dtype float64.
    """

    def __init__(self, domain, residual_key, inverse_covariance_key, sampling_dtype):
        self._kr = str(residual_key)
        self._ki = str(inverse_covariance_key)
        dom = DomainTuple.make(domain)
        self._domain = MultiDomain.make({self._kr: dom, self._ki: dom})
        self._dt = {self._kr: sampling_dtype, self._ki: np.float64}
        _check_sampling_dtype(self._domain, self._dt)
        self._cplx = _iscomplex(sampling_dtype)

    def apply(self, x):
        self._check_input(x)
        r, i = x[self._kr], x[self._ki]
        if self._cplx:
            res = 0.5*r.vdot(r*i.real).real - i.ptw("log").sum()
        else:
            res = 0.5*(r.vdot(r*i) - i.ptw("log").sum())
        if not x.want_metric:
            return res
        # Diagonal metric: `i` on the residual and fac/i**2 on the inverse
        # covariance, with fac matching the weight of the log term above.
        fac = 1. if self._cplx else .5
        met = MultiField.from_dict({self._kr: i.val, self._ki: fac*i.val**(-2)},
                                   domain=self._domain)
        return res.add_metric(SamplingDtypeSetter(makeOp(met), self._dt))

    def _simplify_for_constant_input_nontrivial(self, c_inp):
        from .simplify_for_const import ConstantEnergyOperator
        myassert(len(c_inp.keys()) == 1)
        key = c_inp.keys()[0]
        myassert(key in self._domain.keys())
        cst = c_inp[key]
        if key == self._kr:
            # Constant residual: only the inverse covariance varies.
            res = _SpecialGammaEnergy(cst).ducktape(self._ki)
        else:
            # Constant inverse covariance: a plain Gaussian in the residual
            # plus the now-constant log-determinant term.
            dt = self._dt[self._kr]
            res = GaussianEnergy(inverse_covariance=makeOp(cst),
                                 sampling_dtype=dt).ducktape(self._kr)
            trlog = cst.log().sum().val_rw()
            if not _iscomplex(dt):
                trlog /= 2
            res = res + ConstantEnergyOperator(-trlog)
        res = res + ConstantEnergyOperator(0.)
        myassert(res.target is self.target)
        return None, res


class _SpecialGammaEnergy(EnergyOperator):
    """`VariableCovarianceGaussianEnergy` restricted to a constant residual.

    The input `x` of this energy is the diagonal inverse covariance; the
    residual is fixed at construction time.

    Parameters
    ----------
    residual : Field
        The constant residual. Its dtype selects between the real and the
        complex form of the energy.
    """

    def __init__(self, residual):
        self._domain = DomainTuple.make(residual.domain)
        self._resi = residual
        self._cplx = _iscomplex(self._resi.dtype)
        self._scale = ScalingOperator(self._domain, 1 if self._cplx else .5)

    def apply(self, x):
        self._check_input(x)
        r = self._resi
        if self._cplx:
            res = 0.5*(r*x.real).vdot(r).real - x.ptw("log").sum()
        else:
            res = 0.5*((r*x).vdot(r) - x.ptw("log").sum())
        if not x.want_metric:
            return res
        # Metric of the inverse covariance: 0.5/x**2 (real) or 1/x**2
        # (complex), identical to the corresponding block of
        # VariableCovarianceGaussianEnergy. The scale factor multiplies
        # x**(-2); scaling x before squaring would yield 4/x**2 for real data.
        met = makeOp(self._scale(x.val**(-2)))
        # `x` is the inverse covariance, which is real (float64) even when
        # the residual is complex.
        return res.add_metric(SamplingDtypeSetter(met, np.float64))


class GaussianEnergy(EnergyOperator):
    """Computes a negative-log Gaussian.

    Represents up to constants in :math:`m`:

    .. math ::
        E(f) = - \\log G(f-m, D) = 0.5 (f-m)^\\dagger D^{-1} (f-m),

    an information energy for a Gaussian distribution with mean m and
    covariance D.

    Parameters
    ----------
    mean : Field or MultiField
        Mean of the Gaussian. Default is 0.
    inverse_covariance : LinearOperator
        Inverse covariance of the Gaussian. Default is the identity operator.
    domain : Domain, DomainTuple, tuple of Domain or MultiDomain
        Operator domain. By default it is inferred from `mean` or
        `inverse_covariance` if specified.
    sampling_dtype : type or dict of type
        Here one can specify whether the distribution is a complex Gaussian or
        not. Note that for a complex Gaussian the inverse_covariance is

        .. math ::
            \\langle f f^\\dagger \\rangle_{P(f)}^{-1}/2,

        where the additional factor of 2 is necessary because the
        domain of f has twice as many dimensions as in the real case.
        If `mean` is given and `sampling_dtype` is not, the dtype of `mean`
        is used. If both are given they must agree; for a MultiField mean a
        single dtype is accepted if it agrees with every entry.

    Note
    ----
    At least one of `mean`, `inverse_covariance` and `domain` has to be
    provided.
    """

    def __init__(self, mean=None, inverse_covariance=None, domain=None, sampling_dtype=None):
        if mean is not None and not isinstance(mean, (Field, MultiField)):
            raise TypeError("mean must be a Field or MultiField")
        if inverse_covariance is not None and not isinstance(inverse_covariance, LinearOperator):
            raise TypeError("inverse_covariance must be a LinearOperator")

        # Infer the domain; all given arguments have to agree on it.
        self._domain = None
        if mean is not None:
            self._checkEquivalence(mean.domain)
        if inverse_covariance is not None:
            self._checkEquivalence(inverse_covariance.domain)
        if domain is not None:
            self._checkEquivalence(domain)
        if self._domain is None:
            raise ValueError("no domain given")
        self._mean = mean

        # Infer sampling dtype
        if self._mean is None:
            _check_sampling_dtype(self._domain, sampling_dtype)
        else:
            mean_dtype = _field_to_dtype(self._mean)
            if sampling_dtype is None:
                sampling_dtype = mean_dtype
            elif not self._dtypes_compatible(sampling_dtype, mean_dtype):
                raise ValueError("Sampling dtype and mean not compatible")

        self._icov = inverse_covariance
        if inverse_covariance is None:
            self._op = Squared2NormOperator(self._domain).scale(0.5)
            self._met = ScalingOperator(self._domain, 1)
        else:
            self._op = QuadraticFormOperator(inverse_covariance)
            self._met = inverse_covariance
        if sampling_dtype is not None:
            self._met = SamplingDtypeSetter(self._met, sampling_dtype)

    @staticmethod
    def _dtypes_compatible(requested, reference):
        """Return whether the sampling dtype `requested` agrees with `reference`.

        `reference` is the output of `_field_to_dtype` (a single dtype or a
        dict of dtypes). A single `requested` dtype is compatible with a dict
        `reference` if it equals every entry, mirroring the convention of
        `_check_sampling_dtype` for MultiDomains.
        """
        try:
            if isinstance(reference, dict):
                if isinstance(requested, dict):
                    if set(requested.keys()) != set(reference.keys()):
                        return False
                    return all(np.dtype(requested[kk]) == np.dtype(dt)
                               for kk, dt in reference.items())
                return all(np.dtype(requested) == np.dtype(dt)
                           for dt in reference.values())
            if isinstance(requested, dict):
                return False
            return np.dtype(requested) == np.dtype(reference)
        except (TypeError, ValueError):
            # Not a valid dtype-like: report incompatibility (-> ValueError
            # in __init__), as before.
            return False

    def _checkEquivalence(self, newdom):
        newdom = makeDomain(newdom)
        if self._domain is None:
            self._domain = newdom
        else:
            if self._domain != newdom:
                raise ValueError("domain mismatch")

    def apply(self, x):
        self._check_input(x)
        residual = x if self._mean is None else x - self._mean
        res = self._op(residual).real
        if x.want_metric:
            return res.add_metric(self._met)
        return res

    def __repr__(self):
        dom = '()' if isinstance(self.domain, DomainTuple) else self.domain.keys()
        return f'GaussianEnergy {dom}'


class PoissonianEnergy(EnergyOperator):
    """Likelihood Hamiltonian of an expected-count field given Poissonian
    count data.

    Up to the f-independent term :math:`log(d!)` it evaluates

    .. math ::
        E(f) = -\\log \\text{Poisson}(d|f) = \\sum f - d^\\dagger \\log(f),

    where f is a :class:`Field` in data space holding the expected counts.

    Parameters
    ----------
    d : Field
        Observed counts. Must have an integer dtype and contain only
        non-negative values.

    Raises
    ------
    TypeError
        If `d` is not a Field of integer dtype.
    ValueError
        If `d` contains negative values.
    """

    def __init__(self, d):
        is_count_field = isinstance(d, Field) and np.issubdtype(d.dtype, np.integer)
        if not is_count_field:
            raise TypeError
        if (d.val < 0).any():
            raise ValueError
        self._domain = DomainTuple.make(d.domain)
        self._d = d

    def apply(self, x):
        self._check_input(x)
        energy = x.sum() - x.ptw("log").vdot(self._d)
        if x.want_metric:
            # Fisher metric of the Poisson distribution: 1/f.
            fisher = makeOp(1./x.val)
            return energy.add_metric(SamplingDtypeSetter(fisher, np.float64))
        return energy


class InverseGammaLikelihood(EnergyOperator):
    """Computes the negative log-likelihood of the inverse gamma distribution.

    Its negative log-pdf(x) is given by

    .. math ::

        \\sum_i (\\alpha_i+1)*\\ln(x_i) + \\beta_i/x_i

    This is the likelihood for the variance :math:`x=S_k` given data
    :math:`\\beta = 0.5 |s_k|^2` where the Field :math:`s` is known to have
    the covariance :math:`S_k`.

    Parameters
    ----------
    beta : Field
        beta parameter of the inverse gamma distribution. Must have dtype
        float64.
    alpha : Scalar, Field, optional
        alpha parameter of the inverse gamma distribution. A scalar is
        broadcast over the domain of `beta`. Default: -0.5.

    Raises
    ------
    TypeError
        If `beta` is not a float64 Field or `alpha` is neither a scalar nor
        a Field.
    ValueError
        If `alpha` is a Field on a different domain than `beta`.
    """

    def __init__(self, beta, alpha=-0.5):
        if not isinstance(beta, Field):
            raise TypeError("beta must be a Field")
        self._domain = DomainTuple.make(beta.domain)
        self._beta = beta
        if np.isscalar(alpha):
            alpha = Field(beta.domain, np.full(beta.shape, alpha))
        elif not isinstance(alpha, Field):
            raise TypeError("alpha must be a scalar or a Field")
        elif alpha.domain != self._domain:
            raise ValueError("alpha and beta must be defined on the same domain")
        self._alphap1 = alpha+1
        if self._beta.dtype != np.float64:
            # FIXME Add proper complex support for this energy
            raise TypeError("beta must have dtype float64")
        self._sampling_dtype = _field_to_dtype(self._beta)

    def apply(self, x):
        self._check_input(x)
        res = x.ptw("log").vdot(self._alphap1) + x.ptw("reciprocal").vdot(self._beta)
        if not x.want_metric:
            return res
        met = makeOp(self._alphap1/(x.val**2))
        # _sampling_dtype is always set (float64, enforced in __init__).
        return res.add_metric(SamplingDtypeSetter(met, self._sampling_dtype))


class StudentTEnergy(EnergyOperator):
    """Likelihood energy of Student's t-distribution.

    Up to f-independent constants it evaluates

    .. math ::
        E_\\theta(f) = -\\log \\text{StudentT}_\\theta(f)
                     = \\frac{\\theta + 1}{2} \\log(1 + \\frac{f^2}{\\theta}),

    where f is a field on `domain`. Samples from the metric are drawn with
    dtype `float64`.

    Parameters
    ----------
    domain : `Domain` or `DomainTuple`
        Domain of the operator
    theta : Scalar or Field
        Degree-of-freedom parameter of the Student-t distribution
    """

    def __init__(self, domain, theta):
        self._theta = theta
        self._domain = DomainTuple.make(domain)

    def apply(self, x):
        self._check_input(x)
        theta = self._theta
        prefactor = (theta+1)/2
        res = (prefactor*(x**2/theta).ptw("log1p")).sum()
        if x.want_metric:
            # Fisher metric of the Student-t location parameter.
            fisher = makeOp((theta+1) / (theta+3), self.domain)
            return res.add_metric(SamplingDtypeSetter(fisher, np.float64))
        return res


class BernoulliEnergy(EnergyOperator):
    """Likelihood energy of an expected event frequency given binary event
    data.

    .. math ::
        E(f) = -\\log \\text{Bernoulli}(d|f)
             = -d^\\dagger \\log f  - (1-d)^\\dagger \\log(1-f),

    where f is a field on `d.domain` holding the expected event frequencies.

    Parameters
    ----------
    d : Field
        Integer data field with events (1) and non-events (0).

    Raises
    ------
    TypeError
        If `d` is not a Field of integer dtype.
    ValueError
        If `d` contains values other than 0 and 1.
    """

    def __init__(self, d):
        valid_type = isinstance(d, Field) and np.issubdtype(d.dtype, np.integer)
        if not valid_type:
            raise TypeError
        if np.any((d.val != 0) & (d.val != 1)):
            raise ValueError
        self._d = d
        self._domain = DomainTuple.make(d.domain)

    def apply(self, x):
        self._check_input(x)
        log_hit = x.ptw("log").vdot(self._d)
        log_miss = (1.-x).ptw("log").vdot(self._d-1.)
        res = -log_hit + log_miss
        if x.want_metric:
            # Fisher metric of the Bernoulli distribution: 1/(f(1-f)).
            fisher = makeOp(1./(x.val*(1. - x.val)))
            return res.add_metric(SamplingDtypeSetter(fisher, np.float64))
        return res


class StandardHamiltonian(EnergyOperator):
    """Computes an information Hamiltonian in its standard form, i.e. with the
    prior being a Gaussian with unit covariance.

    Let the likelihood energy be :math:`E_{lh}`. Then this operator computes:

    .. math ::
         H(f) = 0.5 f^\\dagger f + E_{lh}(f).

    Other field priors can be represented via transformations of a white
    Gaussian field into a field with the desired prior probability structure.

    By implementing prior information this way, the field prior is represented
    by a generative model, from which NIFTy can draw samples and infer a field
    using the Maximum a Posteriori (MAP) or the Variational Bayes (VB) method.

    The metric of this operator can be used as covariance for drawing Gaussian
    samples.

    Parameters
    ----------
    lh : EnergyOperator
        The likelihood energy.
    ic_samp : IterationController, optional
        Tells an internal :class:`SamplingEnabler` which convergence criterion
        to use to draw Gaussian samples. If None, no SamplingEnabler is used
        and the metric is the plain sum of likelihood and prior metrics.
    prior_dtype : numpy.dtype or dict of numpy.dtype, optional
        Data type of prior used for sampling. Default: float64.

    See also
    --------
    `Encoding prior knowledge in the structure of the likelihood`,
    Jakob Knollmüller, Torsten A. Ensslin,
    `<https://arxiv.org/abs/1812.04403>`_
    """

    def __init__(self, lh, ic_samp=None, prior_dtype=np.float64):
        self._lh = lh
        self._prior = GaussianEnergy(domain=lh.domain, sampling_dtype=prior_dtype)
        self._ic_samp = ic_samp
        # Kept so that simplified copies use the same prior sampling dtype.
        self._prior_dtype = prior_dtype
        self._domain = lh.domain

    def apply(self, x):
        self._check_input(x)
        if not x.want_metric or self._ic_samp is None:
            return (self._lh + self._prior)(x)
        lhx, prx = self._lh(x), self._prior(x)
        return (lhx+prx).add_metric(SamplingEnabler(lhx.metric, prx.metric, self._ic_samp))

    def __repr__(self):
        subs = 'Likelihood:\n{}'.format(utilities.indent(self._lh.__repr__()))
        subs += '\nPrior:\n{}'.format(self._prior)
        return 'StandardHamiltonian:\n' + utilities.indent(subs)

    def _simplify_for_constant_input_nontrivial(self, c_inp):
        out, lh1 = self._lh.simplify_for_constant_input(c_inp)
        prior_dtype = self._prior_dtype
        # A per-key dtype dict has to be restricted to the keys that remain
        # after removing the constant ones, otherwise GaussianEnergy rejects
        # it. Keys missing from the dict fall back to the constructor default.
        if isinstance(prior_dtype, dict) and isinstance(lh1.domain, MultiDomain):
            prior_dtype = {kk: prior_dtype.get(kk, np.float64)
                           for kk in lh1.domain.keys()}
        return out, StandardHamiltonian(lh1, self._ic_samp, prior_dtype)


class AveragedEnergy(EnergyOperator):
    """Averages an energy over samples.

    Parameters
    ----------
    h : EnergyOperator
       The energy to be averaged.
    res_samples : iterable of Fields
       Set of residual sample points to be added to mean field for
       approximate estimation of the KL. Must contain at least one sample.

    Raises
    ------
    ValueError
        If `res_samples` is empty.

    Notes
    -----
    - Having symmetrized residual samples, with both :math:`v_i` and
      :math:`-v_i` being present, ensures that the distribution mean is
      exactly represented.

    - :class:`AveragedEnergy(h)` approximates
      :math:`\\left< H(f) \\right>_{G(f-m,D)}` if the residuals :math:`f-m`
      are drawn from a Gaussian distribution with covariance :math:`D`.
    """

    def __init__(self, h, res_samples):
        self._h = h
        self._domain = h.domain
        self._res_samples = tuple(res_samples)
        # An empty sample set would only fail later, with a division by zero
        # inside apply().
        if not self._res_samples:
            raise ValueError("res_samples must contain at least one sample")

    def apply(self, x):
        self._check_input(x)
        energies = (self._h(x+v) for v in self._res_samples)
        return utilities.my_sum(energies)/len(self._res_samples)