test_minimizers.py 1.31 KB
```python
import unittest

import numpy as np
from numpy.testing import assert_allclose

import nifty as ift

from itertools import product
from test.common import expand

# Every minimizer is run on every space (see the product() below).
spaces = [ift.RGSpace([1024], distances=0.123), ift.HPSpace(32)]
minimizers = [ift.SteepestDescent, ift.RelaxedNewton, ift.VL_BFGS,
              ift.ConjugateGradient]


class Test_Minimizers(unittest.TestCase):

    @expand(product(minimizers, spaces))
    def test_minimization(self, minimizer_class, space):
        # Fixed seed so the random start and diagonal are reproducible.
        np.random.seed(42)
        starting_point = ift.Field.from_random('normal', domain=space)*10
        covariance_diagonal = ift.Field.from_random(
            'uniform', domain=space) + 0.5
        covariance = ift.DiagonalOperator(space, diagonal=covariance_diagonal)
        required_result = ift.Field(space, val=1.)

        IC = ift.DefaultIterationController(tol_gradnorm=1e-5)
        minimizer = minimizer_class(controller=IC)
        energy = ift.QuadraticEnergy(A=covariance, b=required_result,
                                     position=starting_point)

        (energy, convergence) = minimizer(energy)
        assert convergence == IC.CONVERGED
        # With b = 1 and a diagonal A, the expected minimum is 1/diag(A).
        assert_allclose(energy.position.val.get_full_data(),
                        1./covariance_diagonal.val.get_full_data(),
                        rtol=1e-3, atol=1e-3)
```
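The check performed by the test can be reproduced without NIFTy. The following is a minimal NumPy-only sketch, not the library's implementation: it assumes the quadratic energy has the form E(x) = 0.5 * x.A.x - b.x with a diagonal A, and the helper names `quadratic_energy` and `steepest_descent` are hypothetical. It uses steepest descent with an exact line search, stops once the gradient norm falls below `tol_gradnorm`, and then compares the result with 1/diag(A), as the assertion in the test does.

```python
import numpy as np


def quadratic_energy(diag, b, x):
    """Value and gradient of E(x) = 0.5*x.A.x - b.x for A = diag(diag)."""
    grad = diag * x - b
    value = 0.5 * np.dot(x, diag * x) - np.dot(b, x)
    return value, grad


def steepest_descent(diag, b, x0, tol_gradnorm=1e-5, max_iter=10_000):
    """Steepest descent with exact line search; returns (x, converged)."""
    x = x0.copy()
    for _ in range(max_iter):
        _, grad = quadratic_energy(diag, b, x)
        if np.linalg.norm(grad) < tol_gradnorm:
            return x, True
        # For a quadratic energy the optimal step along -grad is
        # alpha = (g.g) / (g.A.g).
        alpha = np.dot(grad, grad) / np.dot(grad, diag * grad)
        x = x - alpha * grad
    return x, False


rng = np.random.default_rng(42)
n = 1024
x0 = rng.standard_normal(n) * 10       # starting point, scaled like the test
diag = rng.uniform(size=n) + 0.5       # diagonal entries in [0.5, 1.5)
b = np.ones(n)                         # plays the role of required_result

x_min, converged = steepest_descent(diag, b, x0)
np.testing.assert_allclose(x_min, 1.0 / diag, rtol=1e-3, atol=1e-3)
```

Because the diagonal entries lie in [0.5, 1.5), the problem is well conditioned and the loop terminates after a few dozen iterations at most. The analytic minimum follows from setting the gradient A x - b to zero, which for b = 1 gives x = 1/diag(A).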