Commit 22e1116e authored by Martin Reinecke's avatar Martin Reinecke
Browse files

use new functionality in tests

parent f1635286
Pipeline #24386 passed with stage
in 6 minutes and 10 seconds
from .operator_tests import adjoint_implementation, inverse_implementation, full_implementation from .operator_tests import consistency_check
import numpy as np # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import numpy as np
from ..field import Field from ..field import Field
from .. import dobj
__all__ = ['adjoint_implementation', 'inverse_implemenation', 'full_implementation'] __all__ = ["consistency_check"]
def adjoint_implementation(op, domain_dtype=np.float64, target_dtype=np.float64, atol=0, rtol=1e-7): def adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol):
f1 = Field.from_random("normal", domain=op.domain, dtype=domain_dtype) needed_cap = op.TIMES | op.ADJOINT_TIMES
f2 = Field.from_random("normal", domain=op.target, dtype=target_dtype) if (op.capability & needed_cap) != needed_cap:
return
f1 = Field.from_random("normal", op.domain, dtype=domain_dtype)
f2 = Field.from_random("normal", op.target, dtype=target_dtype)
res1 = f1.vdot(op.adjoint_times(f2)) res1 = f1.vdot(op.adjoint_times(f2))
res2 = op.times(f1).vdot(f2) res2 = op.times(f1).vdot(f2)
np.testing.assert_allclose(res1, res2, atol=atol, rtol=rtol) np.testing.assert_allclose(res1, res2, atol=atol, rtol=rtol)
# Return relative error
return (res1 - res2) / (res1 + res2) * 2
def inverse_implementation(op, domain_dtype=np.float64, target_dtype=np.float64, atol=0, rtol=1e-7): def inverse_implementation(op, domain_dtype, target_dtype, atol, rtol):
foo = Field.from_random(domain=op.target, random_type='normal', dtype=target_dtype) needed_cap = op.TIMES | op.INVERSE_TIMES
res = op(op.inverse_times(foo)).val if (op.capability & needed_cap) != needed_cap:
np.testing.assert_allclose(res, foo.val, atol=atol, rtol=rtol) return
foo = Field.from_random("normal", op.target, dtype=target_dtype)
res = op(op.inverse_times(foo))
np.testing.assert_allclose(dobj.to_global_data(res.val),
dobj.to_global_data(foo.val),
atol=atol, rtol=rtol)
foo = Field.from_random(domain=op.domain, random_type='normal', dtype=domain_dtype) foo = Field.from_random("normal", op.domain, dtype=domain_dtype)
res = op.inverse_times(op(foo)).val res = op.inverse_times(op(foo))
np.testing.assert_allclose(res, foo.val, atol=atol, rtol=rtol) np.testing.assert_allclose(dobj.to_global_data(res.val),
dobj.to_global_data(foo.val),
atol=atol, rtol=rtol)
# Return relative error
return (res - foo.val) / (res + foo.val) * 2
def full_implementation(op, domain_dtype, target_dtype, atol, rtol):
adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol)
inverse_implementation(op, domain_dtype, target_dtype, atol, rtol)
def full_implementation(op, domain_dtype=np.float64, target_dtype=np.float64, atol=0, rtol=1e-7):
res1 = inverse_implementation(op, domain_dtype, target_dtype, atol, rtol)
res2 = adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol)
res3 = adjoint_implementation(op.inverse, target_dtype, domain_dtype, atol, rtol)
return res1, res2, res3 def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64,
atol=0, rtol=1e-7):
full_implementation(op, domain_dtype, target_dtype, atol, rtol)
full_implementation(op.adjoint, target_dtype, domain_dtype, atol, rtol)
full_implementation(op.inverse, target_dtype, domain_dtype, atol, rtol)
full_implementation(op.adjoint.inverse, domain_dtype, target_dtype, atol,
rtol)
...@@ -259,8 +259,6 @@ class Curvature_Tests(unittest.TestCase): ...@@ -259,8 +259,6 @@ class Curvature_Tests(unittest.TestCase):
a = (gradient1 - gradient0) / eps a = (gradient1 - gradient0) / eps
b = energy0.curvature(direction) b = energy0.curvature(direction)
print(a.vdot(a))
print(b.vdot(b))
tol = 1e-7 tol = 1e-7
assert_allclose(ift.dobj.to_global_data(a.val), assert_allclose(ift.dobj.to_global_data(a.val),
ift.dobj.to_global_data(b.val), rtol=tol, atol=tol) ift.dobj.to_global_data(b.val), rtol=tol, atol=tol)
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
# #
# Copyright(C) 2013-2017 Max-Planck-Society # Copyright(C) 2013-2018 Max-Planck-Society
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
...@@ -21,23 +21,6 @@ import nifty4 as ift ...@@ -21,23 +21,6 @@ import nifty4 as ift
import numpy as np import numpy as np
from itertools import product from itertools import product
from test.common import expand from test.common import expand
from numpy.testing import assert_allclose
def _check_adjointness(op, dtype=np.float64):
f1 = ift.Field.from_random("normal", domain=op.domain, dtype=dtype)
f2 = ift.Field.from_random("normal", domain=op.target, dtype=dtype)
cap = op.capability
if ((cap & ift.LinearOperator.TIMES) and
(cap & ift.LinearOperator.ADJOINT_TIMES)):
assert_allclose(f1.vdot(op.adjoint_times(f2)),
op.times(f1).vdot(f2),
rtol=1e-8)
if ((cap & ift.LinearOperator.INVERSE_TIMES) and
(cap & ift.LinearOperator.INVERSE_ADJOINT_TIMES)):
assert_allclose(f1.vdot(op.inverse_times(f2)),
op.inverse_adjoint_times(f1).vdot(f2),
rtol=1e-8)
_h_RG_spaces = [ift.RGSpace(7, distances=0.2, harmonic=True), _h_RG_spaces = [ift.RGSpace(7, distances=0.2, harmonic=True),
...@@ -49,29 +32,35 @@ _p_RG_spaces = [ift.RGSpace(19, distances=0.7), ...@@ -49,29 +32,35 @@ _p_RG_spaces = [ift.RGSpace(19, distances=0.7),
_p_spaces = _p_RG_spaces + [ift.HPSpace(17), ift.GLSpace(8, 13)] _p_spaces = _p_RG_spaces + [ift.HPSpace(17), ift.GLSpace(8, 13)]
class Adjointness_Tests(unittest.TestCase): class Consistency_Tests(unittest.TestCase):
@expand(product(_h_spaces, [np.float64, np.complex128])) @expand(product(_h_spaces, [np.float64, np.complex128]))
def testPPO(self, sp, dtype): def testPPO(self, sp, dtype):
op = ift.PowerProjectionOperator(sp) op = ift.PowerProjectionOperator(sp)
_check_adjointness(op, dtype) ift.extra.consistency_check(op, dtype, dtype)
ps = ift.PowerSpace( ps = ift.PowerSpace(
sp, ift.PowerSpace.useful_binbounds(sp, logarithmic=False, nbin=3)) sp, ift.PowerSpace.useful_binbounds(sp, logarithmic=False, nbin=3))
op = ift.PowerProjectionOperator(sp, ps) op = ift.PowerProjectionOperator(sp, ps)
_check_adjointness(op, dtype) ift.extra.consistency_check(op, dtype, dtype)
ps = ift.PowerSpace( ps = ift.PowerSpace(
sp, ift.PowerSpace.useful_binbounds(sp, logarithmic=True, nbin=3)) sp, ift.PowerSpace.useful_binbounds(sp, logarithmic=True, nbin=3))
op = ift.PowerProjectionOperator(sp, ps) op = ift.PowerProjectionOperator(sp, ps)
_check_adjointness(op, dtype) ift.extra.consistency_check(op, dtype, dtype)
@expand(product(_h_RG_spaces+_p_RG_spaces, @expand(product(_h_RG_spaces+_p_RG_spaces,
[np.float64, np.complex128])) [np.float64, np.complex128]))
def testFFT(self, sp, dtype): def testFFT(self, sp, dtype):
op = ift.FFTOperator(sp) op = ift.FFTOperator(sp)
_check_adjointness(op, dtype) ift.extra.consistency_check(op, dtype, dtype)
op = ift.FFTOperator(sp.get_default_codomain()) op = ift.FFTOperator(sp.get_default_codomain())
_check_adjointness(op, dtype) ift.extra.consistency_check(op, dtype, dtype)
@expand(product(_h_spaces, [np.float64, np.complex128])) @expand(product(_h_spaces, [np.float64, np.complex128]))
def testHarmonic(self, sp, dtype): def testHarmonic(self, sp, dtype):
op = ift.HarmonicTransformOperator(sp) op = ift.HarmonicTransformOperator(sp)
_check_adjointness(op, dtype) ift.extra.consistency_check(op, dtype, dtype)
@expand(product(_h_spaces+_p_spaces, [np.float64, np.complex128]))
def testDiagonal(self, sp, dtype):
op = ift.DiagonalOperator(ift.Field.from_random("normal", sp,
dtype=dtype))
ift.extra.consistency_check(op, dtype, dtype)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment