Commit c88b1575 authored by Neel Shah's avatar Neel Shah
Browse files

rename new operator to TensorDotOperator

parent eb51f86d
Pipeline #106979 failed with stages
......@@ -30,7 +30,7 @@ pmp = pytest.mark.parametrize
domain = list2fixture([dtuple])
spaces = list2fixture((None, (2,), (1, 3), (1, 2, 3), (0, 1, 2, 3)))
def test_matrix_product_endomorphic(domain, spaces, n_tests=4):
def test_tensor_dot_endomorphic(domain, spaces, n_tests=4):
mat_shape = ()
if spaces != None:
for i in spaces:
......@@ -42,10 +42,10 @@ def test_matrix_product_endomorphic(domain, spaces, n_tests=4):
for i in range(n_tests):
mat = ift.random.current_rng().standard_normal(mat_shape)
mat = mat + 1j*ift.random.current_rng().standard_normal(mat_shape)
op = ift.MatrixProductOperator(domain, mat, spaces=spaces)
op = ift.TensorDotOperator(domain, mat, spaces=spaces)
ift.extra.check_linear_operator(op)
def test_matrix_product_spaces(domain, spaces, n_tests=4):
def test_tensor_dot_spaces(domain, spaces, n_tests=4):
mat_shape = (7, 8)
if spaces != None:
for i in spaces:
......@@ -56,23 +56,23 @@ def test_matrix_product_spaces(domain, spaces, n_tests=4):
for i in range(n_tests):
mat = ift.random.current_rng().standard_normal(mat_shape)
mat = mat + 1j*ift.random.current_rng().standard_normal(mat_shape)
op = ift.MatrixProductOperator(domain, mat, spaces=spaces)
op = ift.TensorDotOperator(domain, mat, spaces=spaces)
ift.extra.check_linear_operator(op)
def test_matrix_product_flatten(domain, n_tests=4):
def test_tensor_dot_flatten(domain, n_tests=4):
appl_shape = (ift.utilities.my_product(domain.shape),)
mat_shape = appl_shape * 2
for i in range(n_tests):
mat = ift.random.current_rng().standard_normal(mat_shape)
mat = mat + 1j*ift.random.current_rng().standard_normal(mat_shape)
op = ift.MatrixProductOperator(domain, mat, spaces=None, flatten=True)
op = ift.TensorDotOperator(domain, mat, spaces=None, flatten=True)
ift.extra.check_linear_operator(op)
# the below function demonstrates the only error that cannot be caught
# when the operator is initialized. It is caused due to the matrix having
# when the operator is initialized. It is caused due to the tensor having
# too few dimensions to stand in the places of summed over axes of the domain
# as explained in the operator's documentation.
def test_matrix_product_invalid_shapes(domain):
def test_tensor_dot_invalid_shapes(domain):
mat_shape = ()
spaces = (2,)
if spaces != None:
......@@ -82,7 +82,7 @@ def test_matrix_product_invalid_shapes(domain):
mat_shape += domain.shape
with pytest.raises(ValueError):
mat = ift.random.current_rng().standard_normal(mat_shape)
op = ift.MatrixProductOperator(domain, mat, spaces=spaces)
op = ift.TensorDotOperator(domain, mat, spaces=spaces)
ift.extra.check_linear_operator(op)
mat_shape = ()
spaces = (3,)
......@@ -92,7 +92,7 @@ def test_matrix_product_invalid_shapes(domain):
else:
mat_shape += domain.shape
mat = ift.random.current_rng().standard_normal(mat_shape)
op = ift.MatrixProductOperator(domain, mat, spaces=spaces)
op = ift.TensorDotOperator(domain, mat, spaces=spaces)
ift.extra.check_linear_operator(op)
mat_shape = (7,)
spaces = (1, 2)
......@@ -103,7 +103,7 @@ def test_matrix_product_invalid_shapes(domain):
mat_shape += domain.shape
with pytest.raises(ValueError):
mat = ift.random.current_rng().standard_normal(mat_shape)
op = ift.MatrixProductOperator(domain, mat, spaces=spaces)
op = ift.TensorDotOperator(domain, mat, spaces=spaces)
ift.extra.check_linear_operator(op)
mat_shape = (7,)
spaces = (1, 3)
......@@ -113,16 +113,16 @@ def test_matrix_product_invalid_shapes(domain):
else:
mat_shape += domain.shape
mat = ift.random.current_rng().standard_normal(mat_shape)
op = ift.MatrixProductOperator(domain, mat, spaces=spaces)
op = ift.TensorDotOperator(domain, mat, spaces=spaces)
ift.extra.check_linear_operator(op)
def test_matrix_product_examples(n_tests=4):
def test_tensor_dot_examples(n_tests=4):
for i in range(n_tests):
# 1. Demonstrate that multiplying by a unitary matrix doesn't change
# the norm of a vector
domain = ift.RGSpace(10, 0.1)
field = ift.Field.from_random(domain=domain, dtype=np.complex128)
norm = ift.MatrixProductOperator(domain, field.conjugate().val).times(field)
norm = ift.TensorDotOperator(domain, field.conjugate().val).times(field)
M = ift.random.current_rng().standard_normal(domain.shape*2)
H = M + M.transpose()
O = np.linalg.eig(H)[1]
......@@ -130,23 +130,23 @@ def test_matrix_product_examples(n_tests=4):
Hd = np.matmul(O_inv, np.matmul(H, O))
Ud = np.diag(np.exp(1j*np.diag(Hd)))
U = np.matmul(O, np.matmul(Ud, O_inv))
U_matmul = ift.MatrixProductOperator(domain, U)
U_matmul = ift.TensorDotOperator(domain, U)
field2 = U_matmul.times(field)
norm2 = ift.MatrixProductOperator(domain, field2.conjugate().val).times(field2)
norm2 = ift.TensorDotOperator(domain, field2.conjugate().val).times(field2)
ift.extra.assert_allclose(norm, norm2, rtol=1e-14, atol=0)
# 2. Demonstrate using the operator to get complex conjugate of a field
domain = dtuple
field = ift.Field.from_random(domain=domain, dtype=np.complex128)
one = ift.Field.from_raw(ift.DomainTuple.make(None), 1)
op = ift.MatrixProductOperator(domain, field.val)
op = ift.TensorDotOperator(domain, field.val)
op_conjugate = op.adjoint_times(one)
ift.extra.assert_equal(op_conjugate, field.conjugate())
# 3. Demonstrate using the operator to take the trace of a square matrix
domain = ift.DomainTuple.make((ift.RGSpace(10, 0.3), ift.RGSpace(10, 0.2)))
field = ift.Field.from_random(domain=domain, dtype=np.complex128)
trace_op = ift.MatrixProductOperator(domain, np.eye(10))
trace_op = ift.TensorDotOperator(domain, np.eye(10))
trace = trace_op.times(field)
np_trace = np.trace(field.val)
np.testing.assert_allclose(trace.val, np_trace, rtol=1e-14, atol=0)
......@@ -157,10 +157,10 @@ def test_matrix_product_examples(n_tests=4):
A = ift.Field.from_random(domain=domain, dtype=np.complex128)
B = ift.Field.from_random(domain=domain, dtype=np.complex128)
C = ift.Field.from_random(domain=domain, dtype=np.complex128)
trace_op = ift.MatrixProductOperator(domain, np.eye(10))
A_matmul = ift.MatrixProductOperator(domain, A.val, spaces=(0,))
B_matmul = ift.MatrixProductOperator(domain, B.val, spaces=(0,))
C_matmul = ift.MatrixProductOperator(domain, C.val, spaces=(0,))
trace_op = ift.TensorDotOperator(domain, np.eye(10))
A_matmul = ift.TensorDotOperator(domain, A.val, spaces=(0,))
B_matmul = ift.TensorDotOperator(domain, B.val, spaces=(0,))
C_matmul = ift.TensorDotOperator(domain, C.val, spaces=(0,))
ABC = A_matmul.times(B_matmul.times(C))
BCA = B_matmul.times(C_matmul.times(A))
CAB = C_matmul.times(A_matmul.times(B))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment