# test_diagonal_operator.py
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

from __future__ import division
import unittest
from numpy.testing import assert_equal, assert_allclose
import nifty2go as ift
from test.common import generate_spaces
from itertools import product
from test.common import expand


class DiagonalOperator_Tests(unittest.TestCase):
    """Tests for ``ift.DiagonalOperator``, run once per generated space.

    Every test builds its operator from a field drawn with
    ``ift.Field.from_random('normal', domain=space)``.
    """

    # Evaluated once at class-creation time; the ``expand`` decorators
    # below read this name from the class body.
    spaces = generate_spaces()

    @expand(product(spaces))
    def test_property(self, space):
        """The operator's first domain entry equals the diagonal's space."""
        diag = ift.Field.from_random('normal', domain=space)
        D = ift.DiagonalOperator(diag)
        # Use an assertion (as the sibling tests do) so that a mismatch is
        # reported as a test failure showing both values, instead of an
        # uninformative bare TypeError.
        assert_equal(D.domain[0], space)

    @expand(product(spaces))
    def test_times_adjoint(self, space):
        """Check the adjoint identity <x, D y> == <D^dagger x, y>."""
        rand1 = ift.Field.from_random('normal', domain=space)
        rand2 = ift.Field.from_random('normal', domain=space)
        diag = ift.Field.from_random('normal', domain=space)
        D = ift.DiagonalOperator(diag)
        tt1 = rand1.vdot(D.times(rand2))
        # Actually exercise adjoint_times: comparing two ``times`` results
        # would only test symmetry of the forward operator.
        tt2 = D.adjoint_times(rand1).vdot(rand2)
        assert_allclose(tt1, tt2)

    @expand(product(spaces))
    def test_times_inverse(self, space):
        """``times`` applied after ``inverse_times`` reproduces the input."""
        rand1 = ift.Field.from_random('normal', domain=space)
        diag = ift.Field.from_random('normal', domain=space)
        D = ift.DiagonalOperator(diag)
        tt1 = D.times(D.inverse_times(rand1))
        assert_allclose(ift.dobj.to_global_data(rand1.val),
                        ift.dobj.to_global_data(tt1.val))

    @expand(product(spaces))
    def test_times(self, space):
        """The result of ``times`` has ``space`` as its first domain entry."""
        rand1 = ift.Field.from_random('normal', domain=space)
        diag = ift.Field.from_random('normal', domain=space)
        D = ift.DiagonalOperator(diag)
        tt = D.times(rand1)
        assert_equal(tt.domain[0], space)

    @expand(product(spaces))
    def test_adjoint_times(self, space):
        """The result of ``adjoint_times`` has ``space`` as first domain."""
        rand1 = ift.Field.from_random('normal', domain=space)
        diag = ift.Field.from_random('normal', domain=space)
        D = ift.DiagonalOperator(diag)
        tt = D.adjoint_times(rand1)
        assert_equal(tt.domain[0], space)

    @expand(product(spaces))
    def test_inverse_times(self, space):
        """The result of ``inverse_times`` has ``space`` as first domain."""
        rand1 = ift.Field.from_random('normal', domain=space)
        diag = ift.Field.from_random('normal', domain=space)
        D = ift.DiagonalOperator(diag)
        tt = D.inverse_times(rand1)
        assert_equal(tt.domain[0], space)

    @expand(product(spaces))
    def test_adjoint_inverse_times(self, space):
        """``adjoint_inverse_times`` output has ``space`` as first domain."""
        rand1 = ift.Field.from_random('normal', domain=space)
        diag = ift.Field.from_random('normal', domain=space)
        D = ift.DiagonalOperator(diag)
        tt = D.adjoint_inverse_times(rand1)
        assert_equal(tt.domain[0], space)

    @expand(product(spaces))
    def test_diagonal(self, space):
        """``diagonal()`` returns the data the operator was built from."""
        diag = ift.Field.from_random('normal', domain=space)
        D = ift.DiagonalOperator(diag)
        diag_op = D.diagonal()
        assert_allclose(ift.dobj.to_global_data(diag.val),
                        ift.dobj.to_global_data(diag_op.val))