# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np

from .. import random, utilities
from ..linearization import Linearization
from ..logger import logger
from ..multi_field import MultiField
from ..operators.endomorphic_operator import EndomorphicOperator
from ..operators.energy_operators import StandardHamiltonian
from ..probing import approximation2endo
from ..sugar import makeOp
from .energy import Energy


class _KLMetric(EndomorphicOperator):
    def __init__(self, KL):
        self._KL = KL
        self._capability = self.TIMES | self.ADJOINT_TIMES
        self._domain = KL.position.domain

    def apply(self, x, mode):
        self._check_input(x, mode)
        return self._KL.apply_metric(x)

    def draw_sample(self, from_inverse=False):
        return self._KL._metric_sample(from_inverse)


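# A sketch of how the global sample range is split over MPI tasks (assuming
# NIFTy's `shareRange` hands any remainder to the lowest ranks): n_samples=10
# on 4 tasks yields the local index ranges (0, 3), (3, 6), (6, 8), (8, 10).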
def _get_lo_hi(comm, n_samples):
    ntask, rank, _ = utilities.get_MPI_params_from_comm(comm)
    return utilities.shareRange(n_samples, ntask, rank)


def _modify_sample_domain(sample, domain):
    """Takes only keys from sample which are also in domain and inserts zeros
    in sample if key is not in domain."""
    from ..multi_domain import MultiDomain
    if not isinstance(sample, MultiField):
        assert sample.domain is domain
        return sample
    assert isinstance(domain, MultiDomain)
    if sample.domain is domain:
        return sample
    out = {kk: vv for kk, vv in sample.items() if kk in domain.keys()}
    out = MultiField.from_dict(out, domain)
    assert domain is out.domain
    return out


class MetricGaussianKL(Energy):
    """Provides the sampled Kullback-Leibler divergence between a distribution
    and a Metric Gaussian.

    A Metric Gaussian is used to approximate another probability distribution.
    It is a Gaussian distribution that uses the Fisher information metric of
    the other distribution at the location of its mean to approximate the
    variance. In order to infer the mean, a stochastic estimate of the
    Kullback-Leibler divergence is minimized. This estimate is obtained by
    sampling the Metric Gaussian at the current mean. During minimization
    these samples are kept constant; only the mean is updated. Due to the
    typically nonlinear structure of the true distribution these samples have
    to be updated eventually by instantiating `MetricGaussianKL` again. For the
    true probability distribution the standard parametrization is assumed.
    The samples of this class can be distributed among MPI tasks.

    Notes
    -----
    Instances of `MetricGaussianKL` should never be created using the
    constructor, but rather via the factory function :attr:`make`!

    See also
    --------
    `Metric Gaussian Variational Inference`, Jakob Knollmüller,
    Torsten A. Enßlin, `<https://arxiv.org/abs/1901.11033>`_
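
    Examples
    --------
    A rough sketch of the intended optimization loop (assuming a
    `StandardHamiltonian` `ham`, a starting position `mean` and a NIFTy
    minimizer `minimizer` have been set up elsewhere; not a verbatim
    recipe)::

        kl = MetricGaussianKL.make(mean, ham, n_samples=5,
                                   mirror_samples=True)
        kl, _ = minimizer(kl)   # samples stay fixed, only the mean moves
        mean = kl.position      # resample by calling make() at the new mean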
    """
    def __init__(self, mean, hamiltonian, n_samples, mirror_samples, comm,
                 local_samples, nanisinf, _callingfrommake=False):
        if not _callingfrommake:
            raise NotImplementedError
        super(MetricGaussianKL, self).__init__(mean)
        assert mean.domain is hamiltonian.domain
        self._hamiltonian = hamiltonian
        self._n_samples = int(n_samples)
        self._mirror_samples = bool(mirror_samples)
        self._comm = comm
        self._local_samples = local_samples
        self._nanisinf = bool(nanisinf)

        lin = Linearization.make_var(mean)
        v, g = [], []
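        # Accumulate this task's energy values and gradients; with mirroring,
        # each sample s contributes both H(mean+s) and H(mean-s), and the
        # totals are averaged over n_eff_samples below.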
        for s in self._local_samples:
            s = _modify_sample_domain(s, mean.domain)
            tmp = hamiltonian(lin+s)
            tv = tmp.val.val
            tg = tmp.gradient
            if mirror_samples:
                tmp = hamiltonian(lin-s)
                tv = tv + tmp.val.val
                tg = tg + tmp.gradient
            v.append(tv)
            g.append(tg)
        self._val = utilities.allreduce_sum(v, self._comm)[()]/self.n_eff_samples
        if np.isnan(self._val) and self._nanisinf:
            self._val = np.inf
        self._grad = utilities.allreduce_sum(g, self._comm)/self.n_eff_samples

    @staticmethod
    def make(mean, hamiltonian, n_samples, constants=[], point_estimates=[],
             mirror_samples=False, napprox=0, comm=None, nanisinf=False):
        """Return instance of :class:`MetricGaussianKL`.

        Parameters
        ----------
        mean : Field
            Mean of the Gaussian probability distribution.
        hamiltonian : StandardHamiltonian
            Hamiltonian of the approximated probability distribution.
        n_samples : integer
            Number of samples used to stochastically estimate the KL.
        constants : list
            List of parameter keys that are kept constant during optimization.
            Default is no constants.
        point_estimates : list
            List of parameter keys for which no samples are drawn, but that are
            (possibly) optimized for, corresponding to point estimates of these.
            Default is to draw samples for the complete domain.
        mirror_samples : boolean
            Whether the negatives of the drawn samples are also used, since
            they are equally legitimate samples. If true, the number of used
            samples doubles. Mirroring samples stabilizes the KL estimate as
            extreme sample variation is counterbalanced. Default is False.
        napprox : int
            Number of samples used to compute the preconditioner for sampling.
            No preconditioning is done by default.
        comm : MPI communicator or None
            If not None, samples will be distributed as evenly as possible
            across this communicator. If `mirror_samples` is set, then a sample and
            its mirror image will always reside on the same task.
        nanisinf : bool
            If true, NaN energies, which can happen due to overflows in the
            forward model, are interpreted as inf. Thereby, the code does not
            crash on these occasions but rather the minimizer is told that the
            position it has tried is not sensible.

        Note
        ----
        The two lists `constants` and `point_estimates` are independent of
        each other. It is possible to sample along domains which are kept
        constant during minimization, and vice versa.
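
        Examples
        --------
        A sketch of distributing samples over MPI tasks (assuming `mean` and
        `ham` exist and mpi4py is installed)::

            from mpi4py import MPI
            kl = MetricGaussianKL.make(mean, ham, n_samples=10,
                                       mirror_samples=True,
                                       comm=MPI.COMM_WORLD)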
        """

        if not isinstance(hamiltonian, StandardHamiltonian):
            raise TypeError
        if hamiltonian.domain is not mean.domain:
            raise ValueError
        if not isinstance(n_samples, int):
            raise TypeError
        if not isinstance(mirror_samples, bool):
            raise TypeError
        if isinstance(mean, MultiField) and set(point_estimates) == set(mean.keys()):
            raise RuntimeError(
                'Point estimates for whole domain. Use EnergyAdapter instead.')
        n_samples = int(n_samples)
        mirror_samples = bool(mirror_samples)

        if isinstance(mean, MultiField):
            cstpos = mean.extract_by_keys(point_estimates)
            _, ham_sampling = hamiltonian.simplify_for_constant_input(cstpos)
        else:
            ham_sampling = hamiltonian
        lin = Linearization.make_var(mean.extract(ham_sampling.domain), True)
        met = ham_sampling(lin).metric
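        # If requested, probe the metric with `napprox` samples to build a
        # diagonal approximation, which speeds up drawing samples from the
        # inverse metric below.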
        if napprox >= 1:
            met._approximation = makeOp(approximation2endo(met, napprox))
        local_samples = []
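        # One sub-seed per global sample index keeps the draws reproducible
        # and independent of how samples are distributed over MPI tasks.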
        sseq = random.spawn_sseq(n_samples)
        for i in range(*_get_lo_hi(comm, n_samples)):
            with random.Context(sseq[i]):
                local_samples.append(met.draw_sample(from_inverse=True))
        local_samples = tuple(local_samples)

        if isinstance(mean, MultiField):
            _, hamiltonian = hamiltonian.simplify_for_constant_input(mean.extract_by_keys(constants))
            mean = mean.extract_by_keys(set(mean.keys()) - set(constants))
        return MetricGaussianKL(
            mean, hamiltonian, n_samples, mirror_samples, comm, local_samples,
            nanisinf, _callingfrommake=True)

    def at(self, position):
        return MetricGaussianKL(
            position, self._hamiltonian, self._n_samples, self._mirror_samples,
            self._comm, self._local_samples, self._nanisinf, True)

    @property
    def value(self):
        return self._val

    @property
    def gradient(self):
        return self._grad

    def apply_metric(self, x):
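        # Apply the sampled metric: average the Hamiltonian's metric at each
        # (mirrored) sample position applied to x.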
        lin = Linearization.make_var(self.position, want_metric=True)
        res = []
        for s in self._local_samples:
            s = _modify_sample_domain(s, self._hamiltonian.domain)
            tmp = self._hamiltonian(lin+s).metric(x)
            if self._mirror_samples:
                tmp = tmp + self._hamiltonian(lin-s).metric(x)
            res.append(tmp)
        return utilities.allreduce_sum(res, self._comm)/self.n_eff_samples

    @property
    def n_eff_samples(self):
        if self._mirror_samples:
            return 2*self._n_samples
        return self._n_samples

    @property
    def metric(self):
        return _KLMetric(self)

    @property
    def samples(self):
        ntask, rank, _ = utilities.get_MPI_params_from_comm(self._comm)
        if ntask == 1:
            for s in self._local_samples:
                yield s
                if self._mirror_samples:
                    yield -s
        else:
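            # Iterate in global sample order; each sample is broadcast from
            # the task that owns it, so every task yields the complete
            # (possibly mirrored) sequence.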
            rank_lo_hi = [utilities.shareRange(self._n_samples, ntask, i) for i in range(ntask)]
            lo, _ = _get_lo_hi(self._comm, self._n_samples)
            for itask, (l, h) in enumerate(rank_lo_hi):
                for i in range(l, h):
                    data = self._local_samples[i-lo] if rank == itask else None
                    s = self._comm.bcast(data, root=itask)
                    yield s
                    if self._mirror_samples:
                        yield -s

    def _metric_sample(self, from_inverse=False):
        if from_inverse:
            raise NotImplementedError()
        s = ('This draws from the Hamiltonian used for evaluation and does '
             'not take point_estimates into account. Make sure that this '
             'is your intended use.')
        logger.warning(s)
        lin = Linearization.make_var(self.position, True)
        samp = []
        sseq = random.spawn_sseq(self._n_samples)
        lo, _ = _get_lo_hi(self._comm, self._n_samples)
        for i, s in enumerate(self._local_samples):
            s = _modify_sample_domain(s, self._hamiltonian.domain)
            with random.Context(sseq[lo+i]):
                tmp = self._hamiltonian(lin+s).metric.draw_sample(from_inverse=False)
                if self._mirror_samples:
                    tmp = tmp + self._hamiltonian(lin-s).metric.draw_sample(from_inverse=False)
                samp.append(tmp)
        return utilities.allreduce_sum(samp, self._comm)/self.n_eff_samples