diff --git a/test/test_kl.py b/test/test_kl.py index ad5d813f57fea7c3c5057fa4593a439e18e7424f..410ff3a9db8116d1bac4aa9b0a5a80eda6bf1a10 100644 --- a/test/test_kl.py +++ b/test/test_kl.py @@ -133,8 +133,7 @@ def test_ParametricVI(mirror_samples, fc): assert_allclose(true_val.val, kl.value, rtol=0.1) samples = model.KL.samples() - mini = ift.SteepestDescent(ift.GradientNormController(iteration_limit=3)) - model.minimize(mini) + model.minimize(ift.ADVIOptimizer(ift.GradientNormController(iteration_limit=3))) samples1 = model.KL.samples() for aa, bb in zip(samples, samples1): ift.extra.assert_allclose(aa, bb)