# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np
Philipp Arras's avatar
Philipp Arras committed
19
20

from .. import random, utilities
21
from ..field import Field
Philipp Arras's avatar
Philipp Arras committed
22
from ..linearization import Linearization
23
from ..multi_field import MultiField
Philipp Arras's avatar
Philipp Arras committed
24
25
26
27
28
from ..operators.endomorphic_operator import EndomorphicOperator
from ..operators.energy_operators import StandardHamiltonian
from ..probing import approximation2endo
from ..sugar import full, makeOp
from .energy import Energy
29
30
31
32
33
34
35
36
37
38


def _shareRange(nwork, nshares, myshare):
    nbase = nwork//nshares
    additional = nwork % nshares
    lo = myshare*nbase + min(myshare, additional)
    hi = lo + nbase + int(myshare < additional)
    return lo, hi


39
40
41
42
43
44
45
46
47
48
49
50
def _getTask(iwork, nwork, nshares):
    nbase = nwork//nshares
    additional = nwork % nshares
    # FIXME: this is crappy code!
    for ishare in range(nshares):
        lo = ishare*nbase + min(ishare, additional)
        hi = lo + nbase + int(ishare < additional)
        if hi>iwork:
            return ishare
    raise RunTimeError("must not arrive here")


Martin Reinecke's avatar
Martin Reinecke committed
51
def _np_allreduce_sum(comm, arr):
52
53
54
55
56
57
58
59
60
    if comm is None:
        return arr
    from mpi4py import MPI
    arr = np.array(arr)
    res = np.empty_like(arr)
    comm.Allreduce(arr, res, MPI.SUM)
    return res


Martin Reinecke's avatar
Martin Reinecke committed
61
def _allreduce_sum_field(comm, fld):
62
63
64
    if comm is None:
        return fld
    if isinstance(fld, Field):
Martin Reinecke's avatar
Martin Reinecke committed
65
        return Field(fld.domain, _np_allreduce_sum(fld.val))
66
    res = tuple(
Martin Reinecke's avatar
Martin Reinecke committed
67
        Field(f.domain, _np_allreduce_sum(comm, f.val))
68
69
70
71
        for f in fld.values())
    return MultiField(fld.domain, res)


Martin Reinecke's avatar
Martin Reinecke committed
72
class _KLMetric(EndomorphicOperator):
    """Operator view of the metric of a KL energy object.

    All work is delegated to the wrapped object `KL`, which has to provide
    `position`, `apply_metric` and `_metric_sample`.

    Parameters
    ----------
    KL : object
        The energy whose metric is exposed, e.g. a `MetricGaussianKL`.
    """

    def __init__(self, KL):
        self._KL = KL
        self._capability = self.TIMES | self.ADJOINT_TIMES
        self._domain = KL.position.domain

    def apply(self, x, mode):
        """Apply the metric of the wrapped KL to `x`.

        `mode` is only validated via `_check_input`; the same
        `KL.apply_metric(x)` call is used for every accepted mode.
        """
        self._check_input(x, mode)
        return self._KL.apply_metric(x)

    def draw_sample(self, from_inverse=False, dtype=np.float64):
        """Draw a sample by forwarding to `KL._metric_sample`."""
        return self._KL._metric_sample(from_inverse, dtype)
84
85
86


class MetricGaussianKL(Energy):
    """Provides the sampled Kullback-Leibler divergence between a distribution
    and a Metric Gaussian.

    A Metric Gaussian is used to approximate another probability distribution.
    It is a Gaussian distribution that uses the Fisher information metric of
    the other distribution at the location of its mean to approximate the
    variance. In order to infer the mean, a stochastic estimate of the
    Kullback-Leibler divergence is minimized. This estimate is obtained by
    sampling the Metric Gaussian at the current mean. During minimization
    these samples are kept constant; only the mean is updated. Due to the
    typically nonlinear structure of the true distribution these samples have
    to be updated eventually by instantiating `MetricGaussianKL` again. For the
    true probability distribution the standard parametrization is assumed.
    The samples of this class can be distributed among MPI tasks.

    Parameters
    ----------
    mean : Field
        Mean of the Gaussian probability distribution.
    hamiltonian : StandardHamiltonian
        Hamiltonian of the approximated probability distribution.
    n_samples : integer
        Number of samples used to stochastically estimate the KL.
    constants : list
        List of parameter keys that are kept constant during optimization.
        Default is no constants.
    point_estimates : list
        List of parameter keys for which no samples are drawn, but that are
        (possibly) optimized for, corresponding to point estimates of these.
        Default is to draw samples for the complete domain.
    mirror_samples : boolean
        Whether the negative of the drawn samples are also used,
        as they are equally legitimate samples. If true, the number of used
        samples doubles. Mirroring samples stabilizes the KL estimate as
        extreme sample variation is counterbalanced. Default is False.
    napprox : int
        Number of samples for computing preconditioner for sampling. No
        preconditioning is done by default.
    comm : MPI communicator or None
        If not None, samples will be distributed as evenly as possible
        across this communicator. If `mirror_samples` is set, then a sample and
        its mirror image will always reside on the same task.
    lh_sampling_dtype : type
        Determines which dtype in data space shall be used for drawing samples
        from the metric. If the inference is based on complex data,
        lh_sampling_dtype shall be set to complex accordingly. The reason for
        the presence of this parameter is that metric of the likelihood energy
        is just an `Operator` which does not know anything about the dtype of
        the fields on which it acts. Default is float64.
    _local_samples : None
        Only a parameter for internal uses. Typically not to be set by users.

    Raises
    ------
    TypeError
        If `hamiltonian` is not a `StandardHamiltonian`, `n_samples` is not
        an `int`, or `mirror_samples` is not a `bool`.
    ValueError
        If `hamiltonian.domain` is not `mean.domain`, or if the number of
        entries in `_local_samples` does not match the number of samples
        assigned to this task.

    Note
    ----
    The two lists `constants` and `point_estimates` are independent from each
    other. It is possible to sample along domains which are kept constant
    during minimization and vice versa.

    See also
    --------
    `Metric Gaussian Variational Inference`, Jakob Knollmüller,
    Torsten A. Enßlin, `<https://arxiv.org/abs/1901.11033>`_
    """

    # The list defaults below are never mutated: both arguments are copied
    # into tuples before use.
    def __init__(self, mean, hamiltonian, n_samples, constants=[],
                 point_estimates=[], mirror_samples=False,
                 napprox=0, comm=None, _local_samples=None,
                 lh_sampling_dtype=np.float64):
        super(MetricGaussianKL, self).__init__(mean)

        if not isinstance(hamiltonian, StandardHamiltonian):
            raise TypeError("hamiltonian must be a StandardHamiltonian")
        if hamiltonian.domain is not mean.domain:
            raise ValueError("hamiltonian.domain and mean.domain differ")
        if not isinstance(n_samples, int):
            raise TypeError("n_samples must be an int")
        self._constants = tuple(constants)
        self._point_estimates = tuple(point_estimates)
        if not isinstance(mirror_samples, bool):
            raise TypeError("mirror_samples must be a bool")

        self._hamiltonian = hamiltonian

        self._n_samples = int(n_samples)
        self._comm = comm
        if comm is not None:
            # This task handles the global sample indices [_lo, _hi).
            ntask = self._comm.Get_size()
            rank = self._comm.Get_rank()
            self._lo, self._hi = _shareRange(self._n_samples, ntask, rank)
        else:
            self._lo, self._hi = 0, self._n_samples

        self._mirror_samples = bool(mirror_samples)
        # Number of samples entering the averages (doubled when mirroring).
        self._n_eff_samples = self._n_samples
        if self._mirror_samples:
            self._n_eff_samples *= 2

        if _local_samples is None:
            met = hamiltonian(Linearization.make_partial_var(
                mean, self._point_estimates, True)).metric
            if napprox >= 1:
                met._approximation = makeOp(approximation2endo(met, napprox))
            _local_samples = []
            # One seed sequence per global sample index; every task spawns
            # all of them and uses only the entries for its own indices.
            sseq = random.spawn_sseq(self._n_samples)
            for i in range(self._lo, self._hi):
                random.push_sseq(sseq[i])
                try:
                    _local_samples.append(
                        met.draw_sample(from_inverse=True,
                                        dtype=lh_sampling_dtype))
                finally:
                    # Always undo the push, even if sampling fails.
                    random.pop_sseq()
            _local_samples = tuple(_local_samples)
        else:
            if len(_local_samples) != self._hi-self._lo:
                raise ValueError("# of samples mismatch")
        self._local_samples = _local_samples
        self._lin = Linearization.make_partial_var(mean, self._constants)
        v, g = None, None
        if len(self._local_samples) == 0:
            # hack if there are too many MPI tasks: contribute zeros of the
            # right type so that the reductions below still work.
            tmp = self._hamiltonian(self._lin)
            v = 0. * tmp.val.val
            g = 0. * tmp.gradient
        else:
            for s in self._local_samples:
                tmp = self._hamiltonian(self._lin+s)
                if self._mirror_samples:
                    tmp = tmp + self._hamiltonian(self._lin-s)
                if v is None:
                    # Writeable copy, because it is accumulated in place.
                    v = tmp.val.val_rw()
                    g = tmp.gradient
                else:
                    v += tmp.val.val
                    g = g + tmp.gradient
        self._val = _np_allreduce_sum(self._comm, v)[()] / self._n_eff_samples
        self._grad = _allreduce_sum_field(self._comm, g) / self._n_eff_samples
        self._metric = None
        self._sampdt = lh_sampling_dtype

    def at(self, position):
        """Return a new `MetricGaussianKL` at `position`.

        The new object reuses the Hamiltonian, the settings, the communicator
        and the local samples of this one; no new samples are drawn.
        """
        return MetricGaussianKL(
            position, self._hamiltonian, self._n_samples, self._constants,
            self._point_estimates, self._mirror_samples, comm=self._comm,
            _local_samples=self._local_samples, lh_sampling_dtype=self._sampdt)

    @property
    def value(self):
        """Sample-averaged value of the Hamiltonian, computed in `__init__`."""
        return self._val

    @property
    def gradient(self):
        """Sample-averaged gradient of the Hamiltonian, computed in
        `__init__`."""
        return self._grad

    def _get_metric(self):
        """Build and cache the local contribution to the averaged metric.

        The result is stored in `self._metric`; nothing is returned. The
        cached operator only covers the samples held by this task.
        """
        if self._metric is not None:
            return
        lin = self._lin.with_want_metric()
        if len(self._local_samples) == 0:
            # hack if there are too many MPI tasks
            self._metric = self._hamiltonian(lin).metric.scale(0.)
            return
        unscaled_metric = utilities.my_sum(
            self._hamiltonian(lin+v).metric for v in self._local_samples)
        if self._mirror_samples:
            unscaled_metric = unscaled_metric + utilities.my_sum(
                self._hamiltonian(lin-v).metric for v in self._local_samples)
        self._metric = unscaled_metric.scale(1./self._n_eff_samples)

    def apply_metric(self, x):
        """Apply the sample-averaged metric to `x`.

        The local metric is applied first and the result is then summed over
        all tasks of the communicator.
        """
        self._get_metric()
        return _allreduce_sum_field(self._comm, self._metric(x))

    @property
    def metric(self):
        """A `_KLMetric` operator wrapping this object."""
        return _KLMetric(self)

    @property
    def samples(self):
        """Generator over all samples, in global sample order.

        If `mirror_samples` is set, every sample is directly followed by its
        negative. With a communicator, each sample is broadcast from the task
        that holds it.

        NOTE(review): in the MPI case every iteration step performs a
        `bcast`, so presumably all tasks have to iterate over this generator
        together -- confirm.
        """
        if self._comm is None:
            for s in self._local_samples:
                yield s
                if self._mirror_samples:
                    yield -s
        else:
            ntask = self._comm.Get_size()
            rank = self._comm.Get_rank()
            for i in range(self._n_samples):
                itask = _getTask(i, self._n_samples, ntask)
                # Only the owning task has the data; the others pass None.
                data = self._local_samples[i-self._lo] if rank == itask else None
                s = self._comm.bcast(data, root=itask)
                yield s
                if self._mirror_samples:
                    yield -s

    def _metric_sample(self, from_inverse=False, dtype=np.float64):
        """Draw a sample from the sample-averaged metric.

        For every local sample (and its mirror image, if enabled) a sample is
        drawn from the metric at the shifted position. The draws are summed,
        reduced over all tasks and divided by the effective number of
        samples.

        Raises
        ------
        NotImplementedError
            If `from_inverse` is true.
        """
        if from_inverse:
            raise NotImplementedError()
        lin = self._lin.with_want_metric()
        samp = full(self._hamiltonian.domain, 0.)
        sseq = random.spawn_sseq(self._n_samples)
        for i, v in enumerate(self._local_samples):
            # Use the seed sequence belonging to the global sample index.
            random.push_sseq(sseq[self._lo+i])
            try:
                samp = samp + self._hamiltonian(lin+v).metric.draw_sample(
                    from_inverse=False, dtype=dtype)
                if self._mirror_samples:
                    samp = samp + self._hamiltonian(lin-v).metric.draw_sample(
                        from_inverse=False, dtype=dtype)
            finally:
                # Always undo the push, even if sampling fails.
                random.pop_sseq()
        return _allreduce_sum_field(self._comm, samp)/self._n_eff_samples