import unittest
from numpy.testing import assert_allclose
from nifty2go import Field, DiagonalOperator, ComposedOperator
from test.common import generate_spaces
from itertools import product
from test.common import expand
from nifty2go.dobj import to_global_data as to_np


class ComposedOperator_Tests(unittest.TestCase):
    """Consistency checks for a ``ComposedOperator`` built from two
    ``DiagonalOperator`` instances on a two-component product domain."""

    # Evaluated once at class-definition time so that the ``@expand``
    # decorators below can build the (space1, space2) parameter grid.
    spaces = generate_spaces()

    @staticmethod
    def _make_operator(space1, space2):
        """Return ``ComposedOperator((op1, op2))`` on ``(space1, space2)``.

        ``op1`` and ``op2`` are ``DiagonalOperator`` instances whose
        diagonals are random normal fields on ``space1`` and ``space2``.
        They are given ``spaces=(0,)`` and ``spaces=(1,)`` respectively,
        presumably restricting each to one component of the product
        domain (NOTE(review): confirm against ``DiagonalOperator`` docs).
        """
        cspace = (space1, space2)
        diag1 = Field.from_random('normal', domain=space1)
        diag2 = Field.from_random('normal', domain=space2)
        op1 = DiagonalOperator(diag1, cspace, spaces=(0,))
        op2 = DiagonalOperator(diag2, cspace, spaces=(1,))
        return ComposedOperator((op1, op2))

    @expand(product(spaces, spaces))
    def test_times_adjoint_times(self, space1, space2):
        """Dot-product test: ``rand2 . op(rand1)`` must equal
        ``rand1 . op^adjoint(rand2)`` for random fields."""
        op = self._make_operator(space1, space2)

        rand1 = Field.from_random('normal', domain=(space1, space2))
        rand2 = Field.from_random('normal', domain=(space1, space2))

        tt1 = rand2.vdot(op.times(rand1))
        tt2 = rand1.vdot(op.adjoint_times(rand2))
        assert_allclose(tt1, tt2)

    @expand(product(spaces, spaces))
    def test_times_inverse_times(self, space1, space2):
        """Round trip: ``op.inverse_times(op.times(x))`` must reproduce
        ``x`` for a random field ``x``."""
        op = self._make_operator(space1, space2)

        rand1 = Field.from_random('normal', domain=(space1, space2))
        tt1 = op.inverse_times(op.times(rand1))

        # to_np (dobj.to_global_data) presumably gathers the field's data
        # into a plain numpy array so assert_allclose can compare it.
        assert_allclose(to_np(tt1.val), to_np(rand1.val))