Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
073ce5ce
Commit
073ce5ce
authored
Nov 11, 2017
by
Martin Reinecke
Browse files
fixes
parent
6108ab56
Changes
2
Show whitespace changes
Inline
Side-by-side
nifty/data_objects/distributed_do.py
View file @
073ce5ce
...
...
@@ -91,10 +91,10 @@ class data_object(object):
if
axis
is
None
:
res
=
np
.
array
(
getattr
(
self
.
_data
,
op
)())
if
(
self
.
_distaxis
==-
1
):
return
res
[
0
]
res2
=
np
.
empty
(
1
,
dtype
=
res
.
dtype
)
return
res
[
()
]
res2
=
np
.
empty
(
()
,
dtype
=
res
.
dtype
)
_comm
.
Allreduce
(
res
,
res2
,
mpiop
)
return
res2
[
0
]
return
res2
[
()
]
if
self
.
_distaxis
in
axis
:
res
=
getattr
(
self
.
_data
,
op
)(
axis
=
axis
)
...
...
@@ -122,6 +122,10 @@ class data_object(object):
def
sum
(
self
,
axis
=
None
):
return
self
.
_contraction_helper
(
"sum"
,
MPI
.
SUM
,
axis
)
def
min
(
self
,
axis
=
None
):
return
self
.
_contraction_helper
(
"min"
,
MPI
.
MIN
,
axis
)
def
max
(
self
,
axis
=
None
):
return
self
.
_contraction_helper
(
"max"
,
MPI
.
MAX
,
axis
)
# FIXME: to be improved!
def
mean
(
self
):
...
...
@@ -239,10 +243,10 @@ def empty_like(a, dtype=None):
def
vdot
(
a
,
b
):
tmp
=
np
.
vdot
(
a
.
_data
,
b
.
_data
)
res
=
np
.
empty
(
1
,
dtype
=
t
ype
(
tmp
)
)
tmp
=
np
.
array
(
np
.
vdot
(
a
.
_data
,
b
.
_data
)
)
res
=
np
.
empty
(
()
,
dtype
=
t
mp
.
dtype
)
_comm
.
Allreduce
(
tmp
,
res
,
MPI
.
SUM
)
return
res
[
0
]
return
res
[
()
]
def
_math_helper
(
x
,
function
,
out
):
...
...
@@ -364,7 +368,7 @@ def redistribute (arr, dist=None, nodist=None):
out
=
np
.
moveaxis
(
out
,
0
,
arr
.
_distaxis
)
return
from_global_data
(
out
,
distaxis
=-
1
)
# real redistribution via Alltoallv
# temporary slow, but simple solution
# temporary slow, but simple solution
for comparison purposes:
#return redistribute(redistribute(arr,dist=-1),dist=dist)
tmp
=
np
.
moveaxis
(
arr
.
_data
,
(
dist
,
arr
.
_distaxis
),
(
0
,
1
))
...
...
nifty/low_level_library.py
View file @
073ce5ce
...
...
@@ -48,7 +48,7 @@ if not special_hartley:
def
_fill_array
(
tmp
,
res
,
axes
):
if
axes
is
None
:
axes
=
range
(
a
.
ndim
)
axes
=
range
(
tmp
.
ndim
)
lastaxis
=
axes
[
-
1
]
ntmplast
=
tmp
.
shape
[
lastaxis
]
slice1
=
[
slice
(
None
)]
*
lastaxis
+
[
slice
(
0
,
ntmplast
)]
...
...
@@ -60,7 +60,7 @@ def hartley(a, axes=None):
# Check if the axes provided are valid given the shape
if
axes
is
not
None
and
\
not
all
(
axis
<
len
(
a
.
shape
)
for
axis
in
axes
):
raise
ValueError
(
"Provided axes do
es
not match array shape"
)
raise
ValueError
(
"Provided axes do not match array shape"
)
if
issubclass
(
a
.
dtype
.
type
,
np
.
complexfloating
):
raise
TypeError
(
"Hartley tansform requires real-valued arrays."
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment