# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np

from ..domain_tuple import DomainTuple
from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field
from ..linearization import Linearization
from ..minimization.energy_adapter import StochasticEnergyAdapter
from ..multi_field import MultiField
from ..operators.einsum import MultiLinearEinsum
from ..operators.energy_operators import EnergyOperator
from ..operators.linear_operator import LinearOperator
from ..operators.multifield2vector import Multifield2Vector
from ..operators.sandwich_operator import SandwichOperator
from ..operators.simple_linear_operators import FieldAdapter
from ..sugar import from_random, full, is_fieldlike, makeDomain, makeField
from ..utilities import myassert


class MeanFieldVI:
    """Collect the operators required for Gaussian meanfield variational
    inference.

    Gaussian meanfield variational inference approximates some target
    distribution with a Gaussian distribution with a diagonal covariance
    matrix. The parameters of the approximation, in this case the mean and
    the standard deviation, are obtained by minimizing a stochastic estimate
    of the Kullback-Leibler divergence between the target and the
    approximation. In order to obtain gradients w.r.t. the parameters, the
    reparametrization trick is employed: the stochastic part of the
    approximation is separated from a deterministic function, the generator.
    Samples from the approximation are drawn by passing samples from a
    standard Gaussian through this generator.

    Parameters
    ----------
    position : Field
        The initial estimate of the approximate mean parameter.
    hamiltonian : Energy
        Hamiltonian of the approximated probability distribution.
    n_samples : int
        Number of samples used to stochastically estimate the KL.
    mirror_samples : bool
        Whether the negatives of the drawn samples are also used, as they
        are equally legitimate samples. If true, the number of used samples
        doubles. Mirroring samples stabilizes the KL estimate as extreme
        sample variation is counterbalanced. Since it improves stability in
        many cases, it is recommended to set `mirror_samples` to `True`.
    initial_sig : positive Field or positive float
        The initial estimate of the standard deviation.
    comm : MPI communicator or None
        If not None, samples will be distributed as evenly as possible
        across this communicator. If `mirror_samples` is set, then a sample and
        its mirror image will always reside on the same task.
    nanisinf : bool
        If true, NaN energies, which can happen due to overflows in the
        forward model, are interpreted as inf. Thereby, the code does not
        crash on these occasions; instead the minimizer is told that the
        position it has tried is not sensible.
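
    Examples
    --------
    A minimal usage sketch. `ham` stands for an already constructed
    `StandardHamiltonian`, `ift` for the package itself; the chosen
    minimizer and all numerical settings are purely illustrative::

        position = ift.from_random(ham.domain)
        meanfield = ift.MeanFieldVI(position, ham, n_samples=10,
                                    mirror_samples=True, initial_sig=0.1)
        minimizer = ift.ADVIOptimizer(10)
        for _ in range(5):
            meanfield.minimize(minimizer)
        approximate_mean = meanfield.mean
        posterior_sample = meanfield.draw_sample()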
    """
    def __init__(self, position, hamiltonian, n_samples, mirror_samples,
                 initial_sig=1, comm=None, nanisinf=False):
        Flat = Multifield2Vector(position.domain)
        self._std = FieldAdapter(Flat.target, 'std').absolute()
        latent = FieldAdapter(Flat.target, 'latent')
        self._mean = FieldAdapter(Flat.target, 'mean')
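        # Reparametrization trick: the generator maps a standard-normal
        # 'latent' sample xi to mean + std*xi, a sample from the
        # approximation, so that gradients w.r.t. 'mean' and 'std' can flow
        # through the deterministic part.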
        self._generator = Flat.adjoint(self._mean + self._std * latent)
        self._entropy = GaussianEntropy(self._std.target) @ self._std
        self._mean = Flat.adjoint @ self._mean
        self._std = Flat.adjoint @ self._std
        pos = {'mean': Flat(position)}
        if is_fieldlike(initial_sig):
            pos['std'] = Flat(initial_sig)
        else:
            pos['std'] = full(Flat.target, initial_sig)
        pos = MultiField.from_dict(pos)
        op = hamiltonian(self._generator) + self._entropy
        self._KL = StochasticEnergyAdapter.make(pos, op, ['latent',], n_samples,
                                    mirror_samples, nanisinf=nanisinf, comm=comm)
        self._samdom = latent.domain

    @property
    def mean(self):
        return self._mean.force(self._KL.position)

    @property
    def std(self):
        return self._std.force(self._KL.position)

    @property
    def entropy(self):
        return self._entropy.force(self._KL.position)

    def draw_sample(self):
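        # Fix a freshly drawn standard-normal latent sample as constant
        # input of the generator and evaluate it at the current variational
        # parameters.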
        _, op = self._generator.simplify_for_constant_input(
                from_random(self._samdom))
        return op(self._KL.position)

    def minimize(self, minimizer):
        self._KL, _ = minimizer(self._KL)


class FullCovarianceVI:
    """Collect the operators required for Gaussian full-covariance variational

    Gaussian full-covariance variational inference approximates some target
    distribution with a Gaussian distribution with a full covariance matrix.
    The parameters of the approximation, in this case the mean and a lower
    triangular matrix corresponding to a Cholesky decomposition of the
    covariance, are obtained by minimizing a stochastic estimate of the
    Kullback-Leibler divergence between the target and the approximation.
    In order to obtain gradients w.r.t. the parameters, the reparametrization
    trick is employed: the stochastic part of the approximation is separated
    from a deterministic function, the generator. Samples from the
    approximation are drawn by passing samples from a standard Gaussian
    through this generator.

    Note that the size of the covariance scales quadratically with the number of model
    parameters.

    Parameters
    ----------
    position : Field
        The initial estimate of the approximate mean parameter.
    hamiltonian : Energy
        Hamiltonian of the approximated probability distribution.
    n_samples : int
        Number of samples used to stochastically estimate the KL.
    mirror_samples : bool
        Whether the negatives of the drawn samples are also used, as they
        are equally legitimate samples. If true, the number of used samples
        doubles. Mirroring samples stabilizes the KL estimate as extreme
        sample variation is counterbalanced. Since it improves stability in
        many cases, it is recommended to set `mirror_samples` to `True`.
    initial_sig : positive float
        The initial estimate for the standard deviation. Initially no
        correlation between the parameters is assumed.
    comm : MPI communicator or None
        If not None, samples will be distributed as evenly as possible
        across this communicator. If `mirror_samples` is set, then a sample and
        its mirror image will always reside on the same task.
    nanisinf : bool
        If true, NaN energies, which can happen due to overflows in the
        forward model, are interpreted as inf. Thereby, the code does not
        crash on these occasions; instead the minimizer is told that the
        position it has tried is not sensible.
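
    Examples
    --------
    A minimal usage sketch, analogous to `MeanFieldVI`. `ham` stands for an
    already constructed `StandardHamiltonian`, `ift` for the package itself;
    the chosen minimizer and all numerical settings are purely
    illustrative::

        position = ift.from_random(ham.domain)
        fullcov = ift.FullCovarianceVI(position, ham, n_samples=10,
                                       mirror_samples=True, initial_sig=0.1)
        minimizer = ift.ADVIOptimizer(10)
        for _ in range(5):
            fullcov.minimize(minimizer)
        approximate_mean = fullcov.mean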
    """
    def __init__(self, position, hamiltonian, n_samples, mirror_samples,
                 initial_sig=1, comm=None, nanisinf=False):
        Flat = Multifield2Vector(position.domain)
        flat_domain = Flat.target[0]
        mat_space = DomainTuple.make((flat_domain, flat_domain))
        lat = FieldAdapter(Flat.target, 'latent')
        LT = LowerTriangularInserter(mat_space)
        tri = FieldAdapter(LT.domain, 'cov')
        mean = FieldAdapter(flat_domain, 'mean')
        cov = LT @ tri
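        # Bundle the latent standard-normal vector and the lower-triangular
        # covariance factor into one multi-field and contract them with an
        # einsum; the generator then adds the mean and maps the flat vector
        # back to the original multi-domain.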
        matmul_setup = lat.adjoint @ lat + cov.ducktape_left('co')
        MatMult = MultiLinearEinsum(matmul_setup.target, 'ij,j->i',
                                    key_order=('co', 'latent'))

        self._generator = Flat.adjoint @ (mean + MatMult @ matmul_setup)

        diag_cov = (DiagonalSelector(cov.target) @ cov).absolute()
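        # The entropy term only depends on det(L L^T) = prod(diag(L))^2,
        # hence only the (absolute) diagonal of the Cholesky factor enters.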
        self._entropy = GaussianEntropy(diag_cov.target) @ diag_cov
        diag_tri = np.diag(np.full(flat_domain.shape[0], initial_sig))
        pos = MultiField.from_dict(
                {'mean': Flat(position),
                 'cov': LT.adjoint(makeField(mat_space, diag_tri))})
        op = hamiltonian(self._generator) + self._entropy
        self._KL = StochasticEnergyAdapter.make(pos, op, ['latent',], n_samples,
                                    mirror_samples, nanisinf=nanisinf, comm=comm)
        self._mean = Flat.adjoint @ mean
        self._samdom = lat.domain

    @property
    def mean(self):
        return self._mean.force(self._KL.position)

    @property
    def entropy(self):
        return self._entropy.force(self._KL.position)

    def draw_sample(self):
        _, op = self._generator.simplify_for_constant_input(
                from_random(self._samdom))
        return op(self._KL.position)

    def minimize(self, minimizer):
        self._KL, _ = minimizer(self._KL)


class GaussianEntropy(EnergyOperator):
    """Entropy of a Gaussian distribution given the diagonal of a triangular
    decomposition of the covariance.

    Parameters
    ----------
    domain : Domain, DomainTuple or tuple of Domain
        The domain of the diagonal.
    """

    def __init__(self, domain):
        self._domain = DomainTuple.make(domain)

    def apply(self, x):
        self._check_input(x)
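        # Negative differential entropy of a Gaussian with standard
        # deviations x: -0.5 * sum(log(2*pi*e*x^2)).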
        res = (x*x).scale(2*np.pi*np.e).log().sum().scale(-0.5)
        if not isinstance(x, Linearization):
            return res
        if not x.want_metric:
            return res
        # FIXME not sure about metric
        return res.add_metric(SandwichOperator.make(res.jac))


class LowerTriangularInserter(LinearOperator):
    """Insert the entries of a lower triangular matrix into a matrix.

    Parameters
    ----------
    target : Domain, DomainTuple or tuple of Domain
        A two-dimensional domain with NxN entries.
    """

    def __init__(self, target):
        myassert(len(target.shape) == 2)
        myassert(target.shape[0] == target.shape[1])
        self._target = makeDomain(target)
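        # A lower-triangular NxN matrix has N*(N+1)/2 independent entries.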
        ndof = (target.shape[0]*(target.shape[0]+1))//2
        self._domain = makeDomain(UnstructuredDomain(ndof))
        self._indices = np.tril_indices(target.shape[0])
        self._capability = self.TIMES | self.ADJOINT_TIMES

    def apply(self, x, mode):
        self._check_input(x, mode)
        x = x.val
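        # TIMES scatters the flat vector of independent entries into the
        # lower triangle of an NxN matrix; ADJOINT_TIMES gathers them back.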
        if mode == self.TIMES:
            res = np.zeros(self._target.shape)
            res[self._indices] = x
        else:
            res = x[self._indices].reshape(self._domain.shape)
        return makeField(self._tgt(mode), res)


class DiagonalSelector(LinearOperator):
    """Extract the diagonal of a two-dimensional field.

    Parameters
    ----------
    domain : Domain, DomainTuple or tuple of Domain
        The two-dimensional domain of the input field. Must be of shape NxN.
    """

    def __init__(self, domain):
        myassert(len(domain.shape) == 2)
        myassert(domain.shape[0] == domain.shape[1])
        self._domain = makeDomain(domain)
        self._target = makeDomain(UnstructuredDomain(domain.shape[0]))
        self._capability = self.TIMES | self.ADJOINT_TIMES

    def apply(self, x, mode):
        self._check_input(x, mode)
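        # np.diag extracts the diagonal of a 2d array (TIMES) and embeds a
        # 1d array into a diagonal matrix (ADJOINT_TIMES), covering both
        # modes in one expression.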
        return makeField(self._tgt(mode), np.diag(x.val))