variational_models.py 7.71 KB
 # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np

from ..domain_tuple import DomainTuple
from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field
from ..linearization import Linearization
from ..multi_field import MultiField
from ..operators.einsum import MultiLinearEinsum
from ..operators.energy_operators import EnergyOperator
from ..operators.linear_operator import LinearOperator
from ..operators.multifield2vector import Multifield2Vector
from ..operators.sandwich_operator import SandwichOperator
from ..operators.simple_linear_operators import FieldAdapter
from ..sugar import full, makeField, makeDomain, from_random, is_fieldlike
from ..minimization.energy_adapter import StochasticEnergyAdapter
from ..utilities import myassert


class MeanFieldVI:
    """Collect the operators required for Gaussian meanfield variational
    inference.

    The approximating distribution is a Gaussian with diagonal covariance on
    the flattened (``Multifield2Vector``) version of ``position.domain``.
    Samples are generated as ``mean + |std| * latent``, and the objective is
    ``hamiltonian(generator) + GaussianEntropy(|std|)``, which is estimated
    stochastically over the ``'latent'`` key by ``StochasticEnergyAdapter``.

    Parameters
    ----------
    position : Field or MultiField
        Initial mean of the approximation. Its domain defines the space
        being approximated.
    hamiltonian : Operator
        Energy applied to the sample generator, i.e.
        ``hamiltonian(generator)`` (presumably the negative log posterior,
        e.g. a ``StandardHamiltonian``).
    n_samples : int
        Number of latent samples, forwarded to
        ``StochasticEnergyAdapter.make``.
    mirror_samples : bool
        Forwarded to ``StochasticEnergyAdapter.make`` (presumably whether
        each latent sample is also used with flipped sign).
    initial_sig : scalar or field-like, optional
        Initial standard deviation. A field-like value must live on
        ``position.domain``; a scalar is used for every degree of freedom.
        Default: 1.
    comm : optional
        Communicator forwarded to ``StochasticEnergyAdapter.make``
        (presumably an MPI communicator). Default: None.
    nanisinf : bool, optional
        Forwarded to ``StochasticEnergyAdapter.make`` (presumably treats NaN
        energies as infinite). Default: False.
    """

    def __init__(self, position, hamiltonian, n_samples, mirror_samples,
                 initial_sig=1, comm=None, nanisinf=False):
        flat = Multifield2Vector(position.domain)
        # The 'std' parameter enters only through its absolute value, so its
        # sign is irrelevant.
        std = FieldAdapter(flat.target, 'std').absolute()
        mean = FieldAdapter(flat.target, 'mean')
        latent = FieldAdapter(flat.target, 'latent')
        self._generator = flat.adjoint(mean + std*latent)
        self._entropy = GaussianEntropy(std.target) @ std
        # Accessors mapping the variational parameters back to
        # position.domain.
        self._mean = flat.adjoint @ mean
        self._std = flat.adjoint @ std

        pos = {'mean': flat(position)}
        if is_fieldlike(initial_sig):
            pos['std'] = flat(initial_sig)
        else:
            pos['std'] = full(flat.target, initial_sig)
        pos = MultiField.from_dict(pos)

        op = hamiltonian(self._generator) + self._entropy
        self._KL = StochasticEnergyAdapter.make(
            pos, op, ['latent'], n_samples, mirror_samples,
            nanisinf=nanisinf, comm=comm)
        self._samdom = latent.domain

    @property
    def mean(self):
        """Current mean of the approximation, on ``position.domain``."""
        return self._mean.force(self._KL.position)

    @property
    def std(self):
        """Current standard deviation (absolute value of the 'std'
        parameter), on ``position.domain``."""
        return self._std.force(self._KL.position)

    @property
    def entropy(self):
        """Value of the entropy term at the current parameters.

        Note that ``GaussianEntropy`` evaluates to *minus* the differential
        entropy of the approximation (the sign used in the KL objective).
        """
        return self._entropy.force(self._KL.position)

    def draw_sample(self):
        """Draw one sample from the approximation at the current parameters.

        A random latent field is drawn with ``from_random`` (presumably its
        default, standard normal), fixed as constant input of the generator,
        and the generator is then evaluated at the current parameters.
        """
        _, op = self._generator.simplify_for_constant_input(
            from_random(self._samdom))
        return op(self._KL.position)

    def minimize(self, minimizer):
        """Run ``minimizer`` on the KL energy and keep the updated energy.

        The second return value of ``minimizer`` is discarded.
        """
        self._KL, _ = minimizer(self._KL)


class FullCovarianceVI:
    """Collect the operators required for Gaussian full-covariance variational
    inference.

    The covariance is parameterized as ``Sigma = L L^T``, where ``L`` is a
    lower-triangular matrix whose N(N+1)/2 free entries are the 'cov'
    parameter. Samples are generated as ``mean + L @ latent``, and the
    entropy term uses ``|diag(L)|``.

    Parameters
    ----------
    position : Field or MultiField
        Initial mean of the approximation. Its domain defines the space
        being approximated.
    hamiltonian : Operator
        Energy applied to the sample generator, i.e.
        ``hamiltonian(generator)`` (presumably the negative log posterior).
    n_samples : int
        Number of latent samples, forwarded to
        ``StochasticEnergyAdapter.make``.
    mirror_samples : bool
        Forwarded to ``StochasticEnergyAdapter.make`` (presumably whether
        each latent sample is also used with flipped sign).
    initial_sig : scalar or field-like, optional
        Initial diagonal of ``L``; off-diagonal entries start at zero. As in
        `MeanFieldVI`, a field-like value must live on ``position.domain``
        and a scalar is used for every diagonal entry. Default: 1.
    comm : optional
        Communicator forwarded to ``StochasticEnergyAdapter.make``
        (presumably an MPI communicator). Default: None.
    nanisinf : bool, optional
        Forwarded to ``StochasticEnergyAdapter.make`` (presumably treats NaN
        energies as infinite). Default: False.
    """

    def __init__(self, position, hamiltonian, n_samples, mirror_samples,
                 initial_sig=1, comm=None, nanisinf=False):
        flat = Multifield2Vector(position.domain)
        flat_domain = flat.target[0]
        mat_space = DomainTuple.make((flat_domain, flat_domain))
        latent = FieldAdapter(flat.target, 'latent')
        lt = LowerTriangularInserter(mat_space)
        tri = FieldAdapter(lt.domain, 'cov')
        mean = FieldAdapter(flat_domain, 'mean')
        # Lower-triangular factor L of the covariance.
        cov = lt @ tri
        # Bundle L (under key 'co') and the latent vector into one
        # MultiField so that the einsum 'ij,j->i' computes L @ latent.
        matmul_setup = latent.adjoint @ latent + cov.ducktape_left('co')
        matmul = MultiLinearEinsum(matmul_setup.target, 'ij,j->i',
                                   key_order=('co', 'latent'))
        self._generator = flat.adjoint @ (mean + matmul @ matmul_setup)

        # log det(L L^T) only depends on diag(L), so the entropy only needs
        # the (absolute) diagonal.
        diag_cov = (DiagonalSelector(cov.target) @ cov).absolute()
        self._entropy = GaussianEntropy(diag_cov.target) @ diag_cov

        # Initial L: diagonal matrix with `initial_sig` on the diagonal.
        # Field-like values are accepted for consistency with MeanFieldVI.
        if is_fieldlike(initial_sig):
            diag_init = flat(initial_sig).val
        else:
            diag_init = np.full(flat_domain.shape[0], initial_sig)
        tri_init = lt.adjoint(makeField(mat_space, np.diag(diag_init)))
        pos = MultiField.from_dict({'mean': flat(position), 'cov': tri_init})

        op = hamiltonian(self._generator) + self._entropy
        self._KL = StochasticEnergyAdapter.make(
            pos, op, ['latent'], n_samples, mirror_samples,
            nanisinf=nanisinf, comm=comm)
        self._mean = flat.adjoint @ mean
        self._samdom = latent.domain

    @property
    def mean(self):
        """Current mean of the approximation, on ``position.domain``."""
        return self._mean.force(self._KL.position)

    @property
    def entropy(self):
        """Value of the entropy term at the current parameters.

        Note that ``GaussianEntropy`` evaluates to *minus* the differential
        entropy of the approximation (the sign used in the KL objective).
        """
        return self._entropy.force(self._KL.position)

    def draw_sample(self):
        """Draw one sample from the approximation at the current parameters.

        A random latent field is drawn with ``from_random`` (presumably its
        default, standard normal), fixed as constant input of the generator,
        and the generator is then evaluated at the current parameters.
        """
        _, op = self._generator.simplify_for_constant_input(
            from_random(self._samdom))
        return op(self._KL.position)

    def minimize(self, minimizer):
        """Run ``minimizer`` on the KL energy and keep the updated energy.

        The second return value of ``minimizer`` is discarded.
        """
        self._KL, _ = minimizer(self._KL)


class GaussianEntropy(EnergyOperator):
    """Entropy of a Gaussian distribution given the diagonal of a triangular
    decomposition of the covariance.

    For ``Sigma = L L^T`` with ``diag(L) = x``, the differential entropy is
    ``0.5 * sum(log(2*pi*e*x**2))``. This operator returns the *negative* of
    that value. The VI classes in this module add it to the Hamiltonian, so
    that the sum is the KL objective (up to an additive constant).

    Parameters
    ----------
    domain: Domain, DomainTuple or tuple of Domain
        The domain of the diagonal.
    """

    def __init__(self, domain):
        # Normalize to a DomainTuple (the domain type fields carry) so that a
        # bare Domain, as documented, can be passed.
        self._domain = DomainTuple.make(domain)

    def apply(self, x):
        self._check_input(x)
        res = -0.5*(2*np.pi*np.e*x**2).log().sum()
        if not isinstance(x, Linearization) or not x.want_metric:
            return res
        # FIXME not sure about metric
        # NOTE(review): a sandwich of the Jacobian of a scalar function is
        # presumably the rank-one outer product of the gradient, not a proper
        # Fisher metric.
        return res.add_metric(SandwichOperator.make(res.jac))


class LowerTriangularInserter(LinearOperator):
    """Insert the entries of a lower triangular matrix into a matrix.

    In TIMES mode the N(N+1)/2 input entries are scattered into the lower
    triangle (including the diagonal) of an otherwise-zero NxN matrix, in
    ``np.tril_indices`` order. The adjoint gathers those entries again.

    Parameters
    ----------
    target: Domain
        A two-dimensional domain with NxN entries.
    """

    def __init__(self, target):
        myassert(len(target.shape) == 2)
        myassert(target.shape[0] == target.shape[1])
        self._target = makeDomain(target)
        ndof = (target.shape[0]*(target.shape[0] + 1))//2
        self._domain = makeDomain(UnstructuredDomain(ndof))
        self._indices = np.tril_indices(target.shape[0])
        self._capability = self.TIMES | self.ADJOINT_TIMES

    def apply(self, x, mode):
        self._check_input(x, mode)
        x = x.val
        if mode == self.TIMES:
            # Keep complex input complex. Real/int input still yields at
            # least float64, as before.
            res = np.zeros(self._target.shape,
                           dtype=np.result_type(x.dtype, np.float64))
            res[self._indices] = x
        else:
            res = x[self._indices].reshape(self._domain.shape)
        return makeField(self._tgt(mode), res)


class DiagonalSelector(LinearOperator):
    """Extract the diagonal of a two-dimensional field.

    The adjoint places a length-N vector on the diagonal of an
    otherwise-zero NxN matrix (``np.diag`` of a 1-D array).

    Parameters
    ----------
    domain: Domain
        The two-dimensional domain of the input field. Must be of shape NxN.
    """

    def __init__(self, domain):
        myassert(len(domain.shape) == 2)
        myassert(domain.shape[0] == domain.shape[1])
        self._domain = makeDomain(domain)
        self._target = makeDomain(UnstructuredDomain(domain.shape[0]))
        self._capability = self.TIMES | self.ADJOINT_TIMES

    def apply(self, x, mode):
        self._check_input(x, mode)
        # np.diag extracts the diagonal of a 2-D array (TIMES) and builds a
        # diagonal matrix from a 1-D array (ADJOINT_TIMES).
        return makeField(self._tgt(mode), np.diag(x.val))