test_composed_operator.py 1.57 KB
Newer Older
1
import unittest
Martin Reinecke's avatar
Martin Reinecke committed
2
3
from numpy.testing import assert_allclose
from nifty2go import Field, DiagonalOperator, ComposedOperator
4
5
6
from test.common import generate_spaces
from itertools import product
from test.common import expand
Martin Reinecke's avatar
Martin Reinecke committed
7
from nifty2go.dobj import to_ndarray as to_np
Martin Reinecke's avatar
Martin Reinecke committed
8

9
10
11
12

class ComposedOperator_Tests(unittest.TestCase):
    """Tests for ``ComposedOperator`` built from two ``DiagonalOperator``s.

    Each test is expanded over every ordered pair ``(space1, space2)``
    drawn from ``generate_spaces()``.  The operator under test composes a
    diagonal acting on sub-space 0 of ``(space1, space2)`` with a diagonal
    acting on sub-space 1.
    """

    spaces = generate_spaces()

    @staticmethod
    def _make_composed_operator(space1, space2):
        """Build the composed operator shared by all tests.

        Parameters
        ----------
        space1, space2 : domain spaces
            The two factors of the product domain ``(space1, space2)``.

        Returns
        -------
        tuple
            ``(cspace, op)`` where ``cspace == (space1, space2)`` and ``op``
            is ``ComposedOperator((op1, op2))``.  ``op1`` and ``op2`` are
            ``DiagonalOperator``s on sub-spaces 0 and 1 respectively, whose
            diagonals are drawn with ``Field.from_random('normal', ...)``.
        """
        cspace = (space1, space2)

        # Draw diag1 before diag2 so the sequence of random draws matches
        # the original per-test setup.
        diag1 = Field.from_random('normal', domain=space1)
        diag2 = Field.from_random('normal', domain=space2)

        op1 = DiagonalOperator(diag1, cspace, spaces=(0,))
        op2 = DiagonalOperator(diag2, cspace, spaces=(1,))

        return cspace, ComposedOperator((op1, op2))

    @expand(product(spaces, spaces))
    def test_times_adjoint_times(self, space1, space2):
        """Check that ``adjoint_times`` is consistent with ``times``.

        Compares ``rand2.vdot(op.times(rand1))`` with
        ``rand1.vdot(op.adjoint_times(rand2))`` for two random fields on
        the product domain.
        """
        cspace, op = self._make_composed_operator(space1, space2)

        rand1 = Field.from_random('normal', domain=cspace)
        rand2 = Field.from_random('normal', domain=cspace)

        tt1 = rand2.vdot(op.times(rand1))
        tt2 = rand1.vdot(op.adjoint_times(rand2))
        assert_allclose(tt1, tt2)

    @expand(product(spaces, spaces))
    def test_times_inverse_times(self, space1, space2):
        """Check that ``op.inverse_times(op.times(x))`` recovers ``x``."""
        cspace, op = self._make_composed_operator(space1, space2)

        # NOTE(review): the diagonals are normal random draws, so an entry
        # very close to zero is possible in principle; an exact zero (which
        # would make the inverse undefined) is vanishingly unlikely.
        rand1 = Field.from_random('normal', domain=cspace)
        tt1 = op.inverse_times(op.times(rand1))

        # to_np is nifty2go.dobj.to_ndarray; presumably it turns the field's
        # data object into a plain ndarray that assert_allclose can compare.
        assert_allclose(to_np(tt1.val), to_np(rand1.val))