# variational_models.py (7.38 KB)
 Philipp Arras committed Jun 01, 2021 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 ``````# This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program. If not, see . # # Copyright(C) 2013-2021 Max-Planck-Society # # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. `````` Jakob Knollmüller committed May 30, 2021 18 ``````import numpy as np `````` Philipp Arras committed Jun 01, 2021 19 `````` `````` Jakob Knollmüller committed May 30, 2021 20 21 ``````from ..domain_tuple import DomainTuple from ..domains.unstructured_domain import UnstructuredDomain `````` Philipp Arras committed Jun 01, 2021 22 23 24 25 26 27 28 29 ``````from ..field import Field from ..linearization import Linearization from ..multi_field import MultiField from ..operators.einsum import MultiLinearEinsum from ..operators.energy_operators import EnergyOperator from ..operators.linear_operator import LinearOperator from ..operators.multifield2vector import Multifield2Vector from ..operators.sandwich_operator import SandwichOperator `````` Philipp Frank committed Jun 02, 2021 30 ``````from ..operators.simple_linear_operators import FieldAdapter `````` Philipp Frank committed Jun 02, 2021 31 ``````from ..sugar import full, makeField, makeDomain, from_random, is_fieldlike `````` Philipp Frank committed Jun 02, 2021 32 ``````from ..minimization.energy_adapter import StochasticEnergyAdapter `````` Philipp Frank committed Jun 02, 2021 33 34 35 36 37 ``````from ..utilities 
import myassert def _eval(op, position): return op(position.extract(op.domain)) `````` Philipp Arras committed Jun 01, 2021 38 `````` `````` Jakob Knollmüller committed May 30, 2021 39 `````` `````` Philipp Frank committed Jun 02, 2021 40 41 42 ``````class MeanFieldVI: def __init__(self, initial_position, hamiltonian, n_samples, mirror_samples, initial_sig=1, comm=None, nanisinf=False): `````` Philipp Frank committed Jun 02, 2021 43 44 `````` """Collect the operators required for Gaussian mean-field variational inference. `````` Philipp Arras committed Jun 01, 2021 45 `````` """ `````` Philipp Frank committed Jun 02, 2021 46 47 `````` Flat = Multifield2Vector(initial_position.domain) self._std = FieldAdapter(Flat.target, 'std').absolute() `````` Philipp Frank committed Jun 02, 2021 48 `````` latent = FieldAdapter(Flat.target,'latent') `````` Philipp Frank committed Jun 02, 2021 49 50 51 52 53 54 `````` self._mean = FieldAdapter(Flat.target, 'mean') self._generator = Flat.adjoint(self._mean + self._std * latent) self._entropy = GaussianEntropy(self._std.target) @ self._std self._mean = Flat.adjoint @ self._mean self._std = Flat.adjoint @ self._std pos = {'mean': Flat(initial_position)} `````` Philipp Frank committed Jun 02, 2021 55 `````` if is_fieldlike(initial_sig): `````` Philipp Frank committed Jun 02, 2021 56 `````` pos['std'] = Flat(initial_sig) `````` Philipp Frank committed Jun 02, 2021 57 `````` else: `````` Philipp Frank committed Jun 02, 2021 58 `````` pos['std'] = full(Flat.target, initial_sig) `````` Philipp Frank committed Jun 02, 2021 59 `````` pos = MultiField.from_dict(pos) `````` Philipp Frank committed Jun 02, 2021 60 61 62 63 `````` op = hamiltonian(self._generator) + self._entropy self._KL = StochasticEnergyAdapter.make(pos, op, ['latent',], n_samples, mirror_samples, nanisinf=nanisinf, comm=comm) self._samdom = latent.domain `````` Philipp Frank committed Jun 02, 2021 64 65 `````` @property `````` Philipp Frank committed Jun 02, 2021 66 `````` 
def mean(self): `````` Philipp Frank committed Jun 02, 2021 67 `````` return _eval(self._mean,self._KL.position) `````` Philipp Frank committed Jun 02, 2021 68 69 70 `````` @property def std(self): `````` Philipp Frank committed Jun 02, 2021 71 `````` return _eval(self._std,self._KL.position) `````` Philipp Frank committed Jun 02, 2021 72 73 74 `````` @property def entropy(self): `````` Philipp Frank committed Jun 02, 2021 75 `````` return _eval(self._entropy,self._KL.position) `````` Philipp Frank committed Jun 02, 2021 76 77 78 79 80 `````` def draw_sample(self): _, op = self._generator.simplify_for_constant_input( from_random(self._samdom)) return op(self._KL.position) `````` Philipp Frank committed Jun 02, 2021 81 82 83 84 `````` def minimize(self, minimizer): self._KL, _ = minimizer(self._KL) `````` Philipp Frank committed Jun 02, 2021 85 ``````class FullCovarianceVI: `````` Philipp Frank committed Jun 02, 2021 86 `````` def __init__(self, position, hamiltonian, n_samples, mirror_samples, `````` Philipp Frank committed Jun 02, 2021 87 `````` initial_sig=1, comm=None, nanisinf=False): `````` Philipp Frank committed Jun 02, 2021 88 89 90 91 92 93 `````` """Collect the operators required for Gaussian full-covariance variational inference. 
""" Flat = Multifield2Vector(position.domain) flat_domain = Flat.target[0] mat_space = DomainTuple.make((flat_domain,flat_domain)) `````` Philipp Frank committed Jun 02, 2021 94 `````` lat = FieldAdapter(Flat.target,'latent') `````` Philipp Frank committed Jun 02, 2021 95 96 `````` LT = LowerTriangularInserter(mat_space) tri = FieldAdapter(LT.domain, 'cov') `````` Philipp Frank committed Jun 02, 2021 97 `````` mean = FieldAdapter(flat_domain,'mean') `````` Jakob Knollmüller committed May 30, 2021 98 `````` cov = LT @ tri `````` Philipp Frank committed Jun 02, 2021 99 100 101 102 103 104 `````` matmul_setup = lat.adjoint @ lat + cov.ducktape_left('co') MatMult = MultiLinearEinsum(matmul_setup.target,'ij,j->i', key_order=('co','latent')) self._generator = Flat.adjoint @ (mean + MatMult @ matmul_setup) `````` Philipp Frank committed Jun 02, 2021 105 `````` diag_cov = (DiagonalSelector(cov.target) @ cov).absolute() `````` Philipp Frank committed Jun 02, 2021 106 107 108 `````` self._entropy = GaussianEntropy(diag_cov.target) @ diag_cov diag_tri = np.diag(np.full(flat_domain.shape[0], initial_sig)) pos = MultiField.from_dict( `````` Philipp Frank committed Jun 02, 2021 109 110 `````` {'mean': Flat(position), 'cov': LT.adjoint(makeField(mat_space, diag_tri))}) `````` Philipp Frank committed Jun 02, 2021 111 112 113 114 115 116 117 118 `````` op = hamiltonian(self._generator) + self._entropy self._KL = StochasticEnergyAdapter.make(pos, op, ['latent',], n_samples, mirror_samples, nanisinf=nanisinf, comm=comm) self._mean = Flat.adjoint @ mean self._samdom = lat.domain @property def mean(self): `````` Philipp Frank committed Jun 02, 2021 119 `````` return _eval(self._mean,self._KL.position) `````` Philipp Frank committed Jun 02, 2021 120 121 `````` @property `````` Philipp Frank committed Jun 02, 2021 122 `````` def entropy(self): `````` Philipp Frank committed Jun 02, 2021 123 `````` return _eval(self._entropy,self._KL.position) `````` Philipp Frank committed Jun 02, 2021 
124 125 126 127 128 `````` def draw_sample(self): _, op = self._generator.simplify_for_constant_input( from_random(self._samdom)) return op(self._KL.position) `````` Philipp Frank committed Jun 02, 2021 129 130 131 `````` def minimize(self, minimizer): self._KL, _ = minimizer(self._KL) `````` Jakob Knollmüller committed May 30, 2021 132 133 134 `````` class GaussianEntropy(EnergyOperator): `````` Philipp Arras committed Jun 01, 2021 135 136 `````` """Calculate the entropy of a Gaussian distribution given the diagonal of a triangular decomposition of the covariance. `````` Jakob Knollmüller committed Jun 01, 2021 137 138 139 140 `````` Parameters ---------- domain: Domain `````` Philipp Arras committed Jun 01, 2021 141 142 143 `````` The domain of the diagonal. """ `````` Jakob Knollmüller committed May 30, 2021 144 145 146 147 148 `````` def __init__(self, domain): self._domain = domain def apply(self, x): self._check_input(x) `````` Philipp Arras committed Jun 01, 2021 149 `````` res = -0.5*(2*np.pi*np.e*x**2).log().sum() `````` Jakob Knollmüller committed May 30, 2021 150 151 152 153 `````` if not isinstance(x, Linearization): return Field.scalar(res) if not x.want_metric: return res `````` Philipp Arras committed Jun 01, 2021 154 155 `````` # FIXME not sure about metric return res.add_metric(SandwichOperator.make(res.jac)) `````` Jakob Knollmüller committed May 30, 2021 156 157 `````` `````` Philipp Frank committed Jun 02, 2021 158 159 ``````class LowerTriangularInserter(LinearOperator): """Inserts the DOFs of a lower triangular matrix into a matrix. `````` Philipp Arras committed Jun 01, 2021 160 `````` `````` Jakob Knollmüller committed Jun 01, 2021 161 162 163 `````` Parameters ---------- target: Domain `````` Philipp Arras committed Jun 01, 2021 164 165 166 `````` A two-dimensional domain with NxN entries. 
""" `````` Philipp Frank committed Jun 02, 2021 167 168 169 170 171 172 `````` def __init__(self, target): myassert(len(target.shape) == 2) myassert(target.shape[0] == target.shape[1]) self._target = makeDomain(target) ndof = (target.shape[0]*(target.shape[0]+1))//2 self._domain = makeDomain(UnstructuredDomain(ndof)) `````` Philipp Arras committed Jun 01, 2021 173 `````` self._indices = np.tril_indices(target.shape[0]) `````` Jakob Knollmüller committed May 30, 2021 174 175 176 `````` self._capability = self.TIMES | self.ADJOINT_TIMES def apply(self, x, mode): `````` Philipp Arras committed Jun 01, 2021 177 178 `````` self._check_input(x, mode) x = x.val `````` Jakob Knollmüller committed May 30, 2021 179 `````` if mode == self.TIMES: `````` Philipp Arras committed Jun 01, 2021 180 181 182 183 184 185 `````` res = np.zeros(self._target.shape) res[self._indices] = x else: res = x[self._indices].reshape(self._domain.shape) return makeField(self._tgt(mode), res) `````` Jakob Knollmüller committed May 30, 2021 186 187 `````` class DiagonalSelector(LinearOperator): `````` Philipp Arras committed Jun 01, 2021 188 `````` """Extract the diagonal of a two-dimensional field. `````` Jakob Knollmüller committed Jun 01, 2021 189 190 191 192 `````` Parameters ---------- domain: Domain `````` Philipp Frank committed Jun 02, 2021 193 `````` The two-dimensional domain of the input field. Must be of shape NxN. 
`````` Philipp Arras committed Jun 01, 2021 194 195 `````` """ `````` Philipp Frank committed Jun 02, 2021 196 197 198 199 200 `````` def __init__(self, domain): myassert(len(domain.shape) == 2) myassert(domain.shape[0] == domain.shape[1]) self._domain = makeDomain(domain) self._target = makeDomain(UnstructuredDomain(domain.shape[0])) `````` Jakob Knollmüller committed May 30, 2021 201 202 `````` self._capability = self.TIMES | self.ADJOINT_TIMES `````` Philipp Arras committed Jun 01, 2021 203 204 `````` def apply(self, x, mode): self._check_input(x, mode) `````` Philipp Frank committed Jun 02, 2021 205 `` return makeField(self._tgt(mode), np.diag(x.val))``