# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np

from .. import random, utilities
from ..field import Field
from ..linearization import Linearization
from ..multi_field import MultiField
from ..operators.endomorphic_operator import EndomorphicOperator
from ..operators.energy_operators import StandardHamiltonian
from ..probing import approximation2endo
from ..sugar import full, makeOp
from .energy import Energy


def _shareRange(nwork, nshares, myshare):
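    # Split `nwork` items as evenly as possible among `nshares` shares and
    # return the half-open index range [lo, hi) owned by share `myshare`.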
    nbase = nwork//nshares
    additional = nwork % nshares
    lo = myshare*nbase + min(myshare, additional)
    hi = lo + nbase + int(myshare < additional)
    return lo, hi


def _np_allreduce_sum(comm, arr):
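    # Element-wise sum of `arr` over all tasks of the MPI communicator `comm`;
    # acts as a pass-through when running without MPI.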
    if comm is None:
        return arr
    from mpi4py import MPI
    arr = np.array(arr)
    res = np.empty_like(arr)
    comm.Allreduce(arr, res, MPI.SUM)
    return res


def _allreduce_sum_field(comm, fld):
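    # Like _np_allreduce_sum, but for Field and MultiField objects.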
    if comm is None:
        return fld
    if isinstance(fld, Field):
        return Field(fld.domain, _np_allreduce_sum(comm, fld.val))
    res = tuple(
        Field(f.domain, _np_allreduce_sum(comm, f.val))
        for f in fld.values())
    return MultiField(fld.domain, res)


class _KLMetric(EndomorphicOperator):
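    # Thin wrapper that exposes the sampled metric of a MetricGaussianKL as an
    # EndomorphicOperator, so that it can be applied and sampled from like any
    # other operator.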
    def __init__(self, KL):
        self._KL = KL
        self._capability = self.TIMES | self.ADJOINT_TIMES
        self._domain = KL.position.domain

    def apply(self, x, mode):
        self._check_input(x, mode)
        return self._KL.apply_metric(x)

    def draw_sample(self, from_inverse=False):
        return self._KL._metric_sample(from_inverse)


class MetricGaussianKL(Energy):
    """Provides the sampled Kullback-Leibler divergence between a distribution
    and a Metric Gaussian.

    A Metric Gaussian is used to approximate another probability distribution.
    It is a Gaussian distribution that uses the Fisher information metric of
    the other distribution at the location of its mean to approximate the
    variance. In order to infer the mean, a stochastic estimate of the
    Kullback-Leibler divergence is minimized. This estimate is obtained by
    sampling the Metric Gaussian at the current mean. During minimization
    these samples are kept constant; only the mean is updated. Due to the
    typically nonlinear structure of the true distribution these samples have
    to be updated eventually by instantiating `MetricGaussianKL` again. For the
    true probability distribution the standard parametrization is assumed.
    The samples of this class can be distributed among MPI tasks.

    Parameters
    ----------
    mean : Field
        Mean of the Gaussian probability distribution.
    hamiltonian : StandardHamiltonian
        Hamiltonian of the approximated probability distribution.
    n_samples : integer
        Number of samples used to stochastically estimate the KL.
    constants : list
        List of parameter keys that are kept constant during optimization.
        Default is no constants.
    point_estimates : list
        List of parameter keys for which no samples are drawn, but that are
        (possibly) optimized for, corresponding to point estimates of these.
        Default is to draw samples for the complete domain.
    mirror_samples : boolean
        Whether the negatives of the drawn samples are also used,
        as they are equally legitimate samples. If True, the number of used
        samples doubles. Mirroring samples stabilizes the KL estimate as
        extreme sample variation is counterbalanced. Default is False.
    napprox : int
        Number of samples for computing preconditioner for sampling. No
        preconditioning is done by default.
    comm : MPI communicator or None
        If not None, samples will be distributed as evenly as possible
        across this communicator. If `mirror_samples` is set, then a sample and
        its mirror image will always reside on the same task.
    _local_samples : None
        Only a parameter for internal uses. Typically not to be set by users.
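    nanisinf : bool
        Whether a NaN value of the sampled KL is reported as +inf instead, so
        that the minimizer treats the corresponding position as invalid rather
        than failing. Default is False.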

    Note
    ----
    The two lists `constants` and `point_estimates` are independent from each
    other. It is possible to sample along domains which are kept constant
    during minimization and vice versa.

    See also
    --------
    `Metric Gaussian Variational Inference`, Jakob Knollmüller,
    Torsten A. Enßlin, `<https://arxiv.org/abs/1901.11033>`_
    """

    def __init__(self, mean, hamiltonian, n_samples, constants=[],
                 point_estimates=[], mirror_samples=False,
                 napprox=0, comm=None, _local_samples=None,
                 nanisinf=False):
        super(MetricGaussianKL, self).__init__(mean)

        if not isinstance(hamiltonian, StandardHamiltonian):
            raise TypeError
        if hamiltonian.domain is not mean.domain:
            raise ValueError
        if not isinstance(n_samples, int):
            raise TypeError
        self._constants = tuple(constants)
        self._point_estimates = tuple(point_estimates)
        self._mitigate_nans = nanisinf
        if not isinstance(mirror_samples, bool):
            raise TypeError

        self._hamiltonian = hamiltonian

        self._n_samples = int(n_samples)
        if comm is not None:
            self._comm = comm
            ntask = self._comm.Get_size()
            rank = self._comm.Get_rank()
            self._lo, self._hi = _shareRange(self._n_samples, ntask, rank)
        else:
            self._comm = None
            self._lo, self._hi = 0, self._n_samples

        self._mirror_samples = bool(mirror_samples)
        self._n_eff_samples = self._n_samples
        if self._mirror_samples:
            self._n_eff_samples *= 2

        if _local_samples is None:
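            # Draw fresh residual samples from the inverse metric at the
            # current mean. Each global sample index gets its own RNG stream,
            # so the result does not depend on how the samples are distributed
            # over MPI tasks.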
            met = hamiltonian(Linearization.make_partial_var(
                mean, self._point_estimates, True)).metric
            if napprox >= 1:
                met._approximation = makeOp(approximation2endo(met, napprox))
            _local_samples = []
            sseq = random.spawn_sseq(self._n_samples)
            for i in range(self._lo, self._hi):
                with random.Context(sseq[i]):
                    _local_samples.append(met.draw_sample(from_inverse=True))
            _local_samples = tuple(_local_samples)
        else:
            if len(_local_samples) != self._hi-self._lo:
                raise ValueError("# of samples mismatch")
        self._local_samples = _local_samples
        self._lin = Linearization.make_partial_var(mean, self._constants)
        v, g = None, None
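        # Accumulate value and gradient of the Hamiltonian over the local
        # (and, if requested, mirrored) samples; the totals are MPI-summed and
        # divided by the effective number of samples below.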
        if len(self._local_samples) == 0:  # hack if there are too many MPI tasks
            tmp = self._hamiltonian(self._lin)
            v = 0. * tmp.val.val
            g = 0. * tmp.gradient
        else:
            for s in self._local_samples:
                tmp = self._hamiltonian(self._lin+s)
                if self._mirror_samples:
                    tmp = tmp + self._hamiltonian(self._lin-s)
                if v is None:
                    v = tmp.val.val_rw()
                    g = tmp.gradient
                else:
                    v += tmp.val.val
                    g = g + tmp.gradient
        self._val = _np_allreduce_sum(self._comm, v)[()] / self._n_eff_samples
        if np.isnan(self._val) and self._mitigate_nans:
            self._val = np.inf
        self._grad = _allreduce_sum_field(self._comm, g) / self._n_eff_samples
        self._metric = None

    def at(self, position):
        return MetricGaussianKL(
            position, self._hamiltonian, self._n_samples, self._constants,
            self._point_estimates, self._mirror_samples, comm=self._comm,
            _local_samples=self._local_samples, nanisinf=self._mitigate_nans)

    @property
    def value(self):
        return self._val

    @property
    def gradient(self):
        return self._grad

    def _get_metric(self):
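        # Lazily assemble the sampled KL metric as the average of the
        # Hamiltonian metrics evaluated at the (mirrored) sample positions.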
        lin = self._lin.with_want_metric()
        if self._metric is None:
            if len(self._local_samples) == 0:  # hack if there are too many MPI tasks
                self._metric = self._hamiltonian(lin).metric.scale(0.)
            else:
                mymap = map(lambda v: self._hamiltonian(lin+v).metric,
                            self._local_samples)
                unscaled_metric = utilities.my_sum(mymap)
                if self._mirror_samples:
                    mymap = map(lambda v: self._hamiltonian(lin-v).metric,
                                self._local_samples)
                    unscaled_metric = unscaled_metric + utilities.my_sum(mymap)
                self._metric = unscaled_metric.scale(1./self._n_eff_samples)

    def apply_metric(self, x):
        self._get_metric()
        return _allreduce_sum_field(self._comm, self._metric(x))

    @property
    def metric(self):
        return _KLMetric(self)

    @property
    def samples(self):
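        # Generator over all samples (and their mirror images). With MPI, each
        # task broadcasts its local samples so that every task sees the full
        # set in the same order.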
        if self._comm is None:
            for s in self._local_samples:
                yield s
                if self._mirror_samples:
                    yield -s
        else:
            ntask = self._comm.Get_size()
            rank = self._comm.Get_rank()
            rank_lo_hi = [_shareRange(self._n_samples, ntask, i) for i in range(ntask)]
            for itask, (l, h) in enumerate(rank_lo_hi):
                for i in range(l, h):
                    data = self._local_samples[i-self._lo] if rank == itask else None
                    s = self._comm.bcast(data, root=itask)
                    yield s
                    if self._mirror_samples:
                        yield -s

    def _metric_sample(self, from_inverse=False):
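        # Draw a sample from the sampled KL metric by adding up one metric
        # sample per local (mirrored) sample, summing across tasks and
        # dividing by the effective number of samples.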
        if from_inverse:
            raise NotImplementedError()
        lin = self._lin.with_want_metric()
        samp = full(self._hamiltonian.domain, 0.)
        sseq = random.spawn_sseq(self._n_samples)
        for i, v in enumerate(self._local_samples):
            with random.Context(sseq[self._lo+i]):
                samp = samp + self._hamiltonian(lin+v).metric.draw_sample(from_inverse=False)
                if self._mirror_samples:
                    samp = samp + self._hamiltonian(lin-v).metric.draw_sample(from_inverse=False)
        return _allreduce_sum_field(self._comm, samp)/self._n_eff_samples