Commit 17e2d148 authored by Philipp Arras's avatar Philipp Arras
Browse files

Docstrings and simplifications

parent 8d1fb1ae
......@@ -424,9 +424,9 @@ class AbsDeltaEnergyController(IterationController):
class StochasticAbsDeltaEnergyController(IterationController):
"""An iteration controller checking the standard deviation over a
period of iterations. Convergence is reported once this quantity
falls below the given threshold
"""Check the standard deviation over a period of iterations.
Convergence is reported once this quantity falls below the given threshold.
Parameters
......@@ -434,16 +434,17 @@ class StochasticAbsDeltaEnergyController(IterationController):
deltaE : float
If the standard deviation of the last energies is below this
value, the convergence counter will be increased in this iteration.
convergence_level : int, default=1
convergence_level : int, optional
The number which the convergence counter must reach before the
iteration is considered to be converged
iteration is considered to be converged. Defaults to 1.
iteration_limit : int, optional
The maximum number of iterations that will be carried out.
name : str, optional
If supplied, this string and some diagnostic information will be
printed after every iteration.
memory_length : int, default=10
The number of last energies considered for determining convergence.
memory_length : int, optional
The number of last energies considered for determining convergence,
defaults to 10.
"""
def __init__(self, deltaE, convergence_level=1, iteration_limit=None,
......@@ -469,7 +470,7 @@ class StochasticAbsDeltaEnergyController(IterationController):
inclvl = False
Eval = energy.value
self._memory.append(Eval)
if len(self._memory)>self.memory_length:
if len(self._memory) > self.memory_length:
self._memory = self._memory[1:]
diff = np.std(self._memory)
if self._itcount > 0:
......
......@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
......@@ -116,15 +116,12 @@ def test_ParametricVI(mirror_samples, fc):
h = ift.StandardHamiltonian(lh, ic_samp=ic)
initial_mean = ift.from_random(h.domain, 'normal')
nsamps = 1000
if fc:
model = ift.library.variational_models.FullCovarianceVI(initial_mean, h, nsamps, mirror_samples, initial_sig=0.01)
else:
model = ift.library.variational_models.MeanFieldVI(initial_mean, h, nsamps, mirror_samples, initial_sig=0.01)
args = initial_mean, h, nsamps, mirror_samples, 0.01
model = (ift.FullCovarianceVI if fc else ift.MeanFieldVI)(*args)
kl = model._KL
expected_nsamps = 2*nsamps if mirror_samples else nsamps
myassert(len(tuple(kl._local_ops)) == expected_nsamps)
true_val = []
for i in range(expected_nsamps):
lat_rnd = ift.from_random(model._KL._op.domain['latent'])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment