# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np

Philipp Arras's avatar
Philipp Arras committed
20
from .. import utilities
Martin Reinecke's avatar
Martin Reinecke committed
21
from ..domain_tuple import DomainTuple
Philipp Arras's avatar
Philipp Arras committed
22
from ..field import Field
Philipp Frank's avatar
Philipp Frank committed
23
from ..linearization import Linearization
Philipp Arras's avatar
Philipp Arras committed
24
25
26
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from ..sugar import makeDomain, makeOp
27
from ..utilities import myassert
Philipp Arras's avatar
Philipp Arras committed
28
from .linear_operator import LinearOperator
Martin Reinecke's avatar
Martin Reinecke committed
29
from .operator import Operator
Philipp Frank's avatar
Philipp Frank committed
30
from .adder import Adder
31
from .sampling_enabler import SamplingDtypeSetter, SamplingEnabler
32
from .scaling_operator import ScalingOperator
Philipp Frank's avatar
Philipp Frank committed
33
34
from .sandwich_operator import SandwichOperator
from .simple_linear_operators import VdotOperator, FieldAdapter
Philipp Arras's avatar
Philipp Arras committed
35
36
37
38
39
40


def _check_sampling_dtype(domain, dtypes):
    """Validate a sampling dtype specification against `domain`.

    Parameters
    ----------
    domain : DomainTuple or MultiDomain
        Domain on which samples are drawn.
    dtypes : None, dtype-like or dict of dtype-like
        Sampling dtype(s). `None` is always accepted. For a `DomainTuple` a
        single dtype is expected. For a `MultiDomain` either a single dtype
        or a dict whose keys are exactly the keys of `domain`.

    Raises
    ------
    TypeError
        If `domain` is neither a `DomainTuple` nor a `MultiDomain`, or if a
        dict of dtypes does not have exactly the keys of `domain`.
        Entries that `numpy.dtype` cannot interpret make `numpy.dtype`
        raise its own exception.
    """
    if dtypes is None:
        return
    if isinstance(domain, DomainTuple):
        np.dtype(dtypes)
        return
    if isinstance(domain, MultiDomain):
        if not isinstance(dtypes, dict):
            # A single dtype is accepted for all keys of the MultiDomain.
            np.dtype(dtypes)
            return
        for dt in dtypes.values():
            np.dtype(dt)
        if set(domain.keys()) != set(dtypes.keys()):
            raise TypeError(
                f"Keys of sampling dtype dict {set(dtypes.keys())} do not "
                f"match domain keys {set(domain.keys())}")
        return
    raise TypeError(
        f"Expected DomainTuple or MultiDomain, got {type(domain).__name__}")
Philipp Arras's avatar
Philipp Arras committed
53
54


55
56
57
58
def _iscomplex(dtype):
    """Tell whether `dtype` denotes a complex floating-point type.

    Parameters
    ----------
    dtype : dtype-like
        Anything accepted by :func:`numpy.issubdtype`.

    Returns
    -------
    bool
    """
    is_complex = np.issubdtype(dtype, np.complexfloating)
    return is_complex


Philipp Arras's avatar
Philipp Arras committed
59
60
61
62
63
64
65
66
67
68
69
def _field_to_dtype(field):
    """Return the dtype(s) of `field` in the form used for sampling dtypes.

    A `Field` yields its single dtype. A `MultiField` yields a dict that maps
    each key to the dtype of the corresponding component. The result is
    validated with `_check_sampling_dtype` against ``field.domain`` before
    being returned.

    Raises
    ------
    TypeError
        If `field` is neither a `Field` nor a `MultiField`.
    """
    if isinstance(field, Field):
        result = field.dtype
    elif isinstance(field, MultiField):
        result = {key: component.dtype for key, component in field.items()}
    else:
        raise TypeError
    _check_sampling_dtype(field.domain, result)
    return result


Martin Reinecke's avatar
Martin Reinecke committed
70
class EnergyOperator(Operator):
    """Operator whose target is the scalar domain.

    Energy operators serve as objective functions for field inference. They
    may implement a positive definite, symmetric form (called `metric`) that
    second-order minimizers use as curvature.

    Examples
    --------
     - Information Hamiltonian, i.e. negative-log-probabilities.
     - Gibbs free energy, i.e. an averaged Hamiltonian, aka Kullback-Leibler
       divergence.
    """

    # Every energy maps its input onto a single scalar.
    _target = DomainTuple.scalar_domain()


Philipp Frank's avatar
Philipp Frank committed
86
class LikelihoodOperator(EnergyOperator):
    """Represent a log-likelihood.

    The operator's input consists of the parameters of the likelihood. In
    contrast to a general `EnergyOperator`, the metric of a
    `LikelihoodOperator` is the Fisher information metric of the likelihood.
    """

    def get_metric_at(self, x):
        """Computes the Fisher information metric for a `LikelihoodOperator`
        at `x` using the Jacobian of the coordinate transformation given by
        :func:`~nifty7.operators.operator.Operator.get_transformation`.
        """
        sampling_dtype, transform = self.get_transformation()
        if sampling_dtype is None:
            dtype_setter = None
        else:
            # Unit operator on the transformation's target; it only carries
            # the dtype used when samples are drawn from the metric.
            dtype_setter = SamplingDtypeSetter(
                ScalingOperator(transform.target, 1.), sampling_dtype)
        jacobian = transform(Linearization.make_var(x)).jac
        # Presumably SandwichOperator.make builds
        # jacobian^dagger @ dtype_setter @ jacobian -- TODO confirm.
        return SandwichOperator.make(jacobian, dtype_setter)
Philipp Frank's avatar
Philipp Frank committed
105
106


107
108
class Squared2NormOperator(EnergyOperator):
    """Computes the square of the L2-norm of the output of an operator.

    .. math ::
        E(f) = f^\\dagger f

    Parameters
    ----------
    domain : Domain, DomainTuple, tuple of Domain or MultiDomain
        Domain of the operator in which the L2-norm shall be computed.
    """

    def __init__(self, domain):
        # Normalize the documented input forms (Domain, tuple of Domain, ...)
        # with the same helper GaussianEnergy uses for its domain, so that
        # input checks compare against a proper DomainTuple/MultiDomain.
        self._domain = makeDomain(domain)

    def apply(self, x):
        self._check_input(x)
        if x.jac is None:
            return x.vdot(x)
        res = x.val.vdot(x.val)
        # The gradient of f^dagger f is 2 f.
        return x.new(res, VdotOperator(2*x.val))
Martin Reinecke's avatar
Martin Reinecke committed
125

Martin Reinecke's avatar
Martin Reinecke committed
126

Martin Reinecke's avatar
Martin Reinecke committed
127
class QuadraticFormOperator(EnergyOperator):
    """Computes half the quadratic form of a Field or MultiField with respect
    to the kernel `endo`.

    .. math ::
        E(f) = \\frac12 f^\\dagger \\text{endo}(f)

    Parameters
    ----------
    endo : EndomorphicOperator
         Kernel of the quadratic form

    Raises
    ------
    TypeError
        If `endo` is not an `EndomorphicOperator`.

    Notes
    -----
    The gradient is taken to be :math:`\\text{endo}(f)`. This is the exact
    derivative of :math:`E` only if `endo` is self-adjoint.
    """

    def __init__(self, endo):
        from .endomorphic_operator import EndomorphicOperator
        if not isinstance(endo, EndomorphicOperator):
            raise TypeError("endo must be an EndomorphicOperator")
        self._op = endo
        self._domain = endo.domain

    def apply(self, x):
        self._check_input(x)
        if x.jac is None:
            return 0.5*x.vdot(self._op(x))
        # Apply the kernel only once; the result serves both for the energy
        # value and as the gradient.
        endo_x = self._op(x.val)
        res = 0.5*x.val.vdot(endo_x)
        return x.new(res, VdotOperator(endo_x))
Martin Reinecke's avatar
Martin Reinecke committed
153

Philipp Arras's avatar
Philipp Arras committed
154

Philipp Frank's avatar
Philipp Frank committed
155
class VariableCovarianceGaussianEnergy(LikelihoodOperator):
    """Computes the negative log pdf of a Gaussian with unknown covariance.

    The covariance is assumed to be diagonal.

    .. math ::
        E(s,C) = - \\log G(s, C) = 0.5 (s)^\\dagger C (s) - 0.5 tr log(C),

    an information energy for a Gaussian distribution with residual s and
    inverse diagonal covariance C. If `sampling_dtype` is complex, the
    trace-log term enters with weight 1 instead of 0.5, i.e.
    :math:`E(s,C) = 0.5 (s)^\\dagger C (s) - tr log(C)`.

    The domain of this energy will be a MultiDomain with two keys,
    the target will be the scalar domain.

    Parameters
    ----------
    domain : Domain, DomainTuple, tuple of Domain
        domain of the residual and domain of the covariance diagonal.

    residual_key : key
        Residual key of the Gaussian (converted with `str`).

    inverse_covariance_key : key
        Inverse covariance diagonal key of the Gaussian (converted with `str`).

    sampling_dtype : np.dtype
        Data type of the residual samples. Usually either 'np.float*' or
        'np.complex*'. The inverse covariance is always sampled as float64.

    use_full_fisher: boolean
        Determines if the proper Fisher information metric should be used as
        `metric`. If False, the same approximation as in `get_transformation`
        is used. Default is True.
    """

    def __init__(self, domain, residual_key, inverse_covariance_key,
                 sampling_dtype, use_full_fisher=True):
        self._kr = str(residual_key)
        self._ki = str(inverse_covariance_key)
        dom = DomainTuple.make(domain)
        # Residual and inverse covariance live on the same DomainTuple.
        self._domain = MultiDomain.make({self._kr: dom, self._ki: dom})
        self._dt = {self._kr: sampling_dtype, self._ki: np.float64}
        _check_sampling_dtype(self._domain, self._dt)
        self._cplx = _iscomplex(sampling_dtype)
        self._use_full_fisher = use_full_fisher

    def apply(self, x):
        self._check_input(x)
        r, i = x[self._kr], x[self._ki]
        if self._cplx:
            res = 0.5*r.vdot(r*i.real).real - i.log().sum()
        else:
            res = 0.5*(r.vdot(r*i) - i.log().sum())
        if not x.want_metric:
            return res
        if self._use_full_fisher:
            # Diagonal Fisher metric: C for the residual entry and
            # fac*C^(-2) for the inverse-covariance entry.
            fac = 1. if self._cplx else 0.5
            met = MultiField.from_dict({self._kr: i.val, self._ki: fac*i.val**(-2)},
                                       domain=self._domain)
            met = SamplingDtypeSetter(makeOp(met), self._dt)
        else:
            met = self.get_metric_at(x.val)
        return res.add_metric(met)

    def _simplify_for_constant_input_nontrivial(self, c_inp):
        from .simplify_for_const import ConstantEnergyOperator
        # Exactly one of the two inputs is expected to be constant.
        myassert(len(c_inp.keys()) == 1)
        key = c_inp.keys()[0]
        myassert(key in self._domain.keys())
        cst = c_inp[key]
        if key == self._kr:
            # Constant residual: the energy depends on the inverse covariance
            # only.
            res = _SpecialGammaEnergy(cst).ducktape(self._ki)
        else:
            # Constant inverse covariance: a Gaussian in the residual plus the
            # now constant trace-log term (weight 0.5 real, 1 complex, as in
            # `apply`).
            dt = self._dt[self._kr]
            res = GaussianEnergy(inverse_covariance=makeOp(cst),
                                 sampling_dtype=dt).ducktape(self._kr)
            trlog = cst.log().sum().val_rw()
            if not self._cplx:
                trlog /= 2
            res = res + ConstantEnergyOperator(-trlog)
        # NOTE(review): presumably added so that the result's target is the
        # identical scalar-domain object checked by the assertion below.
        res = res + ConstantEnergyOperator(0.)
        myassert(res.target is self.target)
        return None, res

    def get_transformation(self):
        """
        Note
        ----
        For `VariableCovarianceGaussianEnergy`, a global transformation to
        Euclidean space does not exist. A local approximation invoking the
        residual is used instead.
        """
        r = FieldAdapter(self._domain[self._kr], self._kr)
        ivar = FieldAdapter(self._domain[self._ki], self._ki).real
        # The residual term sqrt(ivar)*r also depends on ivar and adds to the
        # ivar-ivar block of the metric. For real residuals with
        # E[r^2] = 1/ivar, the factor 0.5 makes the total 0.5*ivar^(-2) in
        # expectation, which matches the full Fisher metric in `apply`.
        sc = 1. if self._cplx else 0.5
        f = r.adjoint @ (ivar.sqrt()*r) + ivar.adjoint @ (sc*ivar.log())
        return self._dt, f
250

Philipp Frank's avatar
Philipp Frank committed
251
252

class _SpecialGammaEnergy(LikelihoodOperator):
    """Energy of `VariableCovarianceGaussianEnergy` for a constant residual.

    The input `x` is the inverse variance. The residual `r` is the constant
    passed to the constructor. The energy is ``0.5*(r^dagger x r - sum(log x))``
    for real residuals and ``0.5*r^dagger Re(x) r - sum(log x)`` for complex
    ones.

    Parameters
    ----------
    residual : Field
        The constant residual (presumably a Field; its `domain` and `dtype`
        are used).
    """

    def __init__(self, residual):
        self._domain = DomainTuple.make(residual.domain)
        self._resi = residual
        self._cplx = _iscomplex(self._resi.dtype)
        # The input is a (real) inverse variance. It is sampled as float64,
        # the dtype VariableCovarianceGaussianEnergy uses for its
        # inverse-covariance input, regardless of the residual's dtype.
        self._dt = np.float64

    def apply(self, x):
        self._check_input(x)
        r = self._resi
        if self._cplx:
            res = 0.5*(r*x.real).vdot(r).real - x.log().sum()
        else:
            res = 0.5*((r*x).vdot(r) - x.log().sum())
        if not x.want_metric:
            return res
        return res.add_metric(self.get_metric_at(x.val))

    def get_transformation(self):
        # The squared Jacobian of sc*log(x) is sc^2/x^2: 0.5/x^2 for real
        # residuals and 1/x^2 for complex ones.
        sc = 1. if self._cplx else np.sqrt(0.5)
        return self._dt, sc*ScalingOperator(self._domain, 1.).log()
Martin Reinecke's avatar
Martin Reinecke committed
273

Philipp Arras's avatar
Philipp Arras committed
274

Philipp Frank's avatar
Philipp Frank committed
275
class GaussianEnergy(LikelihoodOperator):
    """Computes a negative-log Gaussian.

    Represents, up to terms that depend on neither :math:`f` nor :math:`m`,

    .. math ::
        E(f) = - \\log G(f-m, D) = 0.5 (f-m)^\\dagger D^{-1} (f-m),

    an information energy for a Gaussian distribution with mean m and
    covariance D.

    Parameters
    ----------
    mean : Field or MultiField
        Mean of the Gaussian. Default is 0.
    inverse_covariance : LinearOperator
        Inverse covariance of the Gaussian. Default is the identity operator.
    domain : Domain, DomainTuple, tuple of Domain or MultiDomain
        Operator domain. By default it is inferred from `mean` or
        `inverse_covariance` if specified.
    sampling_dtype : type or dict of type
        Here one can specify whether the distribution is a complex Gaussian or
        not. If `mean` is given, it defaults to the dtype(s) of `mean` and
        must agree with them. For a MultiField `mean`, a single dtype is
        accepted if every component of `mean` has that dtype. Note that for a
        complex Gaussian the inverse_covariance is

        .. math ::
            (\\langle ff^\\dagger \\rangle_{P(f)})^{-1}/2,

        where the additional factor of 2 is necessary because the domain of
        f has twice as many (real) dimensions as in the real case.

    Raises
    ------
    TypeError
        If `mean` is not a Field/MultiField or `inverse_covariance` is not a
        LinearOperator.
    ValueError
        If no domain can be determined, the given domains disagree, or
        `sampling_dtype` is incompatible with `mean`.

    Note
    ----
    At least one of `mean`, `inverse_covariance` and `domain` has to be
    provided.
    """

    def __init__(self, mean=None, inverse_covariance=None, domain=None, sampling_dtype=None):
        if mean is not None and not isinstance(mean, (Field, MultiField)):
            raise TypeError
        if inverse_covariance is not None and not isinstance(inverse_covariance, LinearOperator):
            raise TypeError

        # Collect the domain from every argument that carries one; all of
        # them have to agree.
        self._domain = None
        if mean is not None:
            self._checkEquivalence(mean.domain)
        if inverse_covariance is not None:
            self._checkEquivalence(inverse_covariance.domain)
        if domain is not None:
            self._checkEquivalence(domain)
        if self._domain is None:
            raise ValueError("no domain given")
        self._mean = mean

        # Infer sampling dtype
        if self._mean is None:
            _check_sampling_dtype(self._domain, sampling_dtype)
        else:
            mean_dtype = _field_to_dtype(self._mean)
            if sampling_dtype is None:
                sampling_dtype = mean_dtype
            else:
                if not self._dtypes_match(sampling_dtype, mean_dtype):
                    raise ValueError("Sampling dtype and mean not compatible")
                if isinstance(mean_dtype, dict) and not isinstance(sampling_dtype, dict):
                    # A single dtype was given for a MultiField mean; use the
                    # per-key dict, which is what would have been inferred.
                    sampling_dtype = mean_dtype

        self._icov = inverse_covariance
        if inverse_covariance is None:
            self._op = Squared2NormOperator(self._domain).scale(0.5)
            self._met = ScalingOperator(self._domain, 1)
        else:
            self._op = QuadraticFormOperator(inverse_covariance)
            self._met = inverse_covariance
        if sampling_dtype is not None:
            self._met = SamplingDtypeSetter(self._met, sampling_dtype)

    @staticmethod
    def _dtypes_match(requested, actual):
        """Return whether the `requested` sampling dtype agrees with `actual`.

        `actual` is a dtype or a dict of dtypes (as returned by
        `_field_to_dtype`). A single `requested` dtype matches a dict if every
        entry of the dict equals it. A dict `requested` must have the same
        keys as `actual`, with equal dtypes.
        """
        if isinstance(actual, dict):
            if isinstance(requested, dict):
                return (set(requested.keys()) == set(actual.keys()) and
                        all(np.dtype(requested[kk]) == np.dtype(vv)
                            for kk, vv in actual.items()))
            return all(np.dtype(requested) == np.dtype(vv) for vv in actual.values())
        if isinstance(requested, dict):
            return False
        return np.dtype(requested) == np.dtype(actual)

    def _checkEquivalence(self, newdom):
        """Set the operator domain from `newdom`, or check that it agrees with
        the domain set so far; raise ValueError on mismatch."""
        newdom = makeDomain(newdom)
        if self._domain is None:
            self._domain = newdom
        else:
            if self._domain != newdom:
                raise ValueError("domain mismatch")

    def apply(self, x):
        self._check_input(x)
        residual = x if self._mean is None else x - self._mean
        res = self._op(residual).real
        if x.want_metric:
            return res.add_metric(self.get_metric_at(x.val))
        return res

    def get_transformation(self):
        # Unwrap a SamplingDtypeSetter so that the square root is taken of the
        # bare inverse covariance while its sampling dtype is returned.
        icov, dtp = self._met, None
        if isinstance(icov, SamplingDtypeSetter):
            dtp = icov._dtype
            icov = icov._op
        return dtp, icov.get_sqrt()

    def __repr__(self):
        dom = '()' if isinstance(self.domain, DomainTuple) else self.domain.keys()
        return f'GaussianEnergy {dom}'

Martin Reinecke's avatar
Martin Reinecke committed
372

Philipp Frank's avatar
Philipp Frank committed
373
class PoissonianEnergy(LikelihoodOperator):
    """Computes the likelihood Hamiltonian of an expected-count field given
    Poissonian count data.

    Up to the f-independent term :math:`log(d!)` it represents

    .. math ::
        E(f) = -\\log \\text{Poisson}(d|f) = \\sum f - d^\\dagger \\log(f),

    where f is a :class:`Field` in data space holding the expected counts.

    Parameters
    ----------
    d : Field
        Count data. Must have an integer dtype and contain no negative
        values.

    Raises
    ------
    TypeError
        If `d` is not a Field with integer dtype.
    ValueError
        If `d` contains negative values.
    """

    def __init__(self, d):
        is_integer_field = isinstance(d, Field) and np.issubdtype(d.dtype, np.integer)
        if not is_integer_field:
            raise TypeError
        if (d.val < 0).any():
            raise ValueError
        self._d = d
        self._domain = DomainTuple.make(d.domain)

    def apply(self, x):
        self._check_input(x)
        energy = x.sum() - x.log().vdot(self._d)
        if x.want_metric:
            return energy.add_metric(self.get_metric_at(x.val))
        return energy

    def get_transformation(self):
        # The squared Jacobian of 2*sqrt(f) is 1/f.
        identity = ScalingOperator(self._domain, 1.)
        return np.float64, 2.*identity.sqrt()

410

Philipp Frank's avatar
Philipp Frank committed
411
class InverseGammaLikelihood(LikelihoodOperator):
    """Computes the negative log-likelihood of the inverse gamma distribution.

    Its negative log-pdf(x) is, up to terms independent of x, given by

    .. math ::

        \\sum_i (\\alpha_i+1)*\\ln(x_i) + \\beta_i/x_i

    This is the likelihood for the variance :math:`x=S_k` given data
    :math:`\\beta = 0.5 |s_k|^2` where the Field :math:`s` is known to have
    the covariance :math:`S_k`.

    Parameters
    ----------
    beta : Field
        beta parameter of the inverse gamma distribution. Must have dtype
        float64.
    alpha : Scalar, Field, optional
        alpha parameter of the inverse gamma distribution. Default is -0.5.

    Raises
    ------
    TypeError
        If `beta` is not a Field of dtype float64, or if `alpha` is neither a
        scalar nor a Field.
    """

    def __init__(self, beta, alpha=-0.5):
        if not isinstance(beta, Field):
            raise TypeError
        self._domain = DomainTuple.make(beta.domain)
        self._beta = beta
        if np.isscalar(alpha):
            # Broadcast a scalar alpha to a Field on beta's domain.
            alpha = Field(beta.domain, np.full(beta.shape, alpha))
        elif not isinstance(alpha, Field):
            raise TypeError
        self._alphap1 = alpha+1
        if not self._beta.dtype == np.float64:
            # FIXME Add proper complex support for this energy
            raise TypeError
        self._sampling_dtype = _field_to_dtype(self._beta)

    def apply(self, x):
        self._check_input(x)
        res = x.log().vdot(self._alphap1) + x.reciprocal().vdot(self._beta)
        if not x.want_metric:
            return res
        return res.add_metric(self.get_metric_at(x.val))

    def get_transformation(self):
        # The squared Jacobian of sqrt(alpha+1)*log(x) is (alpha+1)/x^2.
        fact = self._alphap1.sqrt()
        res = makeOp(fact) @ ScalingOperator(self._domain, 1.).log()
        return self._sampling_dtype, res
458

Philipp Arras's avatar
Philipp Arras committed
459

Philipp Frank's avatar
Philipp Frank committed
460
class StudentTEnergy(LikelihoodOperator):
    """Computes the likelihood energy of Student's t-distribution.

    Up to terms independent of f,

    .. math ::
        E_\\theta(f) = -\\log \\text{StudentT}_\\theta(f)
                     = \\frac{\\theta + 1}{2} \\log(1 + \\frac{f^2}{\\theta}),

    where f is a field defined on `domain`. Samples are drawn assuming
    `float64` data.

    Parameters
    ----------
    domain : `Domain` or `DomainTuple`
        Domain of the operator
    theta : Scalar or Field
        Degree of freedom parameter of the Student t distribution
    """

    def __init__(self, domain, theta):
        self._theta = theta
        self._domain = DomainTuple.make(domain)

    def apply(self, x):
        self._check_input(x)
        weight = (self._theta+1)/2
        res = (weight*(x**2/self._theta).log1p()).sum()
        if x.want_metric:
            return res.add_metric(self.get_metric_at(x.val))
        return res

    def get_transformation(self):
        theta = self._theta
        if not isinstance(theta, (Field, MultiField)):
            # Broadcast a scalar degree of freedom over the domain.
            from ..extra import full
            theta = full(self._domain, theta)
        return np.float64, makeOp(((theta+1)/(theta+3)).sqrt())
496

Philipp Arras's avatar
Philipp Arras committed
497

Philipp Frank's avatar
Philipp Frank committed
498
class BernoulliEnergy(LikelihoodOperator):
    """Computes the likelihood energy of an expected event frequency given
    event data.

    .. math ::
        E(f) = -\\log \\text{Bernoulli}(d|f)
             = -d^\\dagger \\log f  - (1-d)^\\dagger \\log(1-f),

    where f is a field on `d.domain` holding the expected event frequencies.

    Parameters
    ----------
    d : Field
        Event data: 1 for an event, 0 for a non-event.

    Raises
    ------
    TypeError
        If `d` is not a Field with integer dtype.
    ValueError
        If `d` contains values other than 0 and 1.
    """

    def __init__(self, d):
        if not (isinstance(d, Field) and np.issubdtype(d.dtype, np.integer)):
            raise TypeError
        vals = d.val
        if np.any((vals != 0) & (vals != 1)):
            raise ValueError
        self._d = d
        self._domain = DomainTuple.make(d.domain)

    def apply(self, x):
        self._check_input(x)
        log_term = x.log().vdot(self._d)
        log_complement_term = (1.-x).log().vdot(self._d-1.)
        res = -log_term + log_complement_term
        if x.want_metric:
            return res.add_metric(self.get_metric_at(x.val))
        return res

    def get_transformation(self):
        from ..extra import full
        # Builds -2*arctan(sqrt((1-x)/x)); its squared Jacobian is
        # 1/(x*(1-x)).
        one_minus_x = Adder(full(self._domain, 1.)) @ ScalingOperator(self._domain, -1)
        ratio = one_minus_x * ScalingOperator(self._domain, 1).reciprocal()
        return np.float64, -2.*ratio.sqrt().arctan()
Martin Reinecke's avatar
Martin Reinecke committed
535

Philipp Arras's avatar
Philipp Arras committed
536

537
class StandardHamiltonian(EnergyOperator):
    """Computes an information Hamiltonian in its standard form, i.e. with the
    prior being a Gaussian with unit covariance.

    Let the likelihood energy be :math:`E_{lh}`. Then this operator computes:

    .. math ::
         H(f) = 0.5 f^\\dagger f + E_{lh}(f)

    Other field priors can be represented via transformations of a white
    Gaussian field into a field with the desired prior probability structure.

    By implementing prior information this way, the field prior is represented
    by a generative model, from which NIFTy can draw samples and infer a field
    using the Maximum a Posteriori (MAP) or the Variational Bayes (VB) method.

    The metric of this operator can be used as covariance for drawing Gaussian
    samples.

    Parameters
    ----------
    lh : EnergyOperator
        The likelihood energy.
    ic_samp : IterationController, optional
        Tells an internal :class:`SamplingEnabler` which convergence criterion
        to use to draw Gaussian samples. If None (the default), the metric of
        the plain sum of likelihood and prior is used, without a
        :class:`SamplingEnabler`.
    prior_dtype : numpy.dtype or dict of numpy.dtype, optional
        Data type of prior used for sampling. Default: numpy.float64.

    See also
    --------
    `Encoding prior knowledge in the structure of the likelihood`,
    Jakob Knollmüller, Torsten A. Ensslin,
    `<https://arxiv.org/abs/1812.04403>`_
    """

    def __init__(self, lh, ic_samp=None, prior_dtype=np.float64):
        self._lh = lh
        self._prior = GaussianEnergy(domain=lh.domain, sampling_dtype=prior_dtype)
        self._ic_samp = ic_samp
        # Kept so that simplified copies of this Hamiltonian use the same
        # prior sampling dtype instead of falling back to the default.
        self._prior_dtype = prior_dtype
        self._domain = lh.domain

    def apply(self, x):
        self._check_input(x)
        if not x.want_metric or self._ic_samp is None:
            return (self._lh + self._prior)(x)
        lhx, prx = self._lh(x), self._prior(x)
        # Wrap the metric so that samples with covariance (lh + prior metric)
        # can be drawn iteratively with the given controller.
        met = SamplingEnabler(lhx.metric, prx.metric, self._ic_samp)
        return (lhx+prx).add_metric(met)

    def __repr__(self):
        subs = 'Likelihood:\n{}'.format(utilities.indent(repr(self._lh)))
        subs += '\nPrior:\n{}'.format(self._prior)
        return 'StandardHamiltonian:\n' + utilities.indent(subs)

    def _simplify_for_constant_input_nontrivial(self, c_inp):
        out, lh1 = self._lh.simplify_for_constant_input(c_inp)
        dtype = self._prior_dtype
        if isinstance(dtype, dict):
            # The simplified likelihood only depends on the remaining
            # (non-constant) keys; restrict the per-key dtypes to those so
            # they match the domain of the rebuilt prior.
            dtype = {kk: dtype[kk] for kk in lh1.domain.keys()}
        return out, StandardHamiltonian(lh1, self._ic_samp, dtype)


class AveragedEnergy(EnergyOperator):
    """Averages an energy over samples.

    Parameters
    ----------
    h : EnergyOperator
        The energy (e.g. a Hamiltonian) to be averaged.
    res_samples : iterable of Fields
        Set of residual sample points to be added to mean field for
        approximate estimation of the KL. Must contain at least one sample.

    Raises
    ------
    ValueError
        If `res_samples` is empty.

    Notes
    -----
    - Having symmetrized residual samples, with both :math:`v_i` and
      :math:`-v_i` being present, ensures that the distribution mean is
      exactly represented.

    - :class:`AveragedEnergy(h)` approximates
      :math:`\\left< H(f) \\right>_{G(f-m,D)}` if the residuals :math:`f-m`
      are drawn from a Gaussian distribution with covariance :math:`D`.
    """

    def __init__(self, h, res_samples):
        self._h = h
        self._domain = h.domain
        self._res_samples = tuple(res_samples)
        # An empty sample set would otherwise only fail later, inside
        # apply(), with an opaque error from summing an empty sequence.
        if len(self._res_samples) == 0:
            raise ValueError("AveragedEnergy requires at least one residual sample")

    def apply(self, x):
        self._check_input(x)
        mymap = map(lambda v: self._h(x+v), self._res_samples)
        return utilities.my_sum(mymap)/len(self._res_samples)

    def get_transformation(self):
        dtp, trafo = self._h.get_transformation()
        # NOTE(review): the result maps x -> (1/sqrt(N)) sum_i trafo(x + v_i).
        # Its Jacobian is (1/sqrt(N)) sum_i J_i, so the induced metric
        # (1/N) sum_ij J_i^dagger J_j contains cross terms (i != j). It equals
        # the averaged metric (1/N) sum_i J_i^dagger J_i only if those cross
        # terms vanish. Confirm this approximation is intended.
        mymap = map(lambda v: trafo@Adder(v), self._res_samples)
        return dtp, utilities.my_sum(mymap)/np.sqrt(len(self._res_samples))