Commit bc31f0ba authored by Philipp Arras's avatar Philipp Arras
Browse files

Streamline tests

parent 9ddc2cd9
Pipeline #103083 canceled with stages
in 9 minutes and 15 seconds
......@@ -21,6 +21,7 @@ from ..domain_tuple import DomainTuple
from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field
from ..linearization import Linearization
from ..minimization.energy_adapter import StochasticEnergyAdapter
from ..multi_field import MultiField
from ..operators.einsum import MultiLinearEinsum
from ..operators.energy_operators import EnergyOperator
......@@ -28,8 +29,7 @@ from ..operators.linear_operator import LinearOperator
from ..operators.multifield2vector import Multifield2Vector
from ..operators.sandwich_operator import SandwichOperator
from ..operators.simple_linear_operators import FieldAdapter
from ..sugar import full, makeField, makeDomain, from_random, is_fieldlike
from ..minimization.energy_adapter import StochasticEnergyAdapter
from ..sugar import from_random, full, is_fieldlike, makeDomain, makeField
from ..utilities import myassert
......@@ -168,16 +168,16 @@ class GaussianEntropy(EnergyOperator):
Parameters
----------
domain: Domain
domain: Domain FIXME
The domain of the diagonal.
"""
def __init__(self, domain):
self._domain = domain
self._domain = DomainTuple.make(domain)
def apply(self, x):
self._check_input(x)
res = -0.5*(2*np.pi*np.e*x**2).log().sum()
res = (x*x).scale(2*np.pi*np.e).log().sum().scale(-0.5)
if not isinstance(x, Linearization):
return res
if not x.want_metric:
......@@ -191,7 +191,7 @@ class LowerTriangularInserter(LinearOperator):
Parameters
----------
target: Domain
target: Domain FIXME
A two-dimensional domain with NxN entries.
"""
......@@ -220,7 +220,7 @@ class DiagonalSelector(LinearOperator):
Parameters
----------
domain: Domain
domain: Domain FIXME
The two-dimensional domain of the input field. Must be of shape NxN.
"""
......
......@@ -11,15 +11,14 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import nifty7 as ift
import numpy as np
import pytest
import nifty7 as ift
from .common import list2fixture, setup_function, teardown_function
spaces = [ift.GLSpace(5),
......@@ -59,7 +58,7 @@ def test_QuadraticFormOperator(field):
def test_studentt(field):
if isinstance(field.domain, ift.MultiDomain):
return
pytest.skip()
energy = ift.StudentTEnergy(domain=field.domain, theta=.5)
ift.extra.check_operator(energy, field)
theta = ift.from_random(field.domain, 'normal').exp()
......@@ -80,7 +79,7 @@ def test_hamiltonian_and_KL(field):
def test_variablecovariancegaussian(field):
if isinstance(field.domain, ift.MultiDomain):
return
pytest.skip()
dc = {'a': field, 'b': field.ptw("exp")}
mf = ift.MultiField.from_dict(dc)
energy = ift.VariableCovarianceGaussianEnergy(field.domain, 'a', 'b', np.float64)
......@@ -90,7 +89,7 @@ def test_variablecovariancegaussian(field):
def test_specialgamma(field):
if isinstance(field.domain, ift.MultiDomain):
return
pytest.skip()
energy = ift.operators.energy_operators._SpecialGammaEnergy(field)
loc = ift.from_random(energy.domain).exp()
ift.extra.check_operator(energy, loc, ntries=ntries)
......@@ -99,7 +98,7 @@ def test_specialgamma(field):
def test_inverse_gamma(field):
if isinstance(field.domain, ift.MultiDomain):
return
pytest.skip()
field = field.ptw("exp")
space = field.domain
d = ift.random.current_rng().normal(10, size=space.shape)**2
......@@ -110,7 +109,7 @@ def test_inverse_gamma(field):
def testPoissonian(field):
if isinstance(field.domain, ift.MultiDomain):
return
pytest.skip()
field = field.ptw("exp")
space = field.domain
d = ift.random.current_rng().poisson(120, size=space.shape)
......@@ -121,7 +120,7 @@ def testPoissonian(field):
def test_bernoulli(field):
if isinstance(field.domain, ift.MultiDomain):
return
pytest.skip()
field = field.ptw("sigmoid")
space = field.domain
d = ift.random.current_rng().binomial(1, 0.1, size=space.shape)
......@@ -129,10 +128,10 @@ def test_bernoulli(field):
energy = ift.BernoulliEnergy(d)
ift.extra.check_operator(energy, field, tol=1e-10)
def test_gaussian_entropy(field):
if isinstance(field.domain, ift.MultiDomain):
return
pytest.skip()
field = field.ptw("sigmoid")
energy = ift.library.variational_models.GaussianEntropy(field.domain)
ift.extra.check_operator(energy, field)
\ No newline at end of file
ift.extra.check_operator(energy, field)
......@@ -11,15 +11,14 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import nifty7 as ift
import numpy as np
import pytest
import nifty7 as ift
from ..common import list2fixture, setup_function, teardown_function
_h_RG_spaces = [ift.RGSpace(7, distances=0.2, harmonic=True),
......@@ -340,9 +339,9 @@ def testDiagonalExtractor(seed):
ift.extra.check_linear_operator(op)
@pmp('seed', [12, 3])
def testLowerTriangularInserter(seed):
N = 42
square_space = ift.RGSpace([N,N])
@pmp("N", [10, 17])
def testLowerTriangularInserter(seed, N):
square_space = ift.RGSpace([N, N])
op = ift.library.variational_models.LowerTriangularInserter(square_space)
ift.extra.check_linear_operator(op)
......
......@@ -11,25 +11,34 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import nifty7 as ift
import numpy as np
import pytest
from nifty7.library.variational_models import DiagonalSelector
from numpy.testing import assert_allclose
import nifty7 as ift
from ..common import list2fixture, setup_function, teardown_function
from nifty7.library.variational_models import DiagonalSelector
pmp = pytest.mark.parametrize
def test_diagonal_selector():
N = 42
square_space = ift.RGSpace([N,N])
@pmp("N", [17, 32])
def test_diagonal_selector(N):
square_space = ift.RGSpace([N, N])
linear_space = ift.RGSpace(N)
myField = ift.from_random(square_space)
myDiagonalSelector = DiagonalSelector(square_space)
selected = myDiagonalSelector(myField).val
np_selected = np.diag(myField.val)
assert_allclose(np_selected, selected)
@pmp("n", [1, 3])
def test_error(n):
square_space = ift.RGSpace(n*[13])
with pytest.raises(AssertionError):
myDiagonalSelector = DiagonalSelector(square_space)
......@@ -11,24 +11,26 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import nifty7 as ift
import numpy as np
import pytest
from nifty7.library.variational_models import GaussianEntropy
from numpy.testing import assert_allclose
import nifty7 as ift
from ..common import list2fixture, setup_function, teardown_function
from nifty7.library.variational_models import GaussianEntropy
pmp = pytest.mark.parametrize
def test_gaussian_entropy():
N = 42
@pmp("N", [17, 32])
def test_gaussian_entropy(N):
linear_space = ift.RGSpace(N)
myField = ift.from_random(linear_space, 'uniform')
vals = myField.val
entropy = - 0.5 * np.sum(np.log(2 * vals**2 * np.pi * np.e)) # minus due to subtraction in KL
myEntropy = GaussianEntropy(myField.domain)
assert_allclose(entropy, myEntropy(myField).val)
fld = ift.from_random(linear_space, 'uniform')
# minus due to subtraction in KL
entropy = -0.5*np.sum(np.log(2*fld.val*fld.val*np.pi*np.e))
myEntropy = GaussianEntropy(fld.domain)
assert_allclose(entropy, myEntropy(fld).val)
......@@ -11,39 +11,30 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import nifty7 as ift
import numpy as np
import pytest
from nifty7.library.variational_models import LowerTriangularInserter
from numpy.testing import assert_allclose
import nifty7 as ift
from nifty7.library.variational_models import LowerTriangularInserter
from ..common import list2fixture, setup_function, teardown_function
pmp = pytest.mark.parametrize
def test_lower_triangular_inserter():
N = 42
square_space = ift.RGSpace([N,N])
myInserter = LowerTriangularInserter(square_space)
flat_space = myInserter.domain
assert_allclose(flat_space.size, N*(N+1)//2)
myField = ift.from_random(flat_space)
myMatrix = myInserter(myField)
@pmp("N", [17, 32])
def test_lower_triangular_inserter(N):
square_space = ift.RGSpace([N, N])
op = LowerTriangularInserter(square_space)
flat_space = op.domain
assert_allclose(flat_space.shape, (N*(N+1)//2,))
upper_i = []
upper_j = []
mat = op(ift.from_random(flat_space))
for i in range(0,N):
for j in range(i+1,N):
upper_i.append(i)
upper_j.append(j)
zeros = myMatrix.val[(upper_i,upper_j)]
assert_allclose(zeros, np.zeros((N-1)*N//2))
assert_allclose((myMatrix==0).val.sum(), (N-1)*N//2)
for j in range(i+1, N):
assert mat.val[i, j] == 0.
assert_allclose((mat.val == 0).sum(), (N-1)*N//2)
......@@ -11,36 +11,31 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import nifty7 as ift
import numpy as np
import pytest
from numpy.testing import assert_allclose
import nifty7 as ift
def test_multifield2vector():
myFirstSpace = ift.RGSpace([42,43])
mySecondSpace = ift.RGSpace([1,2,3])
myDict = {
'first': ift.from_random(myFirstSpace),
'second': ift.from_random(mySecondSpace)
}
myMultiField = ift.MultiField.from_dict(myDict)
myMultifield2Vector = ift.Multifield2Vector(myMultiField.domain)
myVector = myMultifield2Vector(myMultiField)
assert_allclose(myVector.size, myMultiField.size)
myNumpyVector = np.empty(myMultiField.size)
myNumpyVector[:myFirstSpace.size] = myDict['first'].val.flatten()
myNumpyVector[myFirstSpace.size:] = myDict['second'].val.flatten()
assert_allclose(myNumpyVector, myVector.val)
from ..common import list2fixture, setup_function, teardown_function
from .test_adjoint import _h_spaces, _p_spaces, _pow_spaces
mySecondMultiField = myMultifield2Vector.adjoint(myVector)
pmp = pytest.mark.parametrize
assert_allclose(mySecondMultiField['first'].val, myMultiField['first'].val)
assert_allclose(mySecondMultiField['second'].val, myMultiField['second'].val)
@pmp("dom0", _h_spaces + _p_spaces + _pow_spaces)
@pmp("dom1", _h_spaces + _p_spaces + _pow_spaces)
def test_multifield2vector(dom0, dom1):
fld = ift.MultiField.from_dict({'a': ift.from_random(dom0),
'b': ift.from_random(dom1)})
op = ift.Multifield2Vector(fld.domain)
ift.extra.assert_allclose(fld, op.adjoint(op(fld)))
assert op.target.size == fld.size
arr = np.empty(fld.size)
arr[:dom0.size] = fld['a'].val.flatten()
arr[dom0.size:] = fld['b'].val.flatten()
assert_allclose(arr, op(fld).val)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment