# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

"""Tests for ``ift.TensorDotOperator``."""

import numpy as np
import pytest

import nifty8 as ift

from ..common import list2fixture

dtuple = ift.DomainTuple.make((ift.RGSpace(2, 0.2),
                               ift.RGSpace(3, 0.3),
                               ift.RGSpace(4, 0.4),
                               ift.RGSpace(5, 0.5)))
pmp = pytest.mark.parametrize

domain = list2fixture([dtuple])
spaces = list2fixture((None, (2,), (1, 3), (1, 2, 3), (0, 1, 2, 3)))


def _mat_shape(domain, spaces, prefix=()):
    """Return ``prefix`` followed by the shapes of the contracted spaces.

    Parameters
    ----------
    domain : DomainTuple
        Domain the operator acts on.
    spaces : tuple of int or None
        Indices of the spaces of ``domain`` that are contracted; ``None``
        means the whole of ``domain``.
    prefix : tuple of int
        Leading (non-contracted) axes of the tensor.

    Returns
    -------
    tuple of int
        The tensor shape ``prefix + shape(contracted spaces)``.
    """
    if spaces is None:
        return tuple(prefix) + tuple(domain.shape)
    shape = tuple(prefix)
    for ispace in spaces:
        shape += tuple(domain[ispace].shape)
    return shape


def _random_complex(shape):
    """Draw a complex array with standard-normal real and imaginary parts.

    The real part is drawn before the imaginary part from the current NIFTy
    RNG.
    """
    rng = ift.random.current_rng()
    return rng.standard_normal(shape) + 1j*rng.standard_normal(shape)


def _check_tensor_dot(domain, mat, **kwargs):
    """Build a TensorDotOperator from ``mat`` and run the linear-operator checks.

    ``kwargs`` (e.g. ``spaces``, ``flatten``) are forwarded to the operator.
    """
    op = ift.TensorDotOperator(domain, mat, **kwargs)
    ift.extra.check_linear_operator(op)


def test_tensor_dot_endomorphic(domain, spaces, n_tests=4):
    """Tensors of shape ``contracted + contracted`` pass the operator checks."""
    mat_shape = _mat_shape(domain, spaces)*2
    for _ in range(n_tests):
        _check_tensor_dot(domain, _random_complex(mat_shape), spaces=spaces)


def test_tensor_dot_spaces(domain, spaces, n_tests=4):
    """Tensors with extra leading axes ``(7, 8)`` pass the operator checks."""
    mat_shape = _mat_shape(domain, spaces, prefix=(7, 8))
    for _ in range(n_tests):
        _check_tensor_dot(domain, _random_complex(mat_shape), spaces=spaces)


def test_tensor_dot_flatten(domain, n_tests=4):
    """A square matrix over the flattened domain works with ``flatten=True``."""
    appl_shape = (ift.utilities.my_product(domain.shape),)
    mat_shape = appl_shape*2
    for _ in range(n_tests):
        _check_tensor_dot(domain, _random_complex(mat_shape),
                          spaces=None, flatten=True)


def test_tensor_dot_invalid_shapes(domain):
    """Demonstrate the one shape error that is not caught at initialization.

    The error occurs when the tensor has too few dimensions to fill the
    positions of the summed-over axes of the domain, as explained in the
    operator's documentation. Each failing case is paired with a valid
    counterpart that has the same number of tensor dimensions.
    """
    # (leading tensor axes, contracted spaces, whether a ValueError is expected)
    cases = (
        ((), (2,), True),
        ((), (3,), False),
        ((7,), (1, 2), True),
        ((7,), (1, 3), False),
    )
    for prefix, spaces, expect_error in cases:
        mat = ift.random.current_rng().standard_normal(
            _mat_shape(domain, spaces, prefix))
        if expect_error:
            with pytest.raises(ValueError):
                _check_tensor_dot(domain, mat, spaces=spaces)
        else:
            _check_tensor_dot(domain, mat, spaces=spaces)


def test_tensor_dot_examples(n_tests=4):
    """Check TensorDotOperator against known linear-algebra identities."""
    for _ in range(n_tests):
        # 1. Demonstrate that multiplying by a unitary matrix doesn't change
        # the norm of a vector.
        domain = ift.RGSpace(10, 0.1)
        field = ift.Field.from_random(domain=domain, dtype=np.complex128)
        norm = ift.TensorDotOperator(domain, field.conjugate().val).times(field)

        M = ift.random.current_rng().standard_normal(domain.shape*2)
        H = M + M.transpose()
        # H is real symmetric, so eigh returns real eigenvalues and an
        # orthonormal (real) eigenbasis V. U = V diag(exp(i*w)) V^T is then
        # unitary to machine precision, without an explicit matrix inverse.
        eigvals, eigvecs = np.linalg.eigh(H)
        U = eigvecs @ np.diag(np.exp(1j*eigvals)) @ eigvecs.T
        U_matmul = ift.TensorDotOperator(domain, U)
        field2 = U_matmul.times(field)
        norm2 = ift.TensorDotOperator(domain, field2.conjugate().val).times(field2)
        ift.extra.assert_allclose(norm, norm2, rtol=1e-14, atol=0)

        # 2. Demonstrate using the operator to get the complex conjugate of a
        # field: the adjoint applied to a scalar one yields conj(tensor).
        domain = dtuple
        field = ift.Field.from_random(domain=domain, dtype=np.complex128)
        one = ift.Field.from_raw(ift.DomainTuple.make(None), 1)
        op = ift.TensorDotOperator(domain, field.val)
        op_conjugate = op.adjoint_times(one)
        ift.extra.assert_equal(op_conjugate, field.conjugate())

        # 3. Demonstrate using the operator to take the trace of a square
        # matrix by contracting it with the identity.
        domain = ift.DomainTuple.make((ift.RGSpace(10, 0.3), ift.RGSpace(10, 0.2)))
        field = ift.Field.from_random(domain=domain, dtype=np.complex128)
        trace_op = ift.TensorDotOperator(domain, np.eye(10))
        trace = trace_op.times(field)
        np_trace = np.trace(field.val)
        np.testing.assert_allclose(trace.val, np_trace, rtol=1e-14, atol=0)

        # 4. Demonstrate the cyclic property of the trace, using the operator
        # for matrix products.
        domain = ift.DomainTuple.make((ift.RGSpace(10, 0.3), ift.RGSpace(10, 0.2)))
        A = ift.Field.from_random(domain=domain, dtype=np.complex128)
        B = ift.Field.from_random(domain=domain, dtype=np.complex128)
        C = ift.Field.from_random(domain=domain, dtype=np.complex128)
        trace_op = ift.TensorDotOperator(domain, np.eye(10))
        A_matmul = ift.TensorDotOperator(domain, A.val, spaces=(0,))
        B_matmul = ift.TensorDotOperator(domain, B.val, spaces=(0,))
        C_matmul = ift.TensorDotOperator(domain, C.val, spaces=(0,))
        ABC = A_matmul.times(B_matmul.times(C))
        BCA = B_matmul.times(C_matmul.times(A))
        CAB = C_matmul.times(A_matmul.times(B))
        trace1 = trace_op.times(ABC)
        trace2 = trace_op.times(BCA)
        trace3 = trace_op.times(CAB)
        ift.extra.assert_allclose(trace1, trace2, rtol=1e-14, atol=0)
        ift.extra.assert_allclose(trace2, trace3, rtol=1e-14, atol=0)
        ift.extra.assert_allclose(trace3, trace1, rtol=1e-14, atol=0)