# test_composed_operator.py
import unittest
from numpy.testing import assert_equal,\
    assert_allclose,\
    assert_approx_equal
from nifty2go import Field,\
    DiagonalOperator,\
    ComposedOperator
from test.common import generate_spaces
from itertools import product
from test.common import expand
from nifty2go.dobj import to_ndarray as to_np, from_ndarray as from_np


class ComposedOperator_Tests(unittest.TestCase):
    """Consistency tests for ``ComposedOperator`` built from two diagonals.

    Each test is expanded over every ordered pair of spaces produced by
    ``generate_spaces()``.
    """

    spaces = generate_spaces()

    @staticmethod
    def _make_composed(space1, space2):
        """Build a composition of two random diagonal operators.

        A ``DiagonalOperator`` acting on sub-space 0 of ``(space1, space2)``
        and one acting on sub-space 1 are combined with ``ComposedOperator``.
        Both diagonals are drawn with ``Field.from_random('normal', ...)``.

        Returns
        -------
        tuple
            ``(op, cspace)`` -- the composed operator and the product domain
            ``(space1, space2)`` its factors were constructed on.
        """
        cspace = (space1, space2)
        diag1 = Field.from_random('normal', domain=space1)
        diag2 = Field.from_random('normal', domain=space2)
        op1 = DiagonalOperator(diag1, cspace, spaces=(0,))
        op2 = DiagonalOperator(diag2, cspace, spaces=(1,))
        return ComposedOperator((op1, op2)), cspace

    @expand(product(spaces, spaces))
    def test_times_adjoint_times(self, space1, space2):
        """``rand2 . op(rand1)`` must match ``rand1 . op_adjoint(rand2)``."""
        op, cspace = self._make_composed(space1, space2)

        rand1 = Field.from_random('normal', domain=cspace)
        rand2 = Field.from_random('normal', domain=cspace)

        # The same scalar product evaluated two ways: the values coincide
        # when adjoint_times is the adjoint of times (for these real-valued
        # random fields).
        tt1 = rand2.vdot(op.times(rand1))
        tt2 = rand1.vdot(op.adjoint_times(rand2))
        assert_approx_equal(tt1, tt2)

    @expand(product(spaces, spaces))
    def test_times_inverse_times(self, space1, space2):
        """``op.inverse_times`` must undo ``op.times`` on a random field."""
        op, cspace = self._make_composed(space1, space2)

        rand1 = Field.from_random('normal', domain=cspace)
        tt1 = op.inverse_times(op.times(rand1))

        # Compare the underlying data element-wise; to_np
        # (nifty2go.dobj.to_ndarray) yields arrays assert_allclose accepts.
        assert_allclose(to_np(tt1.val), to_np(rand1.val))