From 63b02fcf7d0c2afc8694af16255683f74bab97a7 Mon Sep 17 00:00:00 2001 From: Martin Reinecke Date: Sat, 24 Feb 2018 16:22:18 +0100 Subject: [PATCH] add (disabled) test illustrating problems with outr minimizers --- test/test_minimization/test_minimizers.py | 44 +++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/test/test_minimization/test_minimizers.py b/test/test_minimization/test_minimizers.py index c3b84899..d993ebe0 100644 --- a/test/test_minimization/test_minimizers.py +++ b/test/test_minimization/test_minimizers.py @@ -32,6 +32,9 @@ minimizers = [ift.SteepestDescent, ift.RelaxedNewton, ift.VL_BFGS, minimizers2 = [ift.RelaxedNewton, ift.VL_BFGS, ift.NonlinearCG, ift.NewtonCG, ift.L_BFGS_B] +minimizers3 = [ift.SteepestDescent, ift.RelaxedNewton, ift.VL_BFGS, + ift.NonlinearCG, ift.NewtonCG, ift.L_BFGS_B] + class Test_Minimizers(unittest.TestCase): @@ -124,3 +127,44 @@ class Test_Minimizers(unittest.TestCase): assert_equal(convergence, IC.CONVERGED) assert_allclose(energy.position.to_global_data(), 1., rtol=1e-3, atol=1e-3) + + @expand(product(minimizers3)) + def DISABLED_test_nonlinear(self, minimizer_class): + print (minimizer_class) + space = ift.UnstructuredDomain((1,)) + starting_point = ift.Field(space, val=5.) + + class ExpEnergy(ift.Energy): + def __init__(self, position): + super(ExpEnergy, self).__init__(position) + + @property + def value(self): + x = self.position.to_global_data()[0] + return -np.exp(-(x**2)) + + @property + def gradient(self): + x = self.position.to_global_data()[0] + return ift.Field(self.position.domain, val=2*x*np.exp(-(x**2))) + + @property + def curvature(self): + x = self.position.to_global_data()[0] + v = (2 - 4*x*x)*np.exp(-x**2) + return ift.DiagonalOperator( + ift.Field(self.position.domain, val=v)) + + IC = ift.GradientNormController(tol_abs_gradnorm=1e-10, + iteration_limit=10000) + try: + minimizer = minimizer_class(controller=IC) + energy = ExpEnergy(position=starting_point) + + (energy, convergence) = minimizer(energy) + except NotImplementedError: + raise SkipTest + + assert_equal(convergence, IC.CONVERGED) + assert_allclose(energy.position.to_global_data(), 0., + atol=1e-3) -- GitLab