Commit f3e0cd1a by Neel Shah

### minor change

parent a5a73894
Pipeline #106470 failed with stages
 ... ... @@ -116,56 +116,57 @@ def test_matrix_product_invalid_shapes(domain): op = ift.MatrixProductOperator(domain, mat, spaces=spaces) ift.extra.check_linear_operator(op) def test_matrix_product_examples(): # 1. Demonstrate that multiplying by a unitary matrix doesn't change # the norm of a vector domain = ift.RGSpace(10, 0.1) field = ift.Field.from_random(domain=domain, dtype=np.complex128) norm = ift.MatrixProductOperator(domain, field.conjugate().val).times(field) M = ift.random.current_rng().standard_normal(domain.shape*2) H = M + M.transpose() O = np.linalg.eig(H)[1] O_inv = np.linalg.inv(O) Hd = np.matmul(O_inv, np.matmul(H, O)) Ud = np.diag(np.exp(1j*np.diag(Hd))) U = np.matmul(O, np.matmul(Ud, O_inv)) U_matmul = ift.MatrixProductOperator(domain, U) field2 = U_matmul.times(field) norm2 = ift.MatrixProductOperator(domain, field2.conjugate().val).times(field2) ift.extra.assert_allclose(norm, norm2, rtol=1e-14, atol=0) # 2. Demonstrate using the operator to get complex conjugate of a field domain = dtuple field = ift.Field.from_random(domain=domain, dtype=np.complex128) one = ift.Field.from_raw(ift.DomainTuple.make(None), 1) op = ift.MatrixProductOperator(domain, field.val) op_conjugate = op.adjoint_times(one) ift.extra.assert_equal(op_conjugate, field.conjugate()) # 3. Demonstrate using the operator to take the trace of a square matrix domain = ift.DomainTuple.make((ift.RGSpace(10, 0.3), ift.RGSpace(10, 0.2))) field = ift.Field.from_random(domain=domain, dtype=np.complex128) trace_op = ift.MatrixProductOperator(domain, np.eye(10)) trace = trace_op.times(field) np_trace = np.trace(field.val) np.testing.assert_allclose(trace.val, np_trace, rtol=1e-14, atol=0) # 4. Demonstrate the cyclic property of the trace using the matrix # product operator for matrix products domain = ift.DomainTuple.make((ift.RGSpace(10, 0.3), ift.RGSpace(10, 0.2))) A = ift.Field.from_random(domain=domain, dtype=np.complex128) B = ift.Field.from_random(domain=domain, dtype=np.complex128) C = ift.Field.from_random(domain=domain, dtype=np.complex128) trace_op = ift.MatrixProductOperator(domain, np.eye(10)) A_matmul = ift.MatrixProductOperator(domain, A.val, spaces=(0,)) B_matmul = ift.MatrixProductOperator(domain, B.val, spaces=(0,)) C_matmul = ift.MatrixProductOperator(domain, C.val, spaces=(0,)) ABC = A_matmul.times(B_matmul.times(C)) BCA = B_matmul.times(C_matmul.times(A)) CAB = C_matmul.times(A_matmul.times(B)) trace1 = trace_op.times(ABC) trace2 = trace_op.times(BCA) trace3 = trace_op.times(CAB) ift.extra.assert_allclose(trace1, trace2, rtol=1e-14, atol=0) ift.extra.assert_allclose(trace2, trace3, rtol=1e-14, atol=0) ift.extra.assert_allclose(trace3, trace1, rtol=1e-14, atol=0) def test_matrix_product_examples(n_tests=4): for i in range(n_tests): # 1. Demonstrate that multiplying by a unitary matrix doesn't change # the norm of a vector domain = ift.RGSpace(10, 0.1) field = ift.Field.from_random(domain=domain, dtype=np.complex128) norm = ift.MatrixProductOperator(domain, field.conjugate().val).times(field) M = ift.random.current_rng().standard_normal(domain.shape*2) H = M + M.transpose() O = np.linalg.eig(H)[1] O_inv = np.linalg.inv(O) Hd = np.matmul(O_inv, np.matmul(H, O)) Ud = np.diag(np.exp(1j*np.diag(Hd))) U = np.matmul(O, np.matmul(Ud, O_inv)) U_matmul = ift.MatrixProductOperator(domain, U) field2 = U_matmul.times(field) norm2 = ift.MatrixProductOperator(domain, field2.conjugate().val).times(field2) ift.extra.assert_allclose(norm, norm2, rtol=1e-14, atol=0) # 2. Demonstrate using the operator to get complex conjugate of a field domain = dtuple field = ift.Field.from_random(domain=domain, dtype=np.complex128) one = ift.Field.from_raw(ift.DomainTuple.make(None), 1) op = ift.MatrixProductOperator(domain, field.val) op_conjugate = op.adjoint_times(one) ift.extra.assert_equal(op_conjugate, field.conjugate()) # 3. Demonstrate using the operator to take the trace of a square matrix domain = ift.DomainTuple.make((ift.RGSpace(10, 0.3), ift.RGSpace(10, 0.2))) field = ift.Field.from_random(domain=domain, dtype=np.complex128) trace_op = ift.MatrixProductOperator(domain, np.eye(10)) trace = trace_op.times(field) np_trace = np.trace(field.val) np.testing.assert_allclose(trace.val, np_trace, rtol=1e-14, atol=0) # 4. Demonstrate the cyclic property of the trace using the matrix # product operator for matrix products domain = ift.DomainTuple.make((ift.RGSpace(10, 0.3), ift.RGSpace(10, 0.2))) A = ift.Field.from_random(domain=domain, dtype=np.complex128) B = ift.Field.from_random(domain=domain, dtype=np.complex128) C = ift.Field.from_random(domain=domain, dtype=np.complex128) trace_op = ift.MatrixProductOperator(domain, np.eye(10)) A_matmul = ift.MatrixProductOperator(domain, A.val, spaces=(0,)) B_matmul = ift.MatrixProductOperator(domain, B.val, spaces=(0,)) C_matmul = ift.MatrixProductOperator(domain, C.val, spaces=(0,)) ABC = A_matmul.times(B_matmul.times(C)) BCA = B_matmul.times(C_matmul.times(A)) CAB = C_matmul.times(A_matmul.times(B)) trace1 = trace_op.times(ABC) trace2 = trace_op.times(BCA) trace3 = trace_op.times(CAB) ift.extra.assert_allclose(trace1, trace2, rtol=1e-14, atol=0) ift.extra.assert_allclose(trace2, trace3, rtol=1e-14, atol=0) ift.extra.assert_allclose(trace3, trace1, rtol=1e-14, atol=0)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!