# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
#
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np

from ..domain_tuple import DomainTuple
from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field
from ..linearization import Linearization
from ..minimization.energy_adapter import StochasticEnergyAdapter
from ..multi_field import MultiField
from ..operators.einsum import MultiLinearEinsum
from ..operators.energy_operators import EnergyOperator
from ..operators.linear_operator import LinearOperator
from ..operators.multifield2vector import Multifield2Vector
from ..operators.sandwich_operator import SandwichOperator
from ..operators.simple_linear_operators import FieldAdapter
from ..sugar import from_random, full, is_fieldlike, makeDomain, makeField
from ..utilities import myassert


class MeanFieldVI:
    """Collect the operators required for Gaussian meanfield variational
    inference.

    Gaussian meanfield variational inference approximates some target
    distribution with a Gaussian distribution with a diagonal covariance
    matrix. The parameters of the approximation, in this case the mean and
    standard deviation, are obtained by minimizing a stochastic estimate of
    the Kullback-Leibler divergence between the target and the
    approximation. In order to obtain gradients w.r.t. the parameters, the
    reparametrization trick is employed, which separates the stochastic
    part of the approximation from a deterministic function, the generator.
    Samples from the approximation are drawn by processing samples from a
    standard Gaussian through this generator.

    Parameters
    ----------
    position : Field
        The initial estimate of the approximate mean parameter.
    hamiltonian : EnergyOperator
        Hamiltonian of the approximated probability distribution.
    n_samples : int
        Number of samples used to stochastically estimate the KL.
    mirror_samples : bool
        Whether the negatives of the drawn samples are also used, as they are
        equally legitimate samples. If true, the number of used samples
        doubles. Mirroring samples stabilizes the KL estimate as extreme
        sample variation is counterbalanced. Since it improves stability in
        many cases, it is recommended to set `mirror_samples` to `True`.
    initial_sig : positive Field or positive float
        The initial estimate of the standard deviation.
    comm : MPI communicator or None
        If not None, samples will be distributed as evenly as possible
        across this communicator.
        If `mirror_samples` is set, then a sample and its mirror image will
        always reside on the same task.
    nanisinf : bool
        If true, NaN energies, which can occur due to overflows in the
        forward model, are interpreted as inf. Thereby, the code does not
        crash on these occasions; instead, the minimizer is told that the
        position it has tried is not sensible.
    """
    def __init__(self, position, hamiltonian, n_samples, mirror_samples,
                 initial_sig=1, comm=None, nanisinf=False):
        # Flatten the parameters into a single unstructured vector space.
        Flat = Multifield2Vector(position.domain)
        self._std = FieldAdapter(Flat.target, 'std').absolute()
        latent = FieldAdapter(Flat.target, 'latent')
        self._mean = FieldAdapter(Flat.target, 'mean')
        # Reparametrization: sample = mean + std * latent, where latent is
        # drawn from a standard normal distribution.
        self._generator = Flat.adjoint(self._mean + self._std * latent)
        self._entropy = GaussianEntropy(self._std.target) @ self._std
        # From here on, _mean and _std map back to the original domain.
        self._mean = Flat.adjoint @ self._mean
        self._std = Flat.adjoint @ self._std
        pos = {'mean': Flat(position)}
        if is_fieldlike(initial_sig):
            pos['std'] = Flat(initial_sig)
        else:
            pos['std'] = full(Flat.target, initial_sig)
        pos = MultiField.from_dict(pos)
        op = hamiltonian(self._generator) + self._entropy
        self._KL = StochasticEnergyAdapter.make(pos, op, ['latent'], n_samples,
                                                mirror_samples, nanisinf=nanisinf, comm=comm)
        self._samdom = latent.domain

    @property
    def mean(self):
        """Mean of the approximation for the current parameters."""
        return self._mean.force(self._KL.position)

    @property
    def std(self):
        """Standard deviation of the approximation for the current
        parameters."""
        return self._std.force(self._KL.position)

    @property
    def entropy(self):
        """Entropy term of the KL (see `GaussianEntropy`) for the current
        parameters."""
        return self._entropy.force(self._KL.position)

    def draw_sample(self):
        """Draw a sample from the approximation by passing a standard normal
        sample through the generator."""
        _, op = self._generator.simplify_for_constant_input(
            from_random(self._samdom))
        return op(self._KL.position)

    def minimize(self, minimizer):
        """Minimize the stochastic KL estimate with `minimizer` and keep the
        resulting parameters."""
        self._KL, _ = minimizer(self._KL)


class FullCovarianceVI:
    """Collect the operators required for Gaussian full-covariance variational
    inference.

    Gaussian full-covariance variational inference approximates some target
    distribution with a Gaussian distribution with a full covariance
    matrix. The parameters of the approximation, in this case the mean and a
    lower triangular matrix corresponding to a Cholesky decomposition of the
    covariance, are obtained by minimizing a stochastic estimate of the
    Kullback-Leibler divergence between the target and the approximation.
    In order to obtain gradients w.r.t. the parameters, the
    reparametrization trick is employed, which separates the stochastic
    part of the approximation from a deterministic function, the generator.
    Samples from the approximation are drawn by processing samples from a
    standard Gaussian through this generator.

    Note that the size of the covariance scales quadratically with the
    number of model parameters.

    Parameters
    ----------
    position : Field
        The initial estimate of the approximate mean parameter.
    hamiltonian : EnergyOperator
        Hamiltonian of the approximated probability distribution.
    n_samples : int
        Number of samples used to stochastically estimate the KL.
    mirror_samples : bool
        Whether the negatives of the drawn samples are also used, as they are
        equally legitimate samples. If true, the number of used samples
        doubles. Mirroring samples stabilizes the KL estimate as extreme
        sample variation is counterbalanced. Since it improves stability in
        many cases, it is recommended to set `mirror_samples` to `True`.
    initial_sig : positive float
        The initial estimate for the standard deviation. Initially no
        correlation between the parameters is assumed.
    comm : MPI communicator or None
        If not None, samples will be distributed as evenly as possible
        across this communicator. If `mirror_samples` is set, then a sample
        and its mirror image will always reside on the same task.
    nanisinf : bool
        If true, NaN energies, which can occur due to overflows in the
        forward model, are interpreted as inf.
        Thereby, the code does not crash on these occasions; instead, the
        minimizer is told that the position it has tried is not sensible.
    """
    def __init__(self, position, hamiltonian, n_samples, mirror_samples,
                 initial_sig=1, comm=None, nanisinf=False):
        # Flatten the parameters into a single unstructured vector space and
        # build the NxN matrix space for the Cholesky factor.
        Flat = Multifield2Vector(position.domain)
        flat_domain = Flat.target[0]
        mat_space = DomainTuple.make((flat_domain, flat_domain))
        lat = FieldAdapter(Flat.target, 'latent')
        # 'cov' holds the N(N+1)/2 free entries of the lower triangular
        # Cholesky factor L of the covariance.
        LT = LowerTriangularInserter(mat_space)
        tri = FieldAdapter(LT.domain, 'cov')
        mean = FieldAdapter(flat_domain, 'mean')
        cov = LT @ tri
        # Generator: mean + L @ latent, where latent is drawn from a
        # standard normal distribution.
        matmul_setup = lat.adjoint @ lat + cov.ducktape_left('co')
        MatMult = MultiLinearEinsum(matmul_setup.target, 'ij,j->i',
                                    key_order=('co', 'latent'))
        self._generator = Flat.adjoint @ (mean + MatMult @ matmul_setup)
        # The entropy only depends on the (absolute) diagonal of L.
        diag_cov = (DiagonalSelector(cov.target) @ cov).absolute()
        self._entropy = GaussianEntropy(diag_cov.target) @ diag_cov
        # Start from L = initial_sig * identity, i.e. uncorrelated parameters.
        diag_tri = np.diag(np.full(flat_domain.shape[0], initial_sig))
        pos = MultiField.from_dict(
            {'mean': Flat(position),
             'cov': LT.adjoint(makeField(mat_space, diag_tri))})
        op = hamiltonian(self._generator) + self._entropy
        self._KL = StochasticEnergyAdapter.make(pos, op, ['latent'], n_samples,
                                                mirror_samples, nanisinf=nanisinf, comm=comm)
        self._mean = Flat.adjoint @ mean
        self._samdom = lat.domain

    @property
    def mean(self):
        """Mean of the approximation for the current parameters."""
        return self._mean.force(self._KL.position)

    @property
    def entropy(self):
        """Entropy term of the KL (see `GaussianEntropy`) for the current
        parameters."""
        return self._entropy.force(self._KL.position)

    def draw_sample(self):
        """Draw a sample from the approximation by passing a standard normal
        sample through the generator."""
        _, op = self._generator.simplify_for_constant_input(
            from_random(self._samdom))
        return op(self._KL.position)

    def minimize(self, minimizer):
        """Minimize the stochastic KL estimate with `minimizer` and keep the
        resulting parameters."""
        self._KL, _ = minimizer(self._KL)


class GaussianEntropy(EnergyOperator):
    """Negative entropy of a Gaussian distribution given the diagonal of a
    triangular decomposition of the covariance.

    For a Cholesky factor with diagonal entries `x_i`, this operator returns
    `-0.5 * sum(log(2*pi*e*x_i**2))`, which is the negative of the entropy
    of the Gaussian. Adding it to the sample-averaged Hamiltonian therefore
    yields the KL divergence between the approximation and the target, up
    to a constant.

    Parameters
    ----------
    domain : Domain, tuple of Domain or DomainTuple
        The domain of the diagonal.
    """
    def __init__(self, domain):
        self._domain = DomainTuple.make(domain)

    def apply(self, x):
        self._check_input(x)
        # Evaluates -0.5 * sum(log(2*pi*e*x**2)) over the diagonal entries x.
        res = (x*x).scale(2*np.pi*np.e).log().sum().scale(-0.5)
        if not isinstance(x, Linearization):
            return res
        if not x.want_metric:
            return res
        # FIXME not sure about metric
        return res.add_metric(SandwichOperator.make(res.jac))


class LowerTriangularInserter(LinearOperator):
    """Insert the entries of a lower triangular matrix into a matrix.

    The domain holds the N(N+1)/2 entries of the lower triangle (including
    the diagonal) as a flat vector; the target is the NxN matrix, whose
    upper triangle is zero. The adjoint extracts the lower triangle again.

    Parameters
    ----------
    target : Domain or DomainTuple
        A two-dimensional domain with NxN entries.
    """
    def __init__(self, target):
        myassert(len(target.shape) == 2)
        myassert(target.shape[0] == target.shape[1])
        self._target = makeDomain(target)
        # Number of entries in the lower triangle, including the diagonal.
        ndof = (target.shape[0]*(target.shape[0]+1))//2
        self._domain = makeDomain(UnstructuredDomain(ndof))
        self._indices = np.tril_indices(target.shape[0])
        self._capability = self.TIMES | self.ADJOINT_TIMES

    def apply(self, x, mode):
        self._check_input(x, mode)
        x = x.val
        if mode == self.TIMES:
            # Scatter the flat vector into the lower triangle of a zero matrix.
            res = np.zeros(self._target.shape)
            res[self._indices] = x
        else:
            # Gather the lower triangle back into a flat vector.
            res = x[self._indices].reshape(self._domain.shape)
        return makeField(self._tgt(mode), res)


class DiagonalSelector(LinearOperator):
    """Extract the diagonal of a two-dimensional field.

    The adjoint embeds a vector as the diagonal of an otherwise zero NxN
    matrix.

    Parameters
    ----------
    domain : Domain or DomainTuple
        The two-dimensional domain of the input field. Must be of shape NxN.
    """
    def __init__(self, domain):
        myassert(len(domain.shape) == 2)
        myassert(domain.shape[0] == domain.shape[1])
        self._domain = makeDomain(domain)
        self._target = makeDomain(UnstructuredDomain(domain.shape[0]))
        self._capability = self.TIMES | self.ADJOINT_TIMES

    def apply(self, x, mode):
        self._check_input(x, mode)
        # np.diag extracts the diagonal of a matrix (TIMES) and builds a
        # diagonal matrix from a vector (ADJOINT_TIMES).
        return makeField(self._tgt(mode), np.diag(x.val))