Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
22e1116e
Commit
22e1116e
authored
Feb 05, 2018
by
Martin Reinecke
Browse files
use new functionality in tests
parent
f1635286
Pipeline
#24386
passed with stage
in 6 minutes and 10 seconds
Changes
4
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty4/extra/__init__.py
View file @
22e1116e
from
.operator_tests
import
adjoint_implementation
,
inverse_implementation
,
full_implementation
from
.operator_tests
import
consistency_check
nifty4/extra/operator_tests.py
View file @
22e1116e
import
numpy
as
np
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import
numpy
as
np
from
..field
import
Field
from
..
import
dobj
__all__
=
[
'adjoint_implementation'
,
'inverse_implemenation'
,
'full_implementation'
]
__all__
=
[
"consistency_check"
]
def
adjoint_implementation
(
op
,
domain_dtype
=
np
.
float64
,
target_dtype
=
np
.
float64
,
atol
=
0
,
rtol
=
1e-7
):
f1
=
Field
.
from_random
(
"normal"
,
domain
=
op
.
domain
,
dtype
=
domain_dtype
)
f2
=
Field
.
from_random
(
"normal"
,
domain
=
op
.
target
,
dtype
=
target_dtype
)
def
adjoint_implementation
(
op
,
domain_dtype
,
target_dtype
,
atol
,
rtol
):
needed_cap
=
op
.
TIMES
|
op
.
ADJOINT_TIMES
if
(
op
.
capability
&
needed_cap
)
!=
needed_cap
:
return
f1
=
Field
.
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
)
f2
=
Field
.
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
)
res1
=
f1
.
vdot
(
op
.
adjoint_times
(
f2
))
res2
=
op
.
times
(
f1
).
vdot
(
f2
)
np
.
testing
.
assert_allclose
(
res1
,
res2
,
atol
=
atol
,
rtol
=
rtol
)
# Return relative error
return
(
res1
-
res2
)
/
(
res1
+
res2
)
*
2
def
inverse_implementation
(
op
,
domain_dtype
=
np
.
float64
,
target_dtype
=
np
.
float64
,
atol
=
0
,
rtol
=
1e-7
):
foo
=
Field
.
from_random
(
domain
=
op
.
target
,
random_type
=
'normal'
,
dtype
=
target_dtype
)
res
=
op
(
op
.
inverse_times
(
foo
)).
val
np
.
testing
.
assert_allclose
(
res
,
foo
.
val
,
atol
=
atol
,
rtol
=
rtol
)
def
inverse_implementation
(
op
,
domain_dtype
,
target_dtype
,
atol
,
rtol
):
needed_cap
=
op
.
TIMES
|
op
.
INVERSE_TIMES
if
(
op
.
capability
&
needed_cap
)
!=
needed_cap
:
return
foo
=
Field
.
from_random
(
"normal"
,
op
.
target
,
dtype
=
target_dtype
)
res
=
op
(
op
.
inverse_times
(
foo
))
np
.
testing
.
assert_allclose
(
dobj
.
to_global_data
(
res
.
val
),
dobj
.
to_global_data
(
foo
.
val
),
atol
=
atol
,
rtol
=
rtol
)
foo
=
Field
.
from_random
(
domain
=
op
.
domain
,
random_type
=
'normal'
,
dtype
=
domain_dtype
)
res
=
op
.
inverse_times
(
op
(
foo
)).
val
np
.
testing
.
assert_allclose
(
res
,
foo
.
val
,
atol
=
atol
,
rtol
=
rtol
)
foo
=
Field
.
from_random
(
"normal"
,
op
.
domain
,
dtype
=
domain_dtype
)
res
=
op
.
inverse_times
(
op
(
foo
))
np
.
testing
.
assert_allclose
(
dobj
.
to_global_data
(
res
.
val
),
dobj
.
to_global_data
(
foo
.
val
),
atol
=
atol
,
rtol
=
rtol
)
# Return relative error
return
(
res
-
foo
.
val
)
/
(
res
+
foo
.
val
)
*
2
def
full_implementation
(
op
,
domain_dtype
,
target_dtype
,
atol
,
rtol
):
adjoint_implementation
(
op
,
domain_dtype
,
target_dtype
,
atol
,
rtol
)
inverse_implementation
(
op
,
domain_dtype
,
target_dtype
,
atol
,
rtol
)
def
full_implementation
(
op
,
domain_dtype
=
np
.
float64
,
target_dtype
=
np
.
float64
,
atol
=
0
,
rtol
=
1e-7
):
res1
=
inverse_implementation
(
op
,
domain_dtype
,
target_dtype
,
atol
,
rtol
)
res2
=
adjoint_implementation
(
op
,
domain_dtype
,
target_dtype
,
atol
,
rtol
)
res3
=
adjoint_implementation
(
op
.
inverse
,
target_dtype
,
domain_dtype
,
atol
,
rtol
)
return
res1
,
res2
,
res3
def
consistency_check
(
op
,
domain_dtype
=
np
.
float64
,
target_dtype
=
np
.
float64
,
atol
=
0
,
rtol
=
1e-7
):
full_implementation
(
op
,
domain_dtype
,
target_dtype
,
atol
,
rtol
)
full_implementation
(
op
.
adjoint
,
target_dtype
,
domain_dtype
,
atol
,
rtol
)
full_implementation
(
op
.
inverse
,
target_dtype
,
domain_dtype
,
atol
,
rtol
)
full_implementation
(
op
.
adjoint
.
inverse
,
domain_dtype
,
target_dtype
,
atol
,
rtol
)
test/test_energies/test_map.py
View file @
22e1116e
...
...
@@ -259,8 +259,6 @@ class Curvature_Tests(unittest.TestCase):
a
=
(
gradient1
-
gradient0
)
/
eps
b
=
energy0
.
curvature
(
direction
)
print
(
a
.
vdot
(
a
))
print
(
b
.
vdot
(
b
))
tol
=
1e-7
assert_allclose
(
ift
.
dobj
.
to_global_data
(
a
.
val
),
ift
.
dobj
.
to_global_data
(
b
.
val
),
rtol
=
tol
,
atol
=
tol
)
test/test_operators/test_adjoint.py
View file @
22e1116e
...
...
@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-201
7
Max-Planck-Society
# Copyright(C) 2013-201
8
Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
...
...
@@ -21,23 +21,6 @@ import nifty4 as ift
import
numpy
as
np
from
itertools
import
product
from
test.common
import
expand
from
numpy.testing
import
assert_allclose
def
_check_adjointness
(
op
,
dtype
=
np
.
float64
):
f1
=
ift
.
Field
.
from_random
(
"normal"
,
domain
=
op
.
domain
,
dtype
=
dtype
)
f2
=
ift
.
Field
.
from_random
(
"normal"
,
domain
=
op
.
target
,
dtype
=
dtype
)
cap
=
op
.
capability
if
((
cap
&
ift
.
LinearOperator
.
TIMES
)
and
(
cap
&
ift
.
LinearOperator
.
ADJOINT_TIMES
)):
assert_allclose
(
f1
.
vdot
(
op
.
adjoint_times
(
f2
)),
op
.
times
(
f1
).
vdot
(
f2
),
rtol
=
1e-8
)
if
((
cap
&
ift
.
LinearOperator
.
INVERSE_TIMES
)
and
(
cap
&
ift
.
LinearOperator
.
INVERSE_ADJOINT_TIMES
)):
assert_allclose
(
f1
.
vdot
(
op
.
inverse_times
(
f2
)),
op
.
inverse_adjoint_times
(
f1
).
vdot
(
f2
),
rtol
=
1e-8
)
_h_RG_spaces
=
[
ift
.
RGSpace
(
7
,
distances
=
0.2
,
harmonic
=
True
),
...
...
@@ -49,29 +32,35 @@ _p_RG_spaces = [ift.RGSpace(19, distances=0.7),
_p_spaces
=
_p_RG_spaces
+
[
ift
.
HPSpace
(
17
),
ift
.
GLSpace
(
8
,
13
)]
class
Adjointness
_Tests
(
unittest
.
TestCase
):
class
Consistency
_Tests
(
unittest
.
TestCase
):
@
expand
(
product
(
_h_spaces
,
[
np
.
float64
,
np
.
complex128
]))
def
testPPO
(
self
,
sp
,
dtype
):
op
=
ift
.
PowerProjectionOperator
(
sp
)
_check_adjointness
(
op
,
dtype
)
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
ps
=
ift
.
PowerSpace
(
sp
,
ift
.
PowerSpace
.
useful_binbounds
(
sp
,
logarithmic
=
False
,
nbin
=
3
))
op
=
ift
.
PowerProjectionOperator
(
sp
,
ps
)
_check_adjointness
(
op
,
dtype
)
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
ps
=
ift
.
PowerSpace
(
sp
,
ift
.
PowerSpace
.
useful_binbounds
(
sp
,
logarithmic
=
True
,
nbin
=
3
))
op
=
ift
.
PowerProjectionOperator
(
sp
,
ps
)
_check_adjointness
(
op
,
dtype
)
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
@
expand
(
product
(
_h_RG_spaces
+
_p_RG_spaces
,
[
np
.
float64
,
np
.
complex128
]))
def
testFFT
(
self
,
sp
,
dtype
):
op
=
ift
.
FFTOperator
(
sp
)
_check_adjointness
(
op
,
dtype
)
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
op
=
ift
.
FFTOperator
(
sp
.
get_default_codomain
())
_check_adjointness
(
op
,
dtype
)
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
@
expand
(
product
(
_h_spaces
,
[
np
.
float64
,
np
.
complex128
]))
def
testHarmonic
(
self
,
sp
,
dtype
):
op
=
ift
.
HarmonicTransformOperator
(
sp
)
_check_adjointness
(
op
,
dtype
)
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
@
expand
(
product
(
_h_spaces
+
_p_spaces
,
[
np
.
float64
,
np
.
complex128
]))
def
testDiagonal
(
self
,
sp
,
dtype
):
op
=
ift
.
DiagonalOperator
(
ift
.
Field
.
from_random
(
"normal"
,
sp
,
dtype
=
dtype
))
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment