# test_diagonal_operator.py
from __future__ import division
import unittest
import numpy as np
from numpy.testing import assert_equal,\
    assert_allclose,\
    assert_approx_equal
from nifty2go import Field, DiagonalOperator
from test.common import generate_spaces
from itertools import product
from test.common import expand
from nifty2go.dobj import to_ndarray as to_np, from_ndarray as from_np

class DiagonalOperator_Tests(unittest.TestCase):
    """Unit tests for ``DiagonalOperator``.

    Every test is expanded once per space produced by ``generate_spaces()``.
    Diagonals and input fields are drawn with ``Field.from_random('normal')``.
    Reference results are computed with plain NumPy on the arrays obtained
    via ``to_np(field.val)``.
    """

    spaces = generate_spaces()

    @expand(product(spaces))
    def test_property(self, space):
        """The operator's domain is the diagonal's space; it is not
        unitary but is self-adjoint."""
        diag = Field.from_random('normal', domain=space)
        D = DiagonalOperator(diag)
        self.assertEqual(D.domain[0], space)
        self.assertEqual(D.unitary, False)
        self.assertEqual(D.self_adjoint, True)

    @expand(product(spaces))
    def test_times_adjoint(self, space):
        """<x, D y> equals <y, D x>, i.e. D is symmetric on random
        inputs."""
        rand1 = Field.from_random('normal', domain=space)
        rand2 = Field.from_random('normal', domain=space)
        diag = Field.from_random('normal', domain=space)
        D = DiagonalOperator(diag)
        tt1 = rand1.vdot(D.times(rand2))
        tt2 = rand2.vdot(D.times(rand1))
        assert_approx_equal(tt1, tt2)

    @expand(product(spaces))
    def test_times_inverse(self, space):
        """Applying the inverse and then the operator reproduces the
        input."""
        rand1 = Field.from_random('normal', domain=space)
        diag = Field.from_random('normal', domain=space)
        D = DiagonalOperator(diag)
        tt1 = D.times(D.inverse_times(rand1))
        assert_allclose(to_np(rand1.val), to_np(tt1.val))

    @expand(product(spaces))
    def test_times(self, space):
        """``times`` stays on the same space and multiplies pointwise by
        the diagonal."""
        rand1 = Field.from_random('normal', domain=space)
        diag = Field.from_random('normal', domain=space)
        D = DiagonalOperator(diag)
        tt = D.times(rand1)
        assert_equal(tt.domain[0], space)
        # Checking values as well as the domain catches a wrong operator
        # that merely preserves the space.
        assert_allclose(to_np(tt.val), to_np(diag.val)*to_np(rand1.val))

    @expand(product(spaces))
    def test_adjoint_times(self, space):
        """``adjoint_times`` stays on the same space and multiplies
        pointwise by the conjugated diagonal."""
        rand1 = Field.from_random('normal', domain=space)
        diag = Field.from_random('normal', domain=space)
        D = DiagonalOperator(diag)
        tt = D.adjoint_times(rand1)
        assert_equal(tt.domain[0], space)
        # np.conj is a no-op for the real diagonals drawn here, but keeps the
        # reference correct should the test ever use complex fields.
        assert_allclose(to_np(tt.val),
                        np.conj(to_np(diag.val))*to_np(rand1.val))

    @expand(product(spaces))
    def test_inverse_times(self, space):
        """``inverse_times`` stays on the same space and divides pointwise
        by the diagonal."""
        rand1 = Field.from_random('normal', domain=space)
        diag = Field.from_random('normal', domain=space)
        D = DiagonalOperator(diag)
        tt = D.inverse_times(rand1)
        assert_equal(tt.domain[0], space)
        assert_allclose(to_np(tt.val), to_np(rand1.val)/to_np(diag.val))

    @expand(product(spaces))
    def test_adjoint_inverse_times(self, space):
        """``adjoint_inverse_times`` stays on the same space and divides
        pointwise by the conjugated diagonal."""
        rand1 = Field.from_random('normal', domain=space)
        diag = Field.from_random('normal', domain=space)
        D = DiagonalOperator(diag)
        tt = D.adjoint_inverse_times(rand1)
        assert_equal(tt.domain[0], space)
        assert_allclose(to_np(tt.val),
                        to_np(rand1.val)/np.conj(to_np(diag.val)))

    @expand(product(spaces))
    def test_diagonal(self, space):
        """``diagonal()`` returns the field the operator was built from."""
        diag = Field.from_random('normal', domain=space)
        D = DiagonalOperator(diag)
        diag_op = D.diagonal()
        assert_allclose(to_np(diag.val), to_np(diag_op.val))