test_diagonal_operator.py 5.61 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
import unittest

import numpy as np
from numpy.testing import assert_equal,\
    assert_allclose,\
    assert_approx_equal

from nifty import Field,\
    DiagonalOperator

11
from test.common import generate_spaces
12
13
14
15
16

from itertools import product
from test.common import expand

class DiagonalOperator_Tests(unittest.TestCase):
    """Tests for ``DiagonalOperator``, parametrized over the test spaces.

    Every test draws its diagonal (and any probe fields) with
    ``Field.from_random('normal', ...)`` on the space under test, builds a
    ``DiagonalOperator`` from it and checks one method or property.
    """

    # Spaces every test below is expanded over.
    spaces = generate_spaces()

    @expand(product(spaces, [True, False], [True, False]))
    def test_property(self, space, bare, copy):
        """Check the domain, ``unitary`` and ``self_adjoint`` properties."""
        diag = Field.from_random('normal', domain=space)
        # Forward ``bare``/``copy`` so the four parametrized variants really
        # exercise different constructor arguments.
        D = DiagonalOperator(space, diagonal=diag, bare=bare, copy=copy)
        assert_equal(D.domain[0], space)
        assert_equal(D.unitary, False)
        assert_equal(D.self_adjoint, True)

    @expand(product(spaces, [True, False], [True, False]))
    def test_times_adjoint(self, space, bare, copy):
        """Check that ``rand1.dot(D rand2)`` equals ``rand2.dot(D rand1)``."""
        rand1 = Field.from_random('normal', domain=space)
        rand2 = Field.from_random('normal', domain=space)
        diag = Field.from_random('normal', domain=space)
        D = DiagonalOperator(space, diagonal=diag, bare=bare, copy=copy)
        tt1 = rand1.dot(D.times(rand2))
        tt2 = rand2.dot(D.times(rand1))
        assert_approx_equal(tt1, tt2)

    @expand(product(spaces, [True, False], [True, False]))
    def test_times_inverse(self, space, bare, copy):
        """Check that ``times`` undoes ``inverse_times``."""
        rand1 = Field.from_random('normal', domain=space)
        diag = Field.from_random('normal', domain=space)
        D = DiagonalOperator(space, diagonal=diag, bare=bare, copy=copy)
        tt1 = D.times(D.inverse_times(rand1))
        assert_allclose(rand1.val.get_full_data(), tt1.val.get_full_data())

    @expand(product(spaces, [True, False], [True, False]))
    def test_times(self, space, bare, copy):
        """Check that ``times`` returns a field on the input space."""
        rand1 = Field.from_random('normal', domain=space)
        diag = Field.from_random('normal', domain=space)
        D = DiagonalOperator(space, diagonal=diag, bare=bare, copy=copy)
        tt = D.times(rand1)
        assert_equal(tt.domain[0], space)

    @expand(product(spaces, [True, False], [True, False]))
    def test_adjoint_times(self, space, bare, copy):
        """Check that ``adjoint_times`` returns a field on the input space."""
        rand1 = Field.from_random('normal', domain=space)
        diag = Field.from_random('normal', domain=space)
        D = DiagonalOperator(space, diagonal=diag, bare=bare, copy=copy)
        tt = D.adjoint_times(rand1)
        assert_equal(tt.domain[0], space)

    @expand(product(spaces, [True, False], [True, False]))
    def test_inverse_times(self, space, bare, copy):
        """Check that ``inverse_times`` returns a field on the input space."""
        rand1 = Field.from_random('normal', domain=space)
        diag = Field.from_random('normal', domain=space)
        D = DiagonalOperator(space, diagonal=diag, bare=bare, copy=copy)
        tt = D.inverse_times(rand1)
        assert_equal(tt.domain[0], space)

    @expand(product(spaces, [True, False], [True, False]))
    def test_adjoint_inverse_times(self, space, bare, copy):
        """Check that ``adjoint_inverse_times`` stays on the input space."""
        rand1 = Field.from_random('normal', domain=space)
        diag = Field.from_random('normal', domain=space)
        D = DiagonalOperator(space, diagonal=diag, bare=bare, copy=copy)
        tt = D.adjoint_inverse_times(rand1)
        assert_equal(tt.domain[0], space)

    @expand(product(spaces, [True, False]))
    def test_diagonal(self, space, copy):
        """Check that ``diagonal()`` returns the values passed in."""
        diag = Field.from_random('normal', domain=space)
        D = DiagonalOperator(space, diagonal=diag, copy=copy)
        diag_op = D.diagonal()
        assert_allclose(diag.val.get_full_data(), diag_op.val.get_full_data())

    @expand(product(spaces, [True, False]))
    def test_inverse(self, space, copy):
        """Check that ``inverse_diagonal()`` is the element-wise reciprocal."""
        diag = Field.from_random('normal', domain=space)
        D = DiagonalOperator(space, diagonal=diag, copy=copy)
        diag_op = D.inverse_diagonal()
        assert_allclose(1./diag.val.get_full_data(),
                        diag_op.val.get_full_data())

    @expand(product(spaces, [True, False]))
    def test_trace(self, space, copy):
        """Check that ``trace()`` is the sum of the diagonal entries."""
        diag = Field.from_random('normal', domain=space)
        D = DiagonalOperator(space, diagonal=diag, copy=copy)
        trace_op = D.trace()
        assert_allclose(trace_op, np.sum(diag.val.get_full_data()))

    @expand(product(spaces, [True, False]))
    def test_inverse_trace(self, space, copy):
        """Check that ``inverse_trace()`` sums the reciprocal entries."""
        diag = Field.from_random('normal', domain=space)
        D = DiagonalOperator(space, diagonal=diag, copy=copy)
        trace_op = D.inverse_trace()
        assert_allclose(trace_op, np.sum(1./diag.val.get_full_data()))

    @expand(product(spaces, [True, False]))
    def test_trace_log(self, space, copy):
        """Check that ``trace_log()`` is the sum of the logs of the entries."""
        diag = Field.from_random('normal', domain=space)
        D = DiagonalOperator(space, diagonal=diag, copy=copy)
        trace_log = D.trace_log()
        # tr(log D) of a diagonal operator is sum(log(d_i)), not
        # log(sum(d_i)) as previously asserted.
        # NOTE(review): 'normal' samples are presumably allowed to be
        # negative, in which case the logs are NaN; assert_allclose treats
        # NaN as equal by default, so consider a positive diagonal here.
        assert_allclose(trace_log, np.sum(np.log(diag.val.get_full_data())))

    @expand(product(spaces, [True, False]))
    def test_determinant(self, space, copy):
        """Check that ``determinant()`` is the product of the entries."""
        diag = Field.from_random('normal', domain=space)
        # ``bare`` is not a parameter of this test; passing it raised a
        # NameError before any assertion ran.
        D = DiagonalOperator(space, diagonal=diag, copy=copy)
        det = D.determinant()
        assert_allclose(det, np.prod(diag.val.get_full_data()))

    @expand(product(spaces, [True, False], [True, False]))
    def test_inverse_determinant(self, space, bare, copy):
        """Check ``inverse_determinant()`` against ``1 / determinant()``."""
        diag = Field.from_random('normal', domain=space)
        D = DiagonalOperator(space, diagonal=diag, bare=bare, copy=copy)
        inv_det = D.inverse_determinant()
        assert_allclose(inv_det, 1./D.determinant())

    @expand(product(spaces, [True, False], [True, False]))
    def test_log_determinant(self, space, bare, copy):
        """Check ``log_determinant()`` against ``log(determinant())``."""
        diag = Field.from_random('normal', domain=space)
        D = DiagonalOperator(space, diagonal=diag, bare=bare, copy=copy)
        log_det = D.log_determinant()
        # NOTE(review): a negative determinant makes the reference value NaN;
        # confirm what log_determinant() is expected to return in that case.
        assert_allclose(log_det, np.log(D.determinant()))