Commit e75ca178 authored by Jakob Knollmüller's avatar Jakob Knollmüller
Browse files

testing successfull

parent 93f178d9
Pipeline #103212 passed with stages
in 15 minutes and 12 seconds
......@@ -70,8 +70,8 @@ class MeanFieldVI:
else:
pos['std'] = full(Flat.target, initial_sig)
pos = MultiField.from_dict(pos)
op = hamiltonian(self._generator) + self._entropy
self._KL = StochasticEnergyAdapter.make(pos, op, ['latent',], n_samples,
self._op = hamiltonian(self._generator) + self._entropy
self._KL = StochasticEnergyAdapter.make(pos, self._op, ['latent',], n_samples,
mirror_samples, nanisinf=nanisinf, comm=comm)
self._samdom = latent.domain
......@@ -139,8 +139,8 @@ class FullCovarianceVI:
pos = MultiField.from_dict(
{'mean': Flat(position),
'cov': LT.adjoint(makeField(mat_space, diag_tri))})
op = hamiltonian(self._generator) + self._entropy
self._KL = StochasticEnergyAdapter.make(pos, op, ['latent',], n_samples,
self._op = hamiltonian(self._generator) + self._entropy
self._KL = StochasticEnergyAdapter.make(pos, self._op, ['latent',], n_samples,
mirror_samples, nanisinf=nanisinf, comm=comm)
self._mean = Flat.adjoint @ mean
self._samdom = lat.domain
......
......@@ -103,3 +103,39 @@ def test_kl(constants, point_estimates, mirror_samples, mf, geo):
else:
res0 = klpure.gradient[kk].val
assert_allclose(res0, res1)
@pmp('mirror_samples', (True, False))
@pmp('fc',(True, False))
def test_ParametricVI(mirror_samples, fc):
dom = ift.RGSpace((12,), (2.12))
op = ift.HarmonicSmoothingOperator(dom, 3)
op = ift.ducktape(dom, None, 'a')*(op.ducktape('b'))
lh = ift.GaussianEnergy(domain=op.target, sampling_dtype=np.float64) @ op
ic = ift.GradientNormController(iteration_limit=5)
ic.enable_logging()
h = ift.StandardHamiltonian(lh, ic_samp=ic)
initial_mean = ift.from_random(h.domain, 'normal')
nsamps = 1000
if fc:
model = ift.library.variational_models.FullCovarianceVI(initial_mean, h, nsamps, mirror_samples, initial_sig=0.01)
else:
model = ift.library.variational_models.MeanFieldVI(initial_mean, h, nsamps, mirror_samples, initial_sig=0.01)
kl = model._KL
expected_nsamps = 2*nsamps if mirror_samples else nsamps
myassert(len(tuple(kl._local_ops)) == expected_nsamps)
true_val = []
for i in range(expected_nsamps):
lat_rnd = ift.from_random(model._op.domain['latent'])
samp = kl.position.to_dict()
samp['latent'] = lat_rnd
samp = ift.MultiField.from_dict(samp)
true_val.append(model._op(samp))
true_val = sum(true_val)/expected_nsamps
assert_allclose(true_val.val, kl.value, rtol=0.1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment