# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

import unittest
from itertools import product

import nifty4 as ift
from numpy.testing import assert_allclose, assert_equal

from test.common import expand
from test.common import generate_spaces


class ComposedOperator_Tests(unittest.TestCase):
    """Tests for operators built from DiagonalOperator and ScalingOperator
    via chaining (``*``), addition (``+``) and subtraction (``-``)."""

    spaces = generate_spaces()

    @staticmethod
    def _product_operator(space1, space2):
        """Return ``op2*op1`` acting on the product domain (space1, space2).

        ``op1`` is a DiagonalOperator with a random normal diagonal acting
        on sub-space 0, ``op2`` one acting on sub-space 1.  The diagonals
        are drawn in the order diag1, diag2.
        """
        cspace = (space1, space2)
        diag1 = ift.Field.from_random('normal', domain=space1)
        diag2 = ift.Field.from_random('normal', domain=space2)
        op1 = ift.DiagonalOperator(diag1, cspace, spaces=(0,))
        op2 = ift.DiagonalOperator(diag2, cspace, spaces=(1,))
        return op2*op1

    def _assert_diagonal_result(self, full_op, space, expected):
        """Apply ``full_op`` to the all-ones field on ``space``.

        Check that ``full_op`` is a single DiagonalOperator and that the
        result equals the constant ``expected`` everywhere.
        """
        x = ift.Field.full(space, 1.)
        res = full_op(x)
        self.assertIsInstance(full_op, ift.DiagonalOperator)
        assert_allclose(res.to_global_data(), expected)

    @expand(product(spaces, spaces))
    def test_times_adjoint_times(self, space1, space2):
        # Adjointness check: <rand2, A rand1> == <rand1, A^adjoint rand2>.
        op = self._product_operator(space1, space2)
        rand1 = ift.Field.from_random('normal', domain=(space1, space2))
        rand2 = ift.Field.from_random('normal', domain=(space1, space2))

        tt1 = rand2.vdot(op.times(rand1))
        tt2 = rand1.vdot(op.adjoint_times(rand2))
        assert_allclose(tt1, tt2)

    @expand(product(spaces, spaces))
    def test_times_inverse_times(self, space1, space2):
        # Applying the operator and then its inverse must recover the input.
        op = self._product_operator(space1, space2)
        rand1 = ift.Field.from_random('normal', domain=(space1, space2))

        tt1 = op.inverse_times(op.times(rand1))
        assert_allclose(tt1.to_global_data(), rand1.to_global_data())

    @expand(product(spaces))
    def test_sum(self, space):
        op1 = ift.DiagonalOperator(ift.Field.full(space, 2.))
        op2 = ift.ScalingOperator(3., space)
        full_op = op1 + op2 - (op2 - op1) + op1 + op1 + op2
        # 2 + 3 - (3 - 2) + 2 + 2 + 3 = 11
        self._assert_diagonal_result(full_op, space, 11.)

    @expand(product(spaces))
    def test_chain(self, space):
        op1 = ift.DiagonalOperator(ift.Field.full(space, 2.))
        op2 = ift.ScalingOperator(3., space)
        full_op = op1 * op2 * (op2 * op1) * op1 * op1 * op2
        # 2 * 3 * (3 * 2) * 2 * 2 * 3 = 432
        self._assert_diagonal_result(full_op, space, 432.)

    @expand(product(spaces))
    def test_mix(self, space):
        op1 = ift.DiagonalOperator(ift.Field.full(space, 2.))
        op2 = ift.ScalingOperator(3., space)
        full_op = op1 * (op2 + op2) * op1 * op1 - op1 * op2
        # 2 * (3 + 3) * 2 * 2 - 2 * 3 = 48 - 6 = 42
        self._assert_diagonal_result(full_op, space, 42.)