Commit 8d1fb1ae authored by Philipp Arras's avatar Philipp Arras
Browse files

Cosmetics

parent d6b86d97
......@@ -56,7 +56,8 @@ if __name__ == "__main__":
fc = ift.FullCovarianceVI(position_fc, H, 3, True, initial_sig=0.01)
mf = ift.MeanFieldVI(position_mf, H, 3, True, initial_sig=0.01)
IC = ift.StochasticAbsDeltaEnergyController(10,iteration_limit=1000, name='optimizer')
IC = ift.StochasticAbsDeltaEnergyController(10, iteration_limit=1000,
name='advi')
minimizer_fc = ift.ADVIOptimizer(IC, eta=0.1)
minimizer_mf = ift.ADVIOptimizer(IC)
......
......@@ -139,8 +139,8 @@ class FullCovarianceVI:
pos = MultiField.from_dict(
{'mean': Flat(position),
'cov': LT.adjoint(makeField(mat_space, diag_tri))})
self._op = hamiltonian(self._generator) + self._entropy
self._KL = StochasticEnergyAdapter.make(pos, self._op, ['latent',], n_samples,
op = hamiltonian(self._generator) + self._entropy
self._KL = StochasticEnergyAdapter.make(pos, op, ['latent',], n_samples,
mirror_samples, nanisinf=nanisinf, comm=comm)
self._mean = Flat.adjoint @ mean
self._samdom = lat.domain
......
......@@ -422,8 +422,9 @@ class AbsDeltaEnergyController(IterationController):
return self.CONTINUE
class StochasticAbsDeltaEnergyController(IterationController):
"""An iteration controller checking the standard deviation over a
"""An iteration controller checking the standard deviation over a
period of iterations. Convergence is reported once this quantity
falls below the given threshold
......
......@@ -48,7 +48,6 @@ class ADVIOptimizer(Minimizer):
self.epsilon = epsilon
self.counter = 1
self._controller = controller
# self.steps = steps
self.s = None
self.resample = resample
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment