# variational_models.py
# Authored by Philipp Arras
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from ..domain_tuple import DomainTuple
from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field
from ..linearization import Linearization
from ..minimization.energy_adapter import StochasticEnergyAdapter
from ..multi_field import MultiField
from ..operators.einsum import MultiLinearEinsum
from ..operators.energy_operators import EnergyOperator
from ..operators.linear_operator import LinearOperator
from ..operators.multifield2vector import Multifield2Vector
from ..operators.sandwich_operator import SandwichOperator
from ..operators.simple_linear_operators import FieldAdapter
from ..sugar import from_random, full, is_fieldlike, makeDomain, makeField
from ..utilities import myassert
class MeanFieldVI:
    """Collect the operators required for Gaussian mean-field variational
    inference.

    The approximating distribution is a Gaussian with diagonal covariance.
    Its free parameters are a mean and a standard deviation per degree of
    freedom (stored under the keys 'mean' and 'std' on a flattened vector
    domain). Samples are produced by the generator ``mean + |std| * latent``,
    where ``latent`` is a random field. The parameters are fitted by
    minimizing a stochastic estimate (over ``latent``) of
    ``hamiltonian(generator)`` plus the output of :class:`GaussianEntropy`.
    If `hamiltonian` is the negative log of the target density, this is the
    KL divergence up to a constant.

    Parameters
    ----------
    position : MultiField
        Initial estimate of the mean. It is flattened with
        :class:`Multifield2Vector`, so presumably a ``MultiField``.
    hamiltonian : EnergyOperator
        Energy of the distribution to be approximated (e.g. a negative
        log-posterior). It is evaluated on the output of the sample generator.
    n_samples : int
        Number of latent samples used by :class:`StochasticEnergyAdapter` to
        estimate the objective.
    mirror_samples : bool
        Whether mirrored samples are used as well; forwarded to
        :meth:`StochasticEnergyAdapter.make`.
    initial_sig : field-like or float, optional
        Initial standard deviation. A field-like value is flattened in the
        same way as `position`; anything else is broadcast to all degrees of
        freedom with :func:`full`. Default: 1.
    comm : optional
        Forwarded to :meth:`StochasticEnergyAdapter.make`; presumably an MPI
        communicator over which the samples are distributed. Default: None.
    nanisinf : bool, optional
        Forwarded to :meth:`StochasticEnergyAdapter.make`; presumably makes
        NaN energies be treated as infinite. Default: False.
    """

    def __init__(self, position, hamiltonian, n_samples, mirror_samples,
                 initial_sig=1, comm=None, nanisinf=False):
        flat = Multifield2Vector(position.domain)
        # All parameters live on the flat vector domain. ``absolute`` makes the
        # effective standard deviation non-negative whatever the sign of the
        # raw 'std' parameter.
        std = FieldAdapter(flat.target, 'std').absolute()
        latent = FieldAdapter(flat.target, 'latent')
        mean = FieldAdapter(flat.target, 'mean')
        # Reparametrization: (mean, std, latent) -> sample in the original
        # (unflattened) domain.
        self._generator = flat.adjoint(mean + std*latent)
        self._entropy = GaussianEntropy(std.target) @ std
        # Views of the parameters in the original domain.
        self._mean = flat.adjoint @ mean
        self._std = flat.adjoint @ std

        if is_fieldlike(initial_sig):
            initial_std = flat(initial_sig)
        else:
            initial_std = full(flat.target, initial_sig)
        pos = MultiField.from_dict({'mean': flat(position),
                                    'std': initial_std})
        op = hamiltonian(self._generator) + self._entropy
        # Only 'latent' is sampled; 'mean' and 'std' are the optimized
        # parameters.
        self._KL = StochasticEnergyAdapter.make(
            pos, op, ['latent'], n_samples, mirror_samples,
            nanisinf=nanisinf, comm=comm)
        self._samdom = latent.domain

    @property
    def mean(self):
        """Mean of the approximation at the current parameters, in the
        original domain."""
        return self._mean.force(self._KL.position)

    @property
    def std(self):
        """Standard deviation (absolute value of the 'std' parameter) at the
        current parameters, in the original domain."""
        return self._std.force(self._KL.position)

    @property
    def entropy(self):
        """Value of the entropy term at the current parameters.

        This is the output of :class:`GaussianEntropy`, which evaluates
        ``-0.5*sum(log(2*pi*e*std**2))``. That is the *negative* entropy of
        the approximation.
        """
        return self._entropy.force(self._KL.position)

    def draw_sample(self):
        """Draw a sample from the current approximation.

        A random latent field is drawn with :func:`from_random`. The generator
        is partially evaluated with it, and the result is applied to the
        current parameters.
        """
        _, op = self._generator.simplify_for_constant_input(
            from_random(self._samdom))
        return op(self._KL.position)

    def minimize(self, minimizer):
        """Minimize the stochastic objective with `minimizer`.

        `minimizer` is called with the current energy and must return a pair
        whose first element is the new energy, which is stored. The second
        element (presumably a convergence status) is discarded.
        """
        self._KL, _ = minimizer(self._KL)
class FullCovarianceVI:
    """Collect the operators required for Gaussian full-covariance variational
    inference.

    The approximating distribution is a Gaussian whose covariance is
    parametrized by a lower-triangular matrix ``L``. Its ``N*(N+1)/2`` free
    entries are stored under the key 'cov'; the mean is stored under 'mean'.
    Samples are produced by the generator ``mean + L @ latent``, so ``L`` acts
    as a triangular factor of the covariance ``L @ L^T``. The parameters are
    fitted by minimizing a stochastic estimate (over ``latent``) of
    ``hamiltonian(generator)`` plus the output of :class:`GaussianEntropy`
    applied to ``|diag(L)|``.

    Parameters
    ----------
    position : MultiField
        Initial estimate of the mean. It is flattened with
        :class:`Multifield2Vector`, so presumably a ``MultiField``.
    hamiltonian : EnergyOperator
        Energy of the distribution to be approximated (e.g. a negative
        log-posterior). It is evaluated on the output of the sample generator.
    n_samples : int
        Number of latent samples used by :class:`StochasticEnergyAdapter` to
        estimate the objective.
    mirror_samples : bool
        Whether mirrored samples are used as well; forwarded to
        :meth:`StochasticEnergyAdapter.make`.
    initial_sig : field-like or float, optional
        Initial standard deviation. ``L`` starts as a diagonal matrix with
        these values, i.e. without correlations. A field-like value is
        flattened in the same way as `position`; anything else is broadcast
        to all degrees of freedom. Default: 1.
    comm : optional
        Forwarded to :meth:`StochasticEnergyAdapter.make`; presumably an MPI
        communicator over which the samples are distributed. Default: None.
    nanisinf : bool, optional
        Forwarded to :meth:`StochasticEnergyAdapter.make`; presumably makes
        NaN energies be treated as infinite. Default: False.
    """

    def __init__(self, position, hamiltonian, n_samples, mirror_samples,
                 initial_sig=1, comm=None, nanisinf=False):
        flat = Multifield2Vector(position.domain)
        flat_domain = flat.target[0]
        mat_space = DomainTuple.make((flat_domain, flat_domain))
        lat = FieldAdapter(flat.target, 'latent')
        lt = LowerTriangularInserter(mat_space)
        tri = FieldAdapter(lt.domain, 'cov')
        mean = FieldAdapter(flat_domain, 'mean')
        cov = lt @ tri
        # Bundle the latent vector (key 'latent') and the matrix (key 'co')
        # into one MultiField so that the einsum can contract them:
        # (L @ latent)_i = L_ij latent_j.
        matmul_setup = lat.adjoint @ lat + cov.ducktape_left('co')
        matmul = MultiLinearEinsum(matmul_setup.target, 'ij,j->i',
                                   key_order=('co', 'latent'))
        self._generator = flat.adjoint @ (mean + matmul @ matmul_setup)
        # For triangular L, det(L @ L^T) = prod(L_ii**2), so the entropy
        # depends only on the absolute diagonal entries.
        diag_cov = (DiagonalSelector(cov.target) @ cov).absolute()
        self._entropy = GaussianEntropy(diag_cov.target) @ diag_cov

        # Accept a field-like initial_sig as well, consistent with MeanFieldVI.
        if is_fieldlike(initial_sig):
            initial_diag = flat(initial_sig).val
        else:
            initial_diag = np.full(flat_domain.shape[0], initial_sig)
        diag_tri = np.diag(initial_diag)
        pos = MultiField.from_dict(
            {'mean': flat(position),
             'cov': lt.adjoint(makeField(mat_space, diag_tri))})
        op = hamiltonian(self._generator) + self._entropy
        # Only 'latent' is sampled; 'mean' and 'cov' are the optimized
        # parameters.
        self._KL = StochasticEnergyAdapter.make(
            pos, op, ['latent'], n_samples, mirror_samples,
            nanisinf=nanisinf, comm=comm)
        self._mean = flat.adjoint @ mean
        self._samdom = lat.domain

    @property
    def mean(self):
        """Mean of the approximation at the current parameters, in the
        original domain."""
        return self._mean.force(self._KL.position)

    @property
    def entropy(self):
        """Value of the entropy term at the current parameters.

        This is the output of :class:`GaussianEntropy` on ``|diag(L)|``, which
        evaluates ``-0.5*sum(log(2*pi*e*L_ii**2))``. That is the *negative*
        entropy of the approximation.
        """
        return self._entropy.force(self._KL.position)

    def draw_sample(self):
        """Draw a sample from the current approximation.

        A random latent field is drawn with :func:`from_random`. The generator
        is partially evaluated with it, and the result is applied to the
        current parameters.
        """
        _, op = self._generator.simplify_for_constant_input(
            from_random(self._samdom))
        return op(self._KL.position)

    def minimize(self, minimizer):
        """Minimize the stochastic objective with `minimizer`.

        `minimizer` is called with the current energy and must return a pair
        whose first element is the new energy, which is stored. The second
        element (presumably a convergence status) is discarded.
        """
        self._KL, _ = minimizer(self._KL)
class GaussianEntropy(EnergyOperator):
    """Negative entropy of a Gaussian distribution given the diagonal of a
    triangular decomposition of its covariance.

    For a covariance ``L @ L^T`` with triangular ``L``, the entropy is
    ``0.5*sum(log(2*pi*e*L_ii**2))``. This operator returns the *negative* of
    that value, ``-0.5*sum(log(2*pi*e*x**2))`` for input ``x = diag(L)``.
    Adding it to a Hamiltonian therefore gives the variational objective to
    be minimized.

    When a metric is requested, the operator ``SandwichOperator.make(J)`` is
    attached, where ``J`` is the Jacobian of the scalar result.
    NOTE(review): this is not a proper Fisher metric and was already flagged
    as uncertain; it may still be useful for second-order minimizers.

    Parameters
    ----------
    domain : Domain, tuple of Domain or DomainTuple
        The domain of the diagonal entries ``x``.
    """

    def __init__(self, domain):
        self._domain = DomainTuple.make(domain)

    def apply(self, x):
        self._check_input(x)
        res = (x*x).scale(2*np.pi*np.e).log().sum().scale(-0.5)
        # Only a Linearization that asks for a metric gets one attached.
        if isinstance(x, Linearization) and x.want_metric:
            # FIXME: not sure about this metric (see class docstring).
            return res.add_metric(SandwichOperator.make(res.jac))
        return res
class LowerTriangularInserter(LinearOperator):
    """Insert the entries of a lower triangular matrix into a matrix.

    The domain is an :class:`UnstructuredDomain` holding the ``N*(N+1)/2``
    entries on and below the diagonal, in the order given by
    :func:`numpy.tril_indices`.

    * ``times`` scatters them into an ``NxN`` matrix whose strictly upper
      triangle is zero.
    * ``adjoint_times`` gathers them back, discarding the strictly upper
      triangle.

    Parameters
    ----------
    target : Domain, tuple of Domain or DomainTuple
        A two-dimensional domain with NxN entries.
    """

    def __init__(self, target):
        # Normalize first so that anything makeDomain accepts (e.g. a list of
        # domains) can be validated via ``.shape``.
        self._target = makeDomain(target)
        shape = self._target.shape
        myassert(len(shape) == 2)
        myassert(shape[0] == shape[1])
        n = shape[0]
        self._domain = makeDomain(UnstructuredDomain((n*(n + 1))//2))
        self._indices = np.tril_indices(n)
        self._capability = self.TIMES | self.ADJOINT_TIMES

    def apply(self, x, mode):
        self._check_input(x, mode)
        x = x.val
        if mode == self.TIMES:
            # float64 for real input (as before), but complex input keeps its
            # imaginary part instead of being silently truncated.
            res = np.zeros(self._target.shape,
                           dtype=np.result_type(x.dtype, np.float64))
            res[self._indices] = x
        else:
            res = x[self._indices].reshape(self._domain.shape)
        return makeField(self._tgt(mode), res)
class DiagonalSelector(LinearOperator):
    """Extract the diagonal of a two-dimensional field.

    * ``times`` maps an ``NxN`` field to the length-``N`` vector of its
      diagonal, on an :class:`UnstructuredDomain`.
    * ``adjoint_times`` maps such a vector to an ``NxN`` field that is zero
      off the diagonal.

    Parameters
    ----------
    domain : Domain, tuple of Domain or DomainTuple
        The two-dimensional domain of the input field. Must be of shape NxN.
    """

    def __init__(self, domain):
        # Normalize first so that anything makeDomain accepts (e.g. a list of
        # domains) can be validated via ``.shape``.
        self._domain = makeDomain(domain)
        shape = self._domain.shape
        myassert(len(shape) == 2)
        myassert(shape[0] == shape[1])
        self._target = makeDomain(UnstructuredDomain(shape[0]))
        self._capability = self.TIMES | self.ADJOINT_TIMES

    def apply(self, x, mode):
        self._check_input(x, mode)
        # np.diag extracts the diagonal of a 2-D array and builds a diagonal
        # matrix from a 1-D array, which is exactly the times/adjoint pair.
        return makeField(self._tgt(mode), np.diag(x.val))