Commit 22e1116e authored by Martin Reinecke's avatar Martin Reinecke

use new functionality in tests

parent f1635286
Pipeline #24386 passed with stage
in 6 minutes and 10 seconds
from .operator_tests import adjoint_implementation, inverse_implementation, full_implementation
from .operator_tests import consistency_check
import numpy as np
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import numpy as np
from ..field import Field
from .. import dobj
__all__ = ['adjoint_implementation', 'inverse_implemenation', 'full_implementation']
__all__ = ["consistency_check"]
def adjoint_implementation(op, domain_dtype=np.float64, target_dtype=np.float64, atol=0, rtol=1e-7):
f1 = Field.from_random("normal", domain=op.domain, dtype=domain_dtype)
f2 = Field.from_random("normal", domain=op.target, dtype=target_dtype)
def adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol):
needed_cap = op.TIMES | op.ADJOINT_TIMES
if (op.capability & needed_cap) != needed_cap:
return
f1 = Field.from_random("normal", op.domain, dtype=domain_dtype)
f2 = Field.from_random("normal", op.target, dtype=target_dtype)
res1 = f1.vdot(op.adjoint_times(f2))
res2 = op.times(f1).vdot(f2)
np.testing.assert_allclose(res1, res2, atol=atol, rtol=rtol)
# Return relative error
return (res1 - res2) / (res1 + res2) * 2
def inverse_implementation(op, domain_dtype=np.float64, target_dtype=np.float64, atol=0, rtol=1e-7):
foo = Field.from_random(domain=op.target, random_type='normal', dtype=target_dtype)
res = op(op.inverse_times(foo)).val
np.testing.assert_allclose(res, foo.val, atol=atol, rtol=rtol)
def inverse_implementation(op, domain_dtype, target_dtype, atol, rtol):
needed_cap = op.TIMES | op.INVERSE_TIMES
if (op.capability & needed_cap) != needed_cap:
return
foo = Field.from_random("normal", op.target, dtype=target_dtype)
res = op(op.inverse_times(foo))
np.testing.assert_allclose(dobj.to_global_data(res.val),
dobj.to_global_data(foo.val),
atol=atol, rtol=rtol)
foo = Field.from_random(domain=op.domain, random_type='normal', dtype=domain_dtype)
res = op.inverse_times(op(foo)).val
np.testing.assert_allclose(res, foo.val, atol=atol, rtol=rtol)
foo = Field.from_random("normal", op.domain, dtype=domain_dtype)
res = op.inverse_times(op(foo))
np.testing.assert_allclose(dobj.to_global_data(res.val),
dobj.to_global_data(foo.val),
atol=atol, rtol=rtol)
# Return relative error
return (res - foo.val) / (res + foo.val) * 2
def full_implementation(op, domain_dtype, target_dtype, atol, rtol):
adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol)
inverse_implementation(op, domain_dtype, target_dtype, atol, rtol)
def full_implementation(op, domain_dtype=np.float64, target_dtype=np.float64, atol=0, rtol=1e-7):
res1 = inverse_implementation(op, domain_dtype, target_dtype, atol, rtol)
res2 = adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol)
res3 = adjoint_implementation(op.inverse, target_dtype, domain_dtype, atol, rtol)
return res1, res2, res3
def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64,
atol=0, rtol=1e-7):
full_implementation(op, domain_dtype, target_dtype, atol, rtol)
full_implementation(op.adjoint, target_dtype, domain_dtype, atol, rtol)
full_implementation(op.inverse, target_dtype, domain_dtype, atol, rtol)
full_implementation(op.adjoint.inverse, domain_dtype, target_dtype, atol,
rtol)
......@@ -259,8 +259,6 @@ class Curvature_Tests(unittest.TestCase):
a = (gradient1 - gradient0) / eps
b = energy0.curvature(direction)
print(a.vdot(a))
print(b.vdot(b))
tol = 1e-7
assert_allclose(ift.dobj.to_global_data(a.val),
ift.dobj.to_global_data(b.val), rtol=tol, atol=tol)
......@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
......@@ -21,23 +21,6 @@ import nifty4 as ift
import numpy as np
from itertools import product
from test.common import expand
from numpy.testing import assert_allclose
def _check_adjointness(op, dtype=np.float64):
f1 = ift.Field.from_random("normal", domain=op.domain, dtype=dtype)
f2 = ift.Field.from_random("normal", domain=op.target, dtype=dtype)
cap = op.capability
if ((cap & ift.LinearOperator.TIMES) and
(cap & ift.LinearOperator.ADJOINT_TIMES)):
assert_allclose(f1.vdot(op.adjoint_times(f2)),
op.times(f1).vdot(f2),
rtol=1e-8)
if ((cap & ift.LinearOperator.INVERSE_TIMES) and
(cap & ift.LinearOperator.INVERSE_ADJOINT_TIMES)):
assert_allclose(f1.vdot(op.inverse_times(f2)),
op.inverse_adjoint_times(f1).vdot(f2),
rtol=1e-8)
_h_RG_spaces = [ift.RGSpace(7, distances=0.2, harmonic=True),
......@@ -49,29 +32,35 @@ _p_RG_spaces = [ift.RGSpace(19, distances=0.7),
_p_spaces = _p_RG_spaces + [ift.HPSpace(17), ift.GLSpace(8, 13)]
class Adjointness_Tests(unittest.TestCase):
class Consistency_Tests(unittest.TestCase):
@expand(product(_h_spaces, [np.float64, np.complex128]))
def testPPO(self, sp, dtype):
op = ift.PowerProjectionOperator(sp)
_check_adjointness(op, dtype)
ift.extra.consistency_check(op, dtype, dtype)
ps = ift.PowerSpace(
sp, ift.PowerSpace.useful_binbounds(sp, logarithmic=False, nbin=3))
op = ift.PowerProjectionOperator(sp, ps)
_check_adjointness(op, dtype)
ift.extra.consistency_check(op, dtype, dtype)
ps = ift.PowerSpace(
sp, ift.PowerSpace.useful_binbounds(sp, logarithmic=True, nbin=3))
op = ift.PowerProjectionOperator(sp, ps)
_check_adjointness(op, dtype)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product(_h_RG_spaces+_p_RG_spaces,
[np.float64, np.complex128]))
def testFFT(self, sp, dtype):
op = ift.FFTOperator(sp)
_check_adjointness(op, dtype)
ift.extra.consistency_check(op, dtype, dtype)
op = ift.FFTOperator(sp.get_default_codomain())
_check_adjointness(op, dtype)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product(_h_spaces, [np.float64, np.complex128]))
def testHarmonic(self, sp, dtype):
op = ift.HarmonicTransformOperator(sp)
_check_adjointness(op, dtype)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product(_h_spaces+_p_spaces, [np.float64, np.complex128]))
def testDiagonal(self, sp, dtype):
op = ift.DiagonalOperator(ift.Field.from_random("normal", sp,
dtype=dtype))
ift.extra.consistency_check(op, dtype, dtype)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment