Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
67bed0b7
Commit
67bed0b7
authored
Nov 08, 2017
by
Martin Reinecke
Browse files
tweaks
parent
93d2e2d6
Pipeline
#21196
passed with stage
in 4 minutes and 57 seconds
Changes
7
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty/data_objects/my_own_do.py
View file @
67bed0b7
...
@@ -6,6 +6,7 @@ class data_object(object):
...
@@ -6,6 +6,7 @@ class data_object(object):
def
__init__
(
self
,
npdata
):
def
__init__
(
self
,
npdata
):
self
.
_data
=
np
.
asarray
(
npdata
)
self
.
_data
=
np
.
asarray
(
npdata
)
# FIXME: subscripting support will most likely go away
def
__getitem__
(
self
,
key
):
def
__getitem__
(
self
,
key
):
res
=
self
.
_data
[
key
]
res
=
self
.
_data
[
key
]
return
res
if
np
.
isscalar
(
res
)
else
data_object
(
res
)
return
res
if
np
.
isscalar
(
res
)
else
data_object
(
res
)
...
@@ -37,6 +38,9 @@ class data_object(object):
...
@@ -37,6 +38,9 @@ class data_object(object):
return
data_object
(
self
.
_data
.
imag
)
return
data_object
(
self
.
_data
.
imag
)
def
_contraction_helper
(
self
,
op
,
axis
):
def
_contraction_helper
(
self
,
op
,
axis
):
if
axis
is
not
None
:
if
len
(
axis
)
==
len
(
self
.
_data
.
shape
):
axis
=
None
if
axis
is
None
:
if
axis
is
None
:
return
getattr
(
self
.
_data
,
op
)()
return
getattr
(
self
.
_data
,
op
)()
...
@@ -164,32 +168,28 @@ def vdot(a, b):
...
@@ -164,32 +168,28 @@ def vdot(a, b):
return
np
.
vdot
(
a
.
_data
,
b
.
_data
)
return
np
.
vdot
(
a
.
_data
,
b
.
_data
)
def
_math_helper
(
x
,
function
,
out
):
if
out
is
not
None
:
function
(
x
.
_data
,
out
=
out
.
_data
)
return
out
else
:
return
data_object
(
function
(
x
.
_data
))
def
abs
(
a
,
out
=
None
):
def
abs
(
a
,
out
=
None
):
if
out
is
None
:
return
_math_helper
(
a
,
np
.
abs
,
out
)
out
=
empty_like
(
a
)
np
.
abs
(
a
.
_data
,
out
=
out
.
_data
)
return
out
def
exp
(
a
,
out
=
None
):
def
exp
(
a
,
out
=
None
):
if
out
is
None
:
return
_math_helper
(
a
,
np
.
exp
,
out
)
out
=
empty_like
(
a
)
np
.
exp
(
a
.
_data
,
out
=
out
.
_data
)
return
out
def
log
(
a
,
out
=
None
):
def
log
(
a
,
out
=
None
):
if
out
is
None
:
return
_math_helper
(
a
,
np
.
log
,
out
)
out
=
empty_like
(
a
)
np
.
log
(
a
.
_data
,
out
=
out
.
_data
)
return
out
def
sqrt
(
a
,
out
=
None
):
def
sqrt
(
a
,
out
=
None
):
if
out
is
None
:
return
_math_helper
(
a
,
np
.
sqrt
,
out
)
out
=
empty_like
(
a
)
np
.
sqrt
(
a
.
_data
,
out
=
out
.
_data
)
return
out
def
bincount
(
x
,
weights
=
None
,
minlength
=
None
):
def
bincount
(
x
,
weights
=
None
,
minlength
=
None
):
...
@@ -224,12 +224,6 @@ def ibegin(arr):
...
@@ -224,12 +224,6 @@ def ibegin(arr):
return
(
0
,)
*
arr
.
_data
.
ndim
return
(
0
,)
*
arr
.
_data
.
ndim
def
create_from_template
(
tmpl
,
local_data
,
dtype
):
res
=
np
.
ndarray
(
tmpl
.
shape
,
dtype
=
dtype
)
res
[()]
=
local_data
return
data_object
(
res
)
def
np_allreduce_sum
(
arr
):
def
np_allreduce_sum
(
arr
):
return
arr
return
arr
...
@@ -249,8 +243,6 @@ def from_local_data (shape, arr, dist_axis):
...
@@ -249,8 +243,6 @@ def from_local_data (shape, arr, dist_axis):
def
from_global_data
(
arr
,
dist_axis
):
def
from_global_data
(
arr
,
dist_axis
):
if
dist_axis
!=-
1
:
if
dist_axis
!=-
1
:
raise
NotImplementedError
raise
NotImplementedError
if
shape
!=
arr
.
shape
:
raise
ValueError
return
data_object
(
arr
)
return
data_object
(
arr
)
...
...
nifty/data_objects/numpy_do.py
View file @
67bed0b7
...
@@ -31,12 +31,6 @@ def ibegin(arr):
...
@@ -31,12 +31,6 @@ def ibegin(arr):
return
(
0
,)
*
arr
.
ndim
return
(
0
,)
*
arr
.
ndim
def
create_from_template
(
tmpl
,
local_data
,
dtype
):
res
=
np
.
ndarray
(
tmpl
.
shape
,
dtype
=
dtype
)
res
[()]
=
local_data
return
res
def
np_allreduce_sum
(
arr
):
def
np_allreduce_sum
(
arr
):
return
arr
return
arr
...
@@ -56,8 +50,6 @@ def from_local_data (shape, arr, dist_axis):
...
@@ -56,8 +50,6 @@ def from_local_data (shape, arr, dist_axis):
def
from_global_data
(
arr
,
dist_axis
):
def
from_global_data
(
arr
,
dist_axis
):
if
dist_axis
!=-
1
:
if
dist_axis
!=-
1
:
raise
NotImplementedError
raise
NotImplementedError
if
shape
!=
arr
.
shape
:
raise
ValueError
return
arr
return
arr
...
...
nifty/field.py
View file @
67bed0b7
...
@@ -455,7 +455,7 @@ class Field(object):
...
@@ -455,7 +455,7 @@ class Field(object):
raise
TypeError
(
"argument must be a Field"
)
raise
TypeError
(
"argument must be a Field"
)
if
other
.
domain
!=
self
.
domain
:
if
other
.
domain
!=
self
.
domain
:
raise
ValueError
(
"domains are incompatible."
)
raise
ValueError
(
"domains are incompatible."
)
self
.
val
[()]
=
other
.
val
[()]
dobj
.
local_data
(
self
.
val
)
[()]
=
dobj
.
local_data
(
other
.
val
)
[()]
# ---General binary methods---
# ---General binary methods---
...
...
nifty/operators/power_projection_operator.py
View file @
67bed0b7
...
@@ -55,34 +55,36 @@ class PowerProjectionOperator(LinearOperator):
...
@@ -55,34 +55,36 @@ class PowerProjectionOperator(LinearOperator):
res
=
Field
.
zeros
(
self
.
_target
,
dtype
=
x
.
dtype
)
res
=
Field
.
zeros
(
self
.
_target
,
dtype
=
x
.
dtype
)
if
dobj
.
dist_axis
(
x
.
val
)
in
x
.
domain
.
axes
[
self
.
_space
]:
# the distributed axis is part of the projected space
if
dobj
.
dist_axis
(
x
.
val
)
in
x
.
domain
.
axes
[
self
.
_space
]:
# the distributed axis is part of the projected space
pindex
=
dobj
.
local_data
(
pindex
)
pindex
=
dobj
.
local_data
(
pindex
)
pindex
.
reshape
((
1
,
pindex
.
size
,
1
))
arr
=
dobj
.
local_data
(
x
.
weight
(
1
).
val
)
firstaxis
=
x
.
domain
.
axes
[
self
.
_space
][
0
]
lastaxis
=
x
.
domain
.
axes
[
self
.
_space
][
-
1
]
presize
=
np
.
prod
(
arr
.
shape
[
0
:
firstaxis
],
dtype
=
np
.
int
)
postsize
=
np
.
prod
(
arr
.
shape
[
lastaxis
+
1
:],
dtype
=
np
.
int
)
arr
=
arr
.
reshape
((
presize
,
pindex
.
size
,
postsize
))
oarr
=
dobj
.
local_data
(
res
.
val
).
reshape
((
presize
,
-
1
,
postsize
))
np
.
add
.
at
(
oarr
,
(
slice
(
None
),
pindex
.
ravel
(),
slice
(
None
)),
arr
)
else
:
else
:
pindex
=
dobj
.
to_ndarray
(
pindex
)
pindex
=
dobj
.
to_ndarray
(
pindex
)
pindex
.
reshape
((
1
,
pindex
.
size
,
1
))
pindex
.
reshape
((
1
,
pindex
.
size
,
1
))
arr
=
dobj
.
local_data
(
x
.
weight
(
1
).
val
)
arr
=
dobj
.
local_data
(
x
.
weight
(
1
).
val
)
firstaxis
=
x
.
domain
.
axes
[
self
.
_space
][
0
]
firstaxis
=
x
.
domain
.
axes
[
self
.
_space
][
0
]
lastaxis
=
x
.
domain
.
axes
[
self
.
_space
][
-
1
]
lastaxis
=
x
.
domain
.
axes
[
self
.
_space
][
-
1
]
presize
=
np
.
prod
(
arr
.
shape
[
0
:
firstaxis
],
dtype
=
np
.
int
)
presize
=
np
.
prod
(
arr
.
shape
[
0
:
firstaxis
],
dtype
=
np
.
int
)
postsize
=
np
.
prod
(
arr
.
shape
[
lastaxis
+
1
:],
dtype
=
np
.
int
)
postsize
=
np
.
prod
(
arr
.
shape
[
lastaxis
+
1
:],
dtype
=
np
.
int
)
arr
=
arr
.
reshape
((
presize
,
pindex
.
size
,
postsize
))
arr
=
arr
.
reshape
((
presize
,
pindex
.
size
,
postsize
))
oarr
=
dobj
.
local_data
(
res
.
val
).
reshape
((
presize
,
-
1
,
postsize
))
oarr
=
dobj
.
local_data
(
res
.
val
).
reshape
((
presize
,
-
1
,
postsize
))
np
.
add
.
at
(
oarr
,
(
slice
(
None
),
pindex
.
ravel
(),
slice
(
None
)),
arr
)
np
.
add
.
at
(
oarr
,
(
slice
(
None
),
pindex
.
ravel
(),
slice
(
None
)),
arr
)
return
res
.
weight
(
-
1
,
spaces
=
self
.
_space
)
return
res
.
weight
(
-
1
,
spaces
=
self
.
_space
)
def
_adjoint_times
(
self
,
x
):
def
_adjoint_times
(
self
,
x
):
pindex
=
self
.
_target
[
self
.
_space
].
pindex
pindex
=
self
.
_target
[
self
.
_space
].
pindex
res
=
Field
.
empty
(
self
.
_domain
,
dtype
=
x
.
dtype
)
if
dobj
.
dist_axis
(
x
.
val
)
in
x
.
domain
.
axes
[
self
.
_space
]:
# the distributed axis is part of the projected space
pindex
=
dobj
.
local_data
(
pindex
)
else
:
pindex
=
dobj
.
to_ndarray
(
pindex
)
pindex
=
pindex
.
reshape
((
1
,
pindex
.
size
,
1
))
pindex
=
pindex
.
reshape
((
1
,
pindex
.
size
,
1
))
arr
=
x
.
val
.
reshape
(
x
.
domain
.
collapsed_shape_for_domain
(
self
.
_space
))
arr
=
dobj
.
local_data
(
x
.
val
)
out
=
arr
[(
slice
(
None
),
dobj
.
to_ndarray
(
pindex
.
ravel
()),
slice
(
None
))]
firstaxis
=
x
.
domain
.
axes
[
self
.
_space
][
0
]
return
Field
(
self
.
_domain
,
out
.
reshape
(
self
.
_domain
.
shape
))
lastaxis
=
x
.
domain
.
axes
[
self
.
_space
][
-
1
]
presize
=
np
.
prod
(
arr
.
shape
[
0
:
firstaxis
],
dtype
=
np
.
int
)
postsize
=
np
.
prod
(
arr
.
shape
[
lastaxis
+
1
:],
dtype
=
np
.
int
)
arr
=
arr
.
reshape
((
presize
,
-
1
,
postsize
))
oarr
=
dobj
.
local_data
(
res
.
val
).
reshape
((
presize
,
-
1
,
postsize
))
oarr
[()]
=
arr
[(
slice
(
None
),
pindex
.
ravel
(),
slice
(
None
))]
return
res
@
property
@
property
def
domain
(
self
):
def
domain
(
self
):
...
...
nifty/spaces/power_space.py
View file @
67bed0b7
...
@@ -143,8 +143,8 @@ class PowerSpace(Space):
...
@@ -143,8 +143,8 @@ class PowerSpace(Space):
else
:
else
:
tbb
=
binbounds
tbb
=
binbounds
locdat
=
np
.
searchsorted
(
tbb
,
dobj
.
local_data
(
k_length_array
.
val
))
locdat
=
np
.
searchsorted
(
tbb
,
dobj
.
local_data
(
k_length_array
.
val
))
temp_pindex
=
dobj
.
create_from_templ
at
e
(
temp_pindex
=
dobj
.
from_local_d
at
a
(
k_length_array
.
val
,
local_data
=
locdat
,
dtype
=
locdat
.
dtype
)
k_length_array
.
val
.
shape
,
locdat
,
dobj
.
dist_axis
(
k_length_array
.
val
)
)
nbin
=
len
(
tbb
)
nbin
=
len
(
tbb
)
temp_rho
=
np
.
bincount
(
dobj
.
local_data
(
temp_pindex
).
ravel
(),
temp_rho
=
np
.
bincount
(
dobj
.
local_data
(
temp_pindex
).
ravel
(),
minlength
=
nbin
)
minlength
=
nbin
)
...
...
nifty/spaces/rg_space.py
View file @
67bed0b7
...
@@ -115,6 +115,7 @@ class RGSpace(Space):
...
@@ -115,6 +115,7 @@ class RGSpace(Space):
tmp
[
t2
]
=
True
tmp
[
t2
]
=
True
return
np
.
sqrt
(
np
.
nonzero
(
tmp
)[
0
])
*
self
.
distances
[
0
]
return
np
.
sqrt
(
np
.
nonzero
(
tmp
)[
0
])
*
self
.
distances
[
0
]
else
:
# do it the hard way
else
:
# do it the hard way
# FIXME: this needs to improve for MPI. Maybe unique()/gather()?
tmp
=
np
.
unique
(
dobj
.
to_ndarray
(
self
.
get_k_length_array
().
val
))
# expensive!
tmp
=
np
.
unique
(
dobj
.
to_ndarray
(
self
.
get_k_length_array
().
val
))
# expensive!
tol
=
1e-12
*
tmp
[
-
1
]
tol
=
1e-12
*
tmp
[
-
1
]
# remove all points that are closer than tol to their right
# remove all points that are closer than tol to their right
...
...
test/test_operators/test_smoothing_operator.py
View file @
67bed0b7
...
@@ -56,8 +56,9 @@ class SmoothingOperator_Tests(unittest.TestCase):
...
@@ -56,8 +56,9 @@ class SmoothingOperator_Tests(unittest.TestCase):
@
expand
(
product
(
spaces
,
[
0.
,
.
5
,
5.
]))
@
expand
(
product
(
spaces
,
[
0.
,
.
5
,
5.
]))
def
test_times
(
self
,
space
,
sigma
):
def
test_times
(
self
,
space
,
sigma
):
op
=
ift
.
FFTSmoothingOperator
(
space
,
sigma
=
sigma
)
op
=
ift
.
FFTSmoothingOperator
(
space
,
sigma
=
sigma
)
rand1
=
ift
.
Field
.
zeros
(
space
)
fld
=
np
.
zeros
(
space
.
shape
,
dtype
=
np
.
float64
)
rand1
.
val
[
0
]
=
1.
fld
[
0
]
=
1.
rand1
=
ift
.
Field
(
space
,
ift
.
dobj
.
from_global_data
(
fld
,
dist_axis
=-
1
))
tt1
=
op
.
times
(
rand1
)
tt1
=
op
.
times
(
rand1
)
assert_allclose
(
1
,
tt1
.
sum
())
assert_allclose
(
1
,
tt1
.
sum
())
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment