Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
e10b54aa
Commit
e10b54aa
authored
Sep 23, 2015
by
Ultima
Browse files
Cleaned up the field class.
Fixed a bug in nifty_mpi_data.
parent
41dd2ed7
Changes
4
Expand all
Hide whitespace changes
Inline
Side-by-side
nifty_core.py
View file @
e10b54aa
This diff is collapsed.
Click to expand it.
nifty_mpi_data.py
View file @
e10b54aa
...
...
@@ -1294,6 +1294,7 @@ class distributor(object):
distribution_strategy
=
'freeform'
,
comm
=
self
.
comm
)
# disperse the data one after another
print
(
'i'
,
i
,
temp_data_update
)
self
.
_disperse_data_primitive
(
data
=
data
,
to_key
=
to_key_list
[
i
],
...
...
@@ -2241,40 +2242,50 @@ class _slicing_distributor(distributor):
# in case of leading scalars, indenify the node with data
# and broadcast the shape to the others
if
sliceified
[
0
]:
local_has_data
=
(
np
.
prod
(
np
.
shape
(
in_data
.
get_local_data
())
)
!=
0
)
local_has_data_list
=
np
.
array
(
self
.
comm
.
allgather
(
local_has_data
))
nodes_with_data
=
np
.
where
(
local_has_data_list
==
True
)[
0
]
if
np
.
shape
(
nodes_with_data
)[
0
]
>
1
:
raise
ValueError
(
"ERROR: scalar index on first dimension, but more "
+
"than one node has data!"
)
elif
np
.
shape
(
nodes_with_data
)[
0
]
==
1
:
node_with_data
=
nodes_with_data
[
0
]
else
:
node_with_data
=
-
1
# Case 1: The in_data d2o has more than one dimension
if
len
(
in_data
.
shape
)
>
1
:
local_has_data
=
(
np
.
prod
(
np
.
shape
(
in_data
.
get_local_data
()))
!=
0
)
local_has_data_list
=
np
.
array
(
self
.
comm
.
allgather
(
local_has_data
))
nodes_with_data
=
np
.
where
(
local_has_data_list
)[
0
]
if
np
.
shape
(
nodes_with_data
)[
0
]
>
1
:
raise
ValueError
(
"ERROR: scalar index on first dimension, but "
+
" more than one node has data!"
)
elif
np
.
shape
(
nodes_with_data
)[
0
]
==
1
:
node_with_data
=
nodes_with_data
[
0
]
else
:
node_with_data
=
-
1
if
node_with_data
==
-
1
:
broadcasted_shape
=
(
0
,)
*
len
(
temp_local_shape
)
else
:
broadcasted_shape
=
self
.
comm
.
bcast
(
temp_local_shape
,
if
node_with_data
==
-
1
:
broadcasted_shape
=
(
0
,)
*
len
(
temp_local_shape
)
else
:
broadcasted_shape
=
self
.
comm
.
bcast
(
temp_local_shape
,
root
=
node_with_data
)
if
self
.
comm
.
rank
!=
node_with_data
:
temp_local_shape
=
np
.
array
(
broadcasted_shape
)
temp_local_shape
[
0
]
=
0
temp_local_shape
=
tuple
(
temp_local_shape
)
if
self
.
comm
.
rank
!=
node_with_data
:
temp_local_shape
=
np
.
array
(
broadcasted_shape
)
temp_local_shape
[
0
]
=
0
temp_local_shape
=
tuple
(
temp_local_shape
)
# Case 2: The in_data d2o is only onedimensional
else
:
# The data contained in the d2o must be stored on one
# single node at the end. Hence it is ok to consolidate
# the data and to make a recursive call.
temp_data
=
in_data
.
get_full_data
()
return
self
.
_enfold
(
temp_data
,
sliceified
)
if
in_data
.
distribution_strategy
!=
'freeform'
:
if
in_data
.
distribution_strategy
in
STRATEGIES
[
'global'
]
:
new_data
=
in_data
.
copy_empty
(
global_shape
=
temp_global_shape
)
new_data
.
set_local_data
(
local_data
,
copy
=
False
)
el
se
:
el
if
in_data
.
distribution_strategy
in
STRATEGIES
[
'local'
]
:
reshaped_data
=
local_data
.
reshape
(
temp_local_shape
)
new_data
=
distributed_data_object
(
local_data
=
reshaped_data
,
distribution_strategy
=
'freeform'
,
comm
=
self
.
comm
)
local_data
=
reshaped_data
,
distribution_strategy
=
in_data
.
distribution_strategy
,
comm
=
self
.
comm
)
return
new_data
else
:
return
local_data
.
reshape
(
temp_local_shape
)
...
...
test/test_nifty_mpi_data.py
View file @
e10b54aa
...
...
@@ -743,15 +743,33 @@ class Test_slicing_get_set_data(unittest.TestCase):
###############################################################################
@
parameterized
.
expand
(
all_distribution_strategies
)
@
parameterized
.
expand
(
all_distribution_strategies
,
testcase_func_name
=
custom_name_func
)
def
test_get_single_value_from_d2o
(
self
,
distribution_strategy
):
(
a
,
obj
)
=
generate_data
((
4
,),
np
.
dtype
(
'float'
),
distribution_strategy
)
assert_equal
(
obj
[
0
],
a
[
0
])
###############################################################################
###############################################################################
@
parameterized
.
expand
(
itertools
.
product
(
all_distribution_strategies
,
all_distribution_strategies
),
testcase_func_name
=
custom_name_func
)
def
test_single_row_from_d2o
(
self
,
distribution_strategy1
,
distribution_strategy2
):
(
a
,
obj
)
=
generate_data
((
8
,
8
),
np
.
dtype
(
'float'
),
distribution_strategy1
)
(
b
,
p
)
=
generate_data
((
8
,),
np
.
dtype
(
'float'
),
distribution_strategy2
)
a
[
4
]
=
b
obj
[
4
]
=
p
assert_equal
(
obj
.
get_full_data
(),
a
)
###############################################################################
###############################################################################
class
Test_boolean_get_set_data
(
unittest
.
TestCase
):
...
...
test/test_nifty_spaces.py
View file @
e10b54aa
...
...
@@ -22,16 +22,16 @@ from nifty.nifty_paradict import space_paradict
from
nifty.nifty_core
import
POINT_DISTRIBUTION_STRATEGIES
from
nifty.rg.nifty_rg
import
RG_DISTRIBUTION_STRATEGIES
,
\
gc
as
RG_GC
gc
as
RG_GC
from
nifty.lm.nifty_lm
import
LM_DISTRIBUTION_STRATEGIES
,
\
GL_DISTRIBUTION_STRATEGIES
,
\
HP_DISTRIBUTION_STRATEGIES
GL_DISTRIBUTION_STRATEGIES
,
\
HP_DISTRIBUTION_STRATEGIES
from
nifty.nifty_power_indices
import
power_indices
from
nifty.nifty_utilities
import
_hermitianize_inverter
as
\
hermitianize_inverter
hermitianize_inverter
###############################################################################
###############################################################################
def
custom_name_func
(
testcase_func
,
param_num
,
param
):
return
"%s_%s"
%
(
...
...
@@ -169,10 +169,10 @@ def check_almost_equality(space, data1, data2, integers=7):
def
flip
(
space
,
data
):
return
space
.
unary_operation
(
hermitianize_inverter
(
data
),
'conjugate'
)
###############################################################################
###############################################################################
class
Test_Common_Space_Features
(
unittest
.
TestCase
):
@
parameterized
.
expand
(
all_spaces
,
...
...
@@ -195,7 +195,6 @@ class Test_Common_Space_Features(unittest.TestCase):
assert
(
callable
(
s
.
apply_scalar_function
))
assert
(
callable
(
s
.
unary_operation
))
assert
(
callable
(
s
.
binary_operation
))
assert
(
callable
(
s
.
get_norm
))
assert
(
callable
(
s
.
get_shape
))
assert
(
callable
(
s
.
get_dim
))
assert
(
callable
(
s
.
get_dof
))
...
...
@@ -207,6 +206,7 @@ class Test_Common_Space_Features(unittest.TestCase):
assert
(
callable
(
s
.
get_random_values
))
assert
(
callable
(
s
.
calc_weight
))
assert
(
callable
(
s
.
get_weight
))
assert
(
callable
(
s
.
calc_norm
))
assert
(
callable
(
s
.
calc_dot
))
assert
(
callable
(
s
.
calc_transform
))
assert
(
callable
(
s
.
calc_smooth
))
...
...
@@ -346,18 +346,6 @@ class Test_Point_Space(unittest.TestCase):
s
.
binary_operation
(
d
,
d2
,
op
)
# TODO: Implement value verification
###############################################################################
@
parameterized
.
expand
(
itertools
.
product
(
DATAMODELS
[
'point_space'
]),
testcase_func_name
=
custom_name_func
)
def
test_get_norm
(
self
,
datamodel
):
num
=
10
s
=
point_space
(
num
,
datamodel
=
datamodel
)
d
=
s
.
cast
(
np
.
arange
(
num
))
assert_almost_equal
(
s
.
get_norm
(
d
),
16.881943016134134
)
assert_almost_equal
(
s
.
get_norm
(
d
,
q
=
3
),
12.651489979526238
)
###############################################################################
@
parameterized
.
expand
(
...
...
@@ -599,6 +587,18 @@ class Test_Point_Space(unittest.TestCase):
assert_equal
(
s
.
calc_dot
(
1
,
1
),
num
)
assert_equal
(
s
.
calc_dot
(
np
.
arange
(
num
),
1
),
num
*
(
num
-
1.
)
/
2.
)
###############################################################################
@
parameterized
.
expand
(
itertools
.
product
(
DATAMODELS
[
'point_space'
]),
testcase_func_name
=
custom_name_func
)
def
test_calc_norm
(
self
,
datamodel
):
num
=
10
s
=
point_space
(
num
,
datamodel
=
datamodel
)
d
=
s
.
cast
(
np
.
arange
(
num
))
assert_almost_equal
(
s
.
calc_norm
(
d
),
16.881943016134134
)
assert_almost_equal
(
s
.
calc_norm
(
d
,
q
=
3
),
12.651489979526238
)
###############################################################################
@
parameterized
.
expand
(
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment