variational_models.py 8.69 KB
 Philipp Arras committed Jun 01, 2021 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 ``````# This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program. If not, see . # # Copyright(C) 2013-2021 Max-Planck-Society # # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. `````` Jakob Knollmüller committed May 30, 2021 18 ``````import numpy as np `````` Philipp Arras committed Jun 01, 2021 19 `````` `````` Jakob Knollmüller committed May 30, 2021 20 21 ``````from ..domain_tuple import DomainTuple from ..domains.unstructured_domain import UnstructuredDomain `````` Philipp Arras committed Jun 01, 2021 22 23 24 25 26 27 28 29 30 31 32 33 ``````from ..field import Field from ..linearization import Linearization from ..multi_domain import MultiDomain from ..multi_field import MultiField from ..operators.einsum import MultiLinearEinsum from ..operators.energy_operators import EnergyOperator from ..operators.linear_operator import LinearOperator from ..operators.multifield2vector import Multifield2Vector from ..operators.sandwich_operator import SandwichOperator from ..operators.simple_linear_operators import FieldAdapter, PartialExtractor from ..sugar import domain_union, from_random, full, makeField `````` Jakob Knollmüller committed May 30, 2021 34 35 `````` class MeanfieldModel(): `````` Jakob Knollmüller committed Jun 01, 2021 36 37 38 39 40 41 42 43 `````` ''' Collects the operators required for Gaussian mean-field variational inference. Parameters ---------- domain: MultiDomain The domain of the model parameters. ''' `````` Jakob Knollmüller committed May 30, 2021 44 `````` def __init__(self, domain): `````` Philipp Arras committed Jun 01, 2021 45 46 `````` self.domain = MultiDomain.make(domain) self.Flat = Multifield2Vector(self.domain) `````` Jakob Knollmüller committed May 30, 2021 47 48 49 50 51 52 53 `````` self.std = FieldAdapter(self.Flat.target,'var').absolute() self.latent = FieldAdapter(self.Flat.target,'latent') self.mean = FieldAdapter(self.Flat.target,'mean') self.generator = self.Flat.adjoint(self.mean + self.std * self.latent) self.entropy = GaussianEntropy(self.std.target) @ self.std `````` Jakob Knollmüller committed Jun 01, 2021 54 55 56 57 58 59 60 61 62 63 64 65 66 `````` def get_initial_pos(self, initial_mean=None, initial_sig = 1): ''' Provides an initial position for a given mean parameter vector and an initial standard deviation. Parameters ---------- initial_mean: MultiField The initial mean of the variational approximation. If not None, a Gaussian sample with mean zero and standard deviation of 0.1 is used. Default: None initial_sig: positive float The initial standard deviation shared by all parameters. Default: 1 ''' `````` Jakob Knollmüller committed May 30, 2021 67 68 69 70 71 72 73 74 75 76 `````` initial_pos = from_random(self.generator.domain).to_dict() initial_pos['latent'] = full(self.generator.domain['latent'], 0.) initial_pos['var'] = full(self.generator.domain['var'], initial_sig) if initial_mean is None: initial_mean = 0.1*from_random(self.generator.target) initial_pos['mean'] = self.Flat(initial_mean) return MultiField.from_dict(initial_pos) `````` Philipp Arras committed Jun 01, 2021 77 `````` `````` Jakob Knollmüller committed May 30, 2021 78 ``````class FullCovarianceModel(): `````` Jakob Knollmüller committed Jun 01, 2021 79 80 81 82 83 84 85 86 `````` ''' Collects the operators required for Gaussian full-covariance variational inference. Parameters ---------- domain: MultiDomain The domain of the model parameters. ''' `````` Jakob Knollmüller committed May 30, 2021 87 `````` def __init__(self, domain): `````` Philipp Arras committed Jun 01, 2021 88 89 `````` self.domain = MultiDomain.make(domain) self.Flat = Multifield2Vector(self.domain) `````` Jakob Knollmüller committed May 30, 2021 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 `````` one_space = UnstructuredDomain(1) self.flat_domain = self.Flat.target[0] N_tri = self.flat_domain.shape[0]*(self.flat_domain.shape[0]+1)//2 triangular_space = DomainTuple.make(UnstructuredDomain(N_tri)) tri = FieldAdapter(triangular_space, 'cov') mat_space = DomainTuple.make((self.flat_domain,self.flat_domain)) lat_mat_space = DomainTuple.make((one_space,self.flat_domain)) lat = FieldAdapter(lat_mat_space,'latent') LT = LowerTriangularProjector(triangular_space,mat_space) mean = FieldAdapter(self.flat_domain,'mean') cov = LT @ tri co = FieldAdapter(cov.target, 'co') matmul_setup_dom = domain_union((co.domain,lat.domain)) co_part = PartialExtractor(matmul_setup_dom, co.domain) lat_part = PartialExtractor(matmul_setup_dom, lat.domain) matmul_setup = lat_part.adjoint @ lat.adjoint @ lat + co_part.adjoint @ co.adjoint @ cov MatMult = MultiLinearEinsum(matmul_setup.target,'ij,ki->jk', key_order=('co','latent')) Resp = Respacer(MatMult.target, mean.target) self.generator = self.Flat.adjoint @ (mean + Resp @ MatMult @ matmul_setup) Diag = DiagonalSelector(cov.target, self.Flat.target) diag_cov = Diag(cov).absolute() self.entropy = GaussianEntropy(diag_cov.target) @ diag_cov `````` Philipp Arras committed Jun 01, 2021 116 `````` def get_initial_pos(self, initial_mean=None, initial_sig=1): `````` Jakob Knollmüller committed Jun 01, 2021 117 118 119 120 121 122 123 124 125 126 127 `````` ''' Provides an initial position for a given mean parameter vector and a diagonal covariance with an initial standard deviation. Parameters ---------- initial_mean: MultiField The initial mean of the variational approximation. If not None, a Gaussian sample with mean zero and standard deviation of 0.1 is used. Default: None initial_sig: positive float The initial standard deviation shared by all parameters. Default: 1 ''' `````` Jakob Knollmüller committed May 30, 2021 128 129 `````` initial_pos = from_random(self.generator.domain).to_dict() initial_pos['latent'] = full(self.generator.domain['latent'], 0.) `````` Philipp Arras committed Jun 01, 2021 130 `````` diag_tri = np.diag(np.full(self.flat_domain.shape[0], initial_sig))[np.tril_indices(self.flat_domain.shape[0])] `````` Jakob Knollmüller committed May 30, 2021 131 132 133 134 135 136 137 138 `````` initial_pos['cov'] = makeField(self.generator.domain['cov'], diag_tri) if initial_mean is None: initial_mean = 0.1*from_random(self.generator.target) initial_pos['mean'] = self.Flat(initial_mean) return MultiField.from_dict(initial_pos) class GaussianEntropy(EnergyOperator): `````` Jakob Knollmüller committed Jun 01, 2021 139 140 141 142 143 144 145 146 `````` ''' Calculates the entropy of a Gaussian distribution given the diagonal of a triangular decomposition of the covariance. Parameters ---------- domain: Domain The domain of the diagonal. ''' `````` Jakob Knollmüller committed May 30, 2021 147 148 149 150 151 `````` def __init__(self, domain): self._domain = domain def apply(self, x): self._check_input(x) `````` Philipp Arras committed Jun 01, 2021 152 `````` res = -0.5*(2*np.pi*np.e*x**2).log().sum() `````` Jakob Knollmüller committed May 30, 2021 153 154 155 156 `````` if not isinstance(x, Linearization): return Field.scalar(res) if not x.want_metric: return res `````` Philipp Arras committed Jun 01, 2021 157 158 `````` # FIXME not sure about metric return res.add_metric(SandwichOperator.make(res.jac)) `````` Jakob Knollmüller committed May 30, 2021 159 160 161 `````` class LowerTriangularProjector(LinearOperator): `````` Jakob Knollmüller committed Jun 01, 2021 162 163 164 165 166 167 168 169 170 171 `````` ''' Projects the DOFs of a triangular matrix into the matrix form. Parameters ---------- domain: Domain A one-dimensional domain containing N(N+1)/2 DOFs of a triangular matrix. target: Domain A two-dimensional domain with NxN entries. ''' `````` Jakob Knollmüller committed May 30, 2021 172 `````` def __init__(self, domain, target): `````` Philipp Arras committed Jun 01, 2021 173 174 175 `````` self._domain = DomainTuple.make(domain) self._target = DomainTuple.make(target) self._indices = np.tril_indices(target.shape[0]) `````` Jakob Knollmüller committed May 30, 2021 176 177 178 `````` self._capability = self.TIMES | self.ADJOINT_TIMES def apply(self, x, mode): `````` Philipp Arras committed Jun 01, 2021 179 180 `````` self._check_input(x, mode) x = x.val `````` Jakob Knollmüller committed May 30, 2021 181 `````` if mode == self.TIMES: `````` Philipp Arras committed Jun 01, 2021 182 183 184 185 186 187 `````` res = np.zeros(self._target.shape) res[self._indices] = x else: res = x[self._indices].reshape(self._domain.shape) return makeField(self._tgt(mode), res) `````` Jakob Knollmüller committed May 30, 2021 188 189 `````` class DiagonalSelector(LinearOperator): `````` Jakob Knollmüller committed Jun 01, 2021 190 191 192 193 194 195 196 197 198 199 `````` ''' Extracts the diagonal of a two-dimensional field. Parameters ---------- domain: Domain The two-dimensional domain of the input field target: Domain A one-dimensional domain in which the diagonal of the input field lives. ''' `````` Jakob Knollmüller committed May 30, 2021 200 `````` def __init__(self, domain, target): `````` Philipp Arras committed Jun 01, 2021 201 202 `````` self._domain = DomainTuple.make(domain) self._target = DomainTuple.make(target) `````` Jakob Knollmüller committed May 30, 2021 203 204 205 `````` self._capability = self.TIMES | self.ADJOINT_TIMES def apply(self, x, mode): `````` Philipp Arras committed Jun 01, 2021 206 207 208 209 210 `````` self._check_input(x, mode) x = np.diag(x.val) if mode == self.ADJOINT_TIMES: x = x.reshape(self._domain.shape) return makeField(self._tgt(mode), x) `````` Jakob Knollmüller committed May 30, 2021 211 212 213 `````` class Respacer(LinearOperator): `````` Jakob Knollmüller committed Jun 01, 2021 214 215 216 217 218 219 220 221 222 223 224 `````` ''' Re-maps a field from one domain to another one with the same amounts of DOFs. Wrapps the numpy.reshape method. Parameters ---------- domain: Domain The domain of the input field. target: Domain The domain of the output field. ''' `````` Jakob Knollmüller committed May 30, 2021 225 `````` def __init__(self, domain, target): `````` Philipp Arras committed Jun 01, 2021 226 227 `````` self._domain = DomainTuple.make(domain) self._target = DomainTuple.make(target) `````` Jakob Knollmüller committed May 30, 2021 228 229 `````` self._capability = self.TIMES | self.ADJOINT_TIMES `````` Philipp Arras committed Jun 01, 2021 230 231 232 `````` def apply(self, x, mode): self._check_input(x, mode) return makeField(self._tgt(mode), x.val.reshape(self._tgt(mode).shape))``````