Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
f399b0da
Commit
f399b0da
authored
Apr 19, 2016
by
csongor
Browse files
Add axis keyword to point_space and test it
parent
51428d74
Pipeline
#1794
skipped
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty_core.py
View file @
f399b0da
...
...
@@ -892,7 +892,7 @@ class point_space(space):
def
apply_scalar_function
(
self
,
x
,
function
,
inplace
=
False
):
return
x
.
apply_scalar_function
(
function
,
inplace
=
inplace
)
def
unary_operation
(
self
,
x
,
op
=
'None'
,
**
kwargs
):
def
unary_operation
(
self
,
x
,
op
=
'None'
,
axis
=
None
,
**
kwargs
):
"""
x must be a numpy array which is compatible with the space!
Valid operations are
...
...
@@ -903,21 +903,21 @@ class point_space(space):
'abs'
:
lambda
y
:
getattr
(
y
,
'__abs__'
)(),
'real'
:
lambda
y
:
getattr
(
y
,
'real'
),
'imag'
:
lambda
y
:
getattr
(
y
,
'imag'
),
'nanmin'
:
lambda
y
:
getattr
(
y
,
'nanmin'
)(),
'amin'
:
lambda
y
:
getattr
(
y
,
'amin'
)(),
'nanmax'
:
lambda
y
:
getattr
(
y
,
'nanmax'
)(),
'amax'
:
lambda
y
:
getattr
(
y
,
'amax'
)(),
'median'
:
lambda
y
:
getattr
(
y
,
'median'
)(),
'mean'
:
lambda
y
:
getattr
(
y
,
'mean'
)(),
'std'
:
lambda
y
:
getattr
(
y
,
'std'
)(),
'var'
:
lambda
y
:
getattr
(
y
,
'var'
)(),
'nanmin'
:
lambda
y
:
getattr
(
y
,
'nanmin'
)(
axis
=
axis
),
'amin'
:
lambda
y
:
getattr
(
y
,
'amin'
)(
axis
=
axis
),
'nanmax'
:
lambda
y
:
getattr
(
y
,
'nanmax'
)(
axis
=
axis
),
'amax'
:
lambda
y
:
getattr
(
y
,
'amax'
)(
axis
=
axis
),
'median'
:
lambda
y
:
getattr
(
y
,
'median'
)(
axis
=
axis
),
'mean'
:
lambda
y
:
getattr
(
y
,
'mean'
)(
axis
=
axis
),
'std'
:
lambda
y
:
getattr
(
y
,
'std'
)(
axis
=
axis
),
'var'
:
lambda
y
:
getattr
(
y
,
'var'
)(
axis
=
axis
),
'argmin'
:
lambda
y
:
getattr
(
y
,
'argmin_nonflat'
)(),
'argmin_flat'
:
lambda
y
:
getattr
(
y
,
'argmin'
)(),
'argmax'
:
lambda
y
:
getattr
(
y
,
'argmax_nonflat'
)(),
'argmax_flat'
:
lambda
y
:
getattr
(
y
,
'argmax'
)(),
'conjugate'
:
lambda
y
:
getattr
(
y
,
'conjugate'
)(),
'sum'
:
lambda
y
:
getattr
(
y
,
'sum'
)(),
'prod'
:
lambda
y
:
getattr
(
y
,
'prod'
)(),
'sum'
:
lambda
y
:
getattr
(
y
,
'sum'
)(
axis
=
axis
),
'prod'
:
lambda
y
:
getattr
(
y
,
'prod'
)(
axis
=
axis
),
'unique'
:
lambda
y
:
getattr
(
y
,
'unique'
)(),
'copy'
:
lambda
y
:
getattr
(
y
,
'copy'
)(),
'copy_empty'
:
lambda
y
:
getattr
(
y
,
'copy_empty'
)(),
...
...
@@ -925,8 +925,8 @@ class point_space(space):
'isinf'
:
lambda
y
:
getattr
(
y
,
'isinf'
)(),
'isfinite'
:
lambda
y
:
getattr
(
y
,
'isfinite'
)(),
'nan_to_num'
:
lambda
y
:
getattr
(
y
,
'nan_to_num'
)(),
'all'
:
lambda
y
:
getattr
(
y
,
'all'
)(),
'any'
:
lambda
y
:
getattr
(
y
,
'any'
)(),
'all'
:
lambda
y
:
getattr
(
y
,
'all'
)(
axis
=
axis
),
'any'
:
lambda
y
:
getattr
(
y
,
'any'
)(
axis
=
axis
),
'None'
:
lambda
y
:
y
}
return
translation
[
op
](
x
,
**
kwargs
)
...
...
test/test_nifty_spaces.py
View file @
f399b0da
...
...
@@ -27,6 +27,8 @@ from nifty.nifty_power_indices import power_indices
from
nifty.nifty_utilities
import
_hermitianize_inverter
as
\
hermitianize_inverter
from
nifty.operators.nifty_operators
import
power_operator
available
=
[]
try
:
from
nifty
import
lm_space
...
...
@@ -178,6 +180,22 @@ def generate_space(name):
return
space_dict
[
name
]
def
generate_space_with_size
(
name
,
num
,
datamodel
=
'fftw'
):
space_dict
=
{
'space'
:
space
(),
'point_space'
:
point_space
(
num
,
datamodel
=
datamodel
),
'rg_space'
:
rg_space
((
num
,
num
),
datamodel
=
datamodel
),
}
if
'lm_space'
in
available
:
space_dict
[
'lm_space'
]
=
lm_space
(
mmax
=
num
,
lmax
=
num
,
datamodel
=
datamodel
)
if
'hp_space'
in
available
:
space_dict
[
'hp_space'
]
=
hp_space
(
num
,
datamodel
=
datamodel
)
if
'gl_space'
in
available
:
space_dict
[
'gl_space'
]
=
gl_space
(
nlat
=
num
,
nlon
=
num
,
datamodel
=
datamodel
)
return
space_dict
[
name
]
def
generate_data
(
space
):
a
=
np
.
arange
(
space
.
get_dim
()).
reshape
(
space
.
get_shape
())
data
=
space
.
cast
(
a
)
...
...
@@ -1334,4 +1352,21 @@ class Test_Lm_Space(unittest.TestCase):
print
all_spaces
print
generate_space
(
'rg_space'
)
\ No newline at end of file
print
generate_space
(
'rg_space'
)
class
Test_axis
(
unittest
.
TestCase
):
@
parameterized
.
expand
(
itertools
.
product
(
point_like_spaces
,
[
8
,
16
],
[
'sum'
,
'prod'
,
'mean'
,
'var'
,
'std'
,
'median'
,
'all'
,
'any'
,
'nanmin'
,
'nanmax'
],
[
None
,
(
0
,)],
DATAMODELS
[
'point_space'
]),
testcase_func_name
=
custom_name_func
)
def
test_binary_operations
(
self
,
name
,
num
,
op
,
axis
,
datamodel
):
s
=
generate_space_with_size
(
name
,
np
.
prod
(
num
),
datamodel
=
datamodel
)
d
=
generate_data
(
s
)
a
=
d
.
get_full_data
()
assert_almost_equal
(
s
.
unary_operation
(
d
,
op
,
axis
=
axis
),
getattr
(
np
,
op
)(
a
,
axis
=
axis
),
decimal
=
4
)
if
name
in
[
'rg_space'
]:
assert_almost_equal
(
s
.
unary_operation
(
d
,
op
,
axis
=
(
0
,
1
)),
getattr
(
np
,
op
)(
a
,
axis
=
(
0
,
1
)),
decimal
=
4
)
\ No newline at end of file
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment