Commit f4ace45a authored by Jakob Knollmüller's avatar Jakob Knollmüller
Browse files

polish meanfield demo

parent 4d2c5ffd
......@@ -16,9 +16,16 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
###############################################################################
# FIXME Short text what this demo does
# Meanfield and fullcovariance variational inference
#
# The signal is a 1-D lognormal distributed field.
# The data follows a Poisson likelihood.
# The posterior distribution is approximated with a diagonal, as well as a
# full covariance Gaussian distribution. This is achieved by minimizing
# a stochastic estimate of the KL-Divergence
#
# Note that the fullcovariance approximation scales quadratically with the
# number of parameters.
###############################################################################
import numpy as np
......@@ -26,8 +33,11 @@ import numpy as np
import nifty7 as ift
from matplotlib import pyplot as plt
ift.random.push_sseq_from_seed(27)
if __name__ == "__main__":
# Space and model setup
position_space = ift.RGSpace([100])
harmonic_space = position_space.get_default_codomain()
HT = ift.HarmonicTransformOperator(harmonic_space, position_space)
......@@ -39,29 +49,37 @@ if __name__ == "__main__":
sky = 10 * ift.exp(HT(ift.makeOp(A))).ducktape("xi")
R = ift.GeometryRemover(position_space)
mask = np.zeros(position_space.shape)
mask[mask.shape[0]//3:2*mask.shape[0]//3] = 1
mask = ift.Field.from_raw(position_space, mask)
R = ift.MaskOperator(mask)
d_space = R.target[0]
lamb = R(sky)
# Generate simulated signal and data and build likelihood.
mock_position = ift.from_random(sky.domain, "normal")
data = ift.random.current_rng().poisson(lamb(mock_position).val)
likelihood = ift.PoissonianEnergy(ift.makeField(d_space, data)) @ lamb
data = ift.makeField(d_space, data)
likelihood = ift.PoissonianEnergy(data) @ lamb
H = ift.StandardHamiltonian(likelihood)
# Settings for minimization
ic_newton = ift.DeltaEnergyController(
name="Newton", iteration_limit=1, tol_rel_deltaE=1e-8
)
IC = ift.StochasticAbsDeltaEnergyController(5, iteration_limit=200,
name='advi')
minimizer_fc = ift.ADVIOptimizer(IC, eta=0.1)
minimizer_mf = ift.ADVIOptimizer(IC)
H = ift.StandardHamiltonian(likelihood)
# Initial positions
position_fc = ift.from_random(H.domain)*0.1
position_mf = ift.from_random(H.domain)*0.1
# Setup of the variational models
fc = ift.FullCovarianceVI(position_fc, H, 3, True, initial_sig=0.01)
mf = ift.MeanFieldVI(position_mf, H, 3, True, initial_sig=0.01)
IC = ift.StochasticAbsDeltaEnergyController(10, iteration_limit=1000,
name='advi')
minimizer_fc = ift.ADVIOptimizer(IC, eta=0.1)
minimizer_mf = ift.ADVIOptimizer(IC)
niter = 25
niter = 10
for ii in range(niter):
# Plotting
plt.plot(sky(fc.mean).val, "b-", label="Full covariance")
......@@ -69,15 +87,16 @@ if __name__ == "__main__":
for _ in range(5):
plt.plot(sky(fc.draw_sample()).val, "b-", alpha=0.3)
plt.plot(sky(mf.draw_sample()).val, "r-", alpha=0.3)
plt.plot(data, "kx")
plt.plot(R.adjoint(data).val, "kx")
plt.plot(sky(mock_position).val, "k-", label="Ground truth")
plt.legend()
plt.ylim(0, data.max() + 10)
plt.ylim(0.1, data.val.max() + 10)
fname = f"meanfield_{ii:03d}.png"
plt.savefig(fname)
print(f"Saved results as '{fname}' ({ii}/{niter-1}).")
plt.close()
# /Plotting
# Run minimization
fc.minimize(minimizer_fc)
mf.minimize(minimizer_mf)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment