# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

from __future__ import division
20
import unittest
Martin Reinecke's avatar
Martin Reinecke committed
21
from numpy.testing import assert_equal, assert_allclose
Martin Reinecke's avatar
Martin Reinecke committed
22
import nifty4 as ift
23
24
from itertools import product
from test.common import expand
Martin Reinecke's avatar
Martin Reinecke committed
25

26
27

class DiagonalOperator_Tests(unittest.TestCase):
    """Tests for ``ift.DiagonalOperator`` built from a random diagonal field.

    Every test is expanded over each entry of ``spaces``, so each one runs
    once per domain type listed there.
    """

    # Domains every test is parameterized over.
    spaces = [ift.RGSpace(4),
              ift.PowerSpace(ift.RGSpace((4, 4), harmonic=True)),
              ift.LMSpace(5), ift.HPSpace(4), ift.GLSpace(4)]

    @expand(product(spaces))
    def test_property(self, space):
        """The operator's first domain entry equals the diagonal's space."""
        diag = ift.Field.from_random('normal', domain=space)
        D = ift.DiagonalOperator(diag)
        # Use an assertion (as the sibling tests do) so that a mismatch is
        # reported as a test failure with both values shown, instead of as
        # an unrelated-looking TypeError.
        assert_equal(D.domain[0], space)

    @expand(product(spaces))
    def test_times_adjoint(self, space):
        """``rand1.vdot(D(rand2))`` matches ``rand2.vdot(D(rand1))``."""
        rand1 = ift.Field.from_random('normal', domain=space)
        rand2 = ift.Field.from_random('normal', domain=space)
        diag = ift.Field.from_random('normal', domain=space)
        D = ift.DiagonalOperator(diag)
        # NOTE(review): both sides go through times(); adjoint_times() is
        # never called here. Presumably this relies on the diagonal being
        # real so that D is self-adjoint -- confirm that this is intended.
        tt1 = rand1.vdot(D.times(rand2))
        tt2 = rand2.vdot(D.times(rand1))
        assert_allclose(tt1, tt2)

    @expand(product(spaces))
    def test_times_inverse(self, space):
        """Applying ``inverse_times`` and then ``times`` returns the input."""
        rand1 = ift.Field.from_random('normal', domain=space)
        diag = ift.Field.from_random('normal', domain=space)
        D = ift.DiagonalOperator(diag)
        tt1 = D.times(D.inverse_times(rand1))
        assert_allclose(rand1.to_global_data(), tt1.to_global_data())

    @expand(product(spaces))
    def test_times(self, space):
        """The result of ``times`` has ``space`` as its first domain entry."""
        rand1 = ift.Field.from_random('normal', domain=space)
        diag = ift.Field.from_random('normal', domain=space)
        D = ift.DiagonalOperator(diag)
        tt = D.times(rand1)
        assert_equal(tt.domain[0], space)

    @expand(product(spaces))
    def test_adjoint_times(self, space):
        """The result of ``adjoint_times`` has ``space`` as first domain."""
        rand1 = ift.Field.from_random('normal', domain=space)
        diag = ift.Field.from_random('normal', domain=space)
        D = ift.DiagonalOperator(diag)
        tt = D.adjoint_times(rand1)
        assert_equal(tt.domain[0], space)

    @expand(product(spaces))
    def test_inverse_times(self, space):
        """The result of ``inverse_times`` has ``space`` as first domain."""
        rand1 = ift.Field.from_random('normal', domain=space)
        diag = ift.Field.from_random('normal', domain=space)
        D = ift.DiagonalOperator(diag)
        tt = D.inverse_times(rand1)
        assert_equal(tt.domain[0], space)

    @expand(product(spaces))
    def test_adjoint_inverse_times(self, space):
        """``adjoint_inverse_times`` output has ``space`` as first domain."""
        rand1 = ift.Field.from_random('normal', domain=space)
        diag = ift.Field.from_random('normal', domain=space)
        D = ift.DiagonalOperator(diag)
        tt = D.adjoint_inverse_times(rand1)
        assert_equal(tt.domain[0], space)

    @expand(product(spaces))
    def test_diagonal(self, space):
        """``D.diagonal`` holds the same data as the field D was built from."""
        diag = ift.Field.from_random('normal', domain=space)
        D = ift.DiagonalOperator(diag)
        diag_op = D.diagonal
        assert_allclose(diag.to_global_data(), diag_op.to_global_data())