Commit bc31f0ba authored by Philipp Arras's avatar Philipp Arras
Browse files

Streamline tests

parent 9ddc2cd9
...@@ -21,6 +21,7 @@ from ..domain_tuple import DomainTuple ...@@ -21,6 +21,7 @@ from ..domain_tuple import DomainTuple
from ..domains.unstructured_domain import UnstructuredDomain from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field from ..field import Field
from ..linearization import Linearization from ..linearization import Linearization
from ..minimization.energy_adapter import StochasticEnergyAdapter
from ..multi_field import MultiField from ..multi_field import MultiField
from ..operators.einsum import MultiLinearEinsum from ..operators.einsum import MultiLinearEinsum
from ..operators.energy_operators import EnergyOperator from ..operators.energy_operators import EnergyOperator
...@@ -28,8 +29,7 @@ from ..operators.linear_operator import LinearOperator ...@@ -28,8 +29,7 @@ from ..operators.linear_operator import LinearOperator
from ..operators.multifield2vector import Multifield2Vector from ..operators.multifield2vector import Multifield2Vector
from ..operators.sandwich_operator import SandwichOperator from ..operators.sandwich_operator import SandwichOperator
from ..operators.simple_linear_operators import FieldAdapter from ..operators.simple_linear_operators import FieldAdapter
from ..sugar import full, makeField, makeDomain, from_random, is_fieldlike from ..sugar import from_random, full, is_fieldlike, makeDomain, makeField
from ..minimization.energy_adapter import StochasticEnergyAdapter
from ..utilities import myassert from ..utilities import myassert
...@@ -168,16 +168,16 @@ class GaussianEntropy(EnergyOperator): ...@@ -168,16 +168,16 @@ class GaussianEntropy(EnergyOperator):
Parameters Parameters
---------- ----------
domain: Domain domain: Domain FIXME
The domain of the diagonal. The domain of the diagonal.
""" """
def __init__(self, domain): def __init__(self, domain):
self._domain = domain self._domain = DomainTuple.make(domain)
def apply(self, x): def apply(self, x):
self._check_input(x) self._check_input(x)
res = -0.5*(2*np.pi*np.e*x**2).log().sum() res = (x*x).scale(2*np.pi*np.e).log().sum().scale(-0.5)
if not isinstance(x, Linearization): if not isinstance(x, Linearization):
return res return res
if not x.want_metric: if not x.want_metric:
...@@ -191,7 +191,7 @@ class LowerTriangularInserter(LinearOperator): ...@@ -191,7 +191,7 @@ class LowerTriangularInserter(LinearOperator):
Parameters Parameters
---------- ----------
target: Domain target: Domain FIXME
A two-dimensional domain with NxN entries. A two-dimensional domain with NxN entries.
""" """
...@@ -220,7 +220,7 @@ class DiagonalSelector(LinearOperator): ...@@ -220,7 +220,7 @@ class DiagonalSelector(LinearOperator):
Parameters Parameters
---------- ----------
domain: Domain domain: Domain FIXME
The two-dimensional domain of the input field. Must be of shape NxN. The two-dimensional domain of the input field. Must be of shape NxN.
""" """
......
...@@ -11,15 +11,14 @@ ...@@ -11,15 +11,14 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
# #
# Copyright(C) 2013-2020 Max-Planck-Society # Copyright(C) 2013-2021 Max-Planck-Society
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import nifty7 as ift
import numpy as np import numpy as np
import pytest import pytest
import nifty7 as ift
from .common import list2fixture, setup_function, teardown_function from .common import list2fixture, setup_function, teardown_function
spaces = [ift.GLSpace(5), spaces = [ift.GLSpace(5),
...@@ -59,7 +58,7 @@ def test_QuadraticFormOperator(field): ...@@ -59,7 +58,7 @@ def test_QuadraticFormOperator(field):
def test_studentt(field): def test_studentt(field):
if isinstance(field.domain, ift.MultiDomain): if isinstance(field.domain, ift.MultiDomain):
return pytest.skip()
energy = ift.StudentTEnergy(domain=field.domain, theta=.5) energy = ift.StudentTEnergy(domain=field.domain, theta=.5)
ift.extra.check_operator(energy, field) ift.extra.check_operator(energy, field)
theta = ift.from_random(field.domain, 'normal').exp() theta = ift.from_random(field.domain, 'normal').exp()
...@@ -80,7 +79,7 @@ def test_hamiltonian_and_KL(field): ...@@ -80,7 +79,7 @@ def test_hamiltonian_and_KL(field):
def test_variablecovariancegaussian(field): def test_variablecovariancegaussian(field):
if isinstance(field.domain, ift.MultiDomain): if isinstance(field.domain, ift.MultiDomain):
return pytest.skip()
dc = {'a': field, 'b': field.ptw("exp")} dc = {'a': field, 'b': field.ptw("exp")}
mf = ift.MultiField.from_dict(dc) mf = ift.MultiField.from_dict(dc)
energy = ift.VariableCovarianceGaussianEnergy(field.domain, 'a', 'b', np.float64) energy = ift.VariableCovarianceGaussianEnergy(field.domain, 'a', 'b', np.float64)
...@@ -90,7 +89,7 @@ def test_variablecovariancegaussian(field): ...@@ -90,7 +89,7 @@ def test_variablecovariancegaussian(field):
def test_specialgamma(field): def test_specialgamma(field):
if isinstance(field.domain, ift.MultiDomain): if isinstance(field.domain, ift.MultiDomain):
return pytest.skip()
energy = ift.operators.energy_operators._SpecialGammaEnergy(field) energy = ift.operators.energy_operators._SpecialGammaEnergy(field)
loc = ift.from_random(energy.domain).exp() loc = ift.from_random(energy.domain).exp()
ift.extra.check_operator(energy, loc, ntries=ntries) ift.extra.check_operator(energy, loc, ntries=ntries)
...@@ -99,7 +98,7 @@ def test_specialgamma(field): ...@@ -99,7 +98,7 @@ def test_specialgamma(field):
def test_inverse_gamma(field): def test_inverse_gamma(field):
if isinstance(field.domain, ift.MultiDomain): if isinstance(field.domain, ift.MultiDomain):
return pytest.skip()
field = field.ptw("exp") field = field.ptw("exp")
space = field.domain space = field.domain
d = ift.random.current_rng().normal(10, size=space.shape)**2 d = ift.random.current_rng().normal(10, size=space.shape)**2
...@@ -110,7 +109,7 @@ def test_inverse_gamma(field): ...@@ -110,7 +109,7 @@ def test_inverse_gamma(field):
def testPoissonian(field): def testPoissonian(field):
if isinstance(field.domain, ift.MultiDomain): if isinstance(field.domain, ift.MultiDomain):
return pytest.skip()
field = field.ptw("exp") field = field.ptw("exp")
space = field.domain space = field.domain
d = ift.random.current_rng().poisson(120, size=space.shape) d = ift.random.current_rng().poisson(120, size=space.shape)
...@@ -121,7 +120,7 @@ def testPoissonian(field): ...@@ -121,7 +120,7 @@ def testPoissonian(field):
def test_bernoulli(field): def test_bernoulli(field):
if isinstance(field.domain, ift.MultiDomain): if isinstance(field.domain, ift.MultiDomain):
return pytest.skip()
field = field.ptw("sigmoid") field = field.ptw("sigmoid")
space = field.domain space = field.domain
d = ift.random.current_rng().binomial(1, 0.1, size=space.shape) d = ift.random.current_rng().binomial(1, 0.1, size=space.shape)
...@@ -129,10 +128,10 @@ def test_bernoulli(field): ...@@ -129,10 +128,10 @@ def test_bernoulli(field):
energy = ift.BernoulliEnergy(d) energy = ift.BernoulliEnergy(d)
ift.extra.check_operator(energy, field, tol=1e-10) ift.extra.check_operator(energy, field, tol=1e-10)
def test_gaussian_entropy(field): def test_gaussian_entropy(field):
if isinstance(field.domain, ift.MultiDomain): if isinstance(field.domain, ift.MultiDomain):
return pytest.skip()
field = field.ptw("sigmoid") field = field.ptw("sigmoid")
energy = ift.library.variational_models.GaussianEntropy(field.domain) energy = ift.library.variational_models.GaussianEntropy(field.domain)
ift.extra.check_operator(energy, field)
ift.extra.check_operator(energy, field)
\ No newline at end of file
...@@ -11,15 +11,14 @@ ...@@ -11,15 +11,14 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
# #
# Copyright(C) 2013-2020 Max-Planck-Society # Copyright(C) 2013-2021 Max-Planck-Society
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import nifty7 as ift
import numpy as np import numpy as np
import pytest import pytest
import nifty7 as ift
from ..common import list2fixture, setup_function, teardown_function from ..common import list2fixture, setup_function, teardown_function
_h_RG_spaces = [ift.RGSpace(7, distances=0.2, harmonic=True), _h_RG_spaces = [ift.RGSpace(7, distances=0.2, harmonic=True),
...@@ -340,9 +339,9 @@ def testDiagonalExtractor(seed): ...@@ -340,9 +339,9 @@ def testDiagonalExtractor(seed):
ift.extra.check_linear_operator(op) ift.extra.check_linear_operator(op)
@pmp('seed', [12, 3]) @pmp('seed', [12, 3])
def testLowerTriangularInserter(seed): @pmp("N", [10, 17])
N = 42 def testLowerTriangularInserter(seed, N):
square_space = ift.RGSpace([N,N]) square_space = ift.RGSpace([N, N])
op = ift.library.variational_models.LowerTriangularInserter(square_space) op = ift.library.variational_models.LowerTriangularInserter(square_space)
ift.extra.check_linear_operator(op) ift.extra.check_linear_operator(op)
......
...@@ -11,25 +11,34 @@ ...@@ -11,25 +11,34 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
# #
# Copyright(C) 2013-2020 Max-Planck-Society # Copyright(C) 2013-2021 Max-Planck-Society
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import nifty7 as ift
import numpy as np import numpy as np
import pytest
from nifty7.library.variational_models import DiagonalSelector
from numpy.testing import assert_allclose from numpy.testing import assert_allclose
import nifty7 as ift from ..common import list2fixture, setup_function, teardown_function
from nifty7.library.variational_models import DiagonalSelector pmp = pytest.mark.parametrize
def test_diagonal_selector(): @pmp("N", [17, 32])
N = 42 def test_diagonal_selector(N):
square_space = ift.RGSpace([N,N]) square_space = ift.RGSpace([N, N])
linear_space = ift.RGSpace(N) linear_space = ift.RGSpace(N)
myField = ift.from_random(square_space) myField = ift.from_random(square_space)
myDiagonalSelector = DiagonalSelector(square_space) myDiagonalSelector = DiagonalSelector(square_space)
selected = myDiagonalSelector(myField).val selected = myDiagonalSelector(myField).val
np_selected = np.diag(myField.val) np_selected = np.diag(myField.val)
assert_allclose(np_selected, selected) assert_allclose(np_selected, selected)
@pmp("n", [1, 3])
def test_error(n):
square_space = ift.RGSpace(n*[13])
with pytest.raises(AssertionError):
myDiagonalSelector = DiagonalSelector(square_space)
...@@ -11,24 +11,26 @@ ...@@ -11,24 +11,26 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
# #
# Copyright(C) 2013-2020 Max-Planck-Society # Copyright(C) 2013-2021 Max-Planck-Society
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import nifty7 as ift
import numpy as np import numpy as np
import pytest
from nifty7.library.variational_models import GaussianEntropy
from numpy.testing import assert_allclose from numpy.testing import assert_allclose
import nifty7 as ift from ..common import list2fixture, setup_function, teardown_function
from nifty7.library.variational_models import GaussianEntropy pmp = pytest.mark.parametrize
def test_gaussian_entropy(): @pmp("N", [17, 32])
N = 42 def test_gaussian_entropy(N):
linear_space = ift.RGSpace(N) linear_space = ift.RGSpace(N)
myField = ift.from_random(linear_space, 'uniform') fld = ift.from_random(linear_space, 'uniform')
vals = myField.val # minus due to subtraction in KL
entropy = - 0.5 * np.sum(np.log(2 * vals**2 * np.pi * np.e)) # minus due to subtraction in KL entropy = -0.5*np.sum(np.log(2*fld.val*fld.val*np.pi*np.e))
myEntropy = GaussianEntropy(myField.domain) myEntropy = GaussianEntropy(fld.domain)
assert_allclose(entropy, myEntropy(myField).val) assert_allclose(entropy, myEntropy(fld).val)
...@@ -11,39 +11,30 @@ ...@@ -11,39 +11,30 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
# #
# Copyright(C) 2013-2020 Max-Planck-Society # Copyright(C) 2013-2021 Max-Planck-Society
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import nifty7 as ift
import numpy as np import numpy as np
import pytest
from nifty7.library.variational_models import LowerTriangularInserter
from numpy.testing import assert_allclose from numpy.testing import assert_allclose
import nifty7 as ift from ..common import list2fixture, setup_function, teardown_function
from nifty7.library.variational_models import LowerTriangularInserter
pmp = pytest.mark.parametrize
def test_lower_triangular_inserter():
N = 42
square_space = ift.RGSpace([N,N])
myInserter = LowerTriangularInserter(square_space)
flat_space = myInserter.domain
assert_allclose(flat_space.size, N*(N+1)//2)
myField = ift.from_random(flat_space) @pmp("N", [17, 32])
myMatrix = myInserter(myField) def test_lower_triangular_inserter(N):
square_space = ift.RGSpace([N, N])
op = LowerTriangularInserter(square_space)
flat_space = op.domain
assert_allclose(flat_space.shape, (N*(N+1)//2,))
upper_i = [] mat = op(ift.from_random(flat_space))
upper_j = []
for i in range(0,N): for i in range(0,N):
for j in range(i+1,N): for j in range(i+1, N):
upper_i.append(i) assert mat.val[i, j] == 0.
upper_j.append(j) assert_allclose((mat.val == 0).sum(), (N-1)*N//2)
zeros = myMatrix.val[(upper_i,upper_j)]
assert_allclose(zeros, np.zeros((N-1)*N//2))
assert_allclose((myMatrix==0).val.sum(), (N-1)*N//2)
...@@ -11,36 +11,31 @@ ...@@ -11,36 +11,31 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
# #
# Copyright(C) 2013-2020 Max-Planck-Society # Copyright(C) 2013-2021 Max-Planck-Society
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import nifty7 as ift
import numpy as np import numpy as np
import pytest
from numpy.testing import assert_allclose from numpy.testing import assert_allclose
import nifty7 as ift from ..common import list2fixture, setup_function, teardown_function
from .test_adjoint import _h_spaces, _p_spaces, _pow_spaces
def test_multifield2vector():
myFirstSpace = ift.RGSpace([42,43])
mySecondSpace = ift.RGSpace([1,2,3])
myDict = {
'first': ift.from_random(myFirstSpace),
'second': ift.from_random(mySecondSpace)
}
myMultiField = ift.MultiField.from_dict(myDict)
myMultifield2Vector = ift.Multifield2Vector(myMultiField.domain)
myVector = myMultifield2Vector(myMultiField)
assert_allclose(myVector.size, myMultiField.size)
myNumpyVector = np.empty(myMultiField.size)
myNumpyVector[:myFirstSpace.size] = myDict['first'].val.flatten()
myNumpyVector[myFirstSpace.size:] = myDict['second'].val.flatten()
assert_allclose(myNumpyVector, myVector.val)
mySecondMultiField = myMultifield2Vector.adjoint(myVector) pmp = pytest.mark.parametrize
assert_allclose(mySecondMultiField['first'].val, myMultiField['first'].val)
assert_allclose(mySecondMultiField['second'].val, myMultiField['second'].val)
@pmp("dom0", _h_spaces + _p_spaces + _pow_spaces)
@pmp("dom1", _h_spaces + _p_spaces + _pow_spaces)
def test_multifield2vector(dom0, dom1):
fld = ift.MultiField.from_dict({'a': ift.from_random(dom0),
'b': ift.from_random(dom1)})
op = ift.Multifield2Vector(fld.domain)
ift.extra.assert_allclose(fld, op.adjoint(op(fld)))
assert op.target.size == fld.size
arr = np.empty(fld.size)
arr[:dom0.size] = fld['a'].val.flatten()
arr[dom0.size:] = fld['b'].val.flatten()
assert_allclose(arr, op(fld).val)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment