Commit d6b86d97 authored by Philipp Arras's avatar Philipp Arras
Browse files

Cleanup interface

parent 12df7bd8
......@@ -70,8 +70,8 @@ class MeanFieldVI:
else:
pos['std'] = full(Flat.target, initial_sig)
pos = MultiField.from_dict(pos)
self._op = hamiltonian(self._generator) + self._entropy
self._KL = StochasticEnergyAdapter.make(pos, self._op, ['latent',], n_samples,
op = hamiltonian(self._generator) + self._entropy
self._KL = StochasticEnergyAdapter.make(pos, op, ['latent',], n_samples,
mirror_samples, nanisinf=nanisinf, comm=comm)
self._samdom = latent.domain
......
......@@ -127,15 +127,10 @@ def test_ParametricVI(mirror_samples, fc):
true_val = []
for i in range(expected_nsamps):
lat_rnd = ift.from_random(model._op.domain['latent'])
lat_rnd = ift.from_random(model._KL._op.domain['latent'])
samp = kl.position.to_dict()
samp['latent'] = lat_rnd
samp = ift.MultiField.from_dict(samp)
true_val.append(model._op(samp))
true_val.append(model._KL._op(samp))
true_val = sum(true_val)/expected_nsamps
assert_allclose(true_val.val, kl.value, rtol=0.1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment