energy_operators.py 21.1 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
14
# Copyright(C) 2013-2021 Max-Planck-Society
Martin Reinecke's avatar
Martin Reinecke committed
15
#
16
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
Martin Reinecke's avatar
Martin Reinecke committed
17

Philipp Arras's avatar
Philipp Arras committed
18
19
import numpy as np

Philipp Arras's avatar
Philipp Arras committed
20
from .. import utilities
Martin Reinecke's avatar
Martin Reinecke committed
21
from ..domain_tuple import DomainTuple
Philipp Arras's avatar
Philipp Arras committed
22
from ..field import Field
Philipp Frank's avatar
Philipp Frank committed
23
from ..linearization import Linearization
Philipp Arras's avatar
Philipp Arras committed
24
25
26
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from ..sugar import makeDomain, makeOp
27
from ..utilities import myassert
Philipp Arras's avatar
Philipp Arras committed
28
from .linear_operator import LinearOperator
Martin Reinecke's avatar
Martin Reinecke committed
29
from .operator import Operator
Philipp Frank's avatar
Philipp Frank committed
30
from .adder import Adder
31
from .sampling_enabler import SamplingDtypeSetter, SamplingEnabler
32
from .scaling_operator import ScalingOperator
Philipp Frank's avatar
Philipp Frank committed
33
34
from .sandwich_operator import SandwichOperator
from .simple_linear_operators import VdotOperator, FieldAdapter
Philipp Arras's avatar
Philipp Arras committed
35
36
37
38
39
40


def _check_sampling_dtype(domain, dtypes):
    if dtypes is None:
        return
    if isinstance(domain, DomainTuple):
Philipp Arras's avatar
Philipp Arras committed
41
42
        np.dtype(dtypes)
        return
Philipp Arras's avatar
Philipp Arras committed
43
    elif isinstance(domain, MultiDomain):
Philipp Arras's avatar
Philipp Arras committed
44
45
46
47
48
49
50
        if isinstance(dtypes, dict):
            for dt in dtypes.values():
                np.dtype(dt)
            if set(domain.keys()) == set(dtypes.keys()):
                return
        else:
            np.dtype(dtypes)
Philipp Arras's avatar
Philipp Arras committed
51
            return
Philipp Arras's avatar
Philipp Arras committed
52
    raise TypeError
Philipp Arras's avatar
Philipp Arras committed
53
54


55
56
57
58
def _iscomplex(dtype):
    return np.issubdtype(dtype, np.complexfloating)


Philipp Arras's avatar
Philipp Arras committed
59
60
61
62
63
64
65
66
67
68
69
def _field_to_dtype(field):
    """Derive the sampling dtype from a Field or MultiField.

    For a `Field` its dtype is returned; for a `MultiField` a dict mapping
    every key to the dtype of the corresponding sub-field. The result is
    validated against `field.domain` before it is returned.

    Raises
    ------
    TypeError
        If `field` is neither a `Field` nor a `MultiField`.
    """
    if isinstance(field, Field):
        sampling_dtype = field.dtype
    else:
        if not isinstance(field, MultiField):
            raise TypeError
        sampling_dtype = {}
        for key, subfield in field.items():
            sampling_dtype[key] = subfield.dtype
    _check_sampling_dtype(field.domain, sampling_dtype)
    return sampling_dtype


Martin Reinecke's avatar
Martin Reinecke committed
70
class EnergyOperator(Operator):
    """Operator which has a scalar domain as target domain.

    It is intended as an objective function for field inference.

    Examples
    --------
     - Information Hamiltonian, i.e. negative-log-probabilities.
     - Gibbs free energy, i.e. an averaged Hamiltonian, aka Kullback-Leibler
       divergence.
    """
    # Every energy maps into the scalar domain; subclasses only set _domain.
    _target = DomainTuple.scalar_domain()


Philipp Frank's avatar
Philipp Frank committed
84
85
86
87
88
89
90
91
92
class LikelihoodOperator(EnergyOperator):
    """`EnergyOperator` representing a likelihood. The input to the Operator
    are the parameters of the likelihood. Unlike a general `EnergyOperator`,
    the metric of a `LikelihoodOperator` is the Fisher information metric of
    the likelihood.
    """
    def get_metric_at(self, x):
        """Computes the Fisher information metric for a `LikelihoodOperator`
        at `x` using the Jacobian of the coordinate transformation given by
        :func:`~nifty7.operators.operator.Operator.get_transformation`.

        Parameters
        ----------
        x : Field or MultiField
            Position at which the transformation is linearized.

        Returns
        -------
        The operator built by `SandwichOperator.make` from the Jacobian of
        the transformation at `x` and, if the transformation reports a
        sampling dtype, an identity `ScalingOperator` on the transformation
        target wrapped in a `SamplingDtypeSetter`.
        """
        dtp, f = self.get_transformation()
        ch = None
        if dtp is not None:
            # Identity on the transformed space that carries the sampling
            # dtype reported by the transformation.
            ch = SamplingDtypeSetter(ScalingOperator(f.target, 1.), dtp)
        bun = f(Linearization.make_var(x)).jac
        return SandwichOperator.make(bun, ch)
Philipp Frank's avatar
Philipp Frank committed
101
102


103
104
class Squared2NormOperator(EnergyOperator):
    """Computes the square of the L2-norm of its input.

    Parameters
    ----------
    domain : Domain, DomainTuple, tuple of Domain or MultiDomain
        Domain of the operator in which the L2-norm shall be computed.
    """

    def __init__(self, domain):
        # Normalize the input so that a bare Domain or a tuple of Domains, as
        # promised by the docstring, ends up as a proper domain object.
        self._domain = makeDomain(domain)

    def apply(self, x):
        """Return `x^dagger x`; for a `Linearization` also attach the
        Jacobian `VdotOperator(2*x.val)`."""
        self._check_input(x)
        if x.jac is None:
            return x.vdot(x)
        res = x.val.vdot(x.val)
        return x.new(res, VdotOperator(2*x.val))
Martin Reinecke's avatar
Martin Reinecke committed
121

Martin Reinecke's avatar
Martin Reinecke committed
122

Martin Reinecke's avatar
Martin Reinecke committed
123
class QuadraticFormOperator(EnergyOperator):
    """Computes the quadratic form of a Field or MultiField with respect to
    a specific kernel given by `endo`.

    .. math ::
        E(f) = \\frac12 f^\\dagger \\text{endo}(f)

    Parameters
    ----------
    endo : EndomorphicOperator
         Kernel of the quadratic form

    Raises
    ------
    TypeError
        If `endo` is not an `EndomorphicOperator`.
    """

    def __init__(self, endo):
        # Imported locally, as in the original code.
        from .endomorphic_operator import EndomorphicOperator
        if not isinstance(endo, EndomorphicOperator):
            raise TypeError("endo must be an EndomorphicOperator")
        self._op = endo
        self._domain = endo.domain

    def apply(self, x):
        """Return `0.5 * x^dagger endo(x)`; for a `Linearization` also attach
        the Jacobian `VdotOperator(endo(x.val))`."""
        self._check_input(x)
        if x.jac is None:
            return 0.5*x.vdot(self._op(x))
        # Apply the kernel once and reuse it for value and Jacobian.
        kernel_applied = self._op(x.val)
        res = 0.5*x.val.vdot(kernel_applied)
        return x.new(res, VdotOperator(kernel_applied))
Martin Reinecke's avatar
Martin Reinecke committed
149

Philipp Arras's avatar
Philipp Arras committed
150

Philipp Frank's avatar
Philipp Frank committed
151
class VariableCovarianceGaussianEnergy(LikelihoodOperator):
    """Computes the negative log pdf of a Gaussian with unknown covariance.

    The covariance is assumed to be diagonal.

    .. math ::
        E(s,C) = - \\log G(s, C) = 0.5 (s)^\\dagger C (s) - 0.5 tr log(C),

    an information energy for a Gaussian distribution with residual s and
    inverse diagonal covariance C.
    The domain of this energy will be a MultiDomain with two keys,
    the target will be the scalar domain.

    Parameters
    ----------
    domain : Domain, DomainTuple, tuple of Domain
        domain of the residual and domain of the covariance diagonal.

    residual_key : key
        Residual key of the Gaussian.

    inverse_covariance_key : key
        Inverse covariance diagonal key of the Gaussian.

    sampling_dtype : np.dtype
        Data type of the samples. Usually either 'np.float*' or 'np.complex*'

    use_full_fisher: boolean
        Whether or not the proper Fisher information metric should be used as
        a `metric`. If False the same approximation used in
        `get_transformation` is used instead.
        Default is True
    """

    def __init__(self, domain, residual_key, inverse_covariance_key,
                 sampling_dtype, use_full_fisher=True):
        self._kr = str(residual_key)
        self._ki = str(inverse_covariance_key)
        dom = DomainTuple.make(domain)
        # Residual and inverse covariance live on the same domain.
        self._domain = MultiDomain.make({self._kr: dom, self._ki: dom})
        # The inverse covariance is always sampled as float64.
        self._dt = {self._kr: sampling_dtype, self._ki: np.float64}
        _check_sampling_dtype(self._domain, self._dt)
        self._cplx = _iscomplex(sampling_dtype)
        self._use_full_fisher = use_full_fisher

    def apply(self, x):
        """Evaluate the energy and, if requested, attach a metric."""
        self._check_input(x)
        r, i = x[self._kr], x[self._ki]
        if self._cplx:
            # In the complex case the log term enters without the factor 0.5.
            res = 0.5*r.vdot(r*i.real).real - i.ptw("log").sum()
        else:
            res = 0.5*(r.vdot(r*i) - i.ptw("log").sum())
        if not x.want_metric:
            return res
        if self._use_full_fisher:
            met = 1. if self._cplx else 0.5
            met = MultiField.from_dict({self._kr: i.val, self._ki: met*i.val**(-2)},
                                        domain=self._domain)
            met = SamplingDtypeSetter(makeOp(met), self._dt)
        else:
            met = self.get_metric_at(x.val)
        return res.add_metric(met)

    def _simplify_for_constant_input_nontrivial(self, c_inp):
        """Return `(None, energy)` where `energy` is this energy with the
        single key contained in `c_inp` held constant."""
        from .simplify_for_const import ConstantEnergyOperator
        myassert(len(c_inp.keys()) == 1)
        key = c_inp.keys()[0]
        myassert(key in self._domain.keys())
        cst = c_inp[key]
        if key == self._kr:
            # Constant residual: the energy only depends on the inverse
            # covariance.
            res = _SpecialGammaEnergy(cst).ducktape(self._ki)
        else:
            # Constant inverse covariance: a Gaussian energy in the residual
            # plus the constant log-determinant term.
            dt = self._dt[self._kr]
            res = GaussianEnergy(inverse_covariance=makeOp(cst),
                                 sampling_dtype=dt).ducktape(self._kr)
            trlog = cst.log().sum().val_rw()
            if not self._cplx:
                trlog /= 2
            res = res + ConstantEnergyOperator(-trlog)
        res = res + ConstantEnergyOperator(0.)
        myassert(res.target is self.target)
        return None, res

    def get_transformation(self):
        """Note that for the metric of a `VariableCovarianceGaussianEnergy` no
        global transformation to Euclidean space exists. A local approximation
        invoking the residual is used instead.
        """
        r = FieldAdapter(self._domain[self._kr], self._kr)
        ivar = FieldAdapter(self._domain[self._kr], self._ki)
        sc = 1. if self._cplx else 0.5
        return self._dt, r.adjoint@(ivar.ptw('sqrt')*r) + ivar.adjoint@(sc*ivar.ptw('log'))
243

Philipp Frank's avatar
Philipp Frank committed
244
245

class _SpecialGammaEnergy(LikelihoodOperator):
    """Energy in the inverse covariance of a diagonal Gaussian whose
    residual is held fixed.

    Parameters
    ----------
    residual : Field
        The constant residual. Its domain becomes the domain of the operator
        and its dtype is used as sampling dtype.
    """

    def __init__(self, residual):
        self._domain = DomainTuple.make(residual.domain)
        self._resi = residual
        self._cplx = _iscomplex(self._resi.dtype)
        self._dt = self._resi.dtype

    def apply(self, x):
        """Evaluate the energy at the inverse covariance `x` and, if
        requested, attach the metric from `get_metric_at`."""
        self._check_input(x)
        r = self._resi
        if self._cplx:
            # In the complex case the log term enters without the factor 0.5.
            res = 0.5*(r*x.real).vdot(r).real - x.ptw("log").sum()
        else:
            res = 0.5*((r*x).vdot(r) - x.ptw("log").sum())
        if not x.want_metric:
            return res
        return res.add_metric(self.get_metric_at(x.val))

    def get_transformation(self):
        """Return the sampling dtype and a scaled pointwise logarithm as
        coordinate transformation."""
        sc = 1. if self._cplx else np.sqrt(0.5)
        return self._dt, sc*ScalingOperator(self._domain, 1.).ptw('log')
Martin Reinecke's avatar
Martin Reinecke committed
266

Philipp Frank's avatar
Philipp Frank committed
267
class GaussianEnergy(LikelihoodOperator):
    """Computes a negative-log Gaussian.

    Represents up to constants in :math:`m`:

    .. math ::
        E(f) = - \\log G(f-m, D) = 0.5 (f-m)^\\dagger D^{-1} (f-m),

    an information energy for a Gaussian distribution with mean m and
    covariance D.

    Parameters
    ----------
    mean : Field
        Mean of the Gaussian. Default is 0.
    inverse_covariance : LinearOperator
        Inverse covariance of the Gaussian. Default is the identity operator.
    domain : Domain, DomainTuple, tuple of Domain or MultiDomain
        Operator domain. By default it is inferred from `mean` or
        `inverse_covariance` if specified
    sampling_dtype : type
        Here one can specify whether the distribution is a complex Gaussian or
        not. Note that for a complex Gaussian the inverse_covariance is

        .. math ::
            (<ff^dagger>)^{-1}_P(f)/2,

        where the additional factor of 2 is necessary because the
        domain of s has double as many dimensions as in the real case.

    Raises
    ------
    TypeError
        If `mean` is not a Field/MultiField or `inverse_covariance` is not a
        LinearOperator.
    ValueError
        If no domain can be determined, if the given domains do not agree, or
        if `sampling_dtype` differs from the dtype of `mean`.

    Note
    ----
    At least one of the arguments has to be provided.
    """

    def __init__(self, mean=None, inverse_covariance=None, domain=None, sampling_dtype=None):
        if mean is not None and not isinstance(mean, (Field, MultiField)):
            raise TypeError("mean must be a Field or a MultiField")
        if inverse_covariance is not None and not isinstance(inverse_covariance, LinearOperator):
            raise TypeError("inverse_covariance must be a LinearOperator")

        # The first domain found is stored; all others must agree with it.
        self._domain = None
        if mean is not None:
            self._checkEquivalence(mean.domain)
        if inverse_covariance is not None:
            self._checkEquivalence(inverse_covariance.domain)
        if domain is not None:
            self._checkEquivalence(domain)
        if self._domain is None:
            raise ValueError("no domain given")
        self._mean = mean

        # Infer sampling dtype
        if self._mean is None:
            _check_sampling_dtype(self._domain, sampling_dtype)
        else:
            if sampling_dtype is None:
                sampling_dtype = _field_to_dtype(self._mean)
            else:
                if sampling_dtype != _field_to_dtype(self._mean):
                    raise ValueError("Sampling dtype and mean not compatible")

        self._icov = inverse_covariance
        if inverse_covariance is None:
            self._op = Squared2NormOperator(self._domain).scale(0.5)
            self._met = ScalingOperator(self._domain, 1)
        else:
            self._op = QuadraticFormOperator(inverse_covariance)
            self._met = inverse_covariance
        if sampling_dtype is not None:
            self._met = SamplingDtypeSetter(self._met, sampling_dtype)

    def _checkEquivalence(self, newdom):
        """Store `newdom` as the operator domain, or raise `ValueError` if a
        different domain has been stored before."""
        newdom = makeDomain(newdom)
        if self._domain is None:
            self._domain = newdom
        else:
            if self._domain != newdom:
                raise ValueError("domain mismatch")

    def apply(self, x):
        """Evaluate the energy of the residual `x - mean` and, if requested,
        attach the metric from `get_metric_at`."""
        self._check_input(x)
        residual = x if self._mean is None else x - self._mean
        res = self._op(residual).real
        if x.want_metric:
            return res.add_metric(self.get_metric_at(x.val))
        return res

    def get_transformation(self):
        """Return the sampling dtype (or `None`) and the square root of the
        inverse covariance as coordinate transformation."""
        icov, dtp = self._met, None
        if isinstance(icov, SamplingDtypeSetter):
            # Unwrap the dtype wrapper to get at the bare operator.
            dtp = icov._dtype
            icov = icov._op
        return dtp, icov.get_sqrt()

    def __repr__(self):
        dom = '()' if isinstance(self.domain, DomainTuple) else self.domain.keys()
        return f'GaussianEnergy {dom}'

Martin Reinecke's avatar
Martin Reinecke committed
364

Philipp Frank's avatar
Philipp Frank committed
365
class PoissonianEnergy(LikelihoodOperator):
    """Computes likelihood Hamiltonians of expected count field constrained by
    Poissonian count data.

    Represents up to an f-independent term :math:`log(d!)`:

    .. math ::
        E(f) = -\\log \\text{Poisson}(d|f) = \\sum f - d^\\dagger \\log(f),

    where f is a :class:`Field` in data space with the expectation values for
    the counts.

    Parameters
    ----------
    d : Field
        Data field with counts. Needs to have integer dtype and all field
        values need to be non-negative.

    Raises
    ------
    TypeError
        If `d` is not a `Field` with integer dtype.
    ValueError
        If `d` contains negative values.
    """

    def __init__(self, d):
        if not isinstance(d, Field) or not np.issubdtype(d.dtype, np.integer):
            raise TypeError("d must be a Field with integer dtype")
        if np.any(d.val < 0):
            raise ValueError("count data must be non-negative")
        self._d = d
        self._domain = DomainTuple.make(d.domain)

    def apply(self, x):
        """Evaluate `sum(x) - d^dagger log(x)` and, if requested, attach the
        metric from `get_metric_at`."""
        self._check_input(x)
        res = x.sum() - x.ptw("log").vdot(self._d)
        if not x.want_metric:
            return res
        return res.add_metric(self.get_metric_at(x.val))

    def get_transformation(self):
        """Return `np.float64` as sampling dtype and two times the pointwise
        square root as coordinate transformation."""
        return np.float64, 2.*ScalingOperator(self._domain, 1.).sqrt()
401

Philipp Frank's avatar
Philipp Frank committed
402
class InverseGammaLikelihood(LikelihoodOperator):
    """Computes the negative log-likelihood of the inverse gamma distribution.

    Its negative log-pdf(x) is given by

    .. math ::

        \\sum_i (\\alpha_i+1)*\\ln(x_i) + \\beta_i/x_i

    This is the likelihood for the variance :math:`x=S_k` given data
    :math:`\\beta = 0.5 |s_k|^2` where the Field :math:`s` is known to have
    the covariance :math:`S_k`.

    Parameters
    ----------
    beta : Field
        beta parameter of the inverse gamma distribution
    alpha : Scalar, Field, optional
        alpha parameter of the inverse gamma distribution

    Raises
    ------
    TypeError
        If `beta` is not a `Field`, if `alpha` is neither a scalar nor a
        `Field`, or if the dtype of `beta` is not `np.float64`.
    """

    def __init__(self, beta, alpha=-0.5):
        if not isinstance(beta, Field):
            raise TypeError("beta must be a Field")
        self._domain = DomainTuple.make(beta.domain)
        self._beta = beta
        if np.isscalar(alpha):
            # Broadcast a scalar alpha to a Field on the domain of beta.
            alpha = Field(beta.domain, np.full(beta.shape, alpha))
        elif not isinstance(alpha, Field):
            raise TypeError("alpha must be a scalar or a Field")
        self._alphap1 = alpha+1
        if self._beta.dtype != np.float64:
            # FIXME Add proper complex support for this energy
            raise TypeError("beta must have dtype float64")
        self._sampling_dtype = _field_to_dtype(self._beta)

    def apply(self, x):
        """Evaluate `(alpha+1)^dagger log(x) + beta^dagger (1/x)` and, if
        requested, attach the metric from `get_metric_at`."""
        self._check_input(x)
        res = x.ptw("log").vdot(self._alphap1) + x.ptw("reciprocal").vdot(self._beta)
        if not x.want_metric:
            return res
        return res.add_metric(self.get_metric_at(x.val))

    def get_transformation(self):
        """Return the sampling dtype and `sqrt(alpha+1) * log(x)` as
        coordinate transformation."""
        fact = self._alphap1.ptw('sqrt')
        res = makeOp(fact)@ScalingOperator(self._domain, 1.).ptw('log')
        return self._sampling_dtype, res
449

Philipp Frank's avatar
Philipp Frank committed
450
class StudentTEnergy(LikelihoodOperator):
    """Computes likelihood energy corresponding to Student's t-distribution.

    .. math ::
        E_\\theta(f) = -\\log \\text{StudentT}_\\theta(f)
                     = \\frac{\\theta + 1}{2} \\log(1 + \\frac{f^2}{\\theta}),

    where f is a field defined on `domain`. Assumes that the data is `float64`
    for sampling.

    Parameters
    ----------
    domain : `Domain` or `DomainTuple`
        Domain of the operator
    theta : Scalar or Field
        Degree of freedom parameter for the student t distribution
    """

    def __init__(self, domain, theta):
        self._domain = DomainTuple.make(domain)
        self._theta = theta

    def apply(self, x):
        """Evaluate the energy at `x` and, if requested, attach the metric
        from `get_metric_at`."""
        self._check_input(x)
        res = (((self._theta+1)/2)*(x**2/self._theta).ptw("log1p")).sum()
        if not x.want_metric:
            return res
        return res.add_metric(self.get_metric_at(x.val))

    def get_transformation(self):
        """Return `np.float64` as sampling dtype and the diagonal operator
        `sqrt((theta+1)/(theta+3))` as coordinate transformation."""
        if isinstance(self._theta, (Field, MultiField)):
            th = self._theta
        else:
            # A scalar theta is expanded to a constant field on the domain.
            from ..extra import full
            th = full(self._domain, self._theta)
        return np.float64, makeOp(((th+1)/(th+3)).ptw('sqrt'))
486

Philipp Frank's avatar
Philipp Frank committed
487
class BernoulliEnergy(LikelihoodOperator):
    """Computes likelihood energy of expected event frequency constrained by
    event data.

    .. math ::
        E(f) = -\\log \\text{Bernoulli}(d|f)
             = -d^\\dagger \\log f  - (1-d)^\\dagger \\log(1-f),

    where f is a field defined on `d.domain` with the expected
    frequencies of events.

    Parameters
    ----------
    d : Field
        Data field with events (1) or non-events (0).

    Raises
    ------
    TypeError
        If `d` is not a `Field` with integer dtype.
    ValueError
        If `d` contains values other than 0 and 1.
    """

    def __init__(self, d):
        if not isinstance(d, Field) or not np.issubdtype(d.dtype, np.integer):
            raise TypeError("d must be a Field with integer dtype")
        if np.any(np.logical_and(d.val != 0, d.val != 1)):
            raise ValueError("event data must only contain 0 and 1")
        self._d = d
        self._domain = DomainTuple.make(d.domain)

    def apply(self, x):
        """Evaluate the energy at `x` and, if requested, attach the metric
        from `get_metric_at`."""
        self._check_input(x)
        # (d-1)^dagger log(1-x) equals -(1-d)^dagger log(1-x).
        res = -x.ptw("log").vdot(self._d) + (1.-x).ptw("log").vdot(self._d-1.)
        if not x.want_metric:
            return res
        return res.add_metric(self.get_metric_at(x.val))

    def get_transformation(self):
        """Return `np.float64` as sampling dtype and
        `-2 * arctan(sqrt((1-x)/x))` as coordinate transformation."""
        from ..extra import full
        # res(x) = 1 - x
        res = Adder(full(self._domain, 1.))@ScalingOperator(self._domain, -1)
        # res(x) = (1 - x) / x
        res = res * ScalingOperator(self._domain, 1).ptw('reciprocal')
        return np.float64, -2.*res.ptw('sqrt').ptw('arctan')
Martin Reinecke's avatar
Martin Reinecke committed
524

525
class StandardHamiltonian(EnergyOperator):
    """Computes an information Hamiltonian in its standard form, i.e. with the
    prior being a Gaussian with unit covariance.

    Let the likelihood energy be :math:`E_{lh}`. Then this operator computes:

    .. math ::
         H(f) = 0.5 f^\\dagger f + E_{lh}(f)

    Other field priors can be represented via transformations of a white
    Gaussian field into a field with the desired prior probability structure.

    By implementing prior information this way, the field prior is represented
    by a generative model, from which NIFTy can draw samples and infer a field
    using the Maximum a Posteriori (MAP) or the Variational Bayes (VB) method.

    The metric of this operator can be used as covariance for drawing Gaussian
    samples.

    Parameters
    ----------
    lh : EnergyOperator
        The likelihood energy.
    ic_samp : IterationController
        Tells an internal :class:`SamplingEnabler` which convergence criterion
        to use to draw Gaussian samples.
    prior_dtype : numpy.dtype or dict of numpy.dtype, optional
        Data type of prior used for sampling.

    See also
    --------
    `Encoding prior knowledge in the structure of the likelihood`,
    Jakob Knollmüller, Torsten A. Ensslin,
    `<https://arxiv.org/abs/1812.04403>`_
    """

    def __init__(self, lh, ic_samp=None, prior_dtype=np.float64):
        self._lh = lh
        self._prior = GaussianEnergy(domain=lh.domain, sampling_dtype=prior_dtype)
        self._ic_samp = ic_samp
        self._domain = lh.domain
        # Remembered so that a simplified copy keeps the same prior dtype.
        self._prior_dtype = prior_dtype

    def apply(self, x):
        """Evaluate likelihood plus prior energy. If a metric is requested
        and `ic_samp` is set, the metric is wrapped in a `SamplingEnabler`."""
        self._check_input(x)
        if not x.want_metric or self._ic_samp is None:
            return (self._lh + self._prior)(x)
        lhx, prx = self._lh(x), self._prior(x)
        return (lhx+prx).add_metric(SamplingEnabler(lhx.metric, prx.metric, self._ic_samp))

    def __repr__(self):
        subs = 'Likelihood:\n{}'.format(utilities.indent(self._lh.__repr__()))
        subs += '\nPrior:\n{}'.format(self._prior)
        return 'StandardHamiltonian:\n' + utilities.indent(subs)

    def _simplify_for_constant_input_nontrivial(self, c_inp):
        """Simplify the likelihood for the constant input `c_inp` and wrap
        the result in a new `StandardHamiltonian` that keeps `ic_samp` and
        the prior dtype of this one."""
        out, lh1 = self._lh.simplify_for_constant_input(c_inp)
        prior_dtype = self._prior_dtype
        if isinstance(prior_dtype, dict) and isinstance(lh1.domain, MultiDomain):
            # A dict dtype must have exactly the keys of the new domain;
            # assumes the simplified domain only has keys of the old one.
            prior_dtype = {key: prior_dtype[key] for key in lh1.domain.keys()
                           if key in prior_dtype}
        return out, StandardHamiltonian(lh1, self._ic_samp, prior_dtype)
582

Martin Reinecke's avatar
Martin Reinecke committed
583

Martin Reinecke's avatar
Martin Reinecke committed
584
class AveragedEnergy(EnergyOperator):
    """Averages an energy over samples.

    Parameters
    ----------
    h : Hamiltonian
       The energy to be averaged.
    res_samples : iterable of Fields
       Set of residual sample points to be added to the mean field for
       approximate estimation of the KL. Must contain at least one sample.

    Raises
    ------
    ValueError
       If `res_samples` is empty.

    Notes
    -----
    - Having symmetrized residual samples, with both :math:`v_i` and
      :math:`-v_i` being present, ensures that the distribution mean is
      exactly represented.

    - :class:`AveragedEnergy(h)` approximates
      :math:`\\left< H(f) \\right>_{G(f-m,D)}` if the residuals :math:`f-m`
      are drawn from a Gaussian distribution with covariance :math:`D`.
    """

    def __init__(self, h, res_samples):
        self._h = h
        self._domain = h.domain
        # Materialize the iterable once: the samples are traversed on every
        # call to `apply` / `get_transformation`, and their count is needed.
        self._res_samples = tuple(res_samples)
        if not self._res_samples:
            # Fail early with a clear message instead of failing later when
            # averaging divides by the number of samples.
            raise ValueError("res_samples must contain at least one sample")

    def apply(self, x):
        """Return the mean of `h(x + v)` over all residual samples `v`."""
        self._check_input(x)
        shifted_energies = (self._h(x + v) for v in self._res_samples)
        return utilities.my_sum(shifted_energies)/len(self._res_samples)

    def get_transformation(self):
        """Return the dtype and combined transformation of the samples.

        The dtype is the one reported by the wrapped energy's
        `get_transformation`. The transformation is the sum of
        ``trafo @ Adder(v)`` over all residual samples `v`, divided by the
        square root of the number of samples.
        """
        dtp, trafo = self._h.get_transformation()
        shifted_trafos = (trafo @ Adder(v) for v in self._res_samples)
        return dtp, utilities.my_sum(shifted_trafos)/np.sqrt(len(self._res_samples))