Commit f3e0cd1a authored by Neel Shah's avatar Neel Shah
Browse files

minor change

parent a5a73894
Pipeline #106470 failed with stages
......@@ -116,56 +116,57 @@ def test_matrix_product_invalid_shapes(domain):
op = ift.MatrixProductOperator(domain, mat, spaces=spaces)
ift.extra.check_linear_operator(op)
def test_matrix_product_examples():
# 1. Demonstrate that multiplying by a unitary matrix doesn't change
# the norm of a vector
domain = ift.RGSpace(10, 0.1)
field = ift.Field.from_random(domain=domain, dtype=np.complex128)
norm = ift.MatrixProductOperator(domain, field.conjugate().val).times(field)
M = ift.random.current_rng().standard_normal(domain.shape*2)
H = M + M.transpose()
O = np.linalg.eig(H)[1]
O_inv = np.linalg.inv(O)
Hd = np.matmul(O_inv, np.matmul(H, O))
Ud = np.diag(np.exp(1j*np.diag(Hd)))
U = np.matmul(O, np.matmul(Ud, O_inv))
U_matmul = ift.MatrixProductOperator(domain, U)
field2 = U_matmul.times(field)
norm2 = ift.MatrixProductOperator(domain, field2.conjugate().val).times(field2)
ift.extra.assert_allclose(norm, norm2, rtol=1e-14, atol=0)
# 2. Demonstrate using the operator to get complex conjugate of a field
domain = dtuple
field = ift.Field.from_random(domain=domain, dtype=np.complex128)
one = ift.Field.from_raw(ift.DomainTuple.make(None), 1)
op = ift.MatrixProductOperator(domain, field.val)
op_conjugate = op.adjoint_times(one)
ift.extra.assert_equal(op_conjugate, field.conjugate())
# 3. Demonstrate using the operator to take the trace of a square matrix
domain = ift.DomainTuple.make((ift.RGSpace(10, 0.3), ift.RGSpace(10, 0.2)))
field = ift.Field.from_random(domain=domain, dtype=np.complex128)
trace_op = ift.MatrixProductOperator(domain, np.eye(10))
trace = trace_op.times(field)
np_trace = np.trace(field.val)
np.testing.assert_allclose(trace.val, np_trace, rtol=1e-14, atol=0)
# 4. Demonstrate the cyclic property of the trace using the matrix
# product operator for matrix products
domain = ift.DomainTuple.make((ift.RGSpace(10, 0.3), ift.RGSpace(10, 0.2)))
A = ift.Field.from_random(domain=domain, dtype=np.complex128)
B = ift.Field.from_random(domain=domain, dtype=np.complex128)
C = ift.Field.from_random(domain=domain, dtype=np.complex128)
trace_op = ift.MatrixProductOperator(domain, np.eye(10))
A_matmul = ift.MatrixProductOperator(domain, A.val, spaces=(0,))
B_matmul = ift.MatrixProductOperator(domain, B.val, spaces=(0,))
C_matmul = ift.MatrixProductOperator(domain, C.val, spaces=(0,))
ABC = A_matmul.times(B_matmul.times(C))
BCA = B_matmul.times(C_matmul.times(A))
CAB = C_matmul.times(A_matmul.times(B))
trace1 = trace_op.times(ABC)
trace2 = trace_op.times(BCA)
trace3 = trace_op.times(CAB)
ift.extra.assert_allclose(trace1, trace2, rtol=1e-14, atol=0)
ift.extra.assert_allclose(trace2, trace3, rtol=1e-14, atol=0)
ift.extra.assert_allclose(trace3, trace1, rtol=1e-14, atol=0)
def test_matrix_product_examples(n_tests=4):
for i in range(n_tests):
# 1. Demonstrate that multiplying by a unitary matrix doesn't change
# the norm of a vector
domain = ift.RGSpace(10, 0.1)
field = ift.Field.from_random(domain=domain, dtype=np.complex128)
norm = ift.MatrixProductOperator(domain, field.conjugate().val).times(field)
M = ift.random.current_rng().standard_normal(domain.shape*2)
H = M + M.transpose()
O = np.linalg.eig(H)[1]
O_inv = np.linalg.inv(O)
Hd = np.matmul(O_inv, np.matmul(H, O))
Ud = np.diag(np.exp(1j*np.diag(Hd)))
U = np.matmul(O, np.matmul(Ud, O_inv))
U_matmul = ift.MatrixProductOperator(domain, U)
field2 = U_matmul.times(field)
norm2 = ift.MatrixProductOperator(domain, field2.conjugate().val).times(field2)
ift.extra.assert_allclose(norm, norm2, rtol=1e-14, atol=0)
# 2. Demonstrate using the operator to get complex conjugate of a field
domain = dtuple
field = ift.Field.from_random(domain=domain, dtype=np.complex128)
one = ift.Field.from_raw(ift.DomainTuple.make(None), 1)
op = ift.MatrixProductOperator(domain, field.val)
op_conjugate = op.adjoint_times(one)
ift.extra.assert_equal(op_conjugate, field.conjugate())
# 3. Demonstrate using the operator to take the trace of a square matrix
domain = ift.DomainTuple.make((ift.RGSpace(10, 0.3), ift.RGSpace(10, 0.2)))
field = ift.Field.from_random(domain=domain, dtype=np.complex128)
trace_op = ift.MatrixProductOperator(domain, np.eye(10))
trace = trace_op.times(field)
np_trace = np.trace(field.val)
np.testing.assert_allclose(trace.val, np_trace, rtol=1e-14, atol=0)
# 4. Demonstrate the cyclic property of the trace using the matrix
# product operator for matrix products
domain = ift.DomainTuple.make((ift.RGSpace(10, 0.3), ift.RGSpace(10, 0.2)))
A = ift.Field.from_random(domain=domain, dtype=np.complex128)
B = ift.Field.from_random(domain=domain, dtype=np.complex128)
C = ift.Field.from_random(domain=domain, dtype=np.complex128)
trace_op = ift.MatrixProductOperator(domain, np.eye(10))
A_matmul = ift.MatrixProductOperator(domain, A.val, spaces=(0,))
B_matmul = ift.MatrixProductOperator(domain, B.val, spaces=(0,))
C_matmul = ift.MatrixProductOperator(domain, C.val, spaces=(0,))
ABC = A_matmul.times(B_matmul.times(C))
BCA = B_matmul.times(C_matmul.times(A))
CAB = C_matmul.times(A_matmul.times(B))
trace1 = trace_op.times(ABC)
trace2 = trace_op.times(BCA)
trace3 = trace_op.times(CAB)
ift.extra.assert_allclose(trace1, trace2, rtol=1e-14, atol=0)
ift.extra.assert_allclose(trace2, trace3, rtol=1e-14, atol=0)
ift.extra.assert_allclose(trace3, trace1, rtol=1e-14, atol=0)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment