# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np

from .. import random, utilities
from ..linearization import Linearization
from ..logger import logger
from ..multi_field import MultiField
from ..operators.endomorphic_operator import EndomorphicOperator
from ..operators.energy_operators import StandardHamiltonian
from ..probing import approximation2endo
from ..sugar import makeDomain, makeOp
from .energy import Energy


class _KLMetric(EndomorphicOperator):
    def __init__(self, KL):
        self._KL = KL
        self._capability = self.TIMES | self.ADJOINT_TIMES
        self._domain = KL.position.domain

    def apply(self, x, mode):
        self._check_input(x, mode)
        return self._KL.apply_metric(x)

    def draw_sample(self, from_inverse=False):
        return self._KL._metric_sample(from_inverse)


class MetricGaussianKL(Energy):
    """Provides the sampled Kullback-Leibler divergence between a distribution
    and a Metric Gaussian.

    A Metric Gaussian is used to approximate another probability distribution.
    It is a Gaussian distribution that uses the Fisher information metric of
    the other distribution at the location of its mean to approximate the
    variance. In order to infer the mean, a stochastic estimate of the
    Kullback-Leibler divergence is minimized. This estimate is obtained by
    sampling the Metric Gaussian at the current mean. During minimization
    these samples are kept constant; only the mean is updated. Due to the
    typically nonlinear structure of the true distribution, these samples have
    to be updated eventually by instantiating `MetricGaussianKL` again. For the
    true probability distribution the standard parametrization is assumed.
    The samples of this class can be distributed among MPI tasks.

    Parameters
    ----------
    mean : Field
        Mean of the Gaussian probability distribution.
    hamiltonian : StandardHamiltonian
        Hamiltonian of the approximated probability distribution.
    n_samples : integer
        Number of samples used to stochastically estimate the KL.
    constants : list
        List of parameter keys that are kept constant during optimization.
        Default is no constants.
    point_estimates : list
        List of parameter keys for which no samples are drawn, but that are
        (possibly) optimized for, corresponding to point estimates of these.
        Default is to draw samples for the complete domain.
    mirror_samples : boolean
        Whether the negatives of the drawn samples are also used,
        as they are equally legitimate samples. If true, the number of used
        samples doubles. Mirroring samples stabilizes the KL estimate as
        extreme sample variation is counterbalanced. Default is False.
    napprox : int
        Number of samples used to compute the preconditioner for sampling. No
        preconditioning is done by default.
    comm : MPI communicator or None
        If not None, samples will be distributed as evenly as possible
        across this communicator. If `mirror_samples` is set, then a sample and
        its mirror image will always reside on the same task.
    nanisinf : bool
        If true, NaN energies, which can happen due to overflows in the forward
        model, are interpreted as inf. Thereby, the code does not crash on
        these occasions but rather the minimizer is told that the position it
        has tried is not sensible.
    _local_samples : None
        Only a parameter for internal use. Typically not to be set by users.

    Note
    ----
    The two lists `constants` and `point_estimates` are independent of each
    other. It is possible to sample along domains which are kept constant
    during minimization and vice versa.

    See also
    --------
    `Metric Gaussian Variational Inference`, Jakob Knollmüller,
    Torsten A. Enßlin, `<https://arxiv.org/abs/1901.11033>`_
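
    Example
    -------
    A minimal sketch of the intended optimization loop, assuming the package
    is imported as `ift`, `ham` is a `StandardHamiltonian` and `pos` a field
    on its domain; the controller settings and sample counts below are
    illustrative only::

        ic = ift.GradientNormController(iteration_limit=20)
        minimizer = ift.NewtonCG(ic)
        for _ in range(5):
            # Re-instantiate to draw fresh samples at the current mean.
            kl = MetricGaussianKL(pos, ham, n_samples=5, mirror_samples=True)
            kl, _ = minimizer(kl)
            pos = kl.position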
    """

    def __init__(self, mean, hamiltonian, n_samples, constants=[],
                 point_estimates=[], mirror_samples=False,
                 napprox=0, comm=None, _local_samples=None,
                 nanisinf=False):
        super(MetricGaussianKL, self).__init__(mean)

        if not isinstance(hamiltonian, StandardHamiltonian):
            raise TypeError
        if hamiltonian.domain is not mean.domain:
            raise ValueError
        if not isinstance(n_samples, int):
            raise TypeError
        self._mitigate_nans = nanisinf
        if not isinstance(mirror_samples, bool):
            raise TypeError
        if isinstance(mean, MultiField) and set(point_estimates) == set(mean.keys()):
            raise RuntimeError(
                'Point estimates for whole domain. Use EnergyAdapter instead.')

        self._hamiltonian = hamiltonian
        if len(constants) > 0:
            dom = {kk: vv for kk, vv in mean.domain.items() if kk in constants}
            dom = makeDomain(dom)
            cstpos = mean.extract(dom)
            _, self._hamiltonian = hamiltonian.simplify_for_constant_input(cstpos)

        self._n_samples = int(n_samples)
        if comm is not None:
            self._comm = comm
            ntask = self._comm.Get_size()
            rank = self._comm.Get_rank()
            self._lo, self._hi = utilities.shareRange(self._n_samples, ntask, rank)
        else:
            self._comm = None
            self._lo, self._hi = 0, self._n_samples

        self._mirror_samples = bool(mirror_samples)
        self._n_eff_samples = self._n_samples
        if self._mirror_samples:
            self._n_eff_samples *= 2

        if _local_samples is None:
            if len(point_estimates) > 0:
                dom = {kk: vv for kk, vv in mean.domain.items()
                       if kk in point_estimates}
                dom = makeDomain(dom)
                cstpos = mean.extract(dom)
                _, hamiltonian = hamiltonian.simplify_for_constant_input(cstpos)
            met = hamiltonian(Linearization.make_var(mean, True)).metric
            if napprox >= 1:
                met._approximation = makeOp(approximation2endo(met, napprox))
            _local_samples = []
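            # One seed sequence is spawned per global sample; each task only
            # draws its share [lo, hi), so the set of drawn samples does not
            # depend on the number of MPI tasks.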
            sseq = random.spawn_sseq(self._n_samples)
            for i in range(self._lo, self._hi):
                with random.Context(sseq[i]):
                    _local_samples.append(met.draw_sample(from_inverse=True))
            _local_samples = tuple(_local_samples)
        else:
            if len(_local_samples) != self._hi-self._lo:
                raise ValueError("# of samples mismatch")
        self._local_samples = _local_samples
        self._lin = Linearization.make_var(mean)
        v, g = [], []
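        # Accumulate Hamiltonian values and gradients of all local samples;
        # for mirrored samples the Hamiltonian is also evaluated at mean - s.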
        for s in self._local_samples:
            tmp = self._hamiltonian(self._lin+s)
            tv = tmp.val.val
            tg = tmp.gradient
            if self._mirror_samples:
                tmp = self._hamiltonian(self._lin-s)
                tv = tv + tmp.val.val
                tg = tg + tmp.gradient
            v.append(tv)
            g.append(tg)
        self._val = utilities.allreduce_sum(v, self._comm)[()]/self._n_eff_samples
        if np.isnan(self._val) and self._mitigate_nans:
            self._val = np.inf
        self._grad = utilities.allreduce_sum(g, self._comm)/self._n_eff_samples

    def at(self, position):
        return MetricGaussianKL(
            position, self._hamiltonian, self._n_samples,
            mirror_samples=self._mirror_samples, comm=self._comm,
            _local_samples=self._local_samples, nanisinf=self._mitigate_nans)

    @property
    def value(self):
        return self._val

    @property
    def gradient(self):
        return self._grad

    def apply_metric(self, x):
        lin = self._lin.with_want_metric()
        res = []
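        # Apply the metric at every local (and optionally mirrored) sample
        # position to x, then average over all samples across MPI tasks.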
        for s in self._local_samples:
            tmp = self._hamiltonian(lin+s).metric(x)
            if self._mirror_samples:
                tmp = tmp + self._hamiltonian(lin-s).metric(x)
            res.append(tmp)
        return utilities.allreduce_sum(res, self._comm)/self._n_eff_samples

    @property
    def metric(self):
        return _KLMetric(self)

    @property
    def samples(self):
        if self._comm is None:
            for s in self._local_samples:
                yield s
                if self._mirror_samples:
                    yield -s
        else:
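            # Each task broadcasts its local samples in turn, so every task
            # iterates over the full, identically ordered set of samples.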
            ntask = self._comm.Get_size()
            rank = self._comm.Get_rank()
            rank_lo_hi = [utilities.shareRange(self._n_samples, ntask, i) for i in range(ntask)]
            for itask, (l, h) in enumerate(rank_lo_hi):
                for i in range(l, h):
                    data = self._local_samples[i-self._lo] if rank == itask else None
                    s = self._comm.bcast(data, root=itask)
                    yield s
                    if self._mirror_samples:
                        yield -s

    def _metric_sample(self, from_inverse=False):
        if from_inverse:
            raise NotImplementedError()
        s = ('This draws from the Hamiltonian used for evaluation and does '
             'not take point_estimates into account. Make sure that this '
             'is your intended use.')
        logger.warning(s)
        lin = self._lin.with_want_metric()
        samp = []
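        # Draw one metric sample at every local sample position (and its
        # mirror, if enabled), then sum and normalize across MPI tasks.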
        sseq = random.spawn_sseq(self._n_samples)
        for i, v in enumerate(self._local_samples):
            with random.Context(sseq[self._lo+i]):
                tmp = self._hamiltonian(lin+v).metric.draw_sample(from_inverse=False)
                if self._mirror_samples:
                    tmp = tmp + self._hamiltonian(lin-v).metric.draw_sample(from_inverse=False)
                samp.append(tmp)
        return utilities.allreduce_sum(samp, self._comm)/self._n_eff_samples