# test_composed_operator.py
import unittest

from numpy.testing import assert_equal,\
    assert_allclose,\
    assert_approx_equal

from nifty import Field,\
    DiagonalOperator,\
    ComposedOperator

from test.common import generate_spaces

from itertools import product
from test.common import expand

class ComposedOperator_Tests(unittest.TestCase):
    """Tests for ``ComposedOperator`` built from two ``DiagonalOperator``s.

    Every test draws a random diagonal on ``space1`` and on ``space2``,
    wraps each in a ``DiagonalOperator`` and composes the two operators.
    """

    # NOTE(review): ``product([spaces], [spaces])`` yields exactly one pair,
    # ``(spaces, spaces)``, so each test receives the full result of
    # ``generate_spaces()`` for both parameters.  If one test case per
    # pair of individual spaces was intended, this would need to be
    # ``product(spaces, spaces)`` -- confirm against ``test.common``.
    spaces = generate_spaces()

    @staticmethod
    def _compose_diagonals(space1, space2):
        """Return ``(op1, op2, composed)`` with random normal diagonals.

        The diagonals are drawn for ``space1`` first and ``space2`` second,
        before either operator is constructed.
        """
        diag1 = Field.from_random('normal', domain=space1)
        diag2 = Field.from_random('normal', domain=space2)
        op1 = DiagonalOperator(space1, diagonal=diag1)
        op2 = DiagonalOperator(space2, diagonal=diag2)
        return op1, op2, ComposedOperator((op1, op2))

    @expand(product([spaces], [spaces]))
    def test_property(self, space1, space2):
        """Check the ``domain`` and ``unitary`` properties of the composition."""
        op1, op2, op = self._compose_diagonals(space1, space2)

        # Use real assertions so a mismatch is reported as a test failure
        # with a readable message instead of a bare TypeError/ValueError.
        self.assertEqual(
            op.domain, (op1.domain, op2.domain),
            "ComposedOperator.domain does not match its operators' domains")
        assert_equal(
            op.unitary, False,
            err_msg="ComposedOperator.unitary is expected to be False")

    @expand(product([spaces], [spaces]))
    def test_times_adjoint_times(self, space1, space2):
        """Check ``rand2.dot(op.times(rand1))`` against the adjoint form.

        The compared value is ``rand1.dot(op.adjoint_times(rand2))``; both
        are expected to agree to ``assert_approx_equal`` precision.
        """
        _, _, op = self._compose_diagonals(space1, space2)

        rand1 = Field.from_random('normal', domain=(space1, space2))
        rand2 = Field.from_random('normal', domain=(space1, space2))

        tt1 = rand2.dot(op.times(rand1))
        tt2 = rand1.dot(op.adjoint_times(rand2))
        assert_approx_equal(tt1, tt2)

    @expand(product([spaces], [spaces]))
    def test_times_inverse_times(self, space1, space2):
        """Check that ``inverse_times`` undoes ``times`` on a random field."""
        _, _, op = self._compose_diagonals(space1, space2)

        rand1 = Field.from_random('normal', domain=(space1, space2))
        tt1 = op.inverse_times(op.times(rand1))

        assert_allclose(tt1.val.get_full_data(),
                        rand1.val.get_full_data())