# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

Martin Reinecke's avatar
Martin Reinecke committed
19
from __future__ import division
20

21
22
23
import unittest
from itertools import product
from test.common import expand
Martin Reinecke's avatar
Martin Reinecke committed
24

25
26
27
import nifty5 as ift
from numpy.testing import assert_allclose, assert_equal

28
29

class DiagonalOperator_Tests(unittest.TestCase):
    """Tests for ``ift.DiagonalOperator`` on a selection of NIFTy spaces.

    Every test is expanded over ``spaces``. It builds a diagonal operator
    from a random real field and checks the output domains and values of
    the forward, adjoint and inverse applications.
    """

    spaces = [ift.RGSpace(4),
              ift.PowerSpace(ift.RGSpace((4, 4), harmonic=True)),
              ift.LMSpace(5), ift.HPSpace(4), ift.GLSpace(4)]

    @staticmethod
    def _random(space):
        """Return a field on `space` filled with 'normal' random values."""
        return ift.Field.from_random('normal', domain=space)

    def _diagonal_operator(self, space):
        """Return ``(diag, D)``: a random field and the operator built on it."""
        diag = self._random(space)
        return diag, ift.DiagonalOperator(diag)

    @expand(product(spaces))
    def test_property(self, space):
        _, D = self._diagonal_operator(space)
        # Use an assertion (not a bare ``raise TypeError``) so a mismatch is
        # reported as a test failure with both values shown.
        assert_equal(D.domain[0], space)

    @expand(product(spaces))
    def test_times_adjoint(self, space):
        rand1 = self._random(space)
        rand2 = self._random(space)
        _, D = self._diagonal_operator(space)
        # With a real diagonal the operator is symmetric:
        # <x, D y> == <y, D x>.
        tt1 = rand1.vdot(D.times(rand2))
        tt2 = rand2.vdot(D.times(rand1))
        assert_allclose(tt1, tt2)
        # Exercise adjoint_times itself: <x, D y> == <D^dagger x, y>.
        tt3 = D.adjoint_times(rand1).vdot(rand2)
        assert_allclose(tt1, tt3)

    @expand(product(spaces))
    def test_times_inverse(self, space):
        rand1 = self._random(space)
        _, D = self._diagonal_operator(space)
        tt1 = D.times(D.inverse_times(rand1))
        assert_allclose(rand1.local_data, tt1.local_data)

    @expand(product(spaces))
    def test_times(self, space):
        rand1 = self._random(space)
        diag, D = self._diagonal_operator(space)
        tt = D.times(rand1)
        assert_equal(tt.domain[0], space)
        # Applying D multiplies element-wise by the diagonal (the same
        # semantics test_diagonal relies on).
        assert_allclose(tt.local_data, diag.local_data*rand1.local_data)

    @expand(product(spaces))
    def test_adjoint_times(self, space):
        rand1 = self._random(space)
        diag, D = self._diagonal_operator(space)
        tt = D.adjoint_times(rand1)
        assert_equal(tt.domain[0], space)
        # The diagonal is real here, so the adjoint equals the forward
        # application.
        assert_allclose(tt.local_data, diag.local_data*rand1.local_data)

    @expand(product(spaces))
    def test_inverse_times(self, space):
        rand1 = self._random(space)
        diag, D = self._diagonal_operator(space)
        tt = D.inverse_times(rand1)
        assert_equal(tt.domain[0], space)
        assert_allclose(tt.local_data, rand1.local_data/diag.local_data)

    @expand(product(spaces))
    def test_adjoint_inverse_times(self, space):
        rand1 = self._random(space)
        diag, D = self._diagonal_operator(space)
        tt = D.adjoint_inverse_times(rand1)
        assert_equal(tt.domain[0], space)
        # Real diagonal: the adjoint inverse equals the plain inverse.
        assert_allclose(tt.local_data, rand1.local_data/diag.local_data)

    @expand(product(spaces))
    def test_diagonal(self, space):
        diag, D = self._diagonal_operator(space)
        # Applying D to the all-ones field must reproduce the diagonal.
        diag_op = D(ift.Field.full(space, 1.))
        assert_allclose(diag.local_data, diag_op.local_data)