Commit ef08e9eb authored by Neel Shah's avatar Neel Shah
Browse files

Fixed importing procedure

parent 085c8bd5
Pipeline #104587 failed with stages
......@@ -3,11 +3,6 @@ import pytest
import nifty8 as ift
# The below files are temporarily copied in the test directory as I couldn't
# do a relative import from the operators directory.
from .general_matrix_product import GeneralMatrixProduct
from .old_matrix_product_operator import OldMatrixProductOperator
from ..common import list2fixture
dtuple = ift.DomainTuple.make((ift.RGSpace(2, 0.4), ift.RGSpace(3, 0.3),
......@@ -22,12 +17,12 @@ dtype = list2fixture([np.float64, np.complex128])
def metatestMatrixProductOperator_fixes(sp, mat_shape, seed, **kwargs):
with ift.random.Context(seed):
mat = ift.random.current_rng().standard_normal(mat_shape)
op1 = OldMatrixProductOperator(sp, mat, **kwargs)
op1 = ift.OldMatrixProductOperator(sp, mat, **kwargs)
op2 = ift.MatrixProductOperator(sp, mat, **kwargs)
field = ift.Field.from_random(sp)
ift.extra.assert_equal(op1.times(field), op2.times(field))
mat = mat + 1j*ift.random.current_rng().standard_normal(mat_shape)
op1 = OldMatrixProductOperator(sp, mat, **kwargs)
op1 = ift.OldMatrixProductOperator(sp, mat, **kwargs)
op2 = ift.MatrixProductOperator(sp, mat, **kwargs)
ift.extra.assert_equal(op1.times(field), op2.times(field))
ift.extra.assert_equal(op1.adjoint_times(field), op2.adjoint_times(field))
......@@ -36,10 +31,10 @@ def metatestMatrixProductOperator_fixes(sp, mat_shape, seed, **kwargs):
def metatestGeneralMatrixProduct(sp, mat_shape, seed, spaces=None, target=None, flatten=False):
with ift.random.Context(seed):
mat = ift.random.current_rng().standard_normal(mat_shape)
op = GeneralMatrixProduct(sp, mat, spaces=spaces, target=target, flatten=flatten)
op = ift.GeneralMatrixProduct(sp, mat, spaces=spaces, target=target, flatten=flatten)
ift.extra.check_linear_operator(op)
mat = mat + 1j*ift.random.current_rng().standard_normal(mat_shape)
op = GeneralMatrixProduct(sp, mat, spaces=spaces, target=target, flatten=flatten)
op = ift.GeneralMatrixProduct(sp, mat, spaces=spaces, target=target, flatten=flatten)
ift.extra.check_linear_operator(op)
......@@ -49,14 +44,14 @@ def metatestGeneralMatrixProduct_endomorphic(sp, mat_shape, seed, spaces, target
with ift.random.Context(seed):
mat = ift.random.current_rng().standard_normal(mat_shape)
op1 = ift.MatrixProductOperator(sp, mat, spaces=spaces, flatten=flatten)
op2 = GeneralMatrixProduct(sp, mat, spaces=spaces, target=target, flatten=flatten)
op2 = ift.GeneralMatrixProduct(sp, mat, spaces=spaces, target=target, flatten=flatten)
field = ift.Field.from_random(sp)
ift.extra.check_linear_operator(op2)
ift.extra.assert_equal(op1.times(field), op2.times(field))
ift.extra.assert_equal(op1.adjoint_times(field), op2.adjoint_times(field))
mat = mat + 1j*ift.random.current_rng().standard_normal(mat_shape)
op1 = ift.MatrixProductOperator(sp, mat, spaces=spaces, flatten=flatten)
op2 = GeneralMatrixProduct(sp, mat, spaces=spaces, target=target, flatten=flatten)
op2 = ift.GeneralMatrixProduct(sp, mat, spaces=spaces, target=target, flatten=flatten)
ift.extra.check_linear_operator(op2)
ift.extra.assert_equal(op1.times(field), op2.times(field))
ift.extra.assert_equal(op1.adjoint_times(field), op2.adjoint_times(field))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment