# test_composed_operator.py
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

import unittest
from numpy.testing import assert_allclose, assert_equal
import nifty2go as ift
from test.common import generate_spaces
from itertools import product
from test.common import expand


class ComposedOperator_Tests(unittest.TestCase):
    """Tests for operators obtained by multiplying and adding operators.

    Every test is parametrized via ``expand`` over the domains returned
    by ``generate_spaces()``.
    """

    # Domains over which the tests below are parametrized.
    spaces = generate_spaces()

    @expand(product(spaces, spaces))
    def test_times_adjoint_times(self, space1, space2):
        """Compare ``times`` and ``adjoint_times`` through inner products.

        Builds ``op = op2*op1`` from two diagonal operators, each acting
        on one sub-space of the product domain ``(space1, space2)``, and
        checks that ``rand2.vdot(op.times(rand1))`` agrees with
        ``rand1.vdot(op.adjoint_times(rand2))``.
        """
        cspace = (space1, space2)
        diag1 = ift.Field.from_random('normal', domain=space1)
        diag2 = ift.Field.from_random('normal', domain=space2)
        # op1 acts on sub-space 0 only, op2 on sub-space 1 only.
        op1 = ift.DiagonalOperator(diag1, cspace, spaces=(0,))
        op2 = ift.DiagonalOperator(diag2, cspace, spaces=(1,))

        op = op2*op1

        rand1 = ift.Field.from_random('normal', domain=cspace)
        rand2 = ift.Field.from_random('normal', domain=cspace)

        tt1 = rand2.vdot(op.times(rand1))
        tt2 = rand1.vdot(op.adjoint_times(rand2))
        assert_allclose(tt1, tt2)

    @expand(product(spaces, spaces))
    def test_times_inverse_times(self, space1, space2):
        """Check that ``inverse_times`` undoes ``times`` for a chain.

        Builds ``op = op2*op1`` from two diagonal operators, each acting
        on one sub-space of the product domain ``(space1, space2)``, and
        checks that ``op.inverse_times(op.times(x))`` reproduces ``x``.
        """
        cspace = (space1, space2)
        diag1 = ift.Field.from_random('normal', domain=space1)
        diag2 = ift.Field.from_random('normal', domain=space2)
        # op1 acts on sub-space 0 only, op2 on sub-space 1 only.
        op1 = ift.DiagonalOperator(diag1, cspace, spaces=(0,))
        op2 = ift.DiagonalOperator(diag2, cspace, spaces=(1,))

        op = op2*op1

        rand1 = ift.Field.from_random('normal', domain=cspace)
        tt1 = op.inverse_times(op.times(rand1))

        assert_allclose(ift.dobj.to_global_data(tt1.val),
                        ift.dobj.to_global_data(rand1.val))

    @expand(product(spaces))
    def test_sum(self, space):
        """Sums/differences of diagonal-type operators stay diagonal."""
        op1 = ift.DiagonalOperator(ift.Field(space, 2.))
        op2 = ift.ScalingOperator(3., space)
        full_op = op1 + op2 - (op2 - op1) + op1 + op1 + op2
        x = ift.Field(space, 1.)
        res = full_op(x)
        assert_equal(isinstance(full_op, ift.DiagonalOperator), True)
        # With op1 -> 2 and op2 -> 3: 2 + 3 - (3 - 2) + 2 + 2 + 3 = 11.
        assert_allclose(ift.dobj.to_global_data(res.val), 11.)

    @expand(product(spaces))
    def test_chain(self, space):
        """Products of diagonal-type operators stay diagonal."""
        op1 = ift.DiagonalOperator(ift.Field(space, 2.))
        op2 = ift.ScalingOperator(3., space)
        full_op = op1 * op2 * (op2 * op1) * op1 * op1 * op2
        x = ift.Field(space, 1.)
        res = full_op(x)
        assert_equal(isinstance(full_op, ift.DiagonalOperator), True)
        # With op1 -> 2 and op2 -> 3: 2 * 3 * (3 * 2) * 2 * 2 * 3 = 432.
        assert_allclose(ift.dobj.to_global_data(res.val), 432.)

    @expand(product(spaces))
    def test_mix(self, space):
        """Mixed sums and products of diagonal-type operators stay diagonal."""
        op1 = ift.DiagonalOperator(ift.Field(space, 2.))
        op2 = ift.ScalingOperator(3., space)
        full_op = op1 * (op2 + op2) * op1 * op1 - op1 * op2
        x = ift.Field(space, 1.)
        res = full_op(x)
        assert_equal(isinstance(full_op, ift.DiagonalOperator), True)
        # With op1 -> 2 and op2 -> 3: 2 * (3 + 3) * 2 * 2 - 2 * 3 = 42.
        assert_allclose(ift.dobj.to_global_data(res.val), 42.)