Commit a530e555 authored by Neel Shah's avatar Neel Shah
Browse files

added a test for sparse matrices, changed variable names

parent 86940bf1
Pipeline #107178 failed with stages
......@@ -16,56 +16,68 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
import scipy.sparse
import pytest
import nifty8 as ift
from ..common import list2fixture
dtuple = ift.DomainTuple.make((ift.RGSpace(2, 0.2), ift.RGSpace(3, 0.3),
ift.RGSpace(4, 0.4), ift.RGSpace(5, 0.5)))
dtuple = ift.DomainTuple.make((ift.RGSpace(2, 0.2), ift.RGSpace(3, 0.3),
ift.RGSpace(4, 0.4), ift.RGSpace(5, 0.5)))
pmp = pytest.mark.parametrize
domain = list2fixture([dtuple])
spaces = list2fixture((None, (2,), (1, 3), (1, 2, 3), (0, 1, 2, 3)))
def test_tensor_dot_endomorphic(domain, spaces, n_tests=4):
mat_shape = ()
tensor_shape = ()
if spaces != None:
for i in spaces:
mat_shape += domain[i].shape
tensor_shape += domain[i].shape
else:
mat_shape += domain.shape
mat_shape = mat_shape*2
tensor_shape += domain.shape
tensor_shape = tensor_shape*2
for i in range(n_tests):
mat = ift.random.current_rng().standard_normal(mat_shape)
mat = mat + 1j*ift.random.current_rng().standard_normal(mat_shape)
op = ift.TensorDotOperator(domain, mat, spaces=spaces)
tensor = ift.random.current_rng().standard_normal(tensor_shape)
tensor = tensor + 1j*ift.random.current_rng().standard_normal(tensor_shape)
op = ift.TensorDotOperator(domain, tensor, spaces=spaces)
ift.extra.check_linear_operator(op)
def test_tensor_dot_spaces(domain, spaces, n_tests=4):
mat_shape = (7, 8)
tensor_shape = (7, 8)
if spaces != None:
for i in spaces:
mat_shape += domain[i].shape
tensor_shape += domain[i].shape
else:
mat_shape += domain.shape
tensor_shape += domain.shape
for i in range(n_tests):
mat = ift.random.current_rng().standard_normal(mat_shape)
mat = mat + 1j*ift.random.current_rng().standard_normal(mat_shape)
op = ift.TensorDotOperator(domain, mat, spaces=spaces)
tensor = ift.random.current_rng().standard_normal(tensor_shape)
tensor = tensor + 1j*ift.random.current_rng().standard_normal(tensor_shape)
op = ift.TensorDotOperator(domain, tensor, spaces=spaces)
ift.extra.check_linear_operator(op)
def test_tensor_dot_flatten(domain, n_tests=4):
appl_shape = (ift.utilities.my_product(domain.shape),)
mat_shape = appl_shape * 2
tensor_shape = appl_shape * 2
for i in range(n_tests):
mat = ift.random.current_rng().standard_normal(mat_shape)
mat = mat + 1j*ift.random.current_rng().standard_normal(mat_shape)
op = ift.TensorDotOperator(domain, mat, spaces=None, flatten=True)
tensor = ift.random.current_rng().standard_normal(tensor_shape)
tensor = tensor + 1j*ift.random.current_rng().standard_normal(tensor_shape)
op = ift.TensorDotOperator(domain, tensor, spaces=None, flatten=True)
ift.extra.check_linear_operator(op)
def test_tensor_dot_sparse(n_tests=4):
domain = ift.RGSpace(10)
for i in range(n_tests):
tensor = scipy.sparse.rand(10, 10, 0.5)
op = ift.TensorDotOperator(domain, tensor)
ift.extra.check_linear_operator(op)
# the below function demonstrates the only error that cannot be caught
......@@ -73,56 +85,58 @@ def test_tensor_dot_flatten(domain, n_tests=4):
# too few dimensions to stand in the places of summed over axes of the domain
# as explained in the operator's documentation.
def test_tensor_dot_invalid_shapes(domain):
mat_shape = ()
tensor_shape = ()
spaces = (2,)
if spaces != None:
for i in spaces:
mat_shape += domain[i].shape
tensor_shape += domain[i].shape
else:
mat_shape += domain.shape
tensor_shape += domain.shape
with pytest.raises(ValueError):
mat = ift.random.current_rng().standard_normal(mat_shape)
op = ift.TensorDotOperator(domain, mat, spaces=spaces)
tensor = ift.random.current_rng().standard_normal(tensor_shape)
op = ift.TensorDotOperator(domain, tensor, spaces=spaces)
ift.extra.check_linear_operator(op)
mat_shape = ()
tensor_shape = ()
spaces = (3,)
if spaces != None:
for i in spaces:
mat_shape += domain[i].shape
tensor_shape += domain[i].shape
else:
mat_shape += domain.shape
mat = ift.random.current_rng().standard_normal(mat_shape)
op = ift.TensorDotOperator(domain, mat, spaces=spaces)
tensor_shape += domain.shape
tensor = ift.random.current_rng().standard_normal(tensor_shape)
op = ift.TensorDotOperator(domain, tensor, spaces=spaces)
ift.extra.check_linear_operator(op)
mat_shape = (7,)
tensor_shape = (7,)
spaces = (1, 2)
if spaces != None:
for i in spaces:
mat_shape += domain[i].shape
tensor_shape += domain[i].shape
else:
mat_shape += domain.shape
tensor_shape += domain.shape
with pytest.raises(ValueError):
mat = ift.random.current_rng().standard_normal(mat_shape)
op = ift.TensorDotOperator(domain, mat, spaces=spaces)
tensor = ift.random.current_rng().standard_normal(tensor_shape)
op = ift.TensorDotOperator(domain, tensor, spaces=spaces)
ift.extra.check_linear_operator(op)
mat_shape = (7,)
tensor_shape = (7,)
spaces = (1, 3)
if spaces != None:
for i in spaces:
mat_shape += domain[i].shape
tensor_shape += domain[i].shape
else:
mat_shape += domain.shape
mat = ift.random.current_rng().standard_normal(mat_shape)
op = ift.TensorDotOperator(domain, mat, spaces=spaces)
tensor_shape += domain.shape
tensor = ift.random.current_rng().standard_normal(tensor_shape)
op = ift.TensorDotOperator(domain, tensor, spaces=spaces)
ift.extra.check_linear_operator(op)
def test_tensor_dot_examples(n_tests=4):
for i in range(n_tests):
# 1. Demonstrate that multiplying by a unitary matrix doesn't change
# the norm of a vector
domain = ift.RGSpace(10, 0.1)
field = ift.Field.from_random(domain=domain, dtype=np.complex128)
norm = ift.TensorDotOperator(domain, field.conjugate().val).times(field)
norm = ift.TensorDotOperator(
domain, field.conjugate().val).times(field)
M = ift.random.current_rng().standard_normal(domain.shape*2)
H = M + M.transpose()
O = np.linalg.eig(H)[1]
......@@ -132,9 +146,10 @@ def test_tensor_dot_examples(n_tests=4):
U = np.matmul(O, np.matmul(Ud, O_inv))
U_matmul = ift.TensorDotOperator(domain, U)
field2 = U_matmul.times(field)
norm2 = ift.TensorDotOperator(domain, field2.conjugate().val).times(field2)
norm2 = ift.TensorDotOperator(
domain, field2.conjugate().val).times(field2)
ift.extra.assert_allclose(norm, norm2, rtol=1e-14, atol=0)
# 2. Demonstrate using the operator to get complex conjugate of a field
domain = dtuple
field = ift.Field.from_random(domain=domain, dtype=np.complex128)
......@@ -142,18 +157,20 @@ def test_tensor_dot_examples(n_tests=4):
op = ift.TensorDotOperator(domain, field.val)
op_conjugate = op.adjoint_times(one)
ift.extra.assert_equal(op_conjugate, field.conjugate())
# 3. Demonstrate using the operator to take the trace of a square matrix
domain = ift.DomainTuple.make((ift.RGSpace(10, 0.3), ift.RGSpace(10, 0.2)))
domain = ift.DomainTuple.make(
(ift.RGSpace(10, 0.3), ift.RGSpace(10, 0.2)))
field = ift.Field.from_random(domain=domain, dtype=np.complex128)
trace_op = ift.TensorDotOperator(domain, np.eye(10))
trace = trace_op.times(field)
np_trace = np.trace(field.val)
np.testing.assert_allclose(trace.val, np_trace, rtol=1e-14, atol=0)
# 4. Demonstrate the cyclic property of the trace using the matrix
# product operator for matrix products
domain = ift.DomainTuple.make((ift.RGSpace(10, 0.3), ift.RGSpace(10, 0.2)))
# 4. Demonstrate the cyclic property of the trace using the tensor dot
# operator for matrix products
domain = ift.DomainTuple.make(
(ift.RGSpace(10, 0.3), ift.RGSpace(10, 0.2)))
A = ift.Field.from_random(domain=domain, dtype=np.complex128)
B = ift.Field.from_random(domain=domain, dtype=np.complex128)
C = ift.Field.from_random(domain=domain, dtype=np.complex128)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment