Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
a40cbebf
Commit
a40cbebf
authored
Sep 02, 2017
by
Martin Reinecke
Browse files
cleanup
parent
6386588d
Changes
3
Hide whitespace changes
Inline
Side-by-side
nifty/basic_arithmetics.py
View file @
a40cbebf
...
...
@@ -23,18 +23,14 @@ from .field import Field
__all__
=
[
'cos'
,
'sin'
,
'cosh'
,
'sinh'
,
'tan'
,
'tanh'
,
'arccos'
,
'arcsin'
,
'arccosh'
,
'arcsinh'
,
'arctan'
,
'arctanh'
,
'sqrt'
,
'exp'
,
'log'
,
'conjugate'
,
'clipped_exp'
,
'limited_exp'
,
'limited_exp_deriv'
]
'conjugate'
]
def
_math_helper
(
x
,
function
):
if
isinstance
(
x
,
Field
):
result_val
=
function
(
x
.
val
)
result
=
x
.
copy_empty
(
dtype
=
result_val
.
dtype
)
result
.
val
=
result_val
return
Field
(
val
=
function
(
x
.
val
))
else
:
result
=
function
(
np
.
asarray
(
x
))
return
result
return
function
(
np
.
asarray
(
x
))
def
cos
(
x
):
...
...
@@ -93,36 +89,6 @@ def exp(x):
return
_math_helper
(
x
,
np
.
exp
)
def
clipped_exp
(
x
):
return
_math_helper
(
x
,
lambda
z
:
np
.
exp
(
np
.
minimum
(
200
,
z
)))
def
limited_exp
(
x
):
return
_math_helper
(
x
,
_limited_exp_helper
)
def
_limited_exp_helper
(
x
):
thr
=
200.
mask
=
x
>
thr
if
np
.
count_nonzero
(
mask
)
==
0
:
return
np
.
exp
(
x
)
result
=
((
1.
-
thr
)
+
x
)
*
np
.
exp
(
thr
)
result
[
~
mask
]
=
np
.
exp
(
x
[
~
mask
])
return
result
def
limited_exp_deriv
(
x
):
return
_math_helper
(
x
,
_limited_exp_deriv_helper
)
def
_limited_exp_deriv_helper
(
x
):
thr
=
200.
mask
=
x
>
thr
if
np
.
count_nonzero
(
mask
)
==
0
:
return
np
.
exp
(
x
)
result
=
np
.
empty_like
(
x
)
result
[
mask
]
=
np
.
exp
(
thr
)
result
[
~
mask
]
=
np
.
exp
(
x
[
~
mask
])
return
result
def
log
(
x
,
base
=
None
):
result
=
_math_helper
(
x
,
np
.
log
)
if
base
is
not
None
:
...
...
nifty/field.py
View file @
a40cbebf
...
...
@@ -89,8 +89,21 @@ class Field(object):
else
:
global_shape
=
reduce
(
lambda
x
,
y
:
x
+
y
,
shape_tuple
)
dtype
=
self
.
_infer_dtype
(
dtype
=
dtype
,
val
=
val
)
self
.
_val
=
np
.
empty
(
global_shape
,
dtype
=
dtype
)
self
.
set_val
(
new_val
=
val
,
copy
=
copy
)
if
isinstance
(
val
,
Field
):
if
self
.
domain
!=
val
.
domain
:
raise
ValueError
(
"Domain mismatch"
)
self
.
_val
=
np
.
array
(
val
.
val
,
dtype
=
dtype
,
copy
=
copy
)
elif
(
np
.
isscalar
(
val
)):
self
.
_val
=
np
.
full
(
global_shape
,
dtype
=
dtype
,
fill_value
=
val
)
elif
isinstance
(
val
,
np
.
ndarray
):
if
global_shape
==
val
.
shape
:
self
.
_val
=
np
.
array
(
val
,
dtype
=
dtype
,
copy
=
copy
)
else
:
raise
ValueError
(
"Shape mismatch"
)
elif
val
is
None
:
self
.
_val
=
np
.
empty
(
global_shape
,
dtype
=
dtype
)
else
:
raise
TypeError
(
"unknown source type"
)
def
_parse_domain
(
self
,
domain
,
val
=
None
):
if
domain
is
None
:
...
...
@@ -412,7 +425,6 @@ class Field(object):
# apply the rescaler to the random fields
result_list
[
0
].
val
*=
spec
.
real
if
not
real_power
:
result_list
[
1
].
val
*=
spec
.
imag
...
...
@@ -481,7 +493,7 @@ class Field(object):
if
copy
:
self
.
_val
[()]
=
new_val
.
val
else
:
self
.
_val
=
new_val
.
val
self
.
_val
=
np
.
array
(
new_val
.
val
,
dtype
=
self
.
dtype
,
copy
=
False
)
elif
(
np
.
isscalar
(
new_val
)):
self
.
_val
[()]
=
new_val
elif
isinstance
(
new_val
,
np
.
ndarray
):
...
...
@@ -490,7 +502,7 @@ class Field(object):
else
:
if
self
.
shape
!=
new_val
.
shape
:
raise
ValueError
(
"Shape mismatch"
)
self
.
_val
=
n
ew_val
self
.
_val
=
n
p
.
array
(
new_val
,
dtype
=
self
.
dtype
,
copy
=
False
)
else
:
raise
TypeError
(
"unknown source type"
)
return
self
...
...
@@ -573,12 +585,7 @@ class Field(object):
shape
"""
dim_tuple
=
tuple
(
sp
.
dim
for
sp
in
self
.
domain
)
try
:
return
int
(
reduce
(
lambda
x
,
y
:
x
*
y
,
dim_tuple
))
except
TypeError
:
return
0
return
self
.
_val
.
size
@
property
def
dof
(
self
):
...
...
nifty/library/log_normal_wiener_filter/log_normal_wiener_filter_curvature.py
View file @
a40cbebf
from
...operators
import
EndomorphicOperator
,
\
InvertibleOperatorMixin
from
...energies.memoization
import
memo
from
...basic_arithmetics
import
clipped_
exp
from
...basic_arithmetics
import
exp
from
...sugar
import
create_composed_fft_operator
...
...
@@ -71,7 +71,7 @@ class LogNormalWienerFilterCurvature(InvertibleOperatorMixin,
@
property
@
memo
def
_expp_sspace
(
self
):
return
clipped_
exp
(
self
.
_fft
(
self
.
position
))
return
exp
(
self
.
_fft
(
self
.
position
))
@
property
@
memo
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment