# test_diagonal_operator.py -- unit tests for nifty.DiagonalOperator
from __future__ import division

import unittest

import numpy as np
from numpy.testing import assert_equal,\
    assert_allclose,\
    assert_approx_equal

from nifty import Field,\
    DiagonalOperator

from test.common import generate_spaces


from itertools import product
from test.common import expand

class DiagonalOperator_Tests(unittest.TestCase):
    """Unit tests for ``DiagonalOperator``.

    Every test is expanded (via ``expand``) over the spaces returned by
    ``generate_spaces()`` and, where the test takes them, over both values
    of the ``bare`` and ``copy`` constructor flags.  Each test draws a
    random 'normal' field on the space as the operator's diagonal.
    """

    # Spaces that every parametrized test below is expanded over.
    spaces = generate_spaces()

    @expand(product(spaces, [True, False], [True, False]))
    def test_property(self, space, bare, copy):
        """Check the domain, ``unitary`` and ``self_adjoint`` properties."""
        diag = Field.from_random('normal', domain=space)
        # Forward bare/copy so the expanded cases actually differ; the
        # constructor call matches the one used by the other tests.
        D = DiagonalOperator(space, diagonal=diag, bare=bare, copy=copy)
        # Use assertions (not ``raise TypeError``) so a mismatch is reported
        # as a test failure together with the offending values.
        assert_equal(D.domain[0], space)
        assert_equal(D.unitary, False)
        assert_equal(D.self_adjoint, True)

    @expand(product(spaces, [True, False], [True, False]))
    def test_times_adjoint(self, space, bare, copy):
        """Check that ``rand1.vdot(D rand2)`` equals ``rand2.vdot(D rand1)``."""
        rand1 = Field.from_random('normal', domain=space)
        rand2 = Field.from_random('normal', domain=space)
        diag = Field.from_random('normal', domain=space)
        D = DiagonalOperator(space, diagonal=diag, bare=bare, copy=copy)
        tt1 = rand1.vdot(D.times(rand2))
        tt2 = rand2.vdot(D.times(rand1))
        assert_approx_equal(tt1, tt2)

    @expand(product(spaces, [True, False], [True, False]))
    def test_times_inverse(self, space, bare, copy):
        """Check that ``times`` undoes ``inverse_times`` on a random field."""
        rand1 = Field.from_random('normal', domain=space)
        diag = Field.from_random('normal', domain=space)
        D = DiagonalOperator(space, diagonal=diag, bare=bare, copy=copy)
        tt1 = D.times(D.inverse_times(rand1))
        assert_allclose(rand1.val.get_full_data(), tt1.val.get_full_data())

    @expand(product(spaces, [True, False], [True, False]))
    def test_times(self, space, bare, copy):
        """Check that the result of ``times`` lives on the input space."""
        rand1 = Field.from_random('normal', domain=space)
        diag = Field.from_random('normal', domain=space)
        D = DiagonalOperator(space, diagonal=diag, bare=bare, copy=copy)
        tt = D.times(rand1)
        assert_equal(tt.domain[0], space)

    @expand(product(spaces, [True, False], [True, False]))
    def test_adjoint_times(self, space, bare, copy):
        """Check that the result of ``adjoint_times`` lives on the input space."""
        rand1 = Field.from_random('normal', domain=space)
        diag = Field.from_random('normal', domain=space)
        D = DiagonalOperator(space, diagonal=diag, bare=bare, copy=copy)
        tt = D.adjoint_times(rand1)
        assert_equal(tt.domain[0], space)

    @expand(product(spaces, [True, False], [True, False]))
    def test_inverse_times(self, space, bare, copy):
        """Check that the result of ``inverse_times`` lives on the input space."""
        rand1 = Field.from_random('normal', domain=space)
        diag = Field.from_random('normal', domain=space)
        D = DiagonalOperator(space, diagonal=diag, bare=bare, copy=copy)
        tt = D.inverse_times(rand1)
        assert_equal(tt.domain[0], space)

    @expand(product(spaces, [True, False], [True, False]))
    def test_adjoint_inverse_times(self, space, bare, copy):
        """Check that ``adjoint_inverse_times`` returns a field on the input space."""
        rand1 = Field.from_random('normal', domain=space)
        diag = Field.from_random('normal', domain=space)
        D = DiagonalOperator(space, diagonal=diag, bare=bare, copy=copy)
        tt = D.adjoint_inverse_times(rand1)
        assert_equal(tt.domain[0], space)

    @expand(product(spaces, [True, False]))
    def test_diagonal(self, space, copy):
        """Check that ``diagonal()`` returns the values the operator was built from."""
        diag = Field.from_random('normal', domain=space)
        D = DiagonalOperator(space, diagonal=diag, copy=copy)
        diag_op = D.diagonal()
        assert_allclose(diag.val.get_full_data(), diag_op.val.get_full_data())

    @expand(product(spaces, [True, False]))
    def test_inverse(self, space, copy):
        """Check that ``inverse_diagonal()`` is the elementwise reciprocal of the diagonal."""
        diag = Field.from_random('normal', domain=space)
        D = DiagonalOperator(space, diagonal=diag, copy=copy)
        diag_op = D.inverse_diagonal()
        assert_allclose(1./diag.val.get_full_data(),
                        diag_op.val.get_full_data())