# test_relaxed_newton.py

import unittest

from numpy.testing import assert_equal, assert_almost_equal

from nifty import *

from itertools import product
from test.common import expand
from test.common import generate_spaces



class QuadraticPot(Energy):
    """Quadratic test energy ``H(x) = 0.5 * x . N^{-1}(x)``.

    Parameters
    ----------
    position : Field
        Point at which the energy is evaluated; forwarded to ``Energy``.
    N : operator
        Operator providing ``inverse_times``; it is stored as ``self.N`` and
        reused by every energy created through :meth:`at`.
    """

    def __init__(self, position, N):
        super(QuadraticPot, self).__init__(position)
        self.N = N

    def at(self, position):
        """Return a new energy of the same class at ``position``, sharing ``N``."""
        return self.__class__(position, N=self.N)

    @property
    def value(self):
        """Real part of ``0.5 * position.dot(N.inverse_times(position))``."""
        H = 0.5 * self.position.dot(self.N.inverse_times(self.position))
        return H.real

    @property
    def gradient(self):
        """Real part of ``N.inverse_times(position)`` as a float-typed field."""
        g = self.N.inverse_times(self.position)
        # The builtin ``float`` is what the ``np.float`` alias used to point
        # to; the alias itself no longer exists in current NumPy releases.
        return_g = g.copy_empty(dtype=float)
        return_g.val = g.val.real
        return return_g

    @property
    def curvature(self):
        """Return the operator ``N`` handed to the constructor.

        NOTE(review): the Hessian of ``value`` is ``N^{-1}``, not ``N``; the
        two only coincide when ``N`` is its own inverse (e.g. the identity).
        Confirm this is intended before using a non-trivial ``N``.
        """
        return self.N



class RelaxedNewton_Tests(unittest.TestCase):
    """Checks of ``RelaxedNewton`` on a quadratic potential with unit ``N``.

    Every test is run for each combination of the generated spaces,
    iteration limits, convergence tolerances and convergence levels.
    """
    spaces = generate_spaces()

    def _run_minimizer(self, space, iteration_limit, convergence_tolerance,
                       convergence_level):
        """Minimize a ``QuadraticPot`` from a random start on ``space``.

        Returns the ``(energy, convergence)`` pair produced by the minimizer.
        """
        x = Field.from_random('normal', domain=space)
        N = DiagonalOperator(space, diagonal=1.)
        energy = QuadraticPot(position=x, N=N)

        minimizer = RelaxedNewton(iteration_limit=iteration_limit,
                                  convergence_tolerance=convergence_tolerance,
                                  convergence_level=convergence_level)

        return minimizer(energy)

    @expand(product(spaces, [10, 100, 1000], [1E-3, 1E-4, 1E-5], [2, 3, 4]))
    def test_property(self, space, iteration_limit, convergence_tolerance,
                      convergence_level):
        """The result lives on the input space and reports an int status."""
        (energy, convergence) = self._run_minimizer(
            space, iteration_limit, convergence_tolerance, convergence_level)

        self.assertEqual(energy.position.domain[0], space,
                         "minimizer changed the domain of the position")
        self.assertIsInstance(convergence, int)

    # This test used to be called ``test_property`` as well, which rebound
    # the name and kept the checks above from ever being collected.
    @expand(product(spaces, [10, 100, 1000], [1E-3, 1E-4, 1E-5], [2, 3, 4]))
    def test_minimization(self, space, iteration_limit, convergence_tolerance,
                          convergence_level):
        """The minimizer ends at the zero field with energy value zero."""
        test_x = Field(space, val=0.)

        (energy, convergence) = self._run_minimizer(
            space, iteration_limit, convergence_tolerance, convergence_level)

        # assert_almost_equal takes ``decimal``; it has no ``significant``
        # keyword, so the previous calls failed before comparing anything.
        assert_almost_equal(energy.value, 0, decimal=3)
        assert_almost_equal(energy.position.val.get_full_data(),
                            test_x.val.get_full_data(), decimal=3)
        assert_equal(convergence, convergence_level+2)