test_minimizers.py 1.35 KB
Newer Older
Matevz, Sraml (sraml)'s avatar
Matevz, Sraml (sraml) committed
1
import unittest
2
import numpy as np
Martin Reinecke's avatar
changes  
Martin Reinecke committed
3
from numpy.testing import assert_allclose
Martin Reinecke's avatar
Martin Reinecke committed
4
import nifty2go as ift
Martin Reinecke's avatar
changes  
Martin Reinecke committed
5
from itertools import product
Matevz, Sraml (sraml)'s avatar
Matevz, Sraml (sraml) committed
6 7
from test.common import expand

Martin Reinecke's avatar
changes  
Martin Reinecke committed
8 9
# Domains on which every minimizer is exercised.
spaces = [ift.RGSpace([1024], distances=0.123), ift.HPSpace(32)]
# Minimizer classes under test; each is paired with every entry of ``spaces``
# via ``product(minimizers, spaces)`` in the test class.
minimizers = [ift.SteepestDescent, ift.RelaxedNewton, ift.VL_BFGS,
              ift.ConjugateGradient, ift.NonlinearCG]
11 12


Martin Reinecke's avatar
changes  
Martin Reinecke committed
13
class Test_Minimizers(unittest.TestCase):
    """Check that each minimizer solves a simple diagonal quadratic problem."""

    @expand(product(minimizers, spaces))
    def test_quadratic_minimization(self, minimizer_class, space):
        """Minimize a quadratic energy and compare with the exact solution.

        The energy is built from a positive diagonal operator ``A`` and a
        right-hand side ``b`` of all ones. The test expects the minimizer to
        converge to the position ``1 / diag(A)`` (elementwise), i.e. the
        solution of ``A x = b`` for this diagonal ``A`` and ``b = 1``
        (assumes ``QuadraticEnergy`` is minimal where ``A x = b`` -- this is
        what the final comparison relies on).

        Parameters
        ----------
        minimizer_class : type
            One of the entries of ``minimizers``; instantiated with a
            gradient-norm controller.
        space : domain
            One of the entries of ``spaces``; all fields live on it.
        """
        # Fixed seed so the random start point and operator are reproducible.
        # The two random draws below must stay in this order to keep the data.
        np.random.seed(42)
        starting_point = ift.Field.from_random('normal', domain=space)*10
        # Offset by 0.5 to keep the diagonal away from zero (assumes the
        # 'uniform' draw lies in [0, 1) -- TODO confirm).
        covariance_diagonal = ift.Field.from_random(
                                  'uniform', domain=space) + 0.5
        covariance = ift.DiagonalOperator(covariance_diagonal)
        # Right-hand side ``b`` of the quadratic problem. Note this is NOT the
        # expected result; that is ``1 / covariance_diagonal``.
        rhs = ift.Field.ones(space, dtype=np.float64)

        IC = ift.GradientNormController(tol_abs_gradnorm=1e-5)
        minimizer = minimizer_class(controller=IC)
        energy = ift.QuadraticEnergy(A=covariance, b=rhs,
                                     position=starting_point)

        (energy, convergence) = minimizer(energy)
        # assertEqual rather than a bare ``assert``: it is not stripped when
        # Python runs with ``-O`` and it reports both values on failure.
        self.assertEqual(convergence, IC.CONVERGED)
        assert_allclose(ift.dobj.to_global_data(energy.position.val),
                        1./ift.dobj.to_global_data(covariance_diagonal.val),
                        rtol=1e-3, atol=1e-3)