Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
6705b416
Commit
6705b416
authored
Jul 09, 2017
by
Theo Steininger
Browse files
Merge branch 'master' into working_on_demos
parents
24c5b7c8
841b7ecb
Pipeline
#14527
passed with stage
in 6 minutes and 48 seconds
Changes
17
Pipelines
1
Show whitespace changes
Inline
Side-by-side
nifty/field.py
View file @
6705b416
...
@@ -599,58 +599,55 @@ class Field(Loggable, Versionable, object):
...
@@ -599,58 +599,55 @@ class Field(Loggable, Versionable, object):
# hermitianize for the first space
# hermitianize for the first space
(
h
,
a
)
=
domain
[
spaces
[
0
]].
hermitian_decomposition
(
(
h
,
a
)
=
domain
[
spaces
[
0
]].
hermitian_decomposition
(
val
,
val
,
domain_axes
[
spaces
[
0
]],
domain_axes
[
spaces
[
0
]])
preserve_gaussian_variance
=
preserve_gaussian_variance
)
# hermitianize all remaining spaces using the iterative formula
# hermitianize all remaining spaces using the iterative formula
for
space
in
xrange
(
1
,
len
(
spaces
))
:
for
space
in
spaces
[
1
:]
:
(
hh
,
ha
)
=
domain
[
space
].
hermitian_decomposition
(
(
hh
,
ha
)
=
domain
[
space
].
hermitian_decomposition
(
h
,
h
,
domain_axes
[
space
],
domain_axes
[
space
])
preserve_gaussian_variance
=
False
)
(
ah
,
aa
)
=
domain
[
space
].
hermitian_decomposition
(
(
ah
,
aa
)
=
domain
[
space
].
hermitian_decomposition
(
a
,
a
,
domain_axes
[
space
],
domain_axes
[
space
])
preserve_gaussian_variance
=
False
)
c
=
(
hh
-
ha
-
ah
+
aa
).
conjugate
()
c
=
(
hh
-
ha
-
ah
+
aa
).
conjugate
()
full
=
(
hh
+
ha
+
ah
+
aa
)
full
=
(
hh
+
ha
+
ah
+
aa
)
h
=
(
full
+
c
)
/
2.
h
=
(
full
+
c
)
/
2.
a
=
(
full
-
c
)
/
2.
a
=
(
full
-
c
)
/
2.
# correct variance
# correct variance
if
preserve_gaussian_variance
:
h
*=
np
.
sqrt
(
2
)
a
*=
np
.
sqrt
(
2
)
if
not
issubclass
(
val
.
dtype
.
type
,
np
.
complexfloating
):
# in principle one must not correct the variance for the fixed
# in principle one must not correct the variance for the fixed
# points of the hermitianization. However, for a complex field
# points of the hermitianization. However, for a complex field
# the input field loses half of its power at its fixed points
# the input field loses half of its power at its fixed points
# in the `hermitian` part. Hence, here a factor of sqrt(2) is
# in the `hermitian` part. Hence, here a factor of sqrt(2) is
# also necessary!
# also necessary!
# => The hermitianization can be done on a space level since either
# => The hermitianization can be done on a space level since
# nothing must be done (LMSpace) or ALL points need a factor of sqrt(2)
# either nothing must be done (LMSpace) or ALL points need a
# factor of sqrt(2)
# => use the preserve_gaussian_variance flag in the
# => use the preserve_gaussian_variance flag in the
# hermitian_decomposition method above.
# hermitian_decomposition method above.
# This code is for educational purposes:
# This code is for educational purposes:
# fixed_points = [domain[i].hermitian_fixed_points() for i in spaces]
fixed_points
=
[
domain
[
i
].
hermitian_fixed_points
()
# # check if there was at least one flipping during hermitianization
for
i
in
spaces
]
# flipped_Q = np.any([fp is not None for fp in fixed_points])
fixed_points
=
[[
fp
]
if
fp
is
None
else
fp
# # if the array got flipped, correct the variance
for
fp
in
fixed_points
]
# if flipped_Q:
# h *= np.sqrt(2)
for
product_point
in
itertools
.
product
(
*
fixed_points
):
# a *= np.sqrt(2)
slice_object
=
np
.
array
((
slice
(
None
),
)
*
len
(
val
.
shape
),
#
dtype
=
np
.
object
)
# fixed_points = [[fp] if fp is None else fp for fp in fixed_points]
for
i
,
sp
in
enumerate
(
spaces
):
# for product_point in itertools.product(*fixed_points):
point_component
=
product_point
[
i
]
# slice_object = np.array((slice(None), )*len(val.shape),
if
point_component
is
None
:
# dtype=np.object)
point_component
=
slice
(
None
)
# for i, sp in enumerate(spaces):
slice_object
[
list
(
domain_axes
[
sp
])]
=
point_component
# point_component = product_point[i]
# if point_component is None:
slice_object
=
tuple
(
slice_object
)
# point_component = slice(None)
h
[
slice_object
]
/=
np
.
sqrt
(
2
)
# slice_object[list(domain_axes[sp])] = point_component
a
[
slice_object
]
/=
np
.
sqrt
(
2
)
#
# slice_object = tuple(slice_object)
# h[slice_object] /= np.sqrt(2)
# a[slice_object] /= np.sqrt(2)
return
(
h
,
a
)
return
(
h
,
a
)
def
_spec_to_rescaler
(
self
,
spec
,
result_list
,
power_space_index
):
def
_spec_to_rescaler
(
self
,
spec
,
result_list
,
power_space_index
):
...
@@ -667,7 +664,7 @@ class Field(Loggable, Versionable, object):
...
@@ -667,7 +664,7 @@ class Field(Loggable, Versionable, object):
if
pindex
.
distribution_strategy
is
not
local_distribution_strategy
:
if
pindex
.
distribution_strategy
is
not
local_distribution_strategy
:
self
.
logger
.
warn
(
self
.
logger
.
warn
(
"The distribution_strag
e
y of pindex does not fit the "
"The distribution_stra
te
gy of pindex does not fit the "
"slice_local distribution strategy of the synthesized field."
)
"slice_local distribution strategy of the synthesized field."
)
# Now use numpy advanced indexing in order to put the entries of the
# Now use numpy advanced indexing in order to put the entries of the
...
@@ -675,8 +672,11 @@ class Field(Loggable, Versionable, object):
...
@@ -675,8 +672,11 @@ class Field(Loggable, Versionable, object):
# Do this for every 'pindex-slice' in parallel using the 'slice(None)'s
# Do this for every 'pindex-slice' in parallel using the 'slice(None)'s
local_pindex
=
pindex
.
get_local_data
(
copy
=
False
)
local_pindex
=
pindex
.
get_local_data
(
copy
=
False
)
local_blow_up
=
[
slice
(
None
)]
*
len
(
self
.
shape
)
local_blow_up
=
[
slice
(
None
)]
*
len
(
spec
.
shape
)
local_blow_up
[
self
.
domain_axes
[
power_space_index
][
0
]]
=
local_pindex
# it is important to count from behind, since spec potentially grows
# with every iteration
index
=
self
.
domain_axes
[
power_space_index
][
0
]
-
len
(
self
.
shape
)
local_blow_up
[
index
]
=
local_pindex
# here, the power_spectrum is distributed into the new shape
# here, the power_spectrum is distributed into the new shape
local_rescaler
=
spec
[
local_blow_up
]
local_rescaler
=
spec
[
local_blow_up
]
return
local_rescaler
return
local_rescaler
...
...
nifty/minimization/descent_minimizer.py
View file @
6705b416
...
@@ -156,13 +156,20 @@ class DescentMinimizer(Loggable, object):
...
@@ -156,13 +156,20 @@ class DescentMinimizer(Loggable, object):
pk
=
descend_direction
,
pk
=
descend_direction
,
f_k_minus_1
=
f_k_minus_1
)
f_k_minus_1
=
f_k_minus_1
)
f_k_minus_1
=
energy
.
value
f_k_minus_1
=
energy
.
value
energy
=
new_energy
# check if new energy value is bigger than old energy value
if
(
new_energy
.
value
-
energy
.
value
)
>
0
:
self
.
logger
.
info
(
"Line search algorithm returned a new energy "
"that was larger than the old one. Stopping."
)
break
energy
=
new_energy
# check convergence
# check convergence
delta
=
abs
(
gradient
).
max
()
*
(
step_length
/
gradient_norm
)
delta
=
abs
(
gradient
).
max
()
*
(
step_length
/
gradient_norm
)
self
.
logger
.
debug
(
"Iteration : %08u step_length = %3.1E "
self
.
logger
.
debug
(
"Iteration:%08u step_length=%3.1E "
"delta = %3.1E"
%
"delta=%3.1E energy=%3.1E"
%
(
iteration_number
,
step_length
,
delta
))
(
iteration_number
,
step_length
,
delta
,
energy
.
value
))
if
delta
==
0
:
if
delta
==
0
:
convergence
=
self
.
convergence_level
+
2
convergence
=
self
.
convergence_level
+
2
self
.
logger
.
info
(
"Found minimum according to line-search. "
self
.
logger
.
info
(
"Found minimum according to line-search. "
...
...
nifty/minimization/steepest_descent.py
View file @
6705b416
...
@@ -40,8 +40,4 @@ class SteepestDescent(DescentMinimizer):
...
@@ -40,8 +40,4 @@ class SteepestDescent(DescentMinimizer):
"""
"""
descend_direction
=
energy
.
gradient
descend_direction
=
energy
.
gradient
norm
=
descend_direction
.
norm
()
if
norm
!=
1
:
return
descend_direction
/
-
norm
else
:
return
descend_direction
*
-
1
return
descend_direction
*
-
1
nifty/minimization/vl_bfgs.py
View file @
6705b416
...
@@ -25,7 +25,7 @@ from .line_searching import LineSearchStrongWolfe
...
@@ -25,7 +25,7 @@ from .line_searching import LineSearchStrongWolfe
class
VL_BFGS
(
DescentMinimizer
):
class
VL_BFGS
(
DescentMinimizer
):
def
__init__
(
self
,
line_searcher
=
LineSearchStrongWolfe
(),
callback
=
None
,
def
__init__
(
self
,
line_searcher
=
LineSearchStrongWolfe
(),
callback
=
None
,
convergence_tolerance
=
1E-4
,
convergence_level
=
3
,
convergence_tolerance
=
1E-4
,
convergence_level
=
3
,
iteration_limit
=
None
,
max_history_length
=
10
):
iteration_limit
=
None
,
max_history_length
=
5
):
super
(
VL_BFGS
,
self
).
__init__
(
super
(
VL_BFGS
,
self
).
__init__
(
line_searcher
=
line_searcher
,
line_searcher
=
line_searcher
,
...
@@ -84,9 +84,6 @@ class VL_BFGS(DescentMinimizer):
...
@@ -84,9 +84,6 @@ class VL_BFGS(DescentMinimizer):
for
i
in
xrange
(
1
,
len
(
delta
)):
for
i
in
xrange
(
1
,
len
(
delta
)):
descend_direction
+=
delta
[
i
]
*
b
[
i
]
descend_direction
+=
delta
[
i
]
*
b
[
i
]
norm
=
descend_direction
.
norm
()
if
norm
!=
1
:
descend_direction
/=
norm
return
descend_direction
return
descend_direction
...
...
nifty/operators/diagonal_operator/diagonal_operator.py
View file @
6705b416
...
@@ -21,7 +21,6 @@ import numpy as np
...
@@ -21,7 +21,6 @@ import numpy as np
from
d2o
import
distributed_data_object
,
\
from
d2o
import
distributed_data_object
,
\
STRATEGIES
as
DISTRIBUTION_STRATEGIES
STRATEGIES
as
DISTRIBUTION_STRATEGIES
from
nifty.basic_arithmetics
import
log
as
nifty_log
from
nifty.config
import
nifty_configuration
as
gc
from
nifty.config
import
nifty_configuration
as
gc
from
nifty.field
import
Field
from
nifty.field
import
Field
from
nifty.operators.endomorphic_operator
import
EndomorphicOperator
from
nifty.operators.endomorphic_operator
import
EndomorphicOperator
...
...
nifty/operators/linear_operator/linear_operator.py
View file @
6705b416
...
@@ -73,7 +73,7 @@ class LinearOperator(Loggable, object):
...
@@ -73,7 +73,7 @@ class LinearOperator(Loggable, object):
__metaclass__
=
NiftyMeta
__metaclass__
=
NiftyMeta
def __init__(self, default_spaces=None):
    """Initialize the operator.

    Parameters
    ----------
    default_spaces : tuple of int or None
        Spaces the operator acts on by default when a caller does not
        specify any. Stored on the private attribute so that access goes
        through the ``default_spaces`` property.
    """
    # Store via the private attribute; the public read access is provided
    # by the `default_spaces` property defined on this class.
    self._default_spaces = default_spaces
@
staticmethod
@
staticmethod
def
_parse_domain
(
domain
):
def
_parse_domain
(
domain
):
...
@@ -119,10 +119,6 @@ class LinearOperator(Loggable, object):
...
@@ -119,10 +119,6 @@ class LinearOperator(Loggable, object):
def default_spaces(self):
    """Return the spaces this operator acts on by default.

    Read-only accessor for the private ``_default_spaces`` attribute
    (set in ``__init__``); returns a tuple of ints or None.
    """
    return self._default_spaces
@
default_spaces
.
setter
def
default_spaces
(
self
,
spaces
):
self
.
_default_spaces
=
utilities
.
cast_axis_to_tuple
(
spaces
)
def __call__(self, *args, **kwargs):
    """Apply the operator; calling an instance is shorthand for ``times``.

    All positional and keyword arguments are forwarded unchanged to
    ``self.times`` and its result is returned.
    """
    return self.times(*args, **kwargs)
...
...
nifty/operators/projection_operator/projection_operator.py
View file @
6705b416
...
@@ -163,3 +163,9 @@ class ProjectionOperator(EndomorphicOperator):
...
@@ -163,3 +163,9 @@ class ProjectionOperator(EndomorphicOperator):
@property
def self_adjoint(self):
    """Whether the operator equals its own adjoint.

    A projection operator is always self-adjoint, so this is
    unconditionally True.
    """
    return True
# ---Added properties and methods---

@property
def projection_field(self):
    """The field onto which this operator projects (read-only).

    Exposes the private ``_projection_field`` attribute; presumably set
    during construction — TODO confirm against the class __init__ (not
    visible in this view).
    """
    return self._projection_field
nifty/operators/smoothing_operator/smoothing_operator.py
View file @
6705b416
...
@@ -135,8 +135,8 @@ class SmoothingOperator(EndomorphicOperator):
...
@@ -135,8 +135,8 @@ class SmoothingOperator(EndomorphicOperator):
# "space as input domain.")
# "space as input domain.")
self
.
_domain
=
self
.
_parse_domain
(
domain
)
self
.
_domain
=
self
.
_parse_domain
(
domain
)
self
.
sigma
=
sigma
self
.
_
sigma
=
sigma
self
.
log_distances
=
log_distances
self
.
_
log_distances
=
log_distances
def
_inverse_times
(
self
,
x
,
spaces
):
def
_inverse_times
(
self
,
x
,
spaces
):
if
self
.
sigma
==
0
:
if
self
.
sigma
==
0
:
...
@@ -183,18 +183,10 @@ class SmoothingOperator(EndomorphicOperator):
...
@@ -183,18 +183,10 @@ class SmoothingOperator(EndomorphicOperator):
def
sigma
(
self
):
def
sigma
(
self
):
return
self
.
_sigma
return
self
.
_sigma
@
sigma
.
setter
def
sigma
(
self
,
sigma
):
self
.
_sigma
=
np
.
float
(
sigma
)
@property
def log_distances(self):
    """Whether smoothing distances are taken logarithmically (read-only).

    Exposes the private ``_log_distances`` flag set in ``__init__``.
    """
    return self._log_distances
@
log_distances
.
setter
def
log_distances
(
self
,
log_distances
):
self
.
_log_distances
=
bool
(
log_distances
)
@abc.abstractmethod
def _smooth(self, x, spaces, inverse):
    """Apply the (inverse) smoothing kernel to ``x``.

    Abstract hook: concrete SmoothingOperator subclasses must implement
    the actual smoothing. ``spaces`` selects the domain axes to act on;
    ``inverse`` requests the inverse smoothing operation.

    Raises
    ------
    NotImplementedError
        Always, on this abstract base implementation.
    """
    raise NotImplementedError
nifty/spaces/lm_space/lm_space.py
View file @
6705b416
...
@@ -89,25 +89,21 @@ class LMSpace(Space):
...
@@ -89,25 +89,21 @@ class LMSpace(Space):
super
(
LMSpace
,
self
).
__init__
()
super
(
LMSpace
,
self
).
__init__
()
self
.
_lmax
=
self
.
_parse_lmax
(
lmax
)
self
.
_lmax
=
self
.
_parse_lmax
(
lmax
)
def hermitian_decomposition(self, x, axes=None):
    """Split ``x`` into its hermitian and anti-hermitian parts.

    On an LMSpace the decomposition is trivial in this basis: the real
    part of ``x`` is the hermitian component, and the imaginary part —
    kept as a purely imaginary array — is the anti-hermitian component.
    A real-valued ``x`` is therefore its own hermitian part and the
    anti-hermitian part is zero.

    Parameters
    ----------
    x : array-like
        Data object providing ``dtype``, ``real``/``imag`` and
        ``copy``/``copy_empty`` (presumably a d2o/Field data object —
        TODO confirm).
    axes : tuple of int or None
        Unused here; kept for interface compatibility with other
        spaces' ``hermitian_decomposition`` implementations.

    Returns
    -------
    (hermitian_part, anti_hermitian_part) : tuple
    """
    if issubclass(x.dtype.type, np.complexfloating):
        hermitian_part = x.copy_empty()
        anti_hermitian_part = x.copy_empty()
        hermitian_part[:] = x.real
        # multiply by 1j so the anti-hermitian part stays purely imaginary
        anti_hermitian_part[:] = x.imag * 1j
    else:
        # real input: hermitian part is the input itself, anti part is zero
        hermitian_part = x.copy()
        anti_hermitian_part = x.copy_empty()
        anti_hermitian_part[:] = 0

    return (hermitian_part, anti_hermitian_part)
def hermitian_fixed_points(self):
    """Return the fixed points of hermitianization on an LMSpace.

    An LMSpace has no such fixed points to correct, so None is returned
    (callers treat None as "no fixed points").
    """
    return None
# ---Mandatory properties and methods---
# ---Mandatory properties and methods---
...
...
nifty/spaces/rg_space/rg_space.py
View file @
6705b416
...
@@ -102,6 +102,12 @@ class RGSpace(Space):
...
@@ -102,6 +102,12 @@ class RGSpace(Space):
def
hermitian_decomposition
(
self
,
x
,
axes
=
None
,
def
hermitian_decomposition
(
self
,
x
,
axes
=
None
,
preserve_gaussian_variance
=
False
):
preserve_gaussian_variance
=
False
):
# check axes
if
axes
is
None
:
axes
=
range
(
len
(
self
.
shape
))
assert
len
(
x
.
shape
)
>=
len
(
self
.
shape
),
"shapes mismatch"
assert
len
(
axes
)
==
len
(
self
.
shape
),
"axes mismatch"
# compute the hermitian part
# compute the hermitian part
flipped_x
=
self
.
_hermitianize_inverter
(
x
,
axes
=
axes
)
flipped_x
=
self
.
_hermitianize_inverter
(
x
,
axes
=
axes
)
flipped_x
=
flipped_x
.
conjugate
()
flipped_x
=
flipped_x
.
conjugate
()
...
@@ -112,68 +118,46 @@ class RGSpace(Space):
...
@@ -112,68 +118,46 @@ class RGSpace(Space):
# use subtraction since it is faster than flipping another time
# use subtraction since it is faster than flipping another time
anti_hermitian_part
=
(
x
-
hermitian_part
)
anti_hermitian_part
=
(
x
-
hermitian_part
)
if
preserve_gaussian_variance
:
hermitian_part
,
anti_hermitian_part
=
\
self
.
_hermitianize_correct_variance
(
hermitian_part
,
anti_hermitian_part
,
axes
=
axes
)
return
(
hermitian_part
,
anti_hermitian_part
)
return
(
hermitian_part
,
anti_hermitian_part
)
def hermitian_fixed_points(self):
    """Return the grid indices invariant under the point inversion.

    For every axis the possible fixed coordinates are collected: an
    even-length axis contributes two candidates (index 0 and the middle
    index), while an odd-length axis contributes one (0, shifted to the
    middle if the axis is zero-centered). The Cartesian product of these
    per-axis choices yields the list of fixed-point index tuples.

    Returns
    -------
    list of tuple of int
        The index tuples of all fixed points of the hermitianization.
    """
    dimensions = len(self.shape)
    mid_index = np.array(self.shape) // 2
    # even axes have two fixed coordinates (0 and the middle index),
    # odd axes only one
    ndlist = [1] * dimensions
    for k in range(dimensions):
        if self.shape[k] % 2 == 0:
            ndlist[k] = 2
    ndlist = tuple(ndlist)

    fixed_points = []
    for index in np.ndindex(ndlist):
        for k in range(dimensions):
            # an odd, zero-centered axis has its fixed point at the
            # middle index rather than at 0
            if self.shape[k] % 2 != 0 and self.zerocenter[k]:
                index = list(index)
                index[k] = 1
                index = tuple(index)
        # scale the 0/1 choice per axis by the middle index of that axis
        fixed_points += [tuple(index * mid_index)]
    return fixed_points
def
_hermitianize_inverter
(
self
,
x
,
axes
):
def
_hermitianize_inverter
(
self
,
x
,
axes
):
shape
=
x
.
shape
# calculate the number of dimensions the input array has
# calculate the number of dimensions the input array has
dimensions
=
len
(
shape
)
dimensions
=
len
(
x
.
shape
)
# prepare the slicing object which will be used for mirroring
# prepare the slicing object which will be used for mirroring
slice_primitive
=
[
slice
(
None
),
]
*
dimensions
slice_primitive
=
[
slice
(
None
),
]
*
dimensions
# copy the input data
# copy the input data
y
=
x
.
copy
()
y
=
x
.
copy
()
if
axes
is
None
:
axes
=
xrange
(
dimensions
)
# flip in the desired directions
# flip in the desired directions
for
i
in
axes
:
for
k
in
range
(
len
(
axes
)):
i
=
axes
[
k
]
slice_picker
=
slice_primitive
[:]
slice_picker
=
slice_primitive
[:]
if
shape
[
i
]
%
2
==
0
:
slice_picker
[
i
]
=
slice
(
1
,
None
,
None
)
else
:
slice_picker
[
i
]
=
slice
(
None
)
slice_picker
=
tuple
(
slice_picker
)
slice_inverter
=
slice_primitive
[:]
slice_inverter
=
slice_primitive
[:]
if
shape
[
i
]
%
2
==
0
:
if
(
not
self
.
zerocenter
[
k
])
or
self
.
shape
[
k
]
%
2
==
0
:
slice_picker
[
i
]
=
slice
(
1
,
None
,
None
)
slice_inverter
[
i
]
=
slice
(
None
,
0
,
-
1
)
slice_inverter
[
i
]
=
slice
(
None
,
0
,
-
1
)
else
:
else
:
slice_picker
[
i
]
=
slice
(
None
)
slice_inverter
[
i
]
=
slice
(
None
,
None
,
-
1
)
slice_inverter
[
i
]
=
slice
(
None
,
None
,
-
1
)
slice_picker
=
tuple
(
slice_picker
)
slice_inverter
=
tuple
(
slice_inverter
)
slice_inverter
=
tuple
(
slice_inverter
)
try
:
try
:
...
...
nifty/spaces/space/space.py
View file @
6705b416
...
@@ -167,7 +167,7 @@ class Space(DomainObject):
...
@@ -167,7 +167,7 @@ class Space(DomainObject):
If the hermitian decomposition is done via computing the half
If the hermitian decomposition is done via computing the half
sums and differences of `x` and mirrored `x`, all points except the
sums and differences of `x` and mirrored `x`, all points except the
fixed points lose half of their variance. If `x` is complex also
fixed points lose half of their variance. If `x` is complex also
the lose half of their variance since the real(/imaginary) part
the
y
lose half of their variance since the real(/imaginary) part
gets lost.
gets lost.
Returns
Returns
...
...
test/test_field.py
View file @
6705b416
...
@@ -20,20 +20,20 @@ import unittest
...
@@ -20,20 +20,20 @@ import unittest
import
numpy
as
np
import
numpy
as
np
from
numpy.testing
import
assert_
,
\
from
numpy.testing
import
assert_
,
\
assert_equal
assert_almost_equal
,
\
assert_allclose
from
itertools
import
product
from
itertools
import
product
from
nifty
import
Field
,
\
from
nifty
import
Field
,
\
RGSpace
,
\
RGSpace
,
\
FieldArray
LMSpace
,
\
PowerSpace
from
d2o
import
distributed_data_object
,
\
from
d2o
import
distributed_data_object
STRATEGIES
from
test.common
import
expand
from
test.common
import
expand
np
.
random
.
seed
(
123
)
SPACES
=
[
RGSpace
((
4
,)),
RGSpace
((
5
))]
SPACES
=
[
RGSpace
((
4
,)),
RGSpace
((
5
))]
SPACE_COMBINATIONS
=
[(),
SPACES
[
0
],
SPACES
[
1
],
SPACES
]
SPACE_COMBINATIONS
=
[(),
SPACES
[
0
],
SPACES
[
1
],
SPACES
]
...
@@ -55,10 +55,67 @@ class Test_Interface(unittest.TestCase):
...
@@ -55,10 +55,67 @@ class Test_Interface(unittest.TestCase):
f
=
Field
(
domain
=
domain
)
f
=
Field
(
domain
=
domain
)
assert_
(
isinstance
(
getattr
(
f
,
attribute
),
desired_type
))
assert_
(
isinstance
(
getattr
(
f
,
attribute
),
desired_type
))
#class Test_Initialization(unittest.TestCase):
#
class
Test_Functionality
(
unittest
.
TestCase
):
# @parameterized.expand(
@
expand
(
product
([
True
,
False
],
[
True
,
False
],
# itertools.product(SPACE_COMBINATIONS,
[
True
,
False
],
[
True
,
False
],
# []
[(
1
,),
(
4
,),
(
5
,)],
[(
1
,),
(
6
,),
(
7
,)]))
# )
def
test_hermitian_decomposition
(
self
,
z1
,
z2
,
preserve
,
complexdata
,
# def test_
s1
,
s2
):
np
.
random
.
seed
(
123
)
r1
=
RGSpace
(
s1
,
harmonic
=
True
,
zerocenter
=
(
z1
,))
r2
=
RGSpace
(
s2
,
harmonic
=
True
,
zerocenter
=
(
z2
,))
ra
=
RGSpace
(
s1
+
s2
,
harmonic
=
True
,
zerocenter
=
(
z1
,
z2
))
v
=
np
.
random
.
random
(
s1
+
s2
)
if
complexdata
:
v
=
v
+
1j
*
np
.
random
.
random
(
s1
+
s2
)
f1
=
Field
(
ra
,
val
=
v
,
copy
=
True
)
f2
=
Field
((
r1
,
r2
),
val
=
v
,
copy
=
True
)
h1
,
a1
=
Field
.
_hermitian_decomposition
((
ra
,),
f1
.
val
,
(
0
,),
((
0
,
1
,),),
preserve
)
h2
,
a2
=
Field
.
_hermitian_decomposition
((
r1
,
r2
),
f2
.
val
,
(
0
,
1
),
((
0
,),
(
1
,)),
preserve
)
h3
,
a3
=
Field
.
_hermitian_decomposition
((
r1
,
r2
),
f2
.
val
,
(
1
,
0
),
((
0
,),
(
1
,)),
preserve
)
assert_almost_equal
(
h1
.
get_full_data
(),
h2
.
get_full_data
())
assert_almost_equal
(
a1
.
get_full_data
(),
a2
.
get_full_data
())
assert_almost_equal
(
h1
.
get_full_data
(),
h3
.
get_full_data
())
assert_almost_equal
(
a1
.
get_full_data
(),
a3
.
get_full_data
())
@
expand
(
product
([
RGSpace
((
8
,),
harmonic
=
True
,
zerocenter
=
False
),
RGSpace
((
8
,
8
),
harmonic
=
True
,
distances
=
0.123
,
zerocenter
=
True
)],
[
RGSpace
((
8
,),
harmonic
=
True
,
zerocenter
=
False
),
LMSpace
(
12
)]))
def
test_power_synthesize_analyze
(
self
,
space1
,
space2
):
p1
=
PowerSpace
(
space1
)
spec1
=
lambda
k
:
42
/
(
1
+
k
)
**
2
fp1
=
Field
(
p1
,
val
=
spec1
)
p2
=
PowerSpace
(
space2
)
spec2
=
lambda
k
:
42
/
(
1
+
k
)
**
3
fp2
=
Field
(
p2
,
val
=
spec2
)
outer
=
np
.
outer
(
fp1
.
val
.
get_full_data
(),
fp2
.
val
.
get_full_data
())
fp
=
Field
((
p1
,
p2
),
val
=
outer
)
samples
=
1000
ps1
=
0.
ps2
=
0.
for
ii
in
xrange
(
samples
):
sk
=
fp
.
power_synthesize
(
spaces
=
(
0
,
1
),
real_signal
=
True
)
sp
=
sk
.
power_analyze
(
spaces
=
(
0
,
1
),
keep_phase_information
=
False
)