Commit 7a2af8a0 authored by Philipp Arras's avatar Philipp Arras
Browse files

fixup! Port to geoVI compatible NIFTy_7

parent 4d60d6c4
......@@ -70,8 +70,6 @@ if __name__ == "__main__":
ic_newton = ift.DeltaEnergyController(
name="Newton", iteration_limit=1, tol_rel_deltaE=1e-8
)
minimizer_fc = ift.ADVIOptimizer(steps=10)
minimizer_mf = ift.ADVIOptimizer(steps=10)
H = ift.StandardHamiltonian(likelihood)
fullcov_model = ift.FullCovarianceModel(H.domain)
......@@ -79,8 +77,13 @@ if __name__ == "__main__":
position_fc = fullcov_model.get_initial_pos(initial_sig=0.01)
position_mf = meanfield_model.get_initial_pos(initial_sig=0.01)
KL_fc = ift.ParametricGaussianKL(position_fc, H, fullcov_model, 3, True)
KL_mf = ift.ParametricGaussianKL(position_mf, H, meanfield_model, 3, True)
f_KL_fc = lambda x: ift.ParametricGaussianKL(x, H, fullcov_model, 3, True)
KL_fc = f_KL_fc(position_fc)
f_KL_mf = lambda x: ift.ParametricGaussianKL(x, H, meanfield_model, 3, True)
KL_mf = f_KL_mf(position_mf)
minimizer_fc = ift.ADVIOptimizer(10, f_KL_fc)
minimizer_mf = ift.ADVIOptimizer(10, f_KL_mf)
plt.pause(0.001)
for i in range(25):
......
......@@ -540,11 +540,7 @@ def ParametricGaussianKL(variational_parameters, hamiltonian,
raise TypeError
from ..sugar import full, from_random
assert variational_model.generator.target is hamiltonian.domain
# FIXME self._variational_model = variational_model
# FIXME self._full_model = (
# hamiltonian(variational_model.generator) + variational_model.entropy
# )
full_model = hamiltonian(variational_model.generator) + variational_model.entropy
local_samples = []
sseq = random.spawn_sseq(n_samples)
......@@ -556,11 +552,9 @@ def ParametricGaussianKL(variational_parameters, hamiltonian,
DirtyMask = MultiField.from_dict(DirtyMaskDict)
for i in range(*_get_lo_hi(comm, n_samples)):
with random.Context(sseq[i]):
local_samples.append(
DirtyMask * from_random(variational_model.generator.domain)
)
s = DirtyMask * from_random(variational_model.generator.domain)
local_samples.append(s)
local_samples = tuple(local_samples)
# FIXME What about variational_model?
return _SampledKLEnergy(variational_parameters, hamiltonian, n_samples, mirror_samples, comm,
local_samples, nanisinf)
return _SampledKLEnergy(variational_parameters, full_model, n_samples,
mirror_samples, comm, local_samples, nanisinf)
......@@ -37,13 +37,14 @@ class ADVIOptimizer(Minimizer):
A small value guarantees Robbins and Monro conditions.
"""
def __init__(self, steps, eta=1, alpha=0.1, tau=1, epsilon=1e-16):
def __init__(self, steps, resample_energy, eta=1, alpha=0.1, tau=1, epsilon=1e-16):
self.alpha = alpha
self.eta = eta
self.tau = tau
self.epsilon = epsilon
self.counter = 1
self.steps = steps
self._resample = resample_energy
self.s = None
def _step(self, position, gradient):
......@@ -58,18 +59,14 @@ class ADVIOptimizer(Minimizer):
return new_position
def __call__(self, E):
from ..minimization.parametric_gaussian_kl import ParametricGaussianKL
from ..minimization.kl_energies import ParametricGaussianKL
if self.s is None:
self.s = E.gradient ** 2
# FIXME come up with somthing how to determine convergence
convergence = 0
for i in range(self.steps):
x = self._step(E.position, E.gradient)
# FIXME maybe some KL function for resample? Should make it more generic.
E = ParametricGaussianKL.make(
x, E._hamiltonian, E._variational_model, E._n_samples, E._mirror_samples
)
E = self._resample(x)
return E, convergence
def reset(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment