metric_gaussian_kl.py 10.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

18
import numpy as np
Philipp Arras's avatar
Philipp Arras committed
19
20

from .. import random, utilities
21
from ..field import Field
Philipp Arras's avatar
Philipp Arras committed
22
from ..linearization import Linearization
23
from ..multi_field import MultiField
Philipp Arras's avatar
Philipp Arras committed
24
25
26
27
28
from ..operators.endomorphic_operator import EndomorphicOperator
from ..operators.energy_operators import StandardHamiltonian
from ..probing import approximation2endo
from ..sugar import full, makeOp
from .energy import Energy
29
30
31
32
33
34
35
36
37
38


def _shareRange(nwork, nshares, myshare):
    nbase = nwork//nshares
    additional = nwork % nshares
    lo = myshare*nbase + min(myshare, additional)
    hi = lo + nbase + int(myshare < additional)
    return lo, hi


Martin Reinecke's avatar
Martin Reinecke committed
39
def _np_allreduce_sum(comm, arr):
40
41
42
43
44
45
46
47
48
    if comm is None:
        return arr
    from mpi4py import MPI
    arr = np.array(arr)
    res = np.empty_like(arr)
    comm.Allreduce(arr, res, MPI.SUM)
    return res


Martin Reinecke's avatar
Martin Reinecke committed
49
def _allreduce_sum_field(comm, fld):
50
51
52
    if comm is None:
        return fld
    if isinstance(fld, Field):
Martin Reinecke's avatar
Martin Reinecke committed
53
        return Field(fld.domain, _np_allreduce_sum(fld.val))
54
    res = tuple(
Martin Reinecke's avatar
Martin Reinecke committed
55
        Field(f.domain, _np_allreduce_sum(comm, f.val))
56
57
58
59
        for f in fld.values())
    return MultiField(fld.domain, res)


Martin Reinecke's avatar
Martin Reinecke committed
60
class _KLMetric(EndomorphicOperator):
    """Operator wrapper around the metric of a KL energy object.

    All work is delegated to the wrapped object: applying the operator calls
    its `apply_metric` and drawing a sample calls its `_metric_sample`.

    Parameters
    ----------
    KL : object
        Energy whose metric is represented. It must provide
        `position.domain`, `apply_metric(x)` and
        `_metric_sample(from_inverse, dtype)`.
    """

    def __init__(self, KL):
        self._domain = KL.position.domain
        self._capability = self.TIMES | self.ADJOINT_TIMES
        self._KL = KL

    def apply(self, x, mode):
        """Apply the metric to `x`.

        After the input check, the same `apply_metric` call is used for
        every accepted mode.
        """
        self._check_input(x, mode)
        kl = self._KL
        return kl.apply_metric(x)

    def draw_sample(self, from_inverse=False, dtype=np.float64):
        """Return a sample obtained from the wrapped object's
        `_metric_sample`, forwarding `from_inverse` and `dtype` unchanged."""
        kl = self._KL
        return kl._metric_sample(from_inverse, dtype)
72
73
74


class MetricGaussianKL(Energy):
    """Provides the sampled Kullback-Leibler divergence between a distribution
    and a Metric Gaussian.

    A Metric Gaussian is used to approximate another probability distribution.
    It is a Gaussian distribution that uses the Fisher information metric of
    the other distribution at the location of its mean to approximate the
    variance. In order to infer the mean, a stochastic estimate of the
    Kullback-Leibler divergence is minimized. This estimate is obtained by
    sampling the Metric Gaussian at the current mean. During minimization
    these samples are kept constant; only the mean is updated. Due to the
    typically nonlinear structure of the true distribution these samples have
    to be updated eventually by instantiating `MetricGaussianKL` again. For the
    true probability distribution the standard parametrization is assumed.
    The samples of this class can be distributed among MPI tasks.

    Parameters
    ----------
    mean : Field
        Mean of the Gaussian probability distribution.
    hamiltonian : StandardHamiltonian
        Hamiltonian of the approximated probability distribution.
    n_samples : integer
        Number of samples used to stochastically estimate the KL.
    constants : list
        List of parameter keys that are kept constant during optimization.
        Default is no constants.
    point_estimates : list
        List of parameter keys for which no samples are drawn, but that are
        (possibly) optimized for, corresponding to point estimates of these.
        Default is to draw samples for the complete domain.
    mirror_samples : boolean
        Whether the negative of the drawn samples are also used,
        as they are equally legitimate samples. If true, the number of used
        samples doubles. Mirroring samples stabilizes the KL estimate as
        extreme sample variation is counterbalanced. Default is False.
    napprox : int
        Number of samples for computing preconditioner for sampling. No
        preconditioning is done by default.
    comm : MPI communicator or None
        If not None, samples will be distributed as evenly as possible
        across this communicator. If `mirror_samples` is set, then a sample and
        its mirror image will always reside on the same task.
    lh_sampling_dtype : type
        Determines which dtype in data space shall be used for drawing samples
        from the metric. If the inference is based on complex data,
        lh_sampling_dtype shall be set to complex accordingly. The reason for
        the presence of this parameter is that metric of the likelihood energy
        is just an `Operator` which does not know anything about the dtype of
        the fields on which it acts. Default is float64.
    _samples : None
        Only a parameter for internal uses. Typically not to be set by users.

    Raises
    ------
    TypeError
        If `hamiltonian` is not a `StandardHamiltonian`, `n_samples` is not
        an `int` or `mirror_samples` is not a `bool`.
    ValueError
        If `hamiltonian.domain` is not the very same object as `mean.domain`,
        or if `_samples` is given and its length differs from the number of
        samples assigned to this task.

    Note
    ----
    The two lists `constants` and `point_estimates` are independent from each
    other. It is possible to sample along domains which are kept constant
    during minimization and vice versa.

    See also
    --------
    `Metric Gaussian Variational Inference`, Jakob Knollmüller,
    Torsten A. Enßlin, `<https://arxiv.org/abs/1901.11033>`_
    """

    # The list defaults below are never mutated; they are converted to tuples
    # right away.
    def __init__(self, mean, hamiltonian, n_samples, constants=[],
                 point_estimates=[], mirror_samples=False,
                 napprox=0, comm=None, _samples=None,
                 lh_sampling_dtype=np.float64):
        super(MetricGaussianKL, self).__init__(mean)

        if not isinstance(hamiltonian, StandardHamiltonian):
            raise TypeError
        if hamiltonian.domain is not mean.domain:
            raise ValueError
        if not isinstance(n_samples, int):
            raise TypeError
        self._constants = tuple(constants)
        self._point_estimates = tuple(point_estimates)
        if not isinstance(mirror_samples, bool):
            raise TypeError

        self._hamiltonian = hamiltonian

        self._n_samples = int(n_samples)
        # [_lo, _hi) is the range of global sample indices handled locally.
        if comm is not None:
            self._comm = comm
            ntask = self._comm.Get_size()
            rank = self._comm.Get_rank()
            self._lo, self._hi = _shareRange(self._n_samples, ntask, rank)
        else:
            self._comm = None
            self._lo, self._hi = 0, self._n_samples

        self._mirror_samples = bool(mirror_samples)
        # Number of terms in the sample average (doubled when mirroring).
        self._n_eff_samples = self._n_samples
        if self._mirror_samples:
            self._n_eff_samples *= 2

        if _samples is None:
            met = hamiltonian(Linearization.make_partial_var(
                mean, self._point_estimates, True)).metric
            if napprox >= 1:
                met._approximation = makeOp(approximation2endo(met, napprox))
            _samples = []
            # One seed sequence per global sample; each local sample is drawn
            # with the sequence belonging to its global index.
            sseq = random.spawn_sseq(self._n_samples)
            for i in range(self._lo, self._hi):
                random.push_sseq(sseq[i])
                try:
                    _samples.append(met.draw_sample(from_inverse=True,
                                                    dtype=lh_sampling_dtype))
                finally:
                    # Undo the push even if drawing the sample fails.
                    random.pop_sseq()
            _samples = tuple(_samples)
        else:
            if len(_samples) != self._hi-self._lo:
                raise ValueError("# of samples mismatch")
        self._samples = _samples
        self._lin = Linearization.make_partial_var(mean, self._constants)
        v, g = None, None
        if len(self._samples) == 0:  # hack if there are too many MPI tasks
            # Contribute zeros of the right shape/domain to the reductions.
            tmp = self._hamiltonian(self._lin)
            v = 0. * tmp.val.val
            g = 0. * tmp.gradient
        else:
            for s in self._samples:
                tmp = self._hamiltonian(self._lin+s)
                if self._mirror_samples:
                    tmp = tmp + self._hamiltonian(self._lin-s)
                if v is None:
                    v = tmp.val.val_rw()
                    g = tmp.gradient
                else:
                    v += tmp.val.val
                    g = g + tmp.gradient
        self._val = _np_allreduce_sum(self._comm, v)[()] / self._n_eff_samples
        self._grad = _allreduce_sum_field(self._comm, g) / self._n_eff_samples
        self._metric = None
        self._sampdt = lh_sampling_dtype

    def at(self, position):
        """Return a new `MetricGaussianKL` at `position` that reuses the
        hamiltonian, settings, communicator and local samples of this one."""
        return MetricGaussianKL(
            position, self._hamiltonian, self._n_samples, self._constants,
            self._point_estimates, self._mirror_samples, comm=self._comm,
            _samples=self._samples, lh_sampling_dtype=self._sampdt)

    @property
    def value(self):
        """Sample-averaged value computed in the constructor."""
        return self._val

    @property
    def gradient(self):
        """Sample-averaged gradient computed in the constructor."""
        return self._grad

    def _get_metric(self):
        """Build and cache the task-local part of the sample-averaged metric.

        The result is stored in `self._metric`; nothing is returned. The
        metrics at `mean + s` (and at `mean - s` if samples are mirrored) are
        summed over the local samples and scaled by `1/n_eff_samples`. The
        mirrored terms must be included because `n_eff_samples` counts them,
        exactly as for the value, the gradient and the metric samples.
        """
        if self._metric is not None:
            return
        lin = self._lin.with_want_metric()
        if len(self._samples) == 0:  # hack if there are too many MPI tasks
            self._metric = self._hamiltonian(lin).metric.scale(0.)
            return

        ham = self._hamiltonian
        mirror = self._mirror_samples
        local_samples = self._samples

        def _sample_metrics():
            for s in local_samples:
                yield ham(lin+s).metric
                if mirror:
                    yield ham(lin-s).metric

        self.unscaled_metric = utilities.my_sum(_sample_metrics())
        self._metric = self.unscaled_metric.scale(1./self._n_eff_samples)

    def apply_metric(self, x):
        """Apply the sample-averaged metric to `x`, summing the task-local
        contributions over the communicator (if any)."""
        self._get_metric()
        return _allreduce_sum_field(self._comm, self._metric(x))

    @property
    def metric(self):
        """The metric wrapped as a `_KLMetric` operator."""
        return _KLMetric(self)

    @property
    def samples(self):
        """Tuple of all samples.

        With a communicator, the local samples of all tasks are gathered and
        concatenated. If samples are mirrored, the negated samples are
        appended after the drawn ones.
        """
        if self._comm is not None:
            res = self._comm.allgather(self._samples)
            res = tuple(item for sublist in res for item in sublist)
        else:
            res = self._samples
        if self._mirror_samples:
            res = res + tuple(-item for item in res)
        return res

    def _unscaled_metric_sample(self, from_inverse=False, dtype=np.float64):
        """Return the sum over all (mirrored) samples of one draw from the
        metric at each sample position, reduced over the communicator.

        Raises
        ------
        NotImplementedError
            If `from_inverse` is True.
        """
        if from_inverse:
            raise NotImplementedError()
        lin = self._lin.with_want_metric()
        samp = full(self._hamiltonian.domain, 0.)
        sseq = random.spawn_sseq(self._n_samples)
        for i, v in enumerate(self._samples):
            # Local sample i has global index _lo+i.
            random.push_sseq(sseq[self._lo+i])
            try:
                samp = samp + self._hamiltonian(lin+v).metric.draw_sample(
                    from_inverse=False, dtype=dtype)
                if self._mirror_samples:
                    samp = samp + self._hamiltonian(lin-v).metric.draw_sample(
                        from_inverse=False, dtype=dtype)
            finally:
                # Undo the push even if drawing the sample fails.
                random.pop_sseq()
        return _allreduce_sum_field(self._comm, samp)

    def _metric_sample(self, from_inverse=False, dtype=np.float64):
        """Return `_unscaled_metric_sample` divided by `n_eff_samples`."""
        return self._unscaled_metric_sample(from_inverse, dtype)/self._n_eff_samples