energy_operators.py 21.2 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
14
# Copyright(C) 2013-2021 Max-Planck-Society
Martin Reinecke's avatar
Martin Reinecke committed
15
#
16
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
Martin Reinecke's avatar
Martin Reinecke committed
17

Philipp Arras's avatar
Philipp Arras committed
18
19
import numpy as np

Philipp Arras's avatar
Philipp Arras committed
20
from .. import utilities
Martin Reinecke's avatar
Martin Reinecke committed
21
from ..domain_tuple import DomainTuple
Philipp Arras's avatar
Philipp Arras committed
22
from ..field import Field
Philipp Frank's avatar
Philipp Frank committed
23
from ..linearization import Linearization
Philipp Arras's avatar
Philipp Arras committed
24
25
26
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from ..sugar import makeDomain, makeOp
27
from ..utilities import myassert
Philipp Arras's avatar
Philipp Arras committed
28
from .linear_operator import LinearOperator
Martin Reinecke's avatar
Martin Reinecke committed
29
from .operator import Operator
Philipp Frank's avatar
Philipp Frank committed
30
from .adder import Adder
31
from .sampling_enabler import SamplingDtypeSetter, SamplingEnabler
32
from .scaling_operator import ScalingOperator
Philipp Frank's avatar
Philipp Frank committed
33
34
from .sandwich_operator import SandwichOperator
from .simple_linear_operators import VdotOperator, FieldAdapter
Philipp Arras's avatar
Philipp Arras committed
35
36
37
38
39
40


def _check_sampling_dtype(domain, dtypes):
    if dtypes is None:
        return
    if isinstance(domain, DomainTuple):
Philipp Arras's avatar
Philipp Arras committed
41
42
        np.dtype(dtypes)
        return
Philipp Arras's avatar
Philipp Arras committed
43
    elif isinstance(domain, MultiDomain):
Philipp Arras's avatar
Philipp Arras committed
44
45
46
47
48
49
50
        if isinstance(dtypes, dict):
            for dt in dtypes.values():
                np.dtype(dt)
            if set(domain.keys()) == set(dtypes.keys()):
                return
        else:
            np.dtype(dtypes)
Philipp Arras's avatar
Philipp Arras committed
51
            return
Philipp Arras's avatar
Philipp Arras committed
52
    raise TypeError
Philipp Arras's avatar
Philipp Arras committed
53
54


55
56
57
58
def _iscomplex(dtype):
    return np.issubdtype(dtype, np.complexfloating)


Philipp Arras's avatar
Philipp Arras committed
59
60
61
62
63
64
65
66
67
68
69
def _field_to_dtype(field):
    """Derive a sampling dtype specification from a (Multi)Field.

    A `Field` yields its dtype; a `MultiField` yields a dict mapping each key
    to the dtype of the corresponding component. The result is validated
    with `_check_sampling_dtype` against `field.domain`.

    Raises
    ------
    TypeError
        If `field` is neither a `Field` nor a `MultiField`.
    """
    if isinstance(field, MultiField):
        dtypes = {key: component.dtype for key, component in field.items()}
    elif isinstance(field, Field):
        dtypes = field.dtype
    else:
        raise TypeError
    _check_sampling_dtype(field.domain, dtypes)
    return dtypes


Martin Reinecke's avatar
Martin Reinecke committed
70
class EnergyOperator(Operator):
    """Base class for operators whose target is the scalar domain.

    Such operators serve as objective functions for field inference. A
    subclass may additionally provide a positive definite, symmetric form
    (the `metric`) which second-order minimizers use as curvature.

    Examples
    --------
     - Information Hamiltonian, i.e. negative-log-probabilities.
     - Gibbs free energy, i.e. an averaged Hamiltonian, aka Kullback-Leibler
       divergence.
    """

    # Every energy maps its input onto a single scalar value.
    _target = DomainTuple.scalar_domain()


Philipp Frank's avatar
Philipp Frank committed
86
class LikelihoodOperator(EnergyOperator):
    """Represent a log-likelihood.

    The input of the operator consists of the parameters of the likelihood.
    In contrast to a general `EnergyOperator`, the metric of a
    `LikelihoodOperator` is the Fisher information metric of the likelihood.
    """

    def get_metric_at(self, x):
        """Computes the Fisher information metric for a `LikelihoodOperator`
        at `x` using the Jacobian of the coordinate transformation given by
        :func:`~nifty7.operators.operator.Operator.get_transformation`.

        The Jacobian of the transformation at `x` is sandwiched around an
        identity wrapped in a `SamplingDtypeSetter` when the transformation
        specifies a sampling dtype, and around nothing otherwise.
        """
        sampling_dtype, trafo = self.get_transformation()
        if sampling_dtype is None:
            dtype_setter = None
        else:
            identity = ScalingOperator(trafo.target, 1.)
            dtype_setter = SamplingDtypeSetter(identity, sampling_dtype)
        jacobian = trafo(Linearization.make_var(x)).jac
        return SandwichOperator.make(jacobian, dtype_setter)
Philipp Frank's avatar
Philipp Frank committed
105
106


107
108
class Squared2NormOperator(EnergyOperator):
    """Computes the square of the L2-norm of the output of an operator.

    Parameters
    ----------
    domain : Domain, DomainTuple, tuple of Domain or MultiDomain
        Domain of the operator in which the L2-norm shall be computed.
    """

    def __init__(self, domain):
        # Normalize the documented input forms (plain Domain, tuple of
        # Domains, ...) into a DomainTuple/MultiDomain, matching how
        # `GaussianEnergy` treats domains.
        self._domain = makeDomain(domain)

    def apply(self, x):
        self._check_input(x)
        if x.jac is None:
            return x.vdot(x)
        res = x.val.vdot(x.val)
        # The Jacobian of x^dagger x is represented as a Vdot with 2*x.
        return x.new(res, VdotOperator(2*x.val))
Martin Reinecke's avatar
Martin Reinecke committed
125

Martin Reinecke's avatar
Martin Reinecke committed
126

Martin Reinecke's avatar
Martin Reinecke committed
127
class QuadraticFormOperator(EnergyOperator):
    """Computes the L2-norm of a Field or MultiField with respect to a
    specific kernel given by `endo`.

    .. math ::
        E(f) = \\frac12 f^\\dagger \\text{endo}(f)

    Parameters
    ----------
    endo : EndomorphicOperator
         Kernel of the quadratic form

    Raises
    ------
    TypeError
        If `endo` is not an `EndomorphicOperator`.
    """

    def __init__(self, endo):
        from .endomorphic_operator import EndomorphicOperator
        if not isinstance(endo, EndomorphicOperator):
            raise TypeError("op must be an EndomorphicOperator")
        self._op = endo
        self._domain = endo.domain

    def apply(self, x):
        self._check_input(x)
        if x.jac is None:
            return 0.5*x.vdot(self._op(x))
        # Apply the kernel only once: its output is needed both for the
        # energy value and for the Jacobian.
        endo_x = self._op(x.val)
        res = 0.5*x.val.vdot(endo_x)
        return x.new(res, VdotOperator(endo_x))
Martin Reinecke's avatar
Martin Reinecke committed
153

Philipp Arras's avatar
Philipp Arras committed
154

Philipp Frank's avatar
Philipp Frank committed
155
class VariableCovarianceGaussianEnergy(LikelihoodOperator):
    """Computes the negative log pdf of a Gaussian with unknown covariance.

    The covariance is assumed to be diagonal.

    .. math ::
        E(s,C) = - \\log G(s, C) = 0.5 (s)^\\dagger C (s) - 0.5 tr log(C),

    an information energy for a Gaussian distribution with residual s and
    inverse diagonal covariance C. For a complex residual the
    log-determinant term enters without the factor 0.5, i.e. as
    :math:`- tr log(C)`.

    The domain of this energy will be a MultiDomain with two keys,
    the target will be the scalar domain.

    Parameters
    ----------
    domain : Domain, DomainTuple, tuple of Domain
        Domain of the residual and domain of the covariance diagonal.

    residual_key : key
        Residual key of the Gaussian.

    inverse_covariance_key : key
        Inverse covariance diagonal key of the Gaussian.

    sampling_dtype : np.dtype
        Data type of the samples. Usually either 'np.float*' or 'np.complex*'

    use_full_fisher: boolean
        Determines if the proper Fisher information metric should be used as
        `metric`. If False, the same approximation as in `get_transformation`
        is used. Default is True.
    """

    def __init__(self, domain, residual_key, inverse_covariance_key,
                 sampling_dtype, use_full_fisher=True):
        self._kr = str(residual_key)
        self._ki = str(inverse_covariance_key)
        dom = DomainTuple.make(domain)
        self._domain = MultiDomain.make({self._kr: dom, self._ki: dom})
        # The inverse covariance diagonal is always sampled as real float64.
        self._dt = {self._kr: sampling_dtype, self._ki: np.float64}
        _check_sampling_dtype(self._domain, self._dt)
        self._cplx = _iscomplex(sampling_dtype)
        self._use_full_fisher = use_full_fisher

    def apply(self, x):
        self._check_input(x)
        r, i = x[self._kr], x[self._ki]
        if self._cplx:
            res = 0.5*r.vdot(r*i.real).real - i.ptw("log").sum()
        else:
            res = 0.5*(r.vdot(r*i) - i.ptw("log").sum())
        if not x.want_metric:
            return res
        if self._use_full_fisher:
            # Diagonal Fisher metric: i on the residual block and
            # 1/i**2 (complex) or 0.5/i**2 (real) on the inverse covariance.
            met = 1. if self._cplx else 0.5
            met = MultiField.from_dict({self._kr: i.val, self._ki: met*i.val**(-2)},
                                       domain=self._domain)
            met = SamplingDtypeSetter(makeOp(met), self._dt)
        else:
            met = self.get_metric_at(x.val)
        return res.add_metric(met)

    def _simplify_for_constant_input_nontrivial(self, c_inp):
        from .simplify_for_const import ConstantEnergyOperator
        myassert(len(c_inp.keys()) == 1)
        key = c_inp.keys()[0]
        myassert(key in self._domain.keys())
        cst = c_inp[key]
        if key == self._kr:
            # Residual is fixed: only the inverse covariance remains free.
            res = _SpecialGammaEnergy(cst).ducktape(self._ki)
        else:
            # Inverse covariance is fixed: plain Gaussian in the residual plus
            # the constant log-determinant term.
            dt = self._dt[self._kr]
            res = GaussianEnergy(inverse_covariance=makeOp(cst),
                                 sampling_dtype=dt).ducktape(self._kr)
            trlog = cst.log().sum().val_rw()
            if not self._cplx:
                trlog /= 2
            res = res + ConstantEnergyOperator(-trlog)
        res = res + ConstantEnergyOperator(0.)
        myassert(res.target is self.target)
        return None, res

    def get_transformation(self):
        """Note that for the metric of a `VariableCovarianceGaussianEnergy` no
        global transformation to Euclidean space exists. A local approximation
        invoking the residual is used instead.
        """
        r = FieldAdapter(self._domain[self._kr], self._kr)
        ivar = FieldAdapter(self._domain[self._ki], self._ki)
        # The Jacobian of sc*log(i) is sc/i, contributing sc**2/i**2 to the
        # metric. This has to match the Fisher entry used in `apply`
        # (1/i**2 complex, 0.5/i**2 real), hence sc = sqrt(0.5) for real
        # residuals, as in `_SpecialGammaEnergy.get_transformation`.
        sc = 1. if self._cplx else np.sqrt(0.5)
        return self._dt, r.adjoint@(ivar.ptw('sqrt')*r) + ivar.adjoint@(sc*ivar.ptw('log'))
246

Philipp Frank's avatar
Philipp Frank committed
247
248

class _SpecialGammaEnergy(LikelihoodOperator):
    """`VariableCovarianceGaussianEnergy` with a fixed residual.

    The input is the inverse covariance diagonal; the residual passed to the
    constructor is held constant. The energy is 0.5 r^dagger x r - 0.5 tr
    log(x) for real residuals and 0.5 r^dagger x r - tr log(x) for complex
    ones.
    """

    def __init__(self, residual):
        self._resi = residual
        self._dt = residual.dtype
        self._cplx = _iscomplex(self._dt)
        self._domain = DomainTuple.make(residual.domain)

    def apply(self, x):
        self._check_input(x)
        resi = self._resi
        if not self._cplx:
            energy = 0.5*((resi*x).vdot(resi) - x.ptw("log").sum())
        else:
            energy = 0.5*(resi*x.real).vdot(resi).real - x.ptw("log").sum()
        if x.want_metric:
            return energy.add_metric(self.get_metric_at(x.val))
        return energy

    def get_transformation(self):
        # Squared Jacobian of scale*log(x) is scale**2/x**2, i.e. 1/x**2 for
        # complex and 0.5/x**2 for real residuals.
        scale = np.sqrt(0.5) if not self._cplx else 1.
        log_op = ScalingOperator(self._domain, 1.).ptw('log')
        return self._dt, scale*log_op
Martin Reinecke's avatar
Martin Reinecke committed
269

Philipp Frank's avatar
Philipp Frank committed
270
class GaussianEnergy(LikelihoodOperator):
    """Computes a negative-log Gaussian.

    Represents, up to terms that do not depend on :math:`f`:

    .. math ::
        E(f) = - \\log G(f-m, D) = 0.5 (f-m)^\\dagger D^{-1} (f-m),

    an information energy for a Gaussian distribution with mean m and
    covariance D.

    Parameters
    ----------
    mean : Field or MultiField
        Mean of the Gaussian. Default is 0.
    inverse_covariance : LinearOperator
        Inverse covariance of the Gaussian. Default is the identity operator.
    domain : Domain, DomainTuple, tuple of Domain or MultiDomain
        Operator domain. By default it is inferred from `mean` or
        `inverse_covariance` if specified.
    sampling_dtype : type
        Here one can specify whether the distribution is a complex Gaussian or
        not. Note that for a complex Gaussian the inverse_covariance is

        .. math ::
            (\\langle ff^\\dagger \\rangle)^{-1}_{P(f)}/2,

        where the additional factor of 2 is necessary because the
        domain of f has twice as many real dimensions as in the real case.
        If `mean` is given and `sampling_dtype` is None, the dtype of `mean`
        is used; an explicitly given `sampling_dtype` must match it.

    Note
    ----
    At least one of `mean`, `inverse_covariance` and `domain` has to be
    provided. All provided domains must agree.

    Raises
    ------
    TypeError
        If `mean` or `inverse_covariance` has the wrong type.
    ValueError
        If no domain can be determined, the given domains disagree, or
        `sampling_dtype` is incompatible with `mean`.
    """

    def __init__(self, mean=None, inverse_covariance=None, domain=None, sampling_dtype=None):
        if mean is not None and not isinstance(mean, (Field, MultiField)):
            raise TypeError("mean must be a Field or MultiField")
        if inverse_covariance is not None and not isinstance(inverse_covariance, LinearOperator):
            raise TypeError("inverse_covariance must be a LinearOperator")

        self._domain = None
        if mean is not None:
            self._checkEquivalence(mean.domain)
        if inverse_covariance is not None:
            self._checkEquivalence(inverse_covariance.domain)
        if domain is not None:
            self._checkEquivalence(domain)
        if self._domain is None:
            raise ValueError("no domain given")
        self._mean = mean

        # Infer sampling dtype
        if self._mean is None:
            _check_sampling_dtype(self._domain, sampling_dtype)
        else:
            if sampling_dtype is None:
                sampling_dtype = _field_to_dtype(self._mean)
            else:
                if sampling_dtype != _field_to_dtype(self._mean):
                    raise ValueError("Sampling dtype and mean not compatible")

        self._icov = inverse_covariance
        if inverse_covariance is None:
            self._op = Squared2NormOperator(self._domain).scale(0.5)
            self._met = ScalingOperator(self._domain, 1)
        else:
            self._op = QuadraticFormOperator(inverse_covariance)
            self._met = inverse_covariance
        if sampling_dtype is not None:
            self._met = SamplingDtypeSetter(self._met, sampling_dtype)

    def _checkEquivalence(self, newdom):
        """Record `newdom` as the domain, or check it against the recorded one."""
        newdom = makeDomain(newdom)
        if self._domain is None:
            self._domain = newdom
        else:
            if self._domain != newdom:
                raise ValueError("domain mismatch")

    def apply(self, x):
        self._check_input(x)
        residual = x if self._mean is None else x - self._mean
        res = self._op(residual).real
        if x.want_metric:
            return res.add_metric(self.get_metric_at(x.val))
        return res

    def get_transformation(self):
        # Unwrap a SamplingDtypeSetter so the square root is taken of the bare
        # inverse covariance; the dtype is returned separately.
        icov, dtp = self._met, None
        if isinstance(icov, SamplingDtypeSetter):
            dtp = icov._dtype
            icov = icov._op
        return dtp, icov.get_sqrt()

    def __repr__(self):
        dom = '()' if isinstance(self.domain, DomainTuple) else self.domain.keys()
        return f'GaussianEnergy {dom}'

Martin Reinecke's avatar
Martin Reinecke committed
367

Philipp Frank's avatar
Philipp Frank committed
368
class PoissonianEnergy(LikelihoodOperator):
    """Computes likelihood Hamiltonians of expected count field constrained by
    Poissonian count data.

    Up to the f-independent term :math:`log(d!)` this is

    .. math ::
        E(f) = -\\log \\text{Poisson}(d|f) = \\sum f - d^\\dagger \\log(f),

    where f is a :class:`Field` in data space holding the expected counts.

    Parameters
    ----------
    d : Field
        Data field with counts. Must have an integer dtype and contain only
        non-negative values.
    """

    def __init__(self, d):
        is_integer_field = isinstance(d, Field) and np.issubdtype(d.dtype, np.integer)
        if not is_integer_field:
            raise TypeError
        if np.any(d.val < 0):
            raise ValueError
        self._domain = DomainTuple.make(d.domain)
        self._d = d

    def apply(self, x):
        self._check_input(x)
        energy = x.sum() - x.ptw("log").vdot(self._d)
        if x.want_metric:
            return energy.add_metric(self.get_metric_at(x.val))
        return energy

    def get_transformation(self):
        # Squared Jacobian of 2*sqrt(x) is 1/x, the Poisson Fisher metric.
        sqrt_op = ScalingOperator(self._domain, 1.).ptw('sqrt')
        return np.float64, 2.*sqrt_op
404

Philipp Frank's avatar
Philipp Frank committed
405
class InverseGammaLikelihood(LikelihoodOperator):
    """Computes the negative log-likelihood of the inverse gamma distribution.

    Its negative log-pdf(x) is given, up to x-independent terms, by

    .. math ::

        \\sum_i (\\alpha_i+1) \\ln(x_i) + \\beta_i/x_i

    This is the likelihood for the variance :math:`x=S_k` given data
    :math:`\\beta = 0.5 |s_k|^2` where the Field :math:`s` is known to have
    the covariance :math:`S_k`.

    Parameters
    ----------
    beta : Field
        beta parameter of the inverse gamma distribution. Must have dtype
        float64 (complex support is not implemented).
    alpha : Scalar, Field, optional
        alpha parameter of the inverse gamma distribution. A scalar is
        broadcast to the domain of `beta`. Default is -0.5.

    Raises
    ------
    TypeError
        If `beta` is not a float64 Field or `alpha` is neither a scalar nor a
        Field.
    """

    def __init__(self, beta, alpha=-0.5):
        if not isinstance(beta, Field):
            raise TypeError("beta must be a Field")
        self._domain = DomainTuple.make(beta.domain)
        self._beta = beta
        if np.isscalar(alpha):
            alpha = Field(beta.domain, np.full(beta.shape, alpha))
        elif not isinstance(alpha, Field):
            raise TypeError("alpha must be a scalar or a Field")
        self._alphap1 = alpha+1
        if not self._beta.dtype == np.float64:
            # FIXME Add proper complex support for this energy
            raise TypeError("beta must have dtype float64")
        self._sampling_dtype = _field_to_dtype(self._beta)

    def apply(self, x):
        self._check_input(x)
        res = x.ptw("log").vdot(self._alphap1) + x.ptw("reciprocal").vdot(self._beta)
        if not x.want_metric:
            return res
        return res.add_metric(self.get_metric_at(x.val))

    def get_transformation(self):
        # Squared Jacobian of sqrt(alpha+1)*log(x) is (alpha+1)/x**2.
        fact = self._alphap1.ptw('sqrt')
        res = makeOp(fact)@ScalingOperator(self._domain, 1.).ptw('log')
        return self._sampling_dtype, res
452

Philipp Frank's avatar
Philipp Frank committed
453
class StudentTEnergy(LikelihoodOperator):
    """Computes likelihood energy corresponding to Student's t-distribution.

    .. math ::
        E_\\theta(f) = -\\log \\text{StudentT}_\\theta(f)
                     = \\frac{\\theta + 1}{2} \\log(1 + \\frac{f^2}{\\theta}),

    where f is a field defined on `domain`. Assumes that the data is `float64`
    for sampling.

    Parameters
    ----------
    domain : `Domain` or `DomainTuple`
        Domain of the operator
    theta : Scalar or Field
        Degree of freedom parameter for the student t distribution
    """

    def __init__(self, domain, theta):
        self._domain = DomainTuple.make(domain)
        self._theta = theta

    def apply(self, x):
        self._check_input(x)
        res = (((self._theta+1)/2)*(x**2/self._theta).ptw("log1p")).sum()
        if not x.want_metric:
            return res
        return res.add_metric(self.get_metric_at(x.val))

    def get_transformation(self):
        """Linear diagonal transformation with entries sqrt((theta+1)/(theta+3)).

        Sandwiching its Jacobian yields the diagonal metric
        (theta+1)/(theta+3). A scalar `theta` is broadcast to the domain.
        """
        if isinstance(self._theta, (Field, MultiField)):
            th = self._theta
        else:
            from ..extra import full
            th = full(self._domain, self._theta)
        return np.float64, makeOp(((th+1)/(th+3)).ptw('sqrt'))
489

Philipp Frank's avatar
Philipp Frank committed
490
class BernoulliEnergy(LikelihoodOperator):
    """Computes likelihood energy of expected event frequency constrained by
    event data.

    .. math ::
        E(f) = -\\log \\text{Bernoulli}(d|f)
             = -d^\\dagger \\log f  - (1-d)^\\dagger \\log(1-f),

    where f is a field defined on `d.domain` holding the expected event
    frequencies.

    Parameters
    ----------
    d : Field
        Integer data field with events (1) or non-events (0).
    """

    def __init__(self, d):
        is_integer_field = isinstance(d, Field) and np.issubdtype(d.dtype, np.integer)
        if not is_integer_field:
            raise TypeError
        is_binary = np.logical_or(d.val == 0, d.val == 1)
        if not np.all(is_binary):
            raise ValueError
        self._domain = DomainTuple.make(d.domain)
        self._d = d

    def apply(self, x):
        self._check_input(x)
        energy = -x.ptw("log").vdot(self._d) + (1.-x).ptw("log").vdot(self._d-1.)
        if x.want_metric:
            return energy.add_metric(self.get_metric_at(x.val))
        return energy

    def get_transformation(self):
        # -2*arctan(sqrt((1-x)/x)) has squared Jacobian 1/(x(1-x)), the
        # Bernoulli Fisher metric.
        from ..extra import full
        identity = ScalingOperator(self._domain, 1)
        one_minus_x = Adder(full(self._domain, 1.))@ScalingOperator(self._domain, -1)
        odds = one_minus_x*identity.ptw('reciprocal')
        return np.float64, -2.*odds.ptw('sqrt').ptw('arctan')
Martin Reinecke's avatar
Martin Reinecke committed
527

528
class StandardHamiltonian(EnergyOperator):
    """Computes an information Hamiltonian in its standard form, i.e. with the
    prior being a Gaussian with unit covariance.

    Let the likelihood energy be :math:`E_{lh}`. Then this operator computes:

    .. math ::
         H(f) = 0.5 f^\\dagger f + E_{lh}(f).

    Other field priors can be represented via transformations of a white
    Gaussian field into a field with the desired prior probability structure.

    By implementing prior information this way, the field prior is represented
    by a generative model, from which NIFTy can draw samples and infer a field
    using the Maximum a Posteriori (MAP) or the Variational Bayes (VB) method.

    The metric of this operator can be used as covariance for drawing Gaussian
    samples.

    Parameters
    ----------
    lh : EnergyOperator
        The likelihood energy.
    ic_samp : IterationController, optional
        Tells an internal :class:`SamplingEnabler` which convergence criterion
        to use to draw Gaussian samples. If None (default), no
        :class:`SamplingEnabler` is used and the metric is the one provided by
        the summed operator ``lh + prior``.
    prior_dtype : numpy.dtype or dict of numpy.dtype, optional
        Data type of prior used for sampling. Default: ``np.float64``.

    See also
    --------
    `Encoding prior knowledge in the structure of the likelihood`,
    Jakob Knollmüller, Torsten A. Ensslin,
    `<https://arxiv.org/abs/1812.04403>`_
    """

    def __init__(self, lh, ic_samp=None, prior_dtype=np.float64):
        self._lh = lh
        self._prior = GaussianEnergy(domain=lh.domain, sampling_dtype=prior_dtype)
        self._ic_samp = ic_samp
        self._domain = lh.domain
        # Kept so that derived Hamiltonians (constant-input simplification)
        # use the same prior sampling dtype instead of the float64 default.
        self._prior_dtype = prior_dtype

    def apply(self, x):
        self._check_input(x)
        if not x.want_metric or self._ic_samp is None:
            return (self._lh + self._prior)(x)
        # Wrap the likelihood and prior metrics so that samples can be drawn
        # with the iteration controller `ic_samp`.
        lhx, prx = self._lh(x), self._prior(x)
        return (lhx+prx).add_metric(SamplingEnabler(lhx.metric, prx.metric, self._ic_samp))

    def __repr__(self):
        subs = 'Likelihood:\n{}'.format(utilities.indent(self._lh.__repr__()))
        subs += '\nPrior:\n{}'.format(self._prior)
        return 'StandardHamiltonian:\n' + utilities.indent(subs)

    def _simplify_for_constant_input_nontrivial(self, c_inp):
        out, lh1 = self._lh.simplify_for_constant_input(c_inp)
        dtype = self._prior_dtype
        # If the simplified likelihood lives on a MultiDomain, keep only the
        # per-key dtype entries for keys that are still present in it.
        if isinstance(dtype, dict) and isinstance(lh1.domain, MultiDomain):
            keys = lh1.domain.keys()
            dtype = {kk: vv for kk, vv in dtype.items() if kk in keys}
        return out, StandardHamiltonian(lh1, self._ic_samp, prior_dtype=dtype)
585

Martin Reinecke's avatar
Martin Reinecke committed
586

Martin Reinecke's avatar
Martin Reinecke committed
587
class AveragedEnergy(EnergyOperator):
    """Averages an energy over samples.

    Parameters
    ----------
    h : EnergyOperator
       The energy to be averaged.
    res_samples : iterable of Fields
       Set of residual sample points to be added to mean field for
       approximate estimation of the KL. Must contain at least one sample.

    Raises
    ------
    ValueError
       If `res_samples` is empty.

    Notes
    -----
    - Having symmetrized residual samples, with both :math:`v_i` and
      :math:`-v_i` being present, ensures that the distribution mean is
      exactly represented.

    - :class:`AveragedEnergy(h)` approximates
      :math:`\\left< H(f) \\right>_{G(f-m,D)}` if the residuals :math:`f-m`
      are drawn from a Gaussian distribution with covariance :math:`D`.
    """

    def __init__(self, h, res_samples):
        self._h = h
        self._domain = h.domain
        self._res_samples = tuple(res_samples)
        # Without samples, both `apply` and `get_transformation` would fail
        # later (empty sum, division by zero); fail early with a clear message.
        if len(self._res_samples) == 0:
            raise ValueError("AveragedEnergy needs at least one residual sample")

    def apply(self, x):
        self._check_input(x)
        energies = (self._h(x+v) for v in self._res_samples)
        return utilities.my_sum(energies)/len(self._res_samples)

    def get_transformation(self):
        # NOTE(review): the Jacobian of this sum is sum_i J_i / sqrt(N), so the
        # implied metric is (sum_i J_i)^dagger (sum_i J_i) / N rather than the
        # sample average sum_i J_i^dagger J_i / N (for N identical samples it
        # is N times larger). Confirm this approximation is intended.
        dtp, trafo = self._h.get_transformation()
        shifted = (trafo@Adder(v) for v in self._res_samples)
        return dtp, utilities.my_sum(shifted)/np.sqrt(len(self._res_samples))