Commit 1e4e8d70 authored by Gordian Edenhofer's avatar Gordian Edenhofer
Browse files

test_selection_operators.py: Test more spaces

parent da977baf
Pipeline #75017 passed with stages
in 25 minutes and 13 seconds
......@@ -22,26 +22,30 @@ from nifty6.extra import consistency_check
import numpy as np
import nifty6 as ift
from ..common import setup_function, teardown_function
from ..common import list2fixture, setup_function, teardown_function
pmp = pytest.mark.parametrize
# The test cases do not work on a multi-dimensional RGSpace yet
spaces = (
ift.UnstructuredDomain(4),
ift.LMSpace(5),
ift.GLSpace(4),
)
space1 = list2fixture(spaces)
space2 = list2fixture(spaces)
dtype = list2fixture([np.float64, np.complex128])
@pmp("n_unstructured", (3, 9))
@pmp("nside", (4, 8))
def test_split_operator_first_axes_without_intersections(
n_unstructured, nside, n_splits=3
space1, space2, n_splits=3
):
setup_function()
rng = ift.random.current_rng()
pos_space = ift.HPSpace(nside)
dom = ift.DomainTuple.make(
(ift.UnstructuredDomain(n_unstructured), pos_space)
)
orig_idx = np.arange(n_unstructured)
dom = ift.DomainTuple.make((space1, space2))
orig_idx = np.arange(space1.shape[0])
rng.shuffle(orig_idx)
split_idx = np.split(orig_idx, n_splits)
split_idx = np.array_split(orig_idx, n_splits)
split = ift.SplitOperator(
dom, {f"{i:06d}": (si, )
for i, si in enumerate(split_idx)}
......@@ -58,24 +62,16 @@ def test_split_operator_first_axes_without_intersections(
# the generated indices and without intersections.
assert_array_equal(split.adjoint(split_r).val, r.val)
teardown_function()
@pmp("n_unstructured", (3, 9))
@pmp("nside", (4, 8))
def test_split_operator_first_axes_with_intersections(
n_unstructured, nside, n_splits=3
space1, space2, n_splits=3
):
setup_function()
rng = ift.random.current_rng()
pos_space = ift.HPSpace(nside)
dom = ift.DomainTuple.make(
(ift.UnstructuredDomain(n_unstructured), pos_space)
)
orig_idx = np.arange(n_unstructured)
dom = ift.DomainTuple.make((space1, space2))
orig_idx = np.arange(space1.shape[0])
split_idx = [
rng.choice(orig_idx, rng.integers(1, n_unstructured), replace=False)
rng.choice(orig_idx, rng.integers(1, space1.shape[0]), replace=False)
for _ in range(n_splits)
]
split = ift.SplitOperator(
......@@ -95,9 +91,7 @@ def test_split_operator_first_axes_with_intersections(
r_diy = np.copy(r.val)
unique_freq = np.unique(np.concatenate(split_idx), return_counts=True)
# Null values that were not selected
r_diy[list(set(unique_freq[0]) ^ set(range(n_unstructured)))] = 0.
r_diy[list(set(unique_freq[0]) ^ set(range(space1.shape[0])))] = 0.
for idx, freq in zip(*unique_freq):
r_diy[idx] *= freq
assert_allclose(split.adjoint(split_r).val, r_diy)
teardown_function()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment