Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
f0a24f5e
Commit
f0a24f5e
authored
Oct 22, 2018
by
Martin Reinecke
Browse files
tweaks
parent
0e63c553
Changes
3
Hide whitespace changes
Inline
Side-by-side
nifty5/data_objects/distributed_do.py
View file @
f0a24f5e
...
...
@@ -34,7 +34,8 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"distaxis"
,
"from_local_data"
,
"from_global_data"
,
"to_global_data"
,
"redistribute"
,
"default_distaxis"
,
"is_numpy"
,
"absmax"
,
"norm"
,
"lock"
,
"locked"
,
"uniform_full"
,
"transpose"
,
"to_global_data_rw"
,
"ensure_not_distributed"
,
"ensure_default_distributed"
]
"ensure_not_distributed"
,
"ensure_default_distributed"
,
"clipped_exp"
]
_comm
=
MPI
.
COMM_WORLD
ntask
=
_comm
.
Get_size
()
...
...
@@ -303,6 +304,10 @@ for f in ["sqrt", "exp", "log", "tanh", "conjugate"]:
setattr
(
_current_module
,
f
,
func
(
f
))
def
clipped_exp
(
a
):
return
data_object
(
x
.
shape
,
np
.
exp
(
np
.
clip
(
x
.
data
,
-
300
,
300
),
x
.
distaxis
)
def
from_object
(
object
,
dtype
,
copy
,
set_locked
):
if
dtype
is
None
:
dtype
=
object
.
dtype
...
...
nifty5/data_objects/numpy_do.py
View file @
f0a24f5e
...
...
@@ -33,7 +33,8 @@ __all__ = ["ntask", "rank", "master", "local_shape", "data_object", "full",
"distaxis"
,
"from_local_data"
,
"from_global_data"
,
"to_global_data"
,
"redistribute"
,
"default_distaxis"
,
"is_numpy"
,
"absmax"
,
"norm"
,
"lock"
,
"locked"
,
"uniform_full"
,
"to_global_data_rw"
,
"ensure_not_distributed"
,
"ensure_default_distributed"
]
"ensure_not_distributed"
,
"ensure_default_distributed"
,
"clipped_exp"
]
ntask
=
1
rank
=
0
...
...
@@ -149,3 +150,7 @@ def absmax(arr):
def
norm
(
arr
,
ord
=
2
):
return
np
.
linalg
.
norm
(
arr
.
reshape
(
-
1
),
ord
=
ord
)
def
clipped_exp
(
arr
):
return
np
.
exp
(
np
.
clip
(
arr
,
-
300
,
300
))
nifty5/field.py
View file @
f0a24f5e
...
...
@@ -634,6 +634,9 @@ class Field(object):
def
positive_tanh
(
self
):
return
0.5
*
(
1.
+
self
.
tanh
())
def
clipped_exp
(
self
):
return
Field
(
self
.
_domain
,
dobj
.
clipped_exp
(
self
.
_val
))
def
_binary_op
(
self
,
other
,
op
):
# if other is a field, make sure that the domains match
f
=
getattr
(
self
.
_val
,
op
)
...
...
@@ -675,9 +678,3 @@ for f in ["sqrt", "exp", "log", "tanh"]:
return
Field
(
self
.
_domain
,
getattr
(
dobj
,
f
)(
self
.
val
))
return
func2
setattr
(
Field
,
f
,
func
(
f
))
def
func2
(
self
):
np
.
clip
(
self
.
val
,
-
300
,
300
,
out
=
self
.
val
)
return
Field
(
self
.
_domain
,
getattr
(
dobj
,
'exp'
)(
self
.
val
))
setattr
(
Field
,
'clipped_exp'
,
func2
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment