Commit 0dd8321d authored by Neel Shah's avatar Neel Shah
Browse files

Testing GeneralMatrixProduct and fixes in MatrixProductOperator

parent c2dd59ae
Pipeline #104579 canceled with stages
import numpy as np
import pytest
import nifty8 as ift
# The below files are temporarily copied in the test directory as I couldn't
# do a relative import from the operators directory.
from .general_matrix_product import GeneralMatrixProduct
from .old_matrix_product_operator import OldMatrixProductOperator
from ..common import list2fixture
dtuple = ift.DomainTuple.make((ift.RGSpace(2, 0.4), ift.RGSpace(3, 0.3),
ift.RGSpace(4, 0.6), ift.RGSpace(5, 0.2)))
pmp = pytest.mark.parametrize
dtype = list2fixture([np.float64, np.complex128])
# check that the fixes in MatrixProductOperator don't contradict the previous
# version when the previous version works
def metatestMatrixProductOperator_fixes(sp, mat_shape, seed, **kwargs):
with ift.random.Context(seed):
mat = ift.random.current_rng().standard_normal(mat_shape)
op1 = OldMatrixProductOperator(sp, mat, **kwargs)
op2 = ift.MatrixProductOperator(sp, mat, **kwargs)
field = ift.Field.from_random(sp)
ift.extra.assert_equal(op1.times(field), op2.times(field))
mat = mat + 1j*ift.random.current_rng().standard_normal(mat_shape)
op1 = OldMatrixProductOperator(sp, mat, **kwargs)
op2 = ift.MatrixProductOperator(sp, mat, **kwargs)
ift.extra.assert_equal(op1.times(field), op2.times(field))
ift.extra.assert_equal(op1.adjoint_times(field), op2.adjoint_times(field))
def metatestGeneralMatrixProduct(sp, mat_shape, seed, spaces=None, target=None, flatten=False):
with ift.random.Context(seed):
mat = ift.random.current_rng().standard_normal(mat_shape)
op = GeneralMatrixProduct(sp, mat, spaces=spaces, target=target, flatten=flatten)
ift.extra.check_linear_operator(op)
mat = mat + 1j*ift.random.current_rng().standard_normal(mat_shape)
op = GeneralMatrixProduct(sp, mat, spaces=spaces, target=target, flatten=flatten)
ift.extra.check_linear_operator(op)
# test operator's self-consistency and also that it matches with the
# endomorphic MatrixProductOperator
def metatestGeneralMatrixProduct_endomorphic(sp, mat_shape, seed, spaces, target, flatten=False):
with ift.random.Context(seed):
mat = ift.random.current_rng().standard_normal(mat_shape)
op1 = ift.MatrixProductOperator(sp, mat, spaces=spaces, flatten=flatten)
op2 = GeneralMatrixProduct(sp, mat, spaces=spaces, target=target, flatten=flatten)
field = ift.Field.from_random(sp)
ift.extra.check_linear_operator(op2)
ift.extra.assert_equal(op1.times(field), op2.times(field))
ift.extra.assert_equal(op1.adjoint_times(field), op2.adjoint_times(field))
mat = mat + 1j*ift.random.current_rng().standard_normal(mat_shape)
op1 = ift.MatrixProductOperator(sp, mat, spaces=spaces, flatten=flatten)
op2 = GeneralMatrixProduct(sp, mat, spaces=spaces, target=target, flatten=flatten)
ift.extra.check_linear_operator(op2)
ift.extra.assert_equal(op1.times(field), op2.times(field))
ift.extra.assert_equal(op1.adjoint_times(field), op2.adjoint_times(field))
@pmp('sp', [ift.RGSpace(10)])
@pmp('spaces', [None, (0,)])
@pmp('seed', [12, 3])
def testGeneralMatrixProduct_1d(sp, spaces, seed):
mat_shape = sp.shape * 2
metatestGeneralMatrixProduct_endomorphic(sp, mat_shape, seed, spaces=spaces, target=sp)
@pmp('sp', [ift.DomainTuple.make((ift.RGSpace(2, 0.3), ift.RGSpace(10, 0.2)))])
@pmp('spaces', [None, (0,), (1,), (0, 1)])
@pmp('seed', [12, 3])
def testGeneralMatrixProduct_2d_spaces(sp, spaces, seed):
appl_shape = ()
if spaces != None:
for sp_idx in spaces:
appl_shape += sp[sp_idx].shape
else:
appl_shape = sp.shape
mat_shape = appl_shape * 2
metatestGeneralMatrixProduct_endomorphic(sp, mat_shape, seed, spaces=spaces, target=sp)
@pmp('sp', [dtuple])
@pmp('seed', [12, 3])
def testGeneralMatrixProduct_4d_flatten(sp, seed):
appl_shape = (ift.utilities.my_product(sp.shape),)
mat_shape = appl_shape * 2
metatestGeneralMatrixProduct_endomorphic(sp, mat_shape, seed, spaces=None,
target=None, flatten=True)
@pmp('sp', [dtuple])
@pmp('spaces', [(1, 3), None, (2,)])
@pmp('mat_shape', [(5, 15, 25)])
@pmp('seed', [12, 3])
def testGeneralMatrixProduct(sp, mat_shape, spaces, seed):
if spaces != None:
for i in spaces:
mat_shape += (sp.shape[i],)
else:
mat_shape += sp.shape
metatestGeneralMatrixProduct(sp, mat_shape, seed, spaces=spaces)
@pmp('sp', [ift.RGSpace(10)])
@pmp('spaces', [None, (0,)])
@pmp('seed', [12, 3])
def testMatrixProductOperator_1d_fixes(sp, spaces, seed):
mat_shape = sp.shape * 2
metatestMatrixProductOperator_fixes(sp, mat_shape, seed, spaces=spaces)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment