# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

Philipp Arras's avatar
Philipp Arras committed
18
19
import numpy as np

Philipp Arras's avatar
Philipp Arras committed
20
from .. import utilities
Martin Reinecke's avatar
Martin Reinecke committed
21
from ..domain_tuple import DomainTuple
Philipp Arras's avatar
Philipp Arras committed
22
from ..field import Field
Philipp Frank's avatar
Philipp Frank committed
23
from ..linearization import Linearization
Philipp Arras's avatar
Philipp Arras committed
24
25
26
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from ..sugar import makeDomain, makeOp
27
from ..utilities import myassert
Philipp Arras's avatar
Philipp Arras committed
28
from .linear_operator import LinearOperator
Martin Reinecke's avatar
Martin Reinecke committed
29
from .operator import Operator
Philipp Frank's avatar
Philipp Frank committed
30
from .adder import Adder
31
from .sampling_enabler import SamplingDtypeSetter, SamplingEnabler
32
from .scaling_operator import ScalingOperator
Philipp Frank's avatar
Philipp Frank committed
33
34
from .sandwich_operator import SandwichOperator
from .simple_linear_operators import VdotOperator, FieldAdapter
Philipp Arras's avatar
Philipp Arras committed
35
36
37
38
39
40


def _check_sampling_dtype(domain, dtypes):
    if dtypes is None:
        return
    if isinstance(domain, DomainTuple):
Philipp Arras's avatar
Philipp Arras committed
41
42
        np.dtype(dtypes)
        return
Philipp Arras's avatar
Philipp Arras committed
43
    elif isinstance(domain, MultiDomain):
Philipp Arras's avatar
Philipp Arras committed
44
45
46
47
48
49
50
        if isinstance(dtypes, dict):
            for dt in dtypes.values():
                np.dtype(dt)
            if set(domain.keys()) == set(dtypes.keys()):
                return
        else:
            np.dtype(dtypes)
Philipp Arras's avatar
Philipp Arras committed
51
            return
Philipp Arras's avatar
Philipp Arras committed
52
    raise TypeError
Philipp Arras's avatar
Philipp Arras committed
53
54


55
56
57
58
def _iscomplex(dtype):
    return np.issubdtype(dtype, np.complexfloating)


Philipp Arras's avatar
Philipp Arras committed
59
60
61
62
63
64
65
66
67
68
69
def _field_to_dtype(field):
    """Extract the dtype specification of a `Field` or `MultiField`.

    For a `Field` its dtype is returned, for a `MultiField` a dict that maps
    every key to the dtype of the corresponding entry. The result is passed
    through `_check_sampling_dtype` before it is returned.

    Raises
    ------
    TypeError
        If `field` is neither a `Field` nor a `MultiField`.
    """
    if not isinstance(field, (Field, MultiField)):
        raise TypeError
    if isinstance(field, Field):
        dtype = field.dtype
    else:
        dtype = {key: entry.dtype for key, entry in field.items()}
    _check_sampling_dtype(field.domain, dtype)
    return dtype


Martin Reinecke's avatar
Martin Reinecke committed
70
class EnergyOperator(Operator):
    """Operator which has a scalar domain as target domain.

    It is intended as an objective function for field inference.

    Examples
    --------
     - Information Hamiltonian, i.e. negative-log-probabilities.
     - Gibbs free energy, i.e. an averaged Hamiltonian, aka Kullback-Leibler
       divergence.
    """
    # Every energy operator maps into the scalar domain.
    _target = DomainTuple.scalar_domain()


Philipp Frank's avatar
Philipp Frank committed
84
85
86
87
88
89
90
91
92
class LikelihoodOperator(EnergyOperator):
    """`EnergyOperator` that represents a likelihood.

    The input to the operator are the parameters of the likelihood. In
    contrast to a general `EnergyOperator`, the metric of a
    `LikelihoodOperator` is the Fisher information metric of the likelihood.
    """
    def get_metric_at(self, x):
        """Compute the Fisher information metric of the likelihood at `x`.

        The metric is built from the Jacobian of the coordinate
        transformation given by
        :func:`~nifty7.operators.operator.Operator.get_transformation`,
        evaluated at `x`, sandwiched around a unit `ScalingOperator` on the
        target of the transformation. If the transformation comes with a
        sampling dtype, it is attached via a `SamplingDtypeSetter`.
        """
        sampling_dtype, trafo = self.get_transformation()
        bun = ScalingOperator(trafo.target, 1.)
        if sampling_dtype is not None:
            bun = SamplingDtypeSetter(bun, sampling_dtype)
        jacobian = trafo(Linearization.make_var(x)).jac
        return SandwichOperator.make(jacobian, bun)


102
103
class Squared2NormOperator(EnergyOperator):
    """Computes the square of the L2-norm of the output of an operator.

    Parameters
    ----------
    domain : Domain, DomainTuple or tuple of Domain
        Domain of the operator in which the L2-norm shall be computed.
    """

    def __init__(self, domain):
        # Normalize the input so that a bare `Domain` or a tuple of domains
        # (both allowed by the documentation) ends up as a proper domain
        # object, as is done in `GaussianEnergy._checkEquivalence`.
        self._domain = makeDomain(domain)

    def apply(self, x):
        """Return `x^dagger x`; for inputs carrying a Jacobian the result is
        built via `x.new` with a `VdotOperator(2*x.val)`.
        """
        self._check_input(x)
        if x.jac is None:
            return x.vdot(x)
        res = x.val.vdot(x.val)
        return x.new(res, VdotOperator(2*x.val))
Martin Reinecke's avatar
Martin Reinecke committed
120

Martin Reinecke's avatar
Martin Reinecke committed
121

Martin Reinecke's avatar
Martin Reinecke committed
122
class QuadraticFormOperator(EnergyOperator):
    """Computes the L2-norm of a Field or MultiField with respect to a
    specific kernel given by `endo`.

    .. math ::
        E(f) = \\frac12 f^\\dagger \\text{endo}(f)

    Parameters
    ----------
    endo : EndomorphicOperator
         Kernel of the quadratic form

    Raises
    ------
    TypeError
        If `endo` is not an `EndomorphicOperator`.
    """

    def __init__(self, endo):
        # Local import, as in the original code.
        from .endomorphic_operator import EndomorphicOperator
        if not isinstance(endo, EndomorphicOperator):
            raise TypeError("endo must be an EndomorphicOperator")
        self._op = endo
        self._domain = endo.domain

    def apply(self, x):
        """Evaluate `0.5 * x^dagger endo(x)`.

        For inputs carrying a Jacobian the result is built via `x.new` with
        a `VdotOperator(endo(x.val))`.
        """
        self._check_input(x)
        if x.jac is None:
            return 0.5*x.vdot(self._op(x))
        # Apply the kernel only once and reuse the result for value and
        # gradient.
        # NOTE(review): using endo(x) as gradient presumably relies on `endo`
        # being self-adjoint -- confirm against callers.
        op_x = self._op(x.val)
        res = 0.5*x.val.vdot(op_x)
        return x.new(res, VdotOperator(op_x))
Martin Reinecke's avatar
Martin Reinecke committed
148

Philipp Arras's avatar
Philipp Arras committed
149

Philipp Frank's avatar
Philipp Frank committed
150
class VariableCovarianceGaussianEnergy(LikelihoodOperator):
    """Computes the negative log pdf of a Gaussian with unknown covariance.

    The covariance is assumed to be diagonal.

    .. math ::
        E(s,C) = - \\log G(s, C) = 0.5 (s)^\\dagger C (s) - 0.5 tr log(C),

    an information energy for a Gaussian distribution with residual s and
    inverse diagonal covariance C.

    The domain of this energy will be a MultiDomain with two keys,
    the target will be the scalar domain.

    Parameters
    ----------
    domain : Domain, DomainTuple, tuple of Domain
        domain of the residual and domain of the covariance diagonal.

    residual_key : key
        Residual key of the Gaussian.

    inverse_covariance_key : key
        Inverse covariance diagonal key of the Gaussian.

    sampling_dtype : np.dtype
        Data type of the samples. Usually either 'np.float*' or 'np.complex*'

    use_full_fisher: boolean
        Whether or not the proper Fisher information metric should be used as
        a `metric`. If False the same approximation used in
        `get_transformation` is used instead.
        Default is True
    """

    def __init__(self, domain, residual_key, inverse_covariance_key, sampling_dtype,
                 use_full_fisher = True):
        # Keys are stored as strings.
        self._kr = str(residual_key)
        self._ki = str(inverse_covariance_key)
        dom = DomainTuple.make(domain)
        # Residual and inverse covariance live on the same domain.
        self._domain = MultiDomain.make({self._kr: dom, self._ki: dom})
        # The inverse covariance is always sampled as float64.
        self._dt = {self._kr: sampling_dtype, self._ki: np.float64}
        _check_sampling_dtype(self._domain, self._dt)
        self._cplx = _iscomplex(sampling_dtype)
        self._use_fisher = use_full_fisher

    def apply(self, x):
        """Evaluate the energy at `x` and attach a metric if requested."""
        self._check_input(x)
        r, i = x[self._kr], x[self._ki]
        if self._cplx:
            # Complex residual: the log term enters with weight 1 instead
            # of 0.5.
            res = 0.5*r.vdot(r*i.real).real - i.ptw("log").sum()
        else:
            res = 0.5*(r.vdot(r*i) - i.ptw("log").sum())
        if not x.want_metric:
            return res
        if self._use_fisher:
            # Diagonal metric: `i` for the residual, `met*i**(-2)` for the
            # inverse covariance.
            met = 1. if self._cplx else 0.5
            met = MultiField.from_dict({self._kr: i.val, self._ki: met*i.val**(-2)},
                                        domain=self._domain)
            met = SamplingDtypeSetter(makeOp(met), self._dt)
        else:
            # Fall back to the metric derived from `get_transformation`.
            met = self.get_metric_at(x.val)
        return res.add_metric(met)

    def _simplify_for_constant_input_nontrivial(self, c_inp):
        """Simplify the energy if exactly one of the two inputs is constant.

        Returns a tuple `(None, op)`, where `op` is a specialized energy in
        the remaining input.
        """
        from .simplify_for_const import ConstantEnergyOperator
        myassert(len(c_inp.keys()) == 1)
        key = c_inp.keys()[0]
        myassert(key in self._domain.keys())
        cst = c_inp[key]
        if key == self._kr:
            # Constant residual: energy in the inverse covariance only.
            res = _SpecialGammaEnergy(cst).ducktape(self._ki)
        else:
            # Constant inverse covariance: Gaussian energy in the residual
            # plus the constant log term.
            dt = self._dt[self._kr]
            res = GaussianEnergy(inverse_covariance=makeOp(cst),
                                 sampling_dtype=dt).ducktape(self._kr)
            trlog = cst.log().sum().val_rw()
            if not self._cplx:
                trlog /= 2
            res = res + ConstantEnergyOperator(-trlog)
        res = res + ConstantEnergyOperator(0.)
        myassert(res.target is self.target)
        return None, res

    def get_transformation(self):
        """Note that for the metric of a `VariableCovarianceGaussianEnergy` no
        global transformation to Euclidean space exists. A local approximation
        invoking the residual is used instead.
        """
        r = FieldAdapter(self._domain[self._kr], self._kr)
        # Both keys share the same domain, see `__init__`.
        ivar = FieldAdapter(self._domain[self._kr], self._ki)
        # NOTE(review): `_SpecialGammaEnergy.get_transformation` uses
        # sqrt(0.5) for the real case whereas 0.5 is used here -- confirm
        # that the difference is intended.
        sc = 1. if self._cplx else 0.5
        return self._dt, r.adjoint@(ivar.ptw('sqrt')*r) + ivar.adjoint@(sc*ivar.ptw('log'))
242

Philipp Frank's avatar
Philipp Frank committed
243
244

class _SpecialGammaEnergy(LikelihoodOperator):
    """Energy in the inverse covariance of a Gaussian for a fixed residual.

    Used by `VariableCovarianceGaussianEnergy` when its residual input is
    constant.

    Parameters
    ----------
    residual : Field
        The constant residual. Its dtype decides between the real and the
        complex variant and is used as sampling dtype.
    """

    def __init__(self, residual):
        self._domain = DomainTuple.make(residual.domain)
        self._resi = residual
        self._cplx = _iscomplex(self._resi.dtype)
        self._dt = self._resi.dtype

    def apply(self, x):
        self._check_input(x)
        resi = self._resi
        if self._cplx:
            quad = (resi*x.real).vdot(resi).real
            energy = 0.5*quad - x.ptw("log").sum()
        else:
            quad = (resi*x).vdot(resi)
            energy = 0.5*(quad - x.ptw("log").sum())
        if x.want_metric:
            return energy.add_metric(self.get_metric_at(x.val))
        return energy

    def get_transformation(self):
        scale = 1. if self._cplx else np.sqrt(0.5)
        log_op = ScalingOperator(self._domain, 1.).ptw('log')
        return self._dt, scale*log_op
Martin Reinecke's avatar
Martin Reinecke committed
265

Philipp Frank's avatar
Philipp Frank committed
266
class GaussianEnergy(LikelihoodOperator):
    """Computes a negative-log Gaussian.

    Represents up to constants in :math:`m`:

    .. math ::
        E(f) = - \\log G(f-m, D) = 0.5 (f-m)^\\dagger D^{-1} (f-m),

    an information energy for a Gaussian distribution with mean m and
    covariance D.

    Parameters
    ----------
    mean : Field
        Mean of the Gaussian. Default is 0.
    inverse_covariance : LinearOperator
        Inverse covariance of the Gaussian. Default is the identity operator.
    domain : Domain, DomainTuple, tuple of Domain or MultiDomain
        Operator domain. By default it is inferred from `mean` or
        `inverse_covariance` if specified
    sampling_dtype : type
        Here one can specify whether the distribution is a complex Gaussian or
        not. Note that for a complex Gaussian the inverse_covariance is

        .. math ::
            \\left(\\left<ff^\\dagger\\right>_{P(f)}\\right)^{-1}/2,

        where the additional factor of 2 is necessary because the
        domain of f has double as many dimensions as in the real case.
        If `mean` is given and `sampling_dtype` is None, the dtype of `mean`
        is used.

    Raises
    ------
    TypeError
        If `mean` is not a Field/MultiField or `inverse_covariance` is not a
        LinearOperator.
    ValueError
        If no domain can be determined, if the domains of the arguments do
        not agree, or if `sampling_dtype` differs from the dtype of `mean`.

    Note
    ----
    At least one of the arguments has to be provided.
    """

    def __init__(self, mean=None, inverse_covariance=None, domain=None, sampling_dtype=None):
        if mean is not None and not isinstance(mean, (Field, MultiField)):
            raise TypeError("mean must be a Field or a MultiField")
        if inverse_covariance is not None and not isinstance(inverse_covariance, LinearOperator):
            raise TypeError("inverse_covariance must be a LinearOperator")

        # The first domain that is available sets the operator domain, all
        # others have to agree with it.
        self._domain = None
        if mean is not None:
            self._checkEquivalence(mean.domain)
        if inverse_covariance is not None:
            self._checkEquivalence(inverse_covariance.domain)
        if domain is not None:
            self._checkEquivalence(domain)
        if self._domain is None:
            raise ValueError("no domain given")
        self._mean = mean

        # Infer sampling dtype
        if self._mean is None:
            _check_sampling_dtype(self._domain, sampling_dtype)
        else:
            mean_dtype = _field_to_dtype(self._mean)
            if sampling_dtype is None:
                sampling_dtype = mean_dtype
            elif sampling_dtype != mean_dtype:
                raise ValueError("Sampling dtype and mean not compatible")

        self._icov = inverse_covariance
        if inverse_covariance is None:
            self._op = Squared2NormOperator(self._domain).scale(0.5)
            self._met = ScalingOperator(self._domain, 1)
        else:
            self._op = QuadraticFormOperator(inverse_covariance)
            self._met = inverse_covariance
        if sampling_dtype is not None:
            self._met = SamplingDtypeSetter(self._met, sampling_dtype)

    def _checkEquivalence(self, newdom):
        """Set the operator domain on first call, afterwards require that
        `newdom` equals the domain that is already set.
        """
        newdom = makeDomain(newdom)
        if self._domain is None:
            self._domain = newdom
        elif self._domain != newdom:
            raise ValueError("domain mismatch")

    def apply(self, x):
        self._check_input(x)
        residual = x if self._mean is None else x - self._mean
        res = self._op(residual).real
        if x.want_metric:
            return res.add_metric(self.get_metric_at(x.val))
        return res

    def get_transformation(self):
        """Return the sampling dtype (or None) and the square root of the
        inverse covariance.
        """
        icov, dtp = self._met, None
        # Unwrap the dtype setter that `__init__` may have put around the
        # metric.
        if isinstance(icov, SamplingDtypeSetter):
            dtp = icov._dtype
            icov = icov._op
        return dtp, icov.get_sqrt()

    def __repr__(self):
        dom = '()' if isinstance(self.domain, DomainTuple) else self.domain.keys()
        return f'GaussianEnergy {dom}'

Martin Reinecke's avatar
Martin Reinecke committed
363

Philipp Frank's avatar
Philipp Frank committed
364
class PoissonianEnergy(LikelihoodOperator):
    """Computes likelihood Hamiltonians of expected count field constrained by
    Poissonian count data.

    Represents up to an f-independent term :math:`log(d!)`:

    .. math ::
        E(f) = -\\log \\text{Poisson}(d|f) = \\sum f - d^\\dagger \\log(f),

    where f is a :class:`Field` in data space with the expectation values for
    the counts.

    Parameters
    ----------
    d : Field
        Data field with counts. Needs to have integer dtype and all field
        values need to be non-negative.

    Raises
    ------
    TypeError
        If `d` is not a Field with integer dtype.
    ValueError
        If `d` contains negative values.
    """

    def __init__(self, d):
        if not (isinstance(d, Field) and np.issubdtype(d.dtype, np.integer)):
            raise TypeError
        has_negative_counts = np.any(d.val < 0)
        if has_negative_counts:
            raise ValueError
        self._d = d
        self._domain = DomainTuple.make(d.domain)

    def apply(self, x):
        self._check_input(x)
        total = x.sum()
        energy = total - x.ptw("log").vdot(self._d)
        if x.want_metric:
            return energy.add_metric(self.get_metric_at(x.val))
        return energy

    def get_transformation(self):
        identity = ScalingOperator(self._domain, 1.)
        return np.float64, 2.*identity.sqrt()
400

Philipp Frank's avatar
Philipp Frank committed
401
class InverseGammaLikelihood(LikelihoodOperator):
    """Computes the negative log-likelihood of the inverse gamma distribution.

    Its negative log-pdf(x) is given by

    .. math ::

        \\sum_i (\\alpha_i+1)*\\ln(x_i) + \\beta_i/x_i

    This is the likelihood for the variance :math:`x=S_k` given data
    :math:`\\beta = 0.5 |s_k|^2` where the Field :math:`s` is known to have
    the covariance :math:`S_k`.

    Parameters
    ----------
    beta : Field
        beta parameter of the inverse gamma distribution. Needs to have dtype
        `np.float64`.
    alpha : Scalar, Field, optional
        alpha parameter of the inverse gamma distribution

    Raises
    ------
    TypeError
        If `beta` is not a float64 Field or `alpha` is neither a scalar nor a
        Field.
    """

    def __init__(self, beta, alpha=-0.5):
        if not isinstance(beta, Field):
            raise TypeError
        self._domain = DomainTuple.make(beta.domain)
        self._beta = beta
        if np.isscalar(alpha):
            # Broadcast a scalar alpha to a full Field on the domain of beta.
            alpha_values = np.full(beta.shape, alpha)
            alpha = Field(beta.domain, alpha_values)
        elif not isinstance(alpha, Field):
            raise TypeError
        self._alphap1 = alpha+1
        if self._beta.dtype != np.float64:
            # FIXME Add proper complex support for this energy
            raise TypeError
        self._sampling_dtype = _field_to_dtype(self._beta)

    def apply(self, x):
        self._check_input(x)
        log_term = x.ptw("log").vdot(self._alphap1)
        energy = log_term + x.ptw("reciprocal").vdot(self._beta)
        if x.want_metric:
            return energy.add_metric(self.get_metric_at(x.val))
        return energy

    def get_transformation(self):
        sqrt_alphap1 = self._alphap1.ptw('sqrt')
        log_op = ScalingOperator(self._domain, 1.).ptw('log')
        trafo = makeOp(sqrt_alphap1)@log_op
        return self._sampling_dtype, trafo
448

Philipp Frank's avatar
Philipp Frank committed
449
class StudentTEnergy(LikelihoodOperator):
    """Computes likelihood energy corresponding to Student's t-distribution.

    .. math ::
        E_\\theta(f) = -\\log \\text{StudentT}_\\theta(f)
                     = \\frac{\\theta + 1}{2} \\log(1 + \\frac{f^2}{\\theta}),

    where f is a field defined on `domain`. Assumes that the data is `float64`
    for sampling.

    Parameters
    ----------
    domain : `Domain` or `DomainTuple`
        Domain of the operator
    theta : Scalar or Field
        Degree of freedom parameter for the student t distribution
    """

    def __init__(self, domain, theta):
        self._domain = DomainTuple.make(domain)
        self._theta = theta

    def apply(self, x):
        self._check_input(x)
        prefactor = (self._theta+1)/2
        log_term = (x**2/self._theta).ptw("log1p")
        energy = (prefactor*log_term).sum()
        if x.want_metric:
            return energy.add_metric(self.get_metric_at(x.val))
        return energy

    def get_transformation(self):
        theta = self._theta
        if not isinstance(theta, (Field, MultiField)):
            # Scalar theta: expand it to a constant field on the domain.
            from ..extra import full
            theta = full(self._domain, self._theta)
        scaling = ((theta+1)/(theta+3)).ptw('sqrt')
        return np.float64, makeOp(scaling)
485

Philipp Frank's avatar
Philipp Frank committed
486
class BernoulliEnergy(LikelihoodOperator):
    """Computes likelihood energy of expected event frequency constrained by
    event data.

    .. math ::
        E(f) = -\\log \\text{Bernoulli}(d|f)
             = -d^\\dagger \\log f  - (1-d)^\\dagger \\log(1-f),

    where f is a field defined on `d.domain` with the expected
    frequencies of events.

    Parameters
    ----------
    d : Field
        Data field with events (1) or non-events (0).

    Raises
    ------
    TypeError
        If `d` is not a Field with integer dtype.
    ValueError
        If `d` contains values other than 0 and 1.
    """

    def __init__(self, d):
        if not (isinstance(d, Field) and np.issubdtype(d.dtype, np.integer)):
            raise TypeError
        not_binary = np.logical_and(d.val != 0, d.val != 1)
        if np.any(not_binary):
            raise ValueError
        self._d = d
        self._domain = DomainTuple.make(d.domain)

    def apply(self, x):
        self._check_input(x)
        event_term = -x.ptw("log").vdot(self._d)
        nonevent_term = (1.-x).ptw("log").vdot(self._d-1.)
        energy = event_term + nonevent_term
        if x.want_metric:
            return energy.add_metric(self.get_metric_at(x.val))
        return energy

    def get_transformation(self):
        from ..extra import full
        # 1 - x, then (1 - x)/x
        one_minus_x = Adder(full(self._domain, 1.))@ScalingOperator(self._domain, -1)
        ratio = one_minus_x * ScalingOperator(self._domain, 1).ptw('reciprocal')
        return np.float64, -2.*ratio.ptw('sqrt').ptw('arctan')
Martin Reinecke's avatar
Martin Reinecke committed
523

524
class StandardHamiltonian(EnergyOperator):
    """Computes an information Hamiltonian in its standard form, i.e. with the
    prior being a Gaussian with unit covariance.

    Let the likelihood energy be :math:`E_{lh}`. Then this operator computes:

    .. math ::
         H(f) = 0.5 f^\\dagger f + E_{lh}(f)

    Other field priors can be represented via transformations of a white
    Gaussian field into a field with the desired prior probability structure.

    By implementing prior information this way, the field prior is represented
    by a generative model, from which NIFTy can draw samples and infer a field
    using the Maximum a Posteriori (MAP) or the Variational Bayes (VB) method.

    The metric of this operator can be used as covariance for drawing Gaussian
    samples.

    Parameters
    ----------
    lh : EnergyOperator
        The likelihood energy.
    ic_samp : IterationController
        Tells an internal :class:`SamplingEnabler` which convergence criterion
        to use to draw Gaussian samples.
    prior_dtype : numpy.dtype or dict of numpy.dtype, optional
        Data type of prior used for sampling. Default is `np.float64`.

    See also
    --------
    `Encoding prior knowledge in the structure of the likelihood`,
    Jakob Knollmüller, Torsten A. Ensslin,
    `<https://arxiv.org/abs/1812.04403>`_
    """

    def __init__(self, lh, ic_samp=None, prior_dtype=np.float64):
        self._lh = lh
        # Remembered so that simplified copies keep the same prior dtype.
        self._prior_dtype = prior_dtype
        self._prior = GaussianEnergy(domain=lh.domain, sampling_dtype=prior_dtype)
        self._ic_samp = ic_samp
        self._domain = lh.domain

    def apply(self, x):
        self._check_input(x)
        if not x.want_metric or self._ic_samp is None:
            return (self._lh + self._prior)(x)
        lhx, prx = self._lh(x), self._prior(x)
        return (lhx+prx).add_metric(SamplingEnabler(lhx.metric, prx.metric, self._ic_samp))

    def __repr__(self):
        subs = 'Likelihood:\n{}'.format(utilities.indent(self._lh.__repr__()))
        subs += '\nPrior:\n{}'.format(self._prior)
        return 'StandardHamiltonian:\n' + utilities.indent(subs)

    def _simplify_for_constant_input_nontrivial(self, c_inp):
        """Simplify the likelihood for the constant input `c_inp` and wrap the
        result in a new `StandardHamiltonian` with the same iteration
        controller and prior dtype.
        """
        out, lh1 = self._lh.simplify_for_constant_input(c_inp)
        prior_dtype = self._prior_dtype
        if isinstance(prior_dtype, dict) and isinstance(lh1.domain, MultiDomain):
            # A dict of dtypes has to match the key set of the domain it is
            # used for, so drop the entries of keys that are not part of the
            # simplified likelihood's domain.
            remaining = lh1.domain.keys()
            prior_dtype = {key: dtype for key, dtype in prior_dtype.items()
                           if key in remaining}
        return out, StandardHamiltonian(lh1, self._ic_samp, prior_dtype)
581

Martin Reinecke's avatar
Martin Reinecke committed
582

Martin Reinecke's avatar
Martin Reinecke committed
583
class AveragedEnergy(EnergyOperator):
    """Averages an energy over samples.

    Parameters
    ----------
    h: Hamiltonian
       The energy to be averaged.
    res_samples : iterable of Fields
       Set of residual sample points to be added to mean field for
       approximate estimation of the KL.

    Notes
    -----
    - Having symmetrized residual samples, with both :math:`v_i` and
      :math:`-v_i` being present, ensures that the distribution mean is
      exactly represented.

    - :class:`AveragedEnergy(h)` approximates
      :math:`\\left< H(f) \\right>_{G(f-m,D)}` if the residuals :math:`f-m`
      are drawn from a Gaussian distribution with covariance :math:`D`.
    """

    def __init__(self, h, res_samples):
        self._h = h
        self._domain = h.domain
        # Materialize the samples, they are iterated more than once.
        self._res_samples = tuple(res_samples)

    def apply(self, x):
        self._check_input(x)
        energies = (self._h(x+sample) for sample in self._res_samples)
        n_samples = len(self._res_samples)
        return utilities.my_sum(energies)/n_samples

    def get_transformation(self):
        dtp, trafo = self._h.get_transformation()
        shifted = (trafo@Adder(sample) for sample in self._res_samples)
        norm = np.sqrt(len(self._res_samples))
        return dtp, utilities.my_sum(shifted)/norm