Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
153d1ce1
Commit
153d1ce1
authored
Nov 08, 2017
by
Martin Reinecke
Browse files
very early (broken) stage of distributed_do
parent
67bed0b7
Pipeline
#21209
passed with stage
in 4 minutes and 11 seconds
Changes
1
Pipelines
1
Show whitespace changes
Inline
Side-by-side
nifty/data_objects/distributed_do.py
0 → 100644
View file @
153d1ce1
import
numpy
as
np
from
.random
import
Random
from
mpi4py
import
MPI
comm
=
MPI
.
COMM_WORLD
ntask
=
comm
.
Get_size
()
rank
=
comm
.
Get_rank
()
def
shareSize
(
nwork
,
nshares
,
myshare
):
nbase
=
nwork
//
nshares
return
nbase
if
myshare
>=
nwork
%
nshares
else
nbase
+
1
def
get_locshape
(
shape
,
distaxis
):
if
distaxis
==-
1
:
return
shape
shape2
=
list
(
shape
)
shape2
[
distaxis
]
=
shareSize
(
shape
[
distaxis
],
ntask
,
rank
)
return
tuple
(
shape2
)
class
data_object
(
object
):
def
__init__
(
self
,
shape
,
data
,
distaxis
):
"""Must not be called directly by users"""
self
.
_shape
=
shape
self
.
_distaxis
=
distaxis
lshape
=
get_locshape
(
self
.
_shape
,
self
.
_distaxis
)
self
.
_data
=
data
def
sanity_checks
(
self
):
# check whether the distaxis is consistent
if
self
.
_distaxis
<-
1
or
self
.
_distaxis
>=
len
(
self
.
_shape
):
raise
ValueError
itmp
=
np
.
array
(
self
.
_distaxis
)
otmp
=
np
.
empty
(
ntask
,
dtype
=
np
.
int
)
comm
.
Allgather
(
itmp
,
otmp
)
if
np
.
any
(
otmp
!=
self
.
_distaxis
):
raise
ValueError
# check whether the global shape is consistent
itmp
=
np
.
array
(
self
.
_shape
)
otmp
=
np
.
empty
((
len
(
self
.
_shape
),
ntask
),
dtype
=
np
.
int
)
comm
.
Allgather
(
itmp
,
otmp
)
for
i
in
range
(
ntask
):
if
(
otmp
[
i
,:]
!=
self
.
_shape
).
any
():
raise
ValueError
# check shape of local data
if
self
.
_distaxis
<
0
:
if
self
.
_data
.
shape
!=
self
.
_shape
:
raise
ValueError
else
:
itmp
=
np
.
array
(
self
.
_shape
)
itmp
[
self
.
_distaxis
]
=
get_local_length
(
self
.
_shape
[
self
.
_distaxis
],
ntask
,
rank
)
if
self
.
_data
.
shape
!=
itmp
:
raise
ValueError
@
property
def
dtype
(
self
):
return
self
.
_data
.
dtype
@
property
def
shape
(
self
):
return
self
.
_shape
@
property
def
size
(
self
):
return
np
.
prod
(
self
.
_shape
)
@
property
def
real
(
self
):
return
data_object
(
self
.
_shape
,
self
.
_data
.
real
,
self
.
_dist_axis
)
@
property
def
imag
(
self
):
return
data_object
(
self
.
_shape
,
self
.
_data
.
imag
,
self
.
_dist_axis
)
def
_contraction_helper
(
self
,
op
,
axis
):
if
axis
is
not
None
:
if
len
(
axis
)
==
len
(
self
.
_data
.
shape
):
axis
=
None
if
axis
is
None
:
res
=
getattr
(
self
.
_data
,
op
)()
MPI
.
COMM_WORLD
.
Allreduce
(
res
,
res2
,
mpiop
)
if
self
.
_distaxis
in
axis
:
pass
# reduce globally, redistribute the result along axis 0(?)
else
:
pass
# reduce locally
# perform the contraction on the data
data
=
getattr
(
self
.
_data
,
op
)(
axis
=
axis
)
# check if the result is scalar or if a result_field must be constr.
if
np
.
isscalar
(
data
):
return
data
else
:
return
data_object
(
data
)
def
sum
(
self
,
axis
=
None
):
return
self
.
_contraction_helper
(
"sum"
,
MPI
.
SUM
,
axis
)
def
_binary_helper
(
self
,
other
,
op
):
a
=
self
.
_data
if
isinstance
(
other
,
data_object
):
b
=
other
.
_data
if
a
.
_shape
!=
b
.
_shape
:
raise
ValueError
(
"shapes are incompatible."
)
if
a
.
_distaxis
!=
b
.
_distaxis
:
raise
ValueError
(
"distributions are incompatible."
)
else
:
b
=
other
tval
=
getattr
(
a
,
op
)(
b
)
return
self
if
tval
is
a
else
data_object
(
self
.
_shape
,
tval
,
self
.
_distaxis
)
def
__add__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'__add__'
)
def
__radd__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'__radd__'
)
def
__iadd__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'__iadd__'
)
def
__sub__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'__sub__'
)
def
__rsub__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'__rsub__'
)
def
__isub__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'__isub__'
)
def
__mul__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'__mul__'
)
def
__rmul__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'__rmul__'
)
def
__imul__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'__imul__'
)
def
__div__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'__div__'
)
def
__rdiv__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'__rdiv__'
)
def
__truediv__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'__truediv__'
)
def
__rtruediv__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'__rtruediv__'
)
def
__pow__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'__pow__'
)
def
__rpow__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'__rpow__'
)
def
__ipow__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'__ipow__'
)
def
__eq__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'__eq__'
)
def
__ne__
(
self
,
other
):
return
self
.
_binary_helper
(
other
,
op
=
'__ne__'
)
def
__neg__
(
self
):
return
data_object
(
-
self
.
_data
)
def
__abs__
(
self
):
return
data_object
(
np
.
abs
(
self
.
_data
))
def
ravel
(
self
):
return
data_object
(
self
.
_data
.
ravel
())
def
reshape
(
self
,
shape
):
return
data_object
(
self
.
_data
.
reshape
(
shape
))
def
all
(
self
):
return
self
.
_data
.
all
()
def
any
(
self
):
return
self
.
_data
.
any
()
def
full
(
shape
,
fill_value
,
dtype
=
None
,
dist_axis
=
0
):
return
data_object
(
shape
,
np
.
full
(
shape
,
local_shape
(
shape
,
dist_axis
),
fill_value
,
dtype
))
def
empty
(
shape
,
dtype
=
np
.
float
):
return
data_object
(
np
.
empty
(
shape
,
dtype
))
def
zeros
(
shape
,
dtype
=
np
.
float
):
return
data_object
(
np
.
zeros
(
shape
,
dtype
))
def
ones
(
shape
,
dtype
=
np
.
float
):
return
data_object
(
np
.
ones
(
shape
,
dtype
))
def
empty_like
(
a
,
dtype
=
None
):
return
data_object
(
np
.
empty_like
(
a
.
_data
,
dtype
))
def
vdot
(
a
,
b
):
return
np
.
vdot
(
a
.
_data
,
b
.
_data
)
def
_math_helper
(
x
,
function
,
out
):
if
out
is
not
None
:
function
(
x
.
_data
,
out
=
out
.
_data
)
return
out
else
:
return
data_object
(
function
(
x
.
_data
))
def
abs
(
a
,
out
=
None
):
return
_math_helper
(
a
,
np
.
abs
,
out
)
def
exp
(
a
,
out
=
None
):
return
_math_helper
(
a
,
np
.
exp
,
out
)
def
log
(
a
,
out
=
None
):
return
_math_helper
(
a
,
np
.
log
,
out
)
def
sqrt
(
a
,
out
=
None
):
return
_math_helper
(
a
,
np
.
sqrt
,
out
)
def
bincount
(
x
,
weights
=
None
,
minlength
=
None
):
if
weights
is
not
None
:
weights
=
weights
.
_data
res
=
np
.
bincount
(
x
.
_data
,
weights
,
minlength
)
return
data_object
(
res
)
def
from_object
(
object
,
dtype
=
None
,
copy
=
True
):
return
data_object
(
np
.
array
(
object
.
_data
,
dtype
=
dtype
,
copy
=
copy
))
def
from_random
(
random_type
,
shape
,
dtype
=
np
.
float64
,
**
kwargs
):
generator_function
=
getattr
(
Random
,
random_type
)
return
data_object
(
generator_function
(
dtype
=
dtype
,
shape
=
shape
,
**
kwargs
))
def
to_ndarray
(
arr
):
return
arr
.
_data
def
from_ndarray
(
arr
):
return
data_object
(
arr
.
shape
,
arr
,
-
1
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment