Commit ada00fdb authored by Martin Reinecke

Merge branch 'more_samplers' into 'NIFTy_7'

Parametric MGVI

See merge request !604
parents c76f44f3 2cbf787a
@@ -148,6 +148,11 @@ run_visual_vi:
  script:
    - python3 demos/variational_inference_visualized.py

run_meanfield:
  stage: demo_runs
  script:
    - python3 demos/parametric_variational_inference.py

run_nonlinearity_guide:
  stage: demo_runs
  script:
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
###############################################################################
# Mean-field and full-covariance variational inference
#
# The signal is a 1-D lognormally distributed field.
# The data follow a Poisson likelihood.
# The posterior distribution is approximated with a diagonal-covariance as
# well as a full-covariance Gaussian distribution. Both approximations are
# obtained by minimizing a stochastic estimate of the KL divergence.
#
# Note that the full-covariance approximation scales quadratically with the
# number of parameters.
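#
# Sketch of the objective that is minimized (up to an additive constant):
# with a generator g and standard-normal samples xi_i, the stochastic KL
# estimate reads
#
#   KL(Q||P) ~= (1/N) * sum_i H(g(xi_i)) - S[Q],
#
# where H is the information Hamiltonian below and S[Q] is the entropy of
# the Gaussian approximation Q.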
###############################################################################
import numpy as np
import nifty7 as ift
from matplotlib import pyplot as plt
ift.random.push_sseq_from_seed(27)
if __name__ == "__main__":
# Space and model setup
position_space = ift.RGSpace([100])
harmonic_space = position_space.get_default_codomain()
HT = ift.HarmonicTransformOperator(harmonic_space, position_space)
p_space = ift.PowerSpace(harmonic_space)
pd = ift.PowerDistributor(harmonic_space, p_space)
a = ift.PS_field(p_space, lambda k: 1.0 / (1.0 + k ** 2))
A = pd(a)
sky = 10 * ift.exp(HT(ift.makeOp(A))).ducktape("xi")
mask = np.zeros(position_space.shape)
mask[mask.shape[0]//3:2*mask.shape[0]//3] = 1
mask = ift.Field.from_raw(position_space, mask)
R = ift.MaskOperator(mask)
d_space = R.target[0]
lamb = R(sky)
# Generate simulated signal and data and build log-likelihood
mock_position = ift.from_random(sky.domain, "normal")
data = ift.random.current_rng().poisson(lamb(mock_position).val)
data = ift.makeField(d_space, data)
loglikelihood = ift.PoissonianEnergy(data) @ lamb
H = ift.StandardHamiltonian(loglikelihood)
# Settings for minimization
IC = ift.StochasticAbsDeltaEnergyController(5, iteration_limit=200,
name='advi')
minimizer_fc = ift.ADVIOptimizer(IC, eta=0.1)
minimizer_mf = ift.ADVIOptimizer(IC)
# Initial positions
position_fc = ift.from_random(H.domain)*0.1
position_mf = ift.from_random(H.domain)*0.1
# Setup of the variational models
fc = ift.FullCovarianceVI(position_fc, H, 3, True, initial_sig=0.01)
mf = ift.MeanFieldVI(position_mf, H, 3, True, initial_sig=0.01)
niter = 10
for ii in range(niter):
# Plotting
plt.plot(sky(fc.mean).val, "b-", label="Full covariance")
plt.plot(sky(mf.mean).val, "r-", label="Mean field")
for _ in range(5):
plt.plot(sky(fc.draw_sample()).val, "b-", alpha=0.3)
plt.plot(sky(mf.draw_sample()).val, "r-", alpha=0.3)
plt.plot(R.adjoint(data).val, "kx")
plt.plot(sky(mock_position).val, "k-", label="Ground truth")
plt.legend()
plt.ylim(0.1, data.val.max() + 10)
fname = f"meanfield_{ii:03d}.png"
plt.savefig(fname)
print(f"Saved results as '{fname}' ({ii}/{niter-1}).")
plt.close()
# /Plotting
# Run minimization
fc.minimize(minimizer_fc)
mf.minimize(minimizer_mf)
@@ -19,13 +19,9 @@
###############################################################################
# Variational Inference (VI)
#
# This script demonstrates how MGVI, GeoVI, MeanfieldVI and FullCovarianceVI
# work on an inference problem with only two real quantities of interest. This
# enables us to plot the posterior probability density as a two-dimensional
# plot.
###############################################################################
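# In this comparison, MGVI and GeoVI construct their approximations from the
# metric of the Hamiltonian around the current latent mean, while MeanfieldVI
# and FullCovarianceVI directly optimize the parameters of a Gaussian with a
# stochastic (ADVI-style) minimizer.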
import numpy as np
@@ -74,65 +70,93 @@ def main():
plt.pause(2.0)
plt.close()
    mapx = xx[z == np.max(z)]
    mapy = yy[z == np.max(z)]
    meanx = (xx*z).sum()/z.sum()
    meany = (yy*z).sum()/z.sum()

    n_samples = 100
    minimizer = ift.NewtonCG(
        ift.GradientNormController(iteration_limit=3, name='Mini'))
    IC = ift.StochasticAbsDeltaEnergyController(0.5, iteration_limit=20,
                                                name='advi')
    stochastic_minimizer_mf = ift.ADVIOptimizer(IC, eta=0.3)
    stochastic_minimizer_fc = ift.ADVIOptimizer(IC, eta=0.3)
    posmg = posgeo = posmf = posfc = ift.from_random(ham.domain, 'normal')
    fc = ift.FullCovarianceVI(posfc, ham, 10, False, initial_sig=0.01)
    mf = ift.MeanFieldVI(posmf, ham, 10, False, initial_sig=0.01)

    fig, axs = plt.subplots(2, 2, figsize=[12, 8])
    axs = axs.flatten()

    def update_plot(runs):
        for axx, (nn, kl, pp, sam) in zip(axs, runs):
            axx.clear()
            axx.imshow(z.T, origin='lower', cmap='gist_earth_r',
                       norm=LogNorm(vmin=1e-3, vmax=np.max(z)),
                       extent=x_limits_scaled + y_limits)
            xs, ys = [], []
            if sam:
                samples = (samp + pp for samp in kl.samples)
            else:
                samples = (kl.draw_sample() for _ in range(n_samples))
            mx, my = 0., 0.
            for samp in samples:
                a = samp.val['a']
                xs.append(a)
                mx += a
                b = samp.val['b']
                ys.append(b)
                my += b
            mx /= n_samples
            my /= n_samples
            axx.scatter(np.array(xs)*scale, np.array(ys),
                        label=f'{nn} samples')
            axx.scatter(mx*scale, my, label=f'{nn} mean')
            axx.scatter(mapx*scale, mapy, label='MAP')
            axx.scatter(meanx*scale, meany, label='Posterior mean')
            axx.set_title(nn)
            axx.set_xlim(x_limits_scaled)
            axx.set_ylim(y_limits)
            axx.set_ylabel('y')
            axx.legend(loc='lower right')
        axs[0].xaxis.set_visible(False)
        axs[1].xaxis.set_visible(False)
        axs[1].yaxis.set_visible(False)
        axs[2].set_xlabel('x')
        axs[2].set_ylabel('y')
        axs[3].yaxis.set_visible(False)
        axs[3].set_xlabel('x')
        plt.tight_layout()
        plt.draw()
        plt.pause(2.0)

    for ii in range(20):
        if ii % 2 == 0:
            # Resample GeoVI and MGVI
            mgkl = ift.MetricGaussianKL(posmg, ham, n_samples, False)
            mini_samp = ift.NewtonCG(
                ift.AbsDeltaEnergyController(1E-8, iteration_limit=5))
            geokl = ift.GeoMetricKL(posgeo, ham, n_samples, mini_samp, False)
        runs = (("MGVI", mgkl, posmg, True),
                ("GeoVI", geokl, posgeo, True),
                ("MeanfieldVI", mf, posmf, False),
                ("FullCovarianceVI", fc, posfc, False))
        update_plot(runs)
        mgkl, _ = minimizer(mgkl)
        geokl, _ = minimizer(geokl)
        mf.minimize(stochastic_minimizer_mf)
        fc.minimize(stochastic_minimizer_fc)
        posmg = mgkl.position
        posgeo = geokl.position
        posmf = mf.mean
        posfc = fc.mean
    runs = (("MGVI", mgkl, posmg, True),
            ("GeoVI", geokl, posgeo, True),
            ("MeanfieldVI", mf, posmf, False),
            ("FullCovarianceVI", fc, posfc, False))
    update_plot(runs)
    ift.logger.info('Finished')
    # Uncomment the following line in order to leave the plots open
    # plt.show()
@@ -53,6 +53,7 @@ from .operators.energy_operators import (
Squared2NormOperator, StudentTEnergy, VariableCovarianceGaussianEnergy)
from .operators.convolution_operators import FuncConvolutionOperator
from .operators.normal_operators import NormalTransform, LognormalTransform
from .operators.multifield2vector import Multifield2Vector
from .probing import probe_with_posterior_samples, probe_diagonal, \
StatCalculator, approximation2endo
@@ -60,17 +61,18 @@ from .probing import probe_with_posterior_samples, probe_diagonal, \
from .minimization.line_search import LineSearch
from .minimization.iteration_controllers import (
IterationController, GradientNormController, DeltaEnergyController,
    GradInfNormController, AbsDeltaEnergyController,
    StochasticAbsDeltaEnergyController)
from .minimization.minimizer import Minimizer
from .minimization.conjugate_gradient import ConjugateGradient
from .minimization.nonlinear_cg import NonlinearCG
from .minimization.descent_minimizers import (
DescentMinimizer, SteepestDescent, VL_BFGS, L_BFGS, RelaxedNewton,
NewtonCG)
from .minimization.stochastic_minimizer import ADVIOptimizer
from .minimization.scipy_minimizer import L_BFGS_B
from .minimization.energy import Energy
from .minimization.quadratic_energy import QuadraticEnergy
from .minimization.energy_adapter import EnergyAdapter, StochasticEnergyAdapter
from .minimization.kl_energies import MetricGaussianKL, GeoMetricKL
from .sugar import *
@@ -90,6 +92,7 @@ from .library.adjust_variances import (make_adjust_variances_hamiltonian,
from .library.nft import Gridder, FinuFFT
from .library.correlated_fields import CorrelatedFieldMaker
from .library.correlated_fields_simple import SimpleCorrelatedField
from .library.variational_models import MeanFieldVI, FullCovarianceVI
from . import extra
@@ -405,6 +405,11 @@ class Field(Operator):
return Field(DomainTuple.make(return_domain), data)
def scale(self, factor):
if factor == 1:
return self
return factor*self
def sum(self, spaces=None):
"""Sums up over the sub-domains given by `spaces`.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from ..domain_tuple import DomainTuple
from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field
from ..linearization import Linearization
from ..minimization.energy_adapter import StochasticEnergyAdapter
from ..multi_field import MultiField
from ..operators.einsum import MultiLinearEinsum
from ..operators.energy_operators import EnergyOperator
from ..operators.linear_operator import LinearOperator
from ..operators.multifield2vector import Multifield2Vector
from ..operators.sandwich_operator import SandwichOperator
from ..operators.simple_linear_operators import FieldAdapter
from ..sugar import from_random, full, is_fieldlike, makeDomain, makeField
from ..utilities import myassert
class MeanFieldVI:
"""Collect the operators required for Gaussian meanfield variational
inference.
Gaussian meanfield variational inference approximates some target
distribution with a Gaussian distribution with a diagonal covariance
matrix. The parameters of the approximation, in this case the mean and
standard deviation, are obtained by minimizing a stochastic estimate of the
Kullback-Leibler divergence between the target and the approximation. In
order to obtain gradients w.r.t. the parameters, the reparametrization trick
is employed, which separates the stochastic part of the approximation from
a deterministic function, the generator. Samples from the approximation are
drawn by processing samples from a standard Gaussian through this
generator.
Parameters
----------
position : Field
The initial estimate of the approximate mean parameter.
hamiltonian : Energy
Hamiltonian of the approximated probability distribution.
n_samples : int
Number of samples used to stochastically estimate the KL.
mirror_samples : bool
Whether the negatives of the drawn samples are also used, as they are
equally legitimate samples. If true, the number of used samples
doubles. Mirroring samples stabilizes the KL estimate as extreme sample
variation is counterbalanced. Since it improves stability in many
cases, it is recommended to set `mirror_samples` to `True`.
initial_sig : positive Field or positive float
The initial estimate of the standard deviation.
comm : MPI communicator or None
If not None, samples will be distributed as evenly as possible across
this communicator. If `mirror_samples` is set, then a sample and its
mirror image will always reside on the same task.
nanisinf : bool
If true, nan energies which can happen due to overflows in the forward
model are interpreted as inf. Thereby, the code does not crash on these
occasions but rather the minimizer is told that the position it has
tried is not sensible.
"""
def __init__(self, position, hamiltonian, n_samples, mirror_samples,
initial_sig=1, comm=None, nanisinf=False):
Flat = Multifield2Vector(position.domain)
self._std = FieldAdapter(Flat.target, 'std').absolute()
latent = FieldAdapter(Flat.target,'latent')
self._mean = FieldAdapter(Flat.target, 'mean')
self._generator = Flat.adjoint(self._mean + self._std * latent)
self._entropy = GaussianEntropy(self._std.target) @ self._std
self._mean = Flat.adjoint @ self._mean
self._std = Flat.adjoint @ self._std
pos = {'mean': Flat(position)}
if is_fieldlike(initial_sig):
pos['std'] = Flat(initial_sig)
else:
pos['std'] = full(Flat.target, initial_sig)
pos = MultiField.from_dict(pos)
op = hamiltonian(self._generator) + self._entropy
self._KL = StochasticEnergyAdapter.make(pos, op, ['latent',], n_samples,
mirror_samples, nanisinf=nanisinf, comm=comm)
self._samdom = latent.domain
@property
def mean(self):
return self._mean.force(self._KL.position)
@property
def std(self):
return self._std.force(self._KL.position)
@property
def entropy(self):
return self._entropy.force(self._KL.position)
@property
def KL(self):
return self._KL
def draw_sample(self):
_, op = self._generator.simplify_for_constant_input(
from_random(self._samdom))
return op(self._KL.position)
def minimize(self, minimizer):
self._KL, _ = minimizer(self._KL)
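# Typical usage (sketch; `ham` is a StandardHamiltonian and `opt` an
# ADVIOptimizer, cf. demos/parametric_variational_inference.py):
#
#   mf = MeanFieldVI(initial_position, ham, n_samples=10,
#                    mirror_samples=True, initial_sig=0.01)
#   mf.minimize(opt)
#   approx_mean = mf.mean
#   posterior_sample = mf.draw_sample()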
class FullCovarianceVI:
"""Collect the operators required for Gaussian full-covariance variational
Gaussian meanfield variational inference approximates some target
distribution with a Gaussian distribution with a diagonal covariance
matrix. The parameters of the approximation, in this case the mean and a
lower triangular matrix corresponding to a Cholesky decomposition of the
covariance, are obtained by minimizing a stochastic estimate of the
Kullback-Leibler divergence between the target and the approximation. In
order to obtain gradients w.r.t. the parameters, the reparametrization trick
is employed, which separates the stochastic part of the approximation from
a deterministic function, the generator. Samples from the approximation are
drawn by processing samples from a standard Gaussian through this
generator.
Note that the size of the covariance scales quadratically with the number
of model parameters.
Parameters
----------
position : Field
The initial estimate of the approximate mean parameter.
hamiltonian : Energy
Hamiltonian of the approximated probability distribution.
n_samples : int
Number of samples used to stochastically estimate the KL.
mirror_samples : bool
Whether the negatives of the drawn samples are also used, as they are
equally legitimate samples. If true, the number of used samples
doubles. Mirroring samples stabilizes the KL estimate as extreme sample
variation is counterbalanced. Since it improves stability in many
cases, it is recommended to set `mirror_samples` to `True`.
initial_sig : positive float
The initial estimate for the standard deviation. Initially no
correlation between the parameters is assumed.
comm : MPI communicator or None
If not None, samples will be distributed as evenly as possible across
this communicator. If `mirror_samples` is set, then a sample and its
mirror image will always reside on the same task.
nanisinf : bool
If true, nan energies which can happen due to overflows in the forward
model are interpreted as inf. Thereby, the code does not crash on these
occasions but rather the minimizer is told that the position it has
tried is not sensible.
"""
def __init__(self, position, hamiltonian, n_samples, mirror_samples,
initial_sig=1, comm=None, nanisinf=False):
Flat = Multifield2Vector(position.domain)
flat_domain = Flat.target[0]
mat_space = DomainTuple.make((flat_domain,flat_domain))
lat = FieldAdapter(Flat.target,'latent')
LT = LowerTriangularInserter(mat_space)
tri = FieldAdapter(LT.domain, 'cov')
mean = FieldAdapter(flat_domain,'mean')
cov = LT @ tri
matmul_setup = lat.adjoint @ lat + cov.ducktape_left('co')
MatMult = MultiLinearEinsum(matmul_setup.target,'ij,j->i',
key_order=('co','latent'))
self._generator = Flat.adjoint @ (mean + MatMult @ matmul_setup)
diag_cov = (DiagonalSelector(cov.target) @ cov).absolute()
self._entropy = GaussianEntropy(diag_cov.target) @ diag_cov
diag_tri = np.diag(np.full(flat_domain.shape[0], initial_sig))
pos = MultiField.from_dict(
{'mean': Flat(position),
'cov': LT.adjoint(makeField(mat_space, diag_tri))})
op = hamiltonian(self._generator) + self._entropy
self._KL = StochasticEnergyAdapter.make(pos, op, ['latent',], n_samples,
mirror_samples, nanisinf=nanisinf, comm=comm)
self._mean = Flat.adjoint @ mean
self._samdom = lat.domain
@property
def mean(self):
return self._mean.force(self._KL.position)
@property
def entropy(self):
return self._entropy.force(self._KL.position)
@property
def KL(self):
return self._KL
def draw_sample(self):
_, op = self._generator.simplify_for_constant_input(
from_random(self._samdom))
return op(self._KL.position)
def minimize(self, minimizer):
self._KL, _ = minimizer(self._KL)
class GaussianEntropy(EnergyOperator):
"""Entropy of a Gaussian distribution given the diagonal of a triangular
decomposition of the covariance.
As metric a `SandwichOperator` of the Jacobian is used. This is not a
proper Fisher metric but may be useful for second order minimization.
Parameters
----------
domain: Domain, DomainTuple, list of Domain
The domain of the diagonal.
"""
def __init__(self, domain):
self._domain = DomainTuple.make(domain)
def apply(self, x):
self._check_input(x)
if isinstance(x, Field):
if not np.issubdtype(x.dtype, np.floating):
raise NotImplementedError("only real fields are allowed")
if isinstance(x, MultiField):
for key in x.keys():
if not np.issubdtype(x[key].dtype, np.floating):
raise NotImplementedError("only real fields are allowed")
res = (x*x).scale(2*np.pi*np.e).log().sum().scale(-0.5)
if not isinstance(x, Linearization):
return res
if not x.want_metric:
return res
return res.add_metric(SandwichOperator.make(res.jac))
class LowerTriangularInserter(LinearOperator):
"""Insert the entries of a lower triangular matrix into a matrix.
Parameters
----------
target: Domain, DomainTuple, list of Domain
A two-dimensional domain with NxN entries.
"""
def __init__(self, target):
myassert(len(target.shape) == 2)
myassert(target.shape[0] == target.shape[1])
self._target = makeDomain(target)
ndof = (target.shape[0]*(target.shape[0]+1))//2
self._domain = makeDomain(UnstructuredDomain(ndof))
self._indices = np.tril_indices(target.shape[0])
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
x = x.val