"""Unit tests for operators composed from other operators (products, sums, adjoints, inverses)."""
import unittest
from numpy.testing import assert_allclose, assert_equal
import nifty2go as ift
from test.common import generate_spaces
from itertools import product
from test.common import expand


class ComposedOperator_Tests(unittest.TestCase):
    """Tests for operators built by composing other operators.

    Covers products (``*``) and sums/differences (``+``/``-``) of
    ``DiagonalOperator`` and ``ScalingOperator`` instances, and checks the
    composed operator's ``times``, ``adjoint_times`` and ``inverse_times``.
    """

    # Collection of spaces every test is parametrized over via ``@expand``.
    spaces = generate_spaces()

    @expand(product(spaces, spaces))
    def test_times_adjoint_times(self, space1, space2):
        """Check ``<y, A x> == <x, A^dagger y>`` for a product of diagonals.

        The two diagonal operators are defined on the product domain
        ``(space1, space2)`` and restricted via ``spaces=`` to its first and
        second factor respectively.
        """
        cspace = (space1, space2)
        diag1 = ift.Field.from_random('normal', domain=space1)
        diag2 = ift.Field.from_random('normal', domain=space2)
        op1 = ift.DiagonalOperator(diag1, cspace, spaces=(0,))
        op2 = ift.DiagonalOperator(diag2, cspace, spaces=(1,))

        op = op2*op1

        rand1 = ift.Field.from_random('normal', domain=cspace)
        rand2 = ift.Field.from_random('normal', domain=cspace)

        # Adjoint consistency: both scalar products must agree.
        tt1 = rand2.vdot(op.times(rand1))
        tt2 = rand1.vdot(op.adjoint_times(rand2))
        assert_allclose(tt1, tt2)

    @expand(product(spaces, spaces))
    def test_times_inverse_times(self, space1, space2):
        """Check that ``inverse_times`` undoes ``times`` for a product of diagonals."""
        cspace = (space1, space2)
        diag1 = ift.Field.from_random('normal', domain=space1)
        diag2 = ift.Field.from_random('normal', domain=space2)
        op1 = ift.DiagonalOperator(diag1, cspace, spaces=(0,))
        op2 = ift.DiagonalOperator(diag2, cspace, spaces=(1,))

        op = op2*op1

        rand1 = ift.Field.from_random('normal', domain=cspace)
        tt1 = op.inverse_times(op.times(rand1))

        # Compare the globally gathered data so the check does not depend
        # on how ``dobj`` distributes the field values.
        assert_allclose(ift.dobj.to_global_data(tt1.val),
                        ift.dobj.to_global_data(rand1.val))

    @expand(product(spaces))
    def test_sum(self, space):
        """Sums/differences of diagonal and scaling operators stay diagonal."""
        op1 = ift.DiagonalOperator(ift.Field(space, 2.))
        op2 = ift.ScalingOperator(3., space)
        # 2 + 3 - (3 - 2) + 2 + 2 + 3 == 11
        full_op = op1 + op2 - (op2 - op1) + op1 + op1 + op2
        x = ift.Field(space, 1.)
        res = full_op(x)
        self.assertIsInstance(full_op, ift.DiagonalOperator)
        assert_allclose(ift.dobj.to_global_data(res.val), 11.)

    @expand(product(spaces))
    def test_chain(self, space):
        """Products of diagonal and scaling operators stay diagonal."""
        op1 = ift.DiagonalOperator(ift.Field(space, 2.))
        op2 = ift.ScalingOperator(3., space)
        # 2 * 3 * (3 * 2) * 2 * 2 * 3 == 432
        full_op = op1 * op2 * (op2 * op1) * op1 * op1 * op2
        x = ift.Field(space, 1.)
        res = full_op(x)
        self.assertIsInstance(full_op, ift.DiagonalOperator)
        assert_allclose(ift.dobj.to_global_data(res.val), 432.)

    @expand(product(spaces))
    def test_mix(self, space):
        """Mixed sums and products of diagonal/scaling operators stay diagonal."""
        op1 = ift.DiagonalOperator(ift.Field(space, 2.))
        op2 = ift.ScalingOperator(3., space)
        # 2 * (3 + 3) * 2 * 2 - 2 * 3 == 42
        full_op = op1 * (op2 + op2) * op1 * op1 - op1 * op2
        x = ift.Field(space, 1.)
        res = full_op(x)
        self.assertIsInstance(full_op, ift.DiagonalOperator)
        assert_allclose(ift.dobj.to_global_data(res.val), 42.)