# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

import unittest
from itertools import product
from test.common import expand
import nifty5 as ift
from numpy.testing import assert_allclose, assert_equal


class ComposedOperator_Tests(unittest.TestCase):
    """Checks for operators obtained by chaining, adding and subtracting
    other operators.

    Every test is expanded (via ``expand``) over the domains in ``spaces``.
    """

    # Domains over which each test is expanded; the first two tests use
    # every ordered pair of them as a product domain.
    spaces = [ift.RGSpace(4),
              ift.PowerSpace(ift.RGSpace((4, 4), harmonic=True)),
              ift.LMSpace(5), ift.HPSpace(4), ift.GLSpace(4)]

    @staticmethod
    def _random_diagonal_chain(space1, space2):
        """Build ``op2(op1)`` on the product domain ``(space1, space2)``.

        ``op1`` is a DiagonalOperator with random normal entries acting on
        sub-space 0 only; ``op2`` likewise acts on sub-space 1 only.
        """
        cspace = (space1, space2)
        diag1 = ift.Field.from_random('normal', domain=space1)
        diag2 = ift.Field.from_random('normal', domain=space2)
        op1 = ift.DiagonalOperator(diag1, cspace, spaces=(0,))
        op2 = ift.DiagonalOperator(diag2, cspace, spaces=(1,))
        return op2(op1)

    @staticmethod
    def _scale_ops(space):
        """Return ``(op1, op2)`` on ``space``.

        ``op1`` is built with ``makeOp`` from a field filled with 2, and
        ``op2`` is a ``ScalingOperator`` with factor 3.
        """
        op1 = ift.makeOp(ift.Field.full(space, 2.))
        op2 = ift.ScalingOperator(3., space)
        return op1, op2

    def _check_diagonal_result(self, full_op, space, expected):
        """Apply ``full_op`` to a field of ones on ``space``.

        Then assert that ``full_op`` is a DiagonalOperator and that every
        entry of the result is close to ``expected``.
        """
        x = ift.Field.full(space, 1.)
        res = full_op(x)
        self.assertIsInstance(full_op, ift.DiagonalOperator)
        assert_allclose(res.local_data, expected)

    @expand(product(spaces, spaces))
    def test_times_adjoint_times(self, space1, space2):
        """<y, op(x)> must equal <x, op^adjoint(y)> for the chained op."""
        op = self._random_diagonal_chain(space1, space2)
        rand1 = ift.Field.from_random('normal', domain=(space1, space2))
        rand2 = ift.Field.from_random('normal', domain=(space1, space2))
        tt1 = rand2.vdot(op.times(rand1))
        tt2 = rand1.vdot(op.adjoint_times(rand2))
        assert_allclose(tt1, tt2)

    @expand(product(spaces, spaces))
    def test_times_inverse_times(self, space1, space2):
        """Applying the inverse after the chained op recovers the input."""
        op = self._random_diagonal_chain(space1, space2)
        rand1 = ift.Field.from_random('normal', domain=(space1, space2))
        tt1 = op.inverse_times(op.times(rand1))
        assert_allclose(tt1.local_data, rand1.local_data)

    # product() over a single list yields 1-tuples, so each of the tests
    # below receives exactly one space.
    @expand(product(spaces))
    def test_sum(self, space):
        """Sums and differences of the two scalings stay diagonal."""
        op1, op2 = self._scale_ops(space)
        full_op = op1 + op2 - (op2 - op1) + op1 + op1 + op2
        # 2 + 3 - (3 - 2) + 2 + 2 + 3 == 11
        self._check_diagonal_result(full_op, space, 11.)

    @expand(product(spaces))
    def test_chain(self, space):
        """A chain of the two scalings stays diagonal."""
        op1, op2 = self._scale_ops(space)
        full_op = op1(op2)(op2)(op1)(op1)(op1)(op2)
        # 2**4 * 3**3 == 432
        self._check_diagonal_result(full_op, space, 432.)

    @expand(product(spaces))
    def test_mix(self, space):
        """Mixing chaining, addition and subtraction stays diagonal."""
        op1, op2 = self._scale_ops(space)
        full_op = op1(op2+op2)(op1)(op1) - op1(op2)
        # 2 * (3 + 3) * 2 * 2 - 2 * 3 == 48 - 6 == 42
        self._check_diagonal_result(full_op, space, 42.)