variational_models.py 7.38 KB
Newer Older
Philipp Arras's avatar
Philipp Arras committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

18
import numpy as np
Philipp Arras's avatar
Philipp Arras committed
19

20
21
from ..domain_tuple import DomainTuple
from ..domains.unstructured_domain import UnstructuredDomain
Philipp Arras's avatar
Philipp Arras committed
22
23
24
25
26
27
28
29
from ..field import Field
from ..linearization import Linearization
from ..multi_field import MultiField
from ..operators.einsum import MultiLinearEinsum
from ..operators.energy_operators import EnergyOperator
from ..operators.linear_operator import LinearOperator
from ..operators.multifield2vector import Multifield2Vector
from ..operators.sandwich_operator import SandwichOperator
Philipp Frank's avatar
Philipp Frank committed
30
from ..operators.simple_linear_operators import FieldAdapter
Philipp Frank's avatar
cleanup    
Philipp Frank committed
31
from ..sugar import full, makeField, makeDomain, from_random, is_fieldlike
Philipp Frank's avatar
Philipp Frank committed
32
from ..minimization.energy_adapter import StochasticEnergyAdapter
Philipp Frank's avatar
cleanup    
Philipp Frank committed
33
34
35
36
37
from ..utilities import myassert


def _eval(op, position):
    return op(position.extract(op.domain))
Philipp Arras's avatar
Philipp Arras committed
38

39

Philipp Frank's avatar
Philipp Frank committed
40
41
42
class MeanFieldVI:
    """Gaussian mean-field variational inference.

    The approximation is parametrized by a flattened `mean` and a flattened
    `std`; samples are generated as ``mean + |std| * latent`` and mapped back
    to the domain of the initial position.
    """

    def __init__(self, initial_position, hamiltonian, n_samples, mirror_samples,
                 initial_sig=1, comm=None, nanisinf=False):
        """Collect the operators required for Gaussian mean-field variational
        inference.

        Parameters
        ----------
        initial_position :
            Starting point of the mean. Its `domain` defines the space that
            is flattened by `Multifield2Vector`.
        hamiltonian :
            Operator that is applied to the sample generator; the entropy
            term is added to its result.
        n_samples :
            Forwarded to `StochasticEnergyAdapter.make`.
        mirror_samples :
            Forwarded to `StochasticEnergyAdapter.make`.
        initial_sig :
            Initial value of the standard deviation. A field-like object is
            flattened like the position, anything else is used as fill value
            for the whole flattened domain. Default: 1.
        comm :
            Forwarded to `StochasticEnergyAdapter.make`. Default: None.
        nanisinf :
            Forwarded to `StochasticEnergyAdapter.make`. Default: False.
        """
        flatten = Multifield2Vector(initial_position.domain)
        unflatten = flatten.adjoint
        flat_dom = flatten.target

        # Building blocks, all living on the flattened domain.
        std_op = FieldAdapter(flat_dom, 'std').absolute()
        latent_op = FieldAdapter(flat_dom, 'latent')
        mean_op = FieldAdapter(flat_dom, 'mean')

        self._generator = unflatten(mean_op + std_op * latent_op)
        self._entropy = GaussianEntropy(std_op.target) @ std_op
        # Public-facing mean and std are reported on the original domain.
        self._mean = unflatten @ mean_op
        self._std = unflatten @ std_op

        # Initial parameters of the variational distribution.
        start = {'mean': flatten(initial_position)}
        if is_fieldlike(initial_sig):
            start['std'] = flatten(initial_sig)
        else:
            start['std'] = full(flat_dom, initial_sig)
        start = MultiField.from_dict(start)

        kl_op = hamiltonian(self._generator) + self._entropy
        self._KL = StochasticEnergyAdapter.make(
            start, kl_op, ['latent'], n_samples, mirror_samples,
            nanisinf=nanisinf, comm=comm)
        self._samdom = latent_op.domain

    @property
    def mean(self):
        """Mean of the approximation at the current position of the KL."""
        return _eval(self._mean, self._KL.position)

    @property
    def std(self):
        """Absolute value of the `std` parameter at the current position of
        the KL, mapped back to the original domain."""
        return _eval(self._std, self._KL.position)

    @property
    def entropy(self):
        """Value of the entropy term at the current position of the KL."""
        return _eval(self._entropy, self._KL.position)

    def draw_sample(self):
        """Draw one sample from the current approximation.

        A random latent field is drawn with `from_random` on the latent
        domain, inserted as constant input into the generator, and the
        remaining operator is evaluated at the current position of the KL.
        """
        latent_sample = from_random(self._samdom)
        _, sampler = self._generator.simplify_for_constant_input(latent_sample)
        return sampler(self._KL.position)

    def minimize(self, minimizer):
        """Run `minimizer` on the KL and keep the resulting energy.

        Parameters
        ----------
        minimizer : callable
            Called with the current KL; must return a pair whose first entry
            is the new KL. The second entry is discarded.
        """
        updated_kl, _ = minimizer(self._KL)
        self._KL = updated_kl

Philipp Frank's avatar
Philipp Frank committed
85
class FullCovarianceVI:
    """Gaussian full-covariance variational inference.

    The approximation is parametrized by a flattened `mean` and by the
    degrees of freedom `cov` of a lower triangular matrix ``L``; samples are
    generated as ``mean + L @ latent`` and mapped back to the domain of the
    initial position.
    """

    def __init__(self, position, hamiltonian, n_samples, mirror_samples,
                 initial_sig=1, comm=None, nanisinf=False):
        """Collect the operators required for Gaussian full-covariance variational
        inference.

        Parameters
        ----------
        position :
            Starting point of the mean. Its `domain` defines the space that
            is flattened by `Multifield2Vector`.
        hamiltonian :
            Operator that is applied to the sample generator; the entropy
            term is added to its result.
        n_samples :
            Forwarded to `StochasticEnergyAdapter.make`.
        mirror_samples :
            Forwarded to `StochasticEnergyAdapter.make`.
        initial_sig :
            Value put on the diagonal of the initial lower triangular
            matrix; all other entries start at zero. Default: 1.
        comm :
            Forwarded to `StochasticEnergyAdapter.make`. Default: None.
        nanisinf :
            Forwarded to `StochasticEnergyAdapter.make`. Default: False.
        """
        flatten = Multifield2Vector(position.domain)
        vec_space = flatten.target[0]
        mat_space = DomainTuple.make((vec_space, vec_space))

        latent_op = FieldAdapter(flatten.target, 'latent')
        inserter = LowerTriangularInserter(mat_space)
        tri_op = FieldAdapter(inserter.domain, 'cov')
        mean_op = FieldAdapter(vec_space, 'mean')
        # Lower triangular matrix built from its free parameters.
        cov_op = inserter @ tri_op

        # Bundle matrix ('co') and latent vector ('latent') into one
        # multi-field so the einsum can contract them as matrix @ vector.
        einsum_input = latent_op.adjoint @ latent_op + cov_op.ducktape_left('co')
        mat_vec = MultiLinearEinsum(einsum_input.target, 'ij,j->i',
                                    key_order=('co', 'latent'))

        self._generator = flatten.adjoint @ (mean_op + mat_vec @ einsum_input)

        # The entropy term only needs the (absolute) diagonal of the matrix.
        abs_diag = (DiagonalSelector(cov_op.target) @ cov_op).absolute()
        self._entropy = GaussianEntropy(abs_diag.target) @ abs_diag

        # Initial parameters: given mean, diagonal matrix with `initial_sig`.
        n_dof = vec_space.shape[0]
        initial_matrix = np.diag(np.full(n_dof, initial_sig))
        start = MultiField.from_dict({
            'mean': flatten(position),
            'cov': inserter.adjoint(makeField(mat_space, initial_matrix)),
        })

        kl_op = hamiltonian(self._generator) + self._entropy
        self._KL = StochasticEnergyAdapter.make(
            start, kl_op, ['latent'], n_samples, mirror_samples,
            nanisinf=nanisinf, comm=comm)
        self._mean = flatten.adjoint @ mean_op
        self._samdom = latent_op.domain

    @property
    def mean(self):
        """Mean of the approximation at the current position of the KL."""
        return _eval(self._mean, self._KL.position)

    @property
    def entropy(self):
        """Value of the entropy term at the current position of the KL."""
        return _eval(self._entropy, self._KL.position)

    def draw_sample(self):
        """Draw one sample from the current approximation.

        A random latent field is drawn with `from_random` on the latent
        domain, inserted as constant input into the generator, and the
        remaining operator is evaluated at the current position of the KL.
        """
        latent_sample = from_random(self._samdom)
        _, sampler = self._generator.simplify_for_constant_input(latent_sample)
        return sampler(self._KL.position)

    def minimize(self, minimizer):
        """Run `minimizer` on the KL and keep the resulting energy.

        Parameters
        ----------
        minimizer : callable
            Called with the current KL; must return a pair whose first entry
            is the new KL. The second entry is discarded.
        """
        updated_kl, _ = minimizer(self._KL)
        self._KL = updated_kl
132
133
134


class GaussianEntropy(EnergyOperator):
    """Calculate the negative entropy of a Gaussian distribution given the
    diagonal of a triangular decomposition of the covariance.

    For an input `x` the operator returns
    ``-0.5 * sum(log(2 * pi * e * x**2))``.

    Parameters
    ----------
    domain: Domain
        The domain of the diagonal.
    """

    def __init__(self, domain):
        # Normalize the domain like the other operators in this module do,
        # so that a plain Domain is accepted as well.
        self._domain = makeDomain(domain)

    def apply(self, x):
        """Evaluate the operator.

        Parameters
        ----------
        x : Field or Linearization
            The diagonal entries. Only `x**2` enters, so the sign of the
            entries does not matter.

        Returns
        -------
        Field or Linearization
            A scalar field for plain input. For a Linearization the
            Linearization of the result, with a metric attached if the input
            requests one.
        """
        self._check_input(x)
        res = -0.5*(2*np.pi*np.e*x**2).log().sum()
        if not isinstance(x, Linearization):
            # NOTE(review): depending on the Field API, `.sum()` may already
            # return a scalar Field; only wrap the result if it is not one.
            return res if isinstance(res, Field) else Field.scalar(res)
        if not x.want_metric:
            return res
        # FIXME not sure about metric
        return res.add_metric(SandwichOperator.make(res.jac))
156
157


Philipp Frank's avatar
cleanup    
Philipp Frank committed
158
159
class LowerTriangularInserter(LinearOperator):
    """Inserts the DOFs of a lower triangular matrix into a matrix.

    The domain is an unstructured domain with N*(N+1)/2 entries. The entries
    are written to the positions given by `numpy.tril_indices(N)`; all other
    entries of the matrix are zero. The adjoint reads those positions back.

    Parameters
    ----------
    target: Domain
        A two-dimensional domain with NxN entries.
    """

    def __init__(self, target):
        myassert(len(target.shape) == 2)
        myassert(target.shape[0] == target.shape[1])
        self._target = makeDomain(target)
        n = target.shape[0]
        ndof = (n*(n+1))//2
        self._domain = makeDomain(UnstructuredDomain(ndof))
        self._indices = np.tril_indices(n)
        self._capability = self.TIMES | self.ADJOINT_TIMES

    def apply(self, x, mode):
        """Apply the operator.

        Parameters
        ----------
        x : Field
            Input living on the domain (`TIMES`) or on the target
            (`ADJOINT_TIMES`) of the operator.
        mode : int
            `self.TIMES` or `self.ADJOINT_TIMES`.

        Returns
        -------
        Field
            The filled matrix (`TIMES`) or the extracted lower triangular
            entries (`ADJOINT_TIMES`).
        """
        self._check_input(x, mode)
        x = x.val
        if mode == self.TIMES:
            # Promote the buffer so that complex input keeps its imaginary
            # part; real input still yields a float64 matrix as before.
            res = np.zeros(self._target.shape,
                           dtype=np.result_type(x.dtype, np.float64))
            res[self._indices] = x
        else:
            res = x[self._indices].reshape(self._domain.shape)
        return makeField(self._tgt(mode), res)

186
187

class DiagonalSelector(LinearOperator):
    """Extract the diagonal of a two-dimensional field.

    The adjoint places a one-dimensional field on the diagonal of an
    otherwise zero matrix.

    Parameters
    ----------
    domain: Domain
        The two-dimensional domain of the input field. Must be of shape NxN.
    """

    def __init__(self, domain):
        shape = domain.shape
        myassert(len(shape) == 2)
        myassert(shape[0] == shape[1])
        self._domain = makeDomain(domain)
        self._target = makeDomain(UnstructuredDomain(shape[0]))
        self._capability = self.TIMES | self.ADJOINT_TIMES

    def apply(self, x, mode):
        """Apply the operator.

        Parameters
        ----------
        x : Field
            Input field for the requested mode.
        mode : int
            `self.TIMES` or `self.ADJOINT_TIMES`.

        Returns
        -------
        Field
            Field on the target (`TIMES`) or domain (`ADJOINT_TIMES`).
        """
        self._check_input(x, mode)
        # `np.diag` covers both modes: it extracts the diagonal of a 2-D
        # array and builds a diagonal matrix from a 1-D array.
        values = np.diag(x.val)
        return makeField(self._tgt(mode), values)