Commit f4ace45a authored by Jakob Knollmüller's avatar Jakob Knollmüller
Browse files

polish meanfield demo

parent 4d2c5ffd
...@@ -16,9 +16,16 @@ ...@@ -16,9 +16,16 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
############################################################################### ###############################################################################
# FIXME Short text what this demo does # Meanfield and fullcovariance variational inference
# #
# The signal is a 1-D lognormal distributed field.
# The data follows a Poisson likelihood.
# The posterior distribution is approximated with a diagonal, as well as a
# full covariance Gaussian distribution. This is achieved by minimizing
# a stochastic estimate of the KL-Divergence
# #
# Note that the fullcovariance approximation scales quadratically with the
# number of parameters.
############################################################################### ###############################################################################
import numpy as np import numpy as np
...@@ -26,8 +33,11 @@ import numpy as np ...@@ -26,8 +33,11 @@ import numpy as np
import nifty7 as ift import nifty7 as ift
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
ift.random.push_sseq_from_seed(27)
if __name__ == "__main__": if __name__ == "__main__":
# Space and model setup
position_space = ift.RGSpace([100]) position_space = ift.RGSpace([100])
harmonic_space = position_space.get_default_codomain() harmonic_space = position_space.get_default_codomain()
HT = ift.HarmonicTransformOperator(harmonic_space, position_space) HT = ift.HarmonicTransformOperator(harmonic_space, position_space)
...@@ -39,29 +49,37 @@ if __name__ == "__main__": ...@@ -39,29 +49,37 @@ if __name__ == "__main__":
sky = 10 * ift.exp(HT(ift.makeOp(A))).ducktape("xi") sky = 10 * ift.exp(HT(ift.makeOp(A))).ducktape("xi")
R = ift.GeometryRemover(position_space) R = ift.GeometryRemover(position_space)
mask = np.zeros(position_space.shape)
mask[mask.shape[0]//3:2*mask.shape[0]//3] = 1
mask = ift.Field.from_raw(position_space, mask)
R = ift.MaskOperator(mask)
d_space = R.target[0] d_space = R.target[0]
lamb = R(sky) lamb = R(sky)
# Generate simulated signal and data and build likelihood.
mock_position = ift.from_random(sky.domain, "normal") mock_position = ift.from_random(sky.domain, "normal")
data = ift.random.current_rng().poisson(lamb(mock_position).val) data = ift.random.current_rng().poisson(lamb(mock_position).val)
likelihood = ift.PoissonianEnergy(ift.makeField(d_space, data)) @ lamb data = ift.makeField(d_space, data)
likelihood = ift.PoissonianEnergy(data) @ lamb
H = ift.StandardHamiltonian(likelihood)
# Settings for minimization # Settings for minimization
ic_newton = ift.DeltaEnergyController( IC = ift.StochasticAbsDeltaEnergyController(5, iteration_limit=200,
name="Newton", iteration_limit=1, tol_rel_deltaE=1e-8 name='advi')
) minimizer_fc = ift.ADVIOptimizer(IC, eta=0.1)
minimizer_mf = ift.ADVIOptimizer(IC)
H = ift.StandardHamiltonian(likelihood) # Initial positions
position_fc = ift.from_random(H.domain)*0.1 position_fc = ift.from_random(H.domain)*0.1
position_mf = ift.from_random(H.domain)*0.1 position_mf = ift.from_random(H.domain)*0.1
# Setup of the variational models
fc = ift.FullCovarianceVI(position_fc, H, 3, True, initial_sig=0.01) fc = ift.FullCovarianceVI(position_fc, H, 3, True, initial_sig=0.01)
mf = ift.MeanFieldVI(position_mf, H, 3, True, initial_sig=0.01) mf = ift.MeanFieldVI(position_mf, H, 3, True, initial_sig=0.01)
IC = ift.StochasticAbsDeltaEnergyController(10, iteration_limit=1000,
name='advi')
minimizer_fc = ift.ADVIOptimizer(IC, eta=0.1)
minimizer_mf = ift.ADVIOptimizer(IC)
niter = 25
niter = 10
for ii in range(niter): for ii in range(niter):
# Plotting # Plotting
plt.plot(sky(fc.mean).val, "b-", label="Full covariance") plt.plot(sky(fc.mean).val, "b-", label="Full covariance")
...@@ -69,15 +87,16 @@ if __name__ == "__main__": ...@@ -69,15 +87,16 @@ if __name__ == "__main__":
for _ in range(5): for _ in range(5):
plt.plot(sky(fc.draw_sample()).val, "b-", alpha=0.3) plt.plot(sky(fc.draw_sample()).val, "b-", alpha=0.3)
plt.plot(sky(mf.draw_sample()).val, "r-", alpha=0.3) plt.plot(sky(mf.draw_sample()).val, "r-", alpha=0.3)
plt.plot(data, "kx") plt.plot(R.adjoint(data).val, "kx")
plt.plot(sky(mock_position).val, "k-", label="Ground truth") plt.plot(sky(mock_position).val, "k-", label="Ground truth")
plt.legend() plt.legend()
plt.ylim(0, data.max() + 10) plt.ylim(0.1, data.val.max() + 10)
fname = f"meanfield_{ii:03d}.png" fname = f"meanfield_{ii:03d}.png"
plt.savefig(fname) plt.savefig(fname)
print(f"Saved results as '{fname}' ({ii}/{niter-1}).") print(f"Saved results as '{fname}' ({ii}/{niter-1}).")
plt.close() plt.close()
# /Plotting # /Plotting
# Run minimization
fc.minimize(minimizer_fc) fc.minimize(minimizer_fc)
mf.minimize(minimizer_mf) mf.minimize(minimizer_mf)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment