Commit 6b285039 authored by Neel Shah's avatar Neel Shah
Browse files

removed print statements, added more tests

parent 4b2d8e95
Pipeline #106466 canceled with stages
......@@ -44,8 +44,6 @@ def test_matrix_product_endomorphic(domain, spaces, n_tests=4):
mat = mat + 1j*ift.random.current_rng().standard_normal(mat_shape)
op = ift.MatrixProductOperator(domain, mat, spaces=spaces)
ift.extra.check_linear_operator(op)
print(f'Domain shape={domain.shape}, spaces={spaces}, '+
f'matrix shape={mat_shape}, target=domain (endomorphic)')
def test_matrix_product_spaces(domain, spaces, n_tests=4):
mat_shape = (7, 8)
......@@ -60,8 +58,6 @@ def test_matrix_product_spaces(domain, spaces, n_tests=4):
mat = mat + 1j*ift.random.current_rng().standard_normal(mat_shape)
op = ift.MatrixProductOperator(domain, mat, spaces=spaces)
ift.extra.check_linear_operator(op)
print(f'Domain shape={domain.shape}, spaces={spaces}, '+
f'matrix shape={mat_shape}, target shape={op.target.shape}')
def test_matrix_product_flatten(domain, n_tests=4):
appl_shape = (ift.utilities.my_product(domain.shape),)
......@@ -71,7 +67,6 @@ def test_matrix_product_flatten(domain, n_tests=4):
mat = mat + 1j*ift.random.current_rng().standard_normal(mat_shape)
op = ift.MatrixProductOperator(domain, mat, spaces=None, flatten=True)
ift.extra.check_linear_operator(op)
print(f'flatten=True. Domain shape={domain.shape}, matrix shape={mat_shape}')
# the below function demonstrates the only error that cannot be caught
# when the operator is initialized. It is caused due to the matrix having
......@@ -89,9 +84,6 @@ def test_matrix_product_invalid_shapes(domain):
mat = ift.random.current_rng().standard_normal(mat_shape)
op = ift.MatrixProductOperator(domain, mat, spaces=spaces)
ift.extra.check_linear_operator(op)
print('ValueError raised because positions of unused subspaces of '+
'domain are changed.\n'+
f'Domain shape={domain.shape}, spaces={spaces}, matrix shape={mat_shape}')
mat_shape = ()
spaces = (3,)
if spaces != None:
......@@ -102,10 +94,6 @@ def test_matrix_product_invalid_shapes(domain):
mat = ift.random.current_rng().standard_normal(mat_shape)
op = ift.MatrixProductOperator(domain, mat, spaces=spaces)
ift.extra.check_linear_operator(op)
print('No errors raised because positions of unused subspaces of '+
'domain are not changed.\n'+
f'Domain shape={domain.shape}, spaces={spaces}, '+
f'matrix shape={mat_shape}, target shape= {op.target.shape}')
mat_shape = (7,)
spaces = (1, 2)
if spaces != None:
......@@ -117,9 +105,6 @@ def test_matrix_product_invalid_shapes(domain):
mat = ift.random.current_rng().standard_normal(mat_shape)
op = ift.MatrixProductOperator(domain, mat, spaces=spaces)
ift.extra.check_linear_operator(op)
print('ValueError raised because positions of unused subspaces of '+
'domain are changed.\n'+
f'Domain shape={domain.shape}, spaces={spaces}, matrix shape={mat_shape}')
mat_shape = (7,)
spaces = (1, 3)
if spaces != None:
......@@ -130,7 +115,57 @@ def test_matrix_product_invalid_shapes(domain):
mat = ift.random.current_rng().standard_normal(mat_shape)
op = ift.MatrixProductOperator(domain, mat, spaces=spaces)
ift.extra.check_linear_operator(op)
print('No errors raised because positions of unused subspaces of '+
'domain are not changed.\n'+
f'Domain shape={domain.shape}, spaces={spaces}, '+
f'matrix shape={mat_shape}, target shape={op.target.shape}')
def test_matrix_product_examples():
# 1. Demonstrate that multiplying by a unitary matrix doesn't change
# the norm of a vector
domain = ift.RGSpace(10, 0.1)
field = ift.Field.from_random(domain=domain, dtype=np.complex128)
norm = ift.MatrixProductOperator(domain, field.conjugate().val).times(field)
M = ift.random.current_rng().standard_normal(domain.shape*2)
H = M + M.transpose()
O = np.linalg.eig(H)[1]
O_inv = np.linalg.inv(O)
Hd = np.matmul(O_inv, np.matmul(H, O))
Ud = np.diag(np.exp(1j*np.diag(Hd)))
U = np.matmul(O, np.matmul(Ud, O_inv))
U_matmul = ift.MatrixProductOperator(domain, U)
field2 = U_matmul.times(field)
norm2 = ift.MatrixProductOperator(domain, field2.conjugate().val).times(field2)
ift.extra.assert_allclose(norm, norm2, rtol=1e-14, atol=0)
# 2. Demonstrate using the operator to get complex conjugate of a field
domain = dtuple
field = ift.Field.from_random(domain=domain, dtype=np.complex128)
one = ift.Field.from_raw(ift.DomainTuple.make(None), 1)
op = ift.MatrixProductOperator(domain, field.val)
op_conjugate = op.adjoint_times(one)
ift.extra.assert_equal(op_conjugate, field.conjugate())
# 3. Demonstrate using the operator to take the trace of a square matrix
domain = ift.DomainTuple.make((ift.RGSpace(10, 0.3), ift.RGSpace(10, 0.2)))
field = ift.Field.from_random(domain=domain, dtype=np.complex128)
trace_op = ift.MatrixProductOperator(domain, np.eye(10))
trace = trace_op.times(field)
np_trace = np.trace(field.val)
np.testing.assert_allclose(trace.val, np_trace, rtol=1e-14, atol=0)
# 4. Demonstrate the cyclic property of the trace using the matrix
# product operator for matrix products
domain = ift.DomainTuple.make((ift.RGSpace(10, 0.3), ift.RGSpace(10, 0.2)))
A = ift.Field.from_random(domain=domain, dtype=np.complex128)
B = ift.Field.from_random(domain=domain, dtype=np.complex128)
C = ift.Field.from_random(domain=domain, dtype=np.complex128)
trace_op = ift.MatrixProductOperator(domain, np.eye(10))
A_matmul = ift.MatrixProductOperator(domain, A.val, spaces=(0,))
B_matmul = ift.MatrixProductOperator(domain, B.val, spaces=(0,))
C_matmul = ift.MatrixProductOperator(domain, C.val, spaces=(0,))
ABC = A_matmul.times(B_matmul.times(C))
BCA = B_matmul.times(C_matmul.times(A))
CAB = C_matmul.times(A_matmul.times(B))
trace1 = trace_op.times(ABC)
trace2 = trace_op.times(BCA)
trace3 = trace_op.times(CAB)
ift.extra.assert_allclose(trace1, trace2, rtol=1e-14, atol=0)
ift.extra.assert_allclose(trace2, trace3, rtol=1e-14, atol=0)
ift.extra.assert_allclose(trace3, trace1, rtol=1e-14, atol=0)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment