# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np

from ..domain_tuple import DomainTuple
from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field
from ..linearization import Linearization
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from ..operators.einsum import MultiLinearEinsum
from ..operators.energy_operators import EnergyOperator
from ..operators.linear_operator import LinearOperator
from ..operators.multifield2vector import Multifield2Vector
from ..operators.sandwich_operator import SandwichOperator
from ..operators.simple_linear_operators import FieldAdapter, PartialExtractor
from ..sugar import domain_union, from_random, full, makeField


class MeanfieldModel:
    """Collect the operators required for Gaussian mean-field variational
    inference.

    Parameters
    ----------
    domain: MultiDomain
        The domain of the model parameters.

    Attributes
    ----------
    domain: MultiDomain
        The domain of the model parameters.
    Flat: Multifield2Vector
        Operator built on `domain`; its target is the flat parameter space on
        which 'mean', 'var' and 'latent' are defined.
    std: Operator
        Absolute value of the 'var' parameters. Note that, despite the key
        name, these enter the generator as standard deviations.
    latent: Operator
        Adapter for the 'latent' excitations.
    mean: Operator
        Adapter for the 'mean' parameters.
    generator: Operator
        Maps ('mean', 'var', 'latent') to `mean + std*latent` on `domain`.
    entropy: Operator
        `GaussianEntropy` evaluated on `std`.
    """

    def __init__(self, domain):
        self.domain = MultiDomain.make(domain)
        self.Flat = Multifield2Vector(self.domain)

        flat_space = self.Flat.target
        self.std = FieldAdapter(flat_space, 'var').absolute()
        self.latent = FieldAdapter(flat_space, 'latent')
        self.mean = FieldAdapter(flat_space, 'mean')
        # Reparametrisation: sample = mean + std * latent, mapped back from
        # the flat vector onto the original MultiDomain.
        self.generator = self.Flat.adjoint(self.mean + self.std * self.latent)
        self.entropy = GaussianEntropy(self.std.target) @ self.std

    def get_initial_pos(self, initial_mean=None, initial_sig=1):
        """Provide an initial position for a given mean parameter vector and an
        initial standard deviation.

        Parameters
        ----------
        initial_mean: MultiField
            The initial mean of the variational approximation. If None, a
            Gaussian sample with mean zero and standard deviation of 0.1 is
            used. Default: None
        initial_sig: positive float
            The initial standard deviation shared by all parameters. Default: 1

        Returns
        -------
        MultiField
            Position on the domain of `generator` with the keys 'latent'
            (all zero), 'var' (all `initial_sig`) and 'mean'.
        """
        gen_domain = self.generator.domain
        # The random draw only serves as a dictionary scaffold; the entries
        # set below replace its 'latent', 'var' and 'mean' values.
        initial_pos = from_random(gen_domain).to_dict()
        initial_pos['latent'] = full(gen_domain['latent'], 0.)
        initial_pos['var'] = full(gen_domain['var'], initial_sig)

        if initial_mean is None:
            initial_mean = 0.1*from_random(self.generator.target)

        initial_pos['mean'] = self.Flat(initial_mean)
        return MultiField.from_dict(initial_pos)


class FullCovarianceModel:
    """Collect the operators required for Gaussian full-covariance variational
    inference.

    Parameters
    ----------
    domain: MultiDomain
        The domain of the model parameters.

    Attributes
    ----------
    domain: MultiDomain
        The domain of the model parameters.
    Flat: Multifield2Vector
        Operator built on `domain`; the first entry of its target is used as
        the flat parameter space.
    flat_domain: Domain
        `Flat.target[0]`, the flat space with N degrees of freedom.
    generator: Operator
        Maps ('mean', 'cov', 'latent') to a sample on `domain`. 'cov' holds
        the N(N+1)/2 entries of a lower triangular matrix.
    entropy: Operator
        `GaussianEntropy` evaluated on the absolute value of the diagonal of
        the triangular matrix.
    """

    def __init__(self, domain):
        self.domain = MultiDomain.make(domain)
        self.Flat = Multifield2Vector(self.domain)
        one_space = UnstructuredDomain(1)
        self.flat_domain = self.Flat.target[0]
        n_dof = self.flat_domain.shape[0]
        # Number of entries on and below the diagonal of an N x N matrix.
        n_tri = n_dof*(n_dof + 1)//2
        triangular_space = DomainTuple.make(UnstructuredDomain(n_tri))
        tri = FieldAdapter(triangular_space, 'cov')
        mat_space = DomainTuple.make((self.flat_domain, self.flat_domain))
        # The latent vector is stored as a 1 x N matrix so that it can take
        # part in the einsum contraction below.
        lat_mat_space = DomainTuple.make((one_space, self.flat_domain))
        lat = FieldAdapter(lat_mat_space, 'latent')
        projector = LowerTriangularProjector(triangular_space, mat_space)
        mean = FieldAdapter(self.flat_domain, 'mean')
        cov = projector @ tri
        co = FieldAdapter(cov.target, 'co')

        # Assemble a MultiField with the keys 'co' (the N x N triangular
        # matrix) and 'latent' (the 1 x N excitations) as einsum input.
        matmul_setup_dom = domain_union((co.domain, lat.domain))
        co_part = PartialExtractor(matmul_setup_dom, co.domain)
        lat_part = PartialExtractor(matmul_setup_dom, lat.domain)
        matmul_setup = (lat_part.adjoint @ lat.adjoint @ lat
                        + co_part.adjoint @ co.adjoint @ cov)
        # Contract the first index of 'co' with the second index of 'latent';
        # the result has shape (N, 1).
        mat_mult = MultiLinearEinsum(matmul_setup.target, 'ij,ki->jk',
                                     key_order=('co', 'latent'))

        # Reshape the (N, 1) product onto the flat space of the mean.
        respacer = Respacer(mat_mult.target, mean.target)
        self.generator = self.Flat.adjoint @ (
            mean + respacer @ mat_mult @ matmul_setup)

        diag_selector = DiagonalSelector(cov.target, self.Flat.target)
        diag_cov = diag_selector(cov).absolute()
        self.entropy = GaussianEntropy(diag_cov.target) @ diag_cov

    def get_initial_pos(self, initial_mean=None, initial_sig=1):
        """Provide an initial position for a given mean parameter vector and a
        diagonal covariance with an initial standard deviation.

        Parameters
        ----------
        initial_mean: MultiField
            The initial mean of the variational approximation. If None, a
            Gaussian sample with mean zero and standard deviation of 0.1 is
            used. Default: None
        initial_sig: positive float
            The initial standard deviation shared by all parameters. Default: 1

        Returns
        -------
        MultiField
            Position on the domain of `generator` with the keys 'latent'
            (all zero), 'cov' (triangular entries of `initial_sig` times the
            identity matrix) and 'mean'.
        """
        gen_domain = self.generator.domain
        # The random draw only serves as a dictionary scaffold; the entries
        # set below replace its 'latent', 'cov' and 'mean' values.
        initial_pos = from_random(gen_domain).to_dict()
        initial_pos['latent'] = full(gen_domain['latent'], 0.)
        n_dof = self.flat_domain.shape[0]
        # Lower-triangular entries of initial_sig * identity, in the ordering
        # of np.tril_indices used by LowerTriangularProjector.
        diag_tri = np.diag(np.full(n_dof, initial_sig))[np.tril_indices(n_dof)]
        initial_pos['cov'] = makeField(gen_domain['cov'], diag_tri)
        if initial_mean is None:
            initial_mean = 0.1*from_random(self.generator.target)
        initial_pos['mean'] = self.Flat(initial_mean)
        return MultiField.from_dict(initial_pos)


class GaussianEntropy(EnergyOperator):
    """Calculate the negative entropy of a Gaussian distribution given the
    diagonal of a triangular decomposition of the covariance.

    For the diagonal entries `x` the operator evaluates
    `-0.5*sum(log(2*pi*e*x**2))`, i.e. the entropy enters with a minus sign.

    Parameters
    ----------
    domain: Domain
        The domain of the diagonal.
    """

    def __init__(self, domain):
        self._domain = domain

    def apply(self, x):
        """Evaluate the negative entropy at `x`.

        Parameters
        ----------
        x: Field or Linearization
            The diagonal entries, defined on the domain of the operator.

        Returns
        -------
        Field or Linearization
            A scalar Field for Field input. For Linearization input a
            Linearization, with a metric attached if `x.want_metric` is set.
        """
        self._check_input(x)
        res = -0.5*(2*np.pi*np.e*x**2).log().sum()
        if not isinstance(x, Linearization):
            return Field.scalar(res)
        if not x.want_metric:
            return res
        # FIXME not sure about metric
        return res.add_metric(SandwichOperator.make(res.jac))


class LowerTriangularProjector(LinearOperator):
    """Project the DOFs of a triangular matrix into the matrix form.

    Parameters
    ----------
    domain: Domain
        A one-dimensional domain containing N(N+1)/2 DOFs of a triangular
        matrix.
    target: Domain
        A two-dimensional domain with NxN entries.
    """

    def __init__(self, domain, target):
        self._domain = DomainTuple.make(domain)
        self._target = DomainTuple.make(target)
        # Read the shape from the normalised DomainTuple so that any input
        # accepted by DomainTuple.make works here as well.
        self._indices = np.tril_indices(self._target.shape[0])
        self._capability = self.TIMES | self.ADJOINT_TIMES

    def apply(self, x, mode):
        """Fill the lower triangle (TIMES) or read it out (ADJOINT_TIMES).

        Parameters
        ----------
        x: Field
            Input field on the domain (TIMES) or target (ADJOINT_TIMES).
        mode: int
            Either `self.TIMES` or `self.ADJOINT_TIMES`.

        Returns
        -------
        Field
            The result, defined on `self._tgt(mode)`.
        """
        self._check_input(x, mode)
        x = x.val
        if mode == self.TIMES:
            # Keep the input dtype instead of silently casting to float64.
            res = np.zeros(self._target.shape, dtype=x.dtype)
            res[self._indices] = x
        else:
            res = x[self._indices].reshape(self._domain.shape)
        return makeField(self._tgt(mode), res)


class DiagonalSelector(LinearOperator):
    """Extract the diagonal of a two-dimensional field.

    Parameters
    ----------
    domain: Domain
        The two-dimensional domain of the input field
    target: Domain
        The one-dimensional domain on which the diagonal of the input field is
        defined.
    """

    def __init__(self, domain, target):
        self._domain = DomainTuple.make(domain)
        self._target = DomainTuple.make(target)
        self._capability = self.TIMES | self.ADJOINT_TIMES

    def apply(self, x, mode):
        """Extract the diagonal (TIMES) or build a diagonal matrix
        (ADJOINT_TIMES).

        Parameters
        ----------
        x: Field
            Input field on the domain (TIMES) or target (ADJOINT_TIMES).
        mode: int
            Either `self.TIMES` or `self.ADJOINT_TIMES`.

        Returns
        -------
        Field
            The result, defined on `self._tgt(mode)`.
        """
        self._check_input(x, mode)
        # np.diag extracts the diagonal of a 2D array and builds a diagonal
        # matrix from a 1D array, which covers both modes.
        x = np.diag(x.val)
        if mode == self.ADJOINT_TIMES:
            x = x.reshape(self._domain.shape)
        return makeField(self._tgt(mode), x)


class Respacer(LinearOperator):
    """Re-map a field from one domain to another one with the same amounts of
    DOFs. Wraps the numpy.reshape method.

    Parameters
    ----------
    domain: Domain
        The domain of the input field.
    target: Domain
        The domain of the output field.
    """

    def __init__(self, domain, target):
        self._domain = DomainTuple.make(domain)
        self._target = DomainTuple.make(target)
        self._capability = self.TIMES | self.ADJOINT_TIMES

    def apply(self, x, mode):
        """Reshape the values of `x` onto the output space of `mode`.

        Parameters
        ----------
        x: Field
            Input field on the domain (TIMES) or target (ADJOINT_TIMES).
        mode: int
            Either `self.TIMES` or `self.ADJOINT_TIMES`.

        Returns
        -------
        Field
            Field on `self._tgt(mode)` holding the reshaped values of `x`.
        """
        self._check_input(x, mode)
        out_space = self._tgt(mode)
        return makeField(out_space, x.val.reshape(out_space.shape))