metric_gaussian_kl.py 12.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
14
# Copyright(C) 2013-2020 Max-Planck-Society
15
16
17
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

18
import numpy as np
Philipp Arras's avatar
Philipp Arras committed
19
20

from .. import random, utilities
21
from ..field import Field
Philipp Arras's avatar
Philipp Arras committed
22
from ..linearization import Linearization
23
from ..multi_field import MultiField
Philipp Arras's avatar
Philipp Arras committed
24
25
26
27
28
from ..operators.endomorphic_operator import EndomorphicOperator
from ..operators.energy_operators import StandardHamiltonian
from ..probing import approximation2endo
from ..sugar import full, makeOp
from .energy import Energy
29
30
31
32
33
34
35
36
37
38


def _shareRange(nwork, nshares, myshare):
    nbase = nwork//nshares
    additional = nwork % nshares
    lo = myshare*nbase + min(myshare, additional)
    hi = lo + nbase + int(myshare < additional)
    return lo, hi


Martin Reinecke's avatar
Martin Reinecke committed
39
class _KLMetric(EndomorphicOperator):
    """Expose the sampled-average metric of a :class:`MetricGaussianKL`
    as an endomorphic linear operator.

    This thin adapter lets the KL's metric be used wherever the operator
    interface is expected (e.g. as a curvature operator in minimizers).
    """
    def __init__(self, KL):
        # Keep a reference to the KL energy; all real work is delegated to it.
        self._KL = KL
        self._capability = self.TIMES | self.ADJOINT_TIMES
        self._domain = KL.position.domain

    def apply(self, x, mode):
        # The same computation serves both TIMES and ADJOINT_TIMES: `apply`
        # does not branch on `mode` after validation.
        self._check_input(x, mode)
        return self._KL.apply_metric(x)

    def draw_sample(self, from_inverse=False):
        # Delegates to the KL's private sampler; sampling from the inverse
        # metric is not implemented there and will raise.
        return self._KL._metric_sample(from_inverse)
51
52
53


class MetricGaussianKL(Energy):
    """Provides the sampled Kullback-Leibler divergence between a distribution
    and a Metric Gaussian.

    A Metric Gaussian is used to approximate another probability distribution.
    It is a Gaussian distribution that uses the Fisher information metric of
    the other distribution at the location of its mean to approximate the
    variance. In order to infer the mean, a stochastic estimate of the
    Kullback-Leibler divergence is minimized. This estimate is obtained by
    sampling the Metric Gaussian at the current mean. During minimization
    these samples are kept constant; only the mean is updated. Due to the
    typically nonlinear structure of the true distribution these samples have
    to be updated eventually by instantiating `MetricGaussianKL` again. For
    the true probability distribution the standard parametrization is assumed.
    The samples of this class can be distributed among MPI tasks.

    Parameters
    ----------
    mean : Field
        Mean of the Gaussian probability distribution.
    hamiltonian : StandardHamiltonian
        Hamiltonian of the approximated probability distribution.
    n_samples : integer
        Number of samples used to stochastically estimate the KL.
    constants : list
        List of parameter keys that are kept constant during optimization.
        Default is no constants.
    point_estimates : list
        List of parameter keys for which no samples are drawn, but that are
        (possibly) optimized for, corresponding to point estimates of these.
        Default is to draw samples for the complete domain.
    mirror_samples : boolean
        Whether the negative of the drawn samples are also used,
        as they are equally legitimate samples. If true, the number of used
        samples doubles. Mirroring samples stabilizes the KL estimate as
        extreme sample variation is counterbalanced. Default is False.
    napprox : int
        Number of samples for computing preconditioner for sampling. No
        preconditioning is done by default.
    comm : MPI communicator or None
        If not None, samples will be distributed as evenly as possible
        across this communicator. If `mirror_samples` is set, then a sample
        and its mirror image will always reside on the same task.
    nanisinf : bool
        If true, nan energies which can happen due to overflows in the forward
        model are interpreted as inf. Thereby, the code does not crash on
        these occasions but rather the minimizer is told that the position it
        has tried is not sensible.
    _local_samples : None
        Only a parameter for internal uses. Typically not to be set by users.

    Note
    ----
    The two lists `constants` and `point_estimates` are independent from each
    other. It is possible to sample along domains which are kept constant
    during minimization and vice versa.

    See also
    --------
    `Metric Gaussian Variational Inference`, Jakob Knollmüller,
    Torsten A. Enßlin, `<https://arxiv.org/abs/1901.11033>`_
    """

    # NOTE: the defaults for `constants` and `point_estimates` used to be
    # mutable lists (`[]`); they are only ever read via `tuple(...)`, so the
    # immutable `()` default is behaviorally identical and avoids the
    # shared-mutable-default pitfall.
    def __init__(self, mean, hamiltonian, n_samples, constants=(),
                 point_estimates=(), mirror_samples=False,
                 napprox=0, comm=None, _local_samples=None,
                 nanisinf=False):
        super().__init__(mean)

        # --- argument validation -----------------------------------------
        if not isinstance(hamiltonian, StandardHamiltonian):
            raise TypeError("hamiltonian must be a StandardHamiltonian")
        if hamiltonian.domain is not mean.domain:
            raise ValueError("domain of hamiltonian and mean do not match")
        if not isinstance(n_samples, int):
            raise TypeError("n_samples must be an integer")
        self._constants = tuple(constants)
        self._point_estimates = tuple(point_estimates)
        self._mitigate_nans = nanisinf
        if not isinstance(mirror_samples, bool):
            raise TypeError("mirror_samples must be a boolean")
        self._hamiltonian = hamiltonian

        # --- distribute samples across MPI tasks (if any) ----------------
        self._n_samples = int(n_samples)
        if comm is not None:
            self._comm = comm
            ntask = self._comm.Get_size()
            rank = self._comm.Get_rank()
            # [self._lo, self._hi) is the slice of the global sample index
            # range that this task is responsible for.
            self._lo, self._hi = _shareRange(self._n_samples, ntask, rank)
        else:
            self._comm = None
            self._lo, self._hi = 0, self._n_samples

        self._mirror_samples = bool(mirror_samples)
        # Effective number of samples used for averaging; doubles when each
        # sample is also used with flipped sign.
        self._n_eff_samples = self._n_samples
        if self._mirror_samples:
            self._n_eff_samples *= 2

        # --- draw (or reuse) the local residual samples -------------------
        if _local_samples is None:
            met = hamiltonian(Linearization.make_partial_var(
                mean, self._point_estimates, True)).metric
            if napprox >= 1:
                met._approximation = makeOp(approximation2endo(met, napprox))
            _local_samples = []
            # One seed sequence per *global* sample index, so the drawn
            # samples are independent of the number of MPI tasks.
            sseq = random.spawn_sseq(self._n_samples)
            for i in range(self._lo, self._hi):
                with random.Context(sseq[i]):
                    _local_samples.append(met.draw_sample(from_inverse=True))
            _local_samples = tuple(_local_samples)
        else:
            # Samples handed over from a previous instance (see `at`).
            if len(_local_samples) != self._hi-self._lo:
                raise ValueError("# of samples mismatch")
        self._local_samples = _local_samples
        self._lin = Linearization.make_partial_var(mean, self._constants)

        # --- evaluate energy value and gradient over the local samples ----
        v, g = [], []
        for s in self._local_samples:
            tmp = self._hamiltonian(self._lin+s)
            tv = tmp.val.val
            tg = tmp.gradient
            if self._mirror_samples:
                tmp = self._hamiltonian(self._lin-s)
                tv = tv + tmp.val.val
                tg = tg + tmp.gradient
            v.append(tv)
            g.append(tg)
        # `_sumup` performs a deterministic (task-count independent) sum.
        self._val = self._sumup(v)[()]/self._n_eff_samples
        if np.isnan(self._val) and self._mitigate_nans:
            self._val = np.inf
        self._grad = self._sumup(g)/self._n_eff_samples

    def at(self, position):
        """Return a new KL energy at `position`, reusing the drawn samples."""
        return MetricGaussianKL(
            position, self._hamiltonian, self._n_samples, self._constants,
            self._point_estimates, self._mirror_samples, comm=self._comm,
            _local_samples=self._local_samples, nanisinf=self._mitigate_nans)

    @property
    def value(self):
        return self._val

    @property
    def gradient(self):
        return self._grad

    def apply_metric(self, x):
        """Apply the sampled-average metric of the Hamiltonian to `x`."""
        lin = self._lin.with_want_metric()
        res = []
        for s in self._local_samples:
            tmp = self._hamiltonian(lin+s).metric(x)
            if self._mirror_samples:
                tmp = tmp + self._hamiltonian(lin-s).metric(x)
            res.append(tmp)
        return self._sumup(res)/self._n_eff_samples

    @property
    def metric(self):
        return _KLMetric(self)

    @property
    def samples(self):
        """Generator yielding all samples (and their mirrors, if enabled).

        With MPI, every task iterates over *all* samples in global order;
        samples residing on other tasks are broadcast as needed.
        """
        if self._comm is None:
            for s in self._local_samples:
                yield s
                if self._mirror_samples:
                    yield -s
        else:
            ntask = self._comm.Get_size()
            rank = self._comm.Get_rank()
            rank_lo_hi = [_shareRange(self._n_samples, ntask, i) for i in range(ntask)]
            for itask, (l, h) in enumerate(rank_lo_hi):
                for i in range(l, h):
                    # Owner task supplies the sample; everyone else receives.
                    data = self._local_samples[i-self._lo] if rank == itask else None
                    s = self._comm.bcast(data, root=itask)
                    yield s
                    if self._mirror_samples:
                        yield -s

    def _sumup(self, obj):
        """ This is a deterministic implementation of MPI allreduce

        Numeric addition is not associative due to rounding errors.
        Therefore we provide our own implementation that is consistent
        no matter if MPI is used and how many tasks there are.

        At the beginning, a list `who` is constructed, that states which obj
        can be found on which MPI task.
        Then elements are added pairwise, with increasing pair distance.
        In the first round, the distance between pair members is 1:
          v[0] := v[0] + v[1]
          v[2] := v[2] + v[3]
          v[4] := v[4] + v[5]
        Entries whose summation partner lies beyond the end of the array
        stay unchanged.
        When both summation partners are not located on the same MPI task,
        the second summand is sent to the task holding the first summand and
        the operation is carried out there.
        For the next round, the distance is doubled:
          v[0] := v[0] + v[2]
          v[4] := v[4] + v[6]
          v[8] := v[8] + v[10]
        This is repeated until the distance exceeds the length of the array.
        At this point v[0] contains the sum of all entries, which is then
        broadcast to all tasks.
        """
        if self._comm is None:
            who = np.zeros(self._n_samples, dtype=np.int32)
            rank = 0
            vals = list(obj)  # necessary since we don't want to modify `obj`
        else:
            ntask = self._comm.Get_size()
            rank = self._comm.Get_rank()
            rank_lo_hi = [_shareRange(self._n_samples, ntask, i) for i in range(ntask)]
            lo, hi = rank_lo_hi[rank]
            # Pad with None so indices are global; only [lo, hi) is local.
            vals = [None]*lo + list(obj) + [None]*(self._n_samples-hi)
            who = np.zeros(len(vals), dtype=np.int32)
            for t, (l, h) in enumerate(rank_lo_hi):
                who[l:h] = t

        step = 1
        while step < self._n_samples:
            for j in range(0, self._n_samples, 2*step):
                if j+step < self._n_samples:  # summation partner found
                    if rank == who[j]:
                        if who[j] == who[j+step]:  # no communication required
                            vals[j] = vals[j] + vals[j+step]
                            vals[j+step] = None
                        else:
                            vals[j] = vals[j] + self._comm.recv(source=who[j+step])
                    elif rank == who[j+step]:
                        self._comm.send(vals[j+step], dest=who[j])
                        vals[j+step] = None
            step *= 2
        if self._comm is None:
            return vals[0]
        return self._comm.bcast(vals[0], root=who[0])

    def _metric_sample(self, from_inverse=False):
        """Draw one sample from the sampled-average metric.

        Sampling from the inverse metric is not supported here.
        """
        if from_inverse:
            raise NotImplementedError()
        lin = self._lin.with_want_metric()
        samp = []
        # Use the global seed sequence so results do not depend on the MPI
        # task layout.
        sseq = random.spawn_sseq(self._n_samples)
        for i, v in enumerate(self._local_samples):
            with random.Context(sseq[self._lo+i]):
                tmp = self._hamiltonian(lin+v).metric.draw_sample(from_inverse=False)
                if self._mirror_samples:
                    tmp = tmp + self._hamiltonian(lin-v).metric.draw_sample(from_inverse=False)
                samp.append(tmp)
        return self._sumup(samp)/self._n_eff_samples