Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
On Thursday, 7th July from 1 to 3 pm there will be a maintenance with a short downtime of GitLab.
Open sidebar
ift
NIFTy
Commits
71fb7400
Commit
71fb7400
authored
Jun 06, 2016
by
csongor
Browse files
WIP: unary operations fixes
parent
f78cf8d0
Pipeline
#4574
skipped
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty_field.py
View file @
71fb7400
...
...
@@ -359,7 +359,7 @@ class field(object):
new_val
=
map
(
lambda
z
:
self
.
unary_operation
(
z
,
'copy'
),
new_val
)
self
.
val
=
map
(
lambda
z
:
self
.
cast
(
z
),
new_val
)
self
.
val
=
self
.
cast
(
new_val
)
return
self
.
val
def
get_val
(
self
):
...
...
@@ -996,11 +996,8 @@ class field(object):
"
\n
- codomain = "
+
repr
(
self
.
codomain
)
+
\
"
\n
- ishape = "
+
str
(
self
.
ishape
)
def
_unary_helper
(
self
,
x
,
op
,
**
kwargs
):
result
=
map
(
lambda
z
:
self
.
domain
.
unary_operation
(
z
,
op
=
op
,
**
kwargs
),
self
.
get_val
())
return
result
def
all
(
self
,
**
kwargs
):
return
self
.
_unary_operation
(
self
.
get_val
(),
op
=
'all'
,
**
kwargs
)
def
min
(
self
,
ignore
=
False
,
**
kwargs
):
"""
...
...
@@ -1021,10 +1018,10 @@ class field(object):
np.amin, np.nanmin
"""
return
self
.
_unary_
helper
(
self
.
get_val
(),
op
=
'amin'
,
**
kwargs
)
return
self
.
_unary_
operation
(
self
.
get_val
(),
op
=
'amin'
,
**
kwargs
)
def
nanmin
(
self
,
**
kwargs
):
return
self
.
_unary_
helper
(
self
.
get_val
(),
op
=
'nanmin'
,
**
kwargs
)
return
self
.
_unary_
operation
(
self
.
get_val
(),
op
=
'nanmin'
,
**
kwargs
)
def
max
(
self
,
**
kwargs
):
"""
...
...
@@ -1045,10 +1042,10 @@ class field(object):
np.amax, np.nanmax
"""
return
self
.
_unary_
helper
(
self
.
get_val
(),
op
=
'amax'
,
**
kwargs
)
return
self
.
_unary_
operation
(
self
.
get_val
(),
op
=
'amax'
,
**
kwargs
)
def
nanmax
(
self
,
**
kwargs
):
return
self
.
_unary_
helper
(
self
.
get_val
(),
op
=
'nanmax'
,
**
kwargs
)
return
self
.
_unary_
operation
(
self
.
get_val
(),
op
=
'nanmax'
,
**
kwargs
)
def
median
(
self
,
**
kwargs
):
"""
...
...
@@ -1064,7 +1061,7 @@ class field(object):
np.median
"""
return
self
.
_unary_
helper
(
self
.
get_val
(),
op
=
'median'
,
return
self
.
_unary_
operation
(
self
.
get_val
(),
op
=
'median'
,
**
kwargs
)
def
mean
(
self
,
**
kwargs
):
...
...
@@ -1081,7 +1078,7 @@ class field(object):
np.mean
"""
return
self
.
_unary_
helper
(
self
.
get_val
(),
op
=
'mean'
,
return
self
.
_unary_
operation
(
self
.
get_val
(),
op
=
'mean'
,
**
kwargs
)
def
std
(
self
,
**
kwargs
):
...
...
@@ -1098,7 +1095,7 @@ class field(object):
np.std
"""
return
self
.
_unary_
helper
(
self
.
get_val
(),
op
=
'std'
,
return
self
.
_unary_
operation
(
self
.
get_val
(),
op
=
'std'
,
**
kwargs
)
def
var
(
self
,
**
kwargs
):
...
...
@@ -1115,7 +1112,7 @@ class field(object):
np.var
"""
return
self
.
_unary_
helper
(
self
.
get_val
(),
op
=
'var'
,
return
self
.
_unary_
operation
(
self
.
get_val
(),
op
=
'var'
,
**
kwargs
)
def
argmin
(
self
,
split
=
True
,
**
kwargs
):
...
...
@@ -1141,10 +1138,10 @@ class field(object):
"""
if
split
:
return
self
.
_unary_
helper
(
self
.
get_val
(),
op
=
'argmin_nonflat'
,
return
self
.
_unary_
operation
(
self
.
get_val
(),
op
=
'argmin_nonflat'
,
**
kwargs
)
else
:
return
self
.
_unary_
helper
(
self
.
get_val
(),
op
=
'argmin'
,
return
self
.
_unary_
operation
(
self
.
get_val
(),
op
=
'argmin'
,
**
kwargs
)
def
argmax
(
self
,
split
=
True
,
**
kwargs
):
...
...
@@ -1170,29 +1167,29 @@ class field(object):
"""
if
split
:
return
self
.
_unary_
helper
(
self
.
get_val
(),
op
=
'argmax_nonflat'
,
return
self
.
_unary_
operation
(
self
.
get_val
(),
op
=
'argmax_nonflat'
,
**
kwargs
)
else
:
return
self
.
_unary_
helper
(
self
.
get_val
(),
op
=
'argmax'
,
return
self
.
_unary_
operation
(
self
.
get_val
(),
op
=
'argmax'
,
**
kwargs
)
# TODO: Implement the full range of unary and binary operotions
def
__pos__
(
self
):
new_field
=
self
.
copy_empty
()
new_val
=
self
.
_unary_
helper
(
self
.
get_val
(),
op
=
'pos'
)
new_val
=
self
.
_unary_
operation
(
self
.
get_val
(),
op
=
'pos'
)
new_field
.
set_val
(
new_val
=
new_val
)
return
new_field
def
__neg__
(
self
):
new_field
=
self
.
copy_empty
()
new_val
=
self
.
_unary_
helper
(
self
.
get_val
(),
op
=
'neg'
)
new_val
=
self
.
_unary_
operation
(
self
.
get_val
(),
op
=
'neg'
)
new_field
.
set_val
(
new_val
=
new_val
)
return
new_field
def
__abs__
(
self
):
new_field
=
self
.
copy_empty
()
new_val
=
self
.
_unary_
helper
(
self
.
get_val
(),
op
=
'abs'
)
new_val
=
self
.
_unary_
operation
(
self
.
get_val
(),
op
=
'abs'
)
new_field
.
set_val
(
new_val
=
new_val
)
return
new_field
...
...
@@ -1224,7 +1221,7 @@ class field(object):
working_field
.
set_val
(
new_val
=
new_val
)
return
working_field
def
unary_operation
(
self
,
x
,
op
=
'None'
,
axis
=
None
,
**
kwargs
):
def
_
unary_operation
(
self
,
x
,
op
=
'None'
,
axis
=
None
,
**
kwargs
):
"""
x must be a numpy array which is compatible with the space!
Valid operations are
...
...
test/test_nifty_field.py
View file @
71fb7400
...
...
@@ -9,6 +9,8 @@ import unittest
import
itertools
import
numpy
as
np
from
d2o
import
distributed_data_object
from
nifty
import
space
,
\
point_space
,
\
rg_space
,
\
...
...
@@ -101,6 +103,21 @@ for param in itertools.product([(1,), (4, 6), (5, 8)],
fft_module
=
param
[
6
]),
param
[
5
]]]
def
generate_space_with_size
(
name
,
num
):
space_dict
=
{
'space'
:
space
(),
'point_space'
:
point_space
(
num
),
'rg_space'
:
rg_space
((
num
,
num
)),
'lm_space'
:
lm_space
(
mmax
=
num
,
lmax
=
num
),
'hp_space'
:
hp_space
(
num
),
'gl_space'
:
gl_space
(
nlat
=
num
,
nlon
=
num
),
}
return
space_dict
[
name
]
def
generate_data
(
space
):
a
=
np
.
arange
(
space
.
get_dim
()).
reshape
(
space
.
get_shape
())
return
distributed_data_object
(
a
)
###############################################################################
###############################################################################
...
...
@@ -151,3 +168,21 @@ class Test_field_multiple_init(unittest.TestCase):
assert
(
s1
.
check_codomain
(
f
.
codomain
[
0
]))
assert
(
s2
.
check_codomain
(
f
.
codomain
[
1
]))
assert
(
s1
.
get_shape
()
+
s2
.
get_shape
()
==
f
.
get_shape
())
class
Test_axis
(
unittest
.
TestCase
):
@
parameterized
.
expand
(
itertools
.
product
(
point_like_spaces
,
[
4
],
[
'sum'
,
'prod'
,
'mean'
,
'var'
,
'std'
,
'median'
,
'all'
,
'any'
,
'amin'
,
'nanmin'
,
'argmin'
,
'amax'
,
'nanmax'
,
'argmax'
],
[
None
,
(
0
,)],
DATAMODELS
[
'rg_space'
]),
testcase_func_name
=
custom_name_func
)
def
test_unary_operations
(
self
,
name
,
num
,
op
,
axis
,
datamodel
):
s
=
generate_space_with_size
(
name
,
num
)
d
=
generate_data
(
s
)
a
=
d
.
get_full_data
()
f
=
field
(
val
=
d
,
domain
=
(
s
,),
dtype
=
s
.
dtype
,
datamodel
=
datamodel
)
assert_almost_equal
(
getattr
(
f
,
op
)(
axis
=
axis
),
getattr
(
np
,
op
)(
a
,
axis
=
axis
),
decimal
=
4
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment