From 3a8d202dc03e81ab18b7f0a95335fe111fd13a86 Mon Sep 17 00:00:00 2001 From: Philipp Arras <parras@mpa-garching.mpg.de> Date: Tue, 19 May 2020 14:50:27 +0200 Subject: [PATCH] Add some testing --- nifty6/minimization/iteration_controllers.py | 5 ++++- nifty6/plot.py | 2 +- test/test_kl.py | 15 ++++++++++++--- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/nifty6/minimization/iteration_controllers.py b/nifty6/minimization/iteration_controllers.py index 85c081906..dfa8618f2 100644 --- a/nifty6/minimization/iteration_controllers.py +++ b/nifty6/minimization/iteration_controllers.py @@ -108,7 +108,7 @@ class EnergyHistory(object): self._lst = [] @property - def timestamps(self): + def time_stamps(self): return [x for x, _ in self._lst] @property @@ -128,6 +128,9 @@ class EnergyHistory(object): self._lst += other._lst return self + def __len__(self): + return len(self._lst) + def append_history(func): @functools.wraps(func) diff --git a/nifty6/plot.py b/nifty6/plot.py index 92c5e2a82..8ba30f87c 100644 --- a/nifty6/plot.py +++ b/nifty6/plot.py @@ -288,7 +288,7 @@ def _plot_history(f, ax, **kwargs): plt.yscale(kwargs.pop("yscale", "linear")) mi, ma = np.inf, -np.inf for i, fld in enumerate(f): - xcoord = date2num([dt.fromtimestamp(ts) for ts in fld.timestamps]) + xcoord = date2num([dt.fromtimestamp(ts) for ts in fld.time_stamps]) ycoord = fld.energy_values ax.scatter(xcoord, ycoord, label=label[i], alpha=alpha[i], color=color[i], s=size[i]) diff --git a/test/test_kl.py b/test/test_kl.py index eadbcccfe..fa43322c1 100644 --- a/test/test_kl.py +++ b/test/test_kl.py @@ -39,7 +39,7 @@ def test_kl(constants, point_estimates, mirror_samples, mf): import numpy as np lh = ift.GaussianEnergy(domain=op.target, sampling_dtype=np.float64) @ op ic = ift.GradientNormController(iteration_limit=5) - ic.activate_and_reset_logging() + ic.enable_logging() h = ift.StandardHamiltonian(lh, ic_samp=ic) mean0 = ift.from_random('normal', h.domain) @@ -51,7 +51,14 @@ def test_kl(constants, point_estimates, mirror_samples, mf): point_estimates=point_estimates, mirror_samples=mirror_samples, napprox=0) - ic.pop_history() + assert_(len(ic.history) > 0) + assert_(len(ic.history) == len(ic.history.time_stamps)) + assert_(len(ic.history) == len(ic.history.energy_values)) + ic.history.reset() + assert_(len(ic.history) == 0) + assert_(len(ic.history) == len(ic.history.time_stamps)) + assert_(len(ic.history) == len(ic.history.energy_values)) + locsamp = kl._local_samples klpure = ift.MetricGaussianKL(mean0, h, @@ -87,8 +94,10 @@ def test_kl(constants, point_estimates, mirror_samples, mf): # Test constants (after some minimization) cg = ift.GradientNormController(iteration_limit=5) - minimizer = ift.NewtonCG(cg) + minimizer = ift.NewtonCG(cg, activate_logging=True) kl, _ = minimizer(kl) + if len(constants) != 2: + assert_(len(minimizer.inversion_history) > 0) diff = (mean0 - kl.position).to_dict() for kk in constants: assert_allclose(diff[kk].val, 0*diff[kk].val) -- GitLab