test_composed_operator.py 1.55 KB
Newer Older
1
import unittest
Martin Reinecke's avatar
Martin Reinecke committed
2
from numpy.testing import assert_allclose
Martin Reinecke's avatar
Martin Reinecke committed
3
import nifty2go as ift
4 5 6
from test.common import generate_spaces
from itertools import product
from test.common import expand
Martin Reinecke's avatar
Martin Reinecke committed
7

8 9 10 11

class ComposedOperator_Tests(unittest.TestCase):
    """Tests for the product ``op2*op1`` of two ``ift.DiagonalOperator``s.

    Each test is expanded (via ``expand``) over every ordered pair of
    spaces returned by ``generate_spaces()``.
    """

    spaces = generate_spaces()

    @staticmethod
    def _make_composed_operator(space1, space2):
        """Return ``op2*op1`` built from two random diagonal operators.

        ``op1`` acts on subspace index 0 and ``op2`` on subspace index 1 of
        the product domain ``(space1, space2)``. Their diagonals are random
        fields drawn with ``ift.Field.from_random('normal', ...)``.
        """
        cspace = (space1, space2)
        diag1 = ift.Field.from_random('normal', domain=space1)
        diag2 = ift.Field.from_random('normal', domain=space2)
        op1 = ift.DiagonalOperator(diag1, cspace, spaces=(0,))
        op2 = ift.DiagonalOperator(diag2, cspace, spaces=(1,))
        return op2*op1

    @expand(product(spaces, spaces))
    def test_times_adjoint_times(self, space1, space2):
        """Check the adjoint identity ``<y, A x> == <x, A^adj y>``."""
        op = self._make_composed_operator(space1, space2)

        rand1 = ift.Field.from_random('normal', domain=(space1, space2))
        rand2 = ift.Field.from_random('normal', domain=(space1, space2))

        # Both sides are inner products of random fields, so they agree
        # exactly only if adjoint_times is the true adjoint of times.
        tt1 = rand2.vdot(op.times(rand1))
        tt2 = rand1.vdot(op.adjoint_times(rand2))
        assert_allclose(tt1, tt2)

    @expand(product(spaces, spaces))
    def test_times_inverse_times(self, space1, space2):
        """Check that ``inverse_times`` undoes ``times`` on a random field."""
        op = self._make_composed_operator(space1, space2)

        rand1 = ift.Field.from_random('normal', domain=(space1, space2))
        tt1 = op.inverse_times(op.times(rand1))

        # Compare the underlying data arrays obtained through
        # ift.dobj.to_global_data rather than the Field objects themselves.
        assert_allclose(ift.dobj.to_global_data(tt1.val),
                        ift.dobj.to_global_data(rand1.val))