Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
55564314
Commit
55564314
authored
Aug 03, 2018
by
Martin Reinecke
Browse files
experimental performance tweaks
parent
79df4e2b
Changes
2
Hide whitespace changes
Inline
Side-by-side
nifty5/field.py
View file @
55564314
...
...
@@ -631,10 +631,98 @@ class Field(object):
return
0.5
*
(
1.
+
self
.
tanh
())
return
Field
(
self
.
_domain
,
0.5
*
(
1.
+
np
.
tanh
(
self
.
_uni
)))
def
__add__
(
self
,
other
):
# if other is a field, make sure that the domains match
if
isinstance
(
other
,
Field
):
if
other
.
_domain
is
not
self
.
_domain
:
raise
ValueError
(
"domains are incompatible."
)
if
self
.
_uni
is
None
:
if
other
.
_uni
is
None
:
return
Field
(
self
.
_domain
,
self
.
_val
+
other
.
_val
)
if
other
.
_uni
==
0
:
return
self
return
Field
(
self
.
_domain
,
self
.
_val
+
other
.
_uni
)
else
:
if
self
.
_uni
==
0
:
return
other
if
other
.
_uni
is
None
:
return
Field
(
self
.
_domain
,
other
.
_val
+
self
.
_uni
)
return
Field
(
self
.
_domain
,
self
.
_uni
+
other
.
_uni
)
if
np
.
isscalar
(
other
):
if
self
.
_uni
is
None
:
return
Field
(
self
.
_domain
,
self
.
_val
+
other
)
return
Field
(
self
.
_domain
,
self
.
_uni
+
other
)
if
isinstance
(
other
,
(
dobj
.
data_object
,
np
.
ndarray
)):
return
Field
(
self
.
_domain
,
self
.
_val
+
other
)
return
NotImplemented
def
__radd__
(
self
,
other
):
return
self
.
__add__
(
other
)
def
__sub__
(
self
,
other
):
# if other is a field, make sure that the domains match
if
isinstance
(
other
,
Field
):
if
other
.
_domain
is
not
self
.
_domain
:
raise
ValueError
(
"domains are incompatible."
)
if
self
.
_uni
is
None
:
if
other
.
_uni
is
None
:
return
Field
(
self
.
_domain
,
self
.
_val
-
other
.
_val
)
if
other
.
_uni
==
0
:
return
self
return
Field
(
self
.
_domain
,
self
.
_val
-
other
.
_uni
)
else
:
if
self
.
_uni
==
0
:
return
-
other
if
other
.
_uni
is
None
:
return
Field
(
self
.
_domain
,
self
.
_uni
-
other
.
_val
)
return
Field
(
self
.
_domain
,
self
.
_uni
-
other
.
_uni
)
for
op
in
[
"__add__"
,
"__radd__"
,
"__sub__"
,
"__rsub__"
,
"__mul__"
,
"__rmul__"
,
if
np
.
isscalar
(
other
):
if
self
.
_uni
is
None
:
return
Field
(
self
.
_domain
,
self
.
_val
-
other
)
return
Field
(
self
.
_domain
,
self
.
_uni
-
other
)
if
isinstance
(
other
,
(
dobj
.
data_object
,
np
.
ndarray
)):
return
Field
(
self
.
_domain
,
self
.
_val
-
other
)
return
NotImplemented
def
__mul__
(
self
,
other
):
# if other is a field, make sure that the domains match
if
isinstance
(
other
,
Field
):
if
other
.
_domain
is
not
self
.
_domain
:
raise
ValueError
(
"domains are incompatible."
)
if
self
.
_uni
is
None
:
if
other
.
_uni
is
None
:
return
Field
(
self
.
_domain
,
self
.
_val
*
other
.
_val
)
if
other
.
_uni
==
1
:
return
self
if
other
.
_uni
==
0
:
return
other
return
Field
(
self
.
_domain
,
self
.
_val
*
other
.
_uni
)
else
:
if
self
.
_uni
==
1
:
return
other
if
self
.
_uni
==
0
:
return
self
if
other
.
_uni
is
None
:
return
Field
(
self
.
_domain
,
other
.
_val
*
self
.
_uni
)
return
Field
(
self
.
_domain
,
self
.
_uni
*
other
.
_uni
)
if
np
.
isscalar
(
other
):
if
self
.
_uni
is
None
:
if
other
==
1
:
return
self
if
other
==
0
:
return
Field
(
self
.
_domain
,
other
)
return
Field
(
self
.
_domain
,
self
.
_val
*
other
)
return
Field
(
self
.
_domain
,
self
.
_uni
*
other
)
if
isinstance
(
other
,
(
dobj
.
data_object
,
np
.
ndarray
)):
return
Field
(
self
.
_domain
,
self
.
_val
*
other
)
return
NotImplemented
for
op
in
[
"__rsub__"
,
"__rmul__"
,
"__div__"
,
"__rdiv__"
,
"__truediv__"
,
"__rtruediv__"
,
"__floordiv__"
,
"__rfloordiv__"
,
...
...
nifty5/operators/diagonal_operator.py
View file @
55564314
...
...
@@ -94,11 +94,12 @@ class DiagonalOperator(EndomorphicOperator):
self
.
_ldiag
=
self
.
_ldiag
.
reshape
(
self
.
_reshaper
)
else
:
self
.
_ldiag
=
diagonal
.
local_data
self
.
_
update_diagmin
()
self
.
_
fill_rest
()
def
_
update_diagmin
(
self
):
def
_
fill_rest
(
self
):
self
.
_ldiag
.
flags
.
writeable
=
False
if
not
np
.
issubdtype
(
self
.
_ldiag
.
dtype
,
np
.
complexfloating
):
self
.
_complex
=
np
.
issubdtype
(
self
.
_ldiag
.
dtype
,
np
.
complexfloating
)
if
not
self
.
_complex
:
lmin
=
self
.
_ldiag
.
min
()
if
self
.
_ldiag
.
size
>
0
else
1.
self
.
_diagmin
=
dobj
.
np_allreduce_min
(
np
.
array
(
lmin
))[()]
...
...
@@ -110,7 +111,7 @@ class DiagonalOperator(EndomorphicOperator):
else
:
res
.
_spaces
=
tuple
(
set
(
self
.
_spaces
)
|
set
(
spc
))
res
.
_ldiag
=
ldiag
res
.
_
update_diagmin
()
res
.
_
fill_rest
()
return
res
def
_scale
(
self
,
fct
):
...
...
@@ -137,21 +138,17 @@ class DiagonalOperator(EndomorphicOperator):
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
if
mode
==
self
.
TIMES
:
# shortcut for most common cases
if
mode
==
1
or
(
not
self
.
_complex
and
mode
==
2
)
:
return
Field
(
x
.
domain
,
val
=
x
.
val
*
self
.
_ldiag
)
elif
mode
==
self
.
ADJOINT_TIMES
:
if
np
.
issubdtype
(
self
.
_ldiag
.
dtype
,
np
.
floating
):
return
Field
(
x
.
domain
,
val
=
x
.
val
*
self
.
_ldiag
)
else
:
return
Field
(
x
.
domain
,
val
=
x
.
val
*
self
.
_ldiag
.
conj
())
elif
mode
==
self
.
INVERSE_TIMES
:
return
Field
(
x
.
domain
,
val
=
x
.
val
/
self
.
_ldiag
)
else
:
if
np
.
issubdtype
(
self
.
_ldiag
.
dtype
,
np
.
floating
):
return
Field
(
x
.
domain
,
val
=
x
.
val
/
self
.
_ldiag
)
else
:
return
Field
(
x
.
domain
,
val
=
x
.
val
/
self
.
_ldiag
.
conj
())
xdiag
=
self
.
_ldiag
if
self
.
_complex
and
(
mode
&
10
):
# adjoint or inverse adjoint
xdiag
=
xdiag
.
conj
()
if
mode
&
3
:
return
Field
(
x
.
domain
,
val
=
x
.
val
*
xdiag
)
return
Field
(
x
.
domain
,
val
=
x
.
val
/
xdiag
)
@
property
def
domain
(
self
):
...
...
@@ -162,23 +159,15 @@ class DiagonalOperator(EndomorphicOperator):
return
self
.
_all_ops
def
_flip_modes
(
self
,
trafo
):
ADJ
=
self
.
ADJOINT_BIT
INV
=
self
.
INVERSE_BIT
if
trafo
==
0
:
return
self
if
trafo
==
ADJ
and
np
.
issubdtype
(
self
.
_ldiag
.
dtype
,
np
.
floating
):
return
self
if
trafo
==
ADJ
:
return
self
.
_from_ldiag
((),
self
.
_ldiag
.
conjugate
())
elif
trafo
==
INV
:
return
self
.
_from_ldiag
((),
1.
/
self
.
_ldiag
)
elif
trafo
==
ADJ
|
INV
:
return
self
.
_from_ldiag
((),
1.
/
self
.
_ldiag
.
conjugate
())
raise
ValueError
(
"invalid operator transformation"
)
xdiag
=
self
.
_ldiag
if
self
.
_complex
and
(
trafo
&
self
.
ADJOINT_BIT
):
xdiag
=
xdiag
.
conj
()
if
trafo
&
self
.
INVERSE_BIT
:
xdiag
=
1.
/
xdiag
return
self
.
_from_ldiag
((),
xdiag
)
def
draw_sample
(
self
,
from_inverse
=
False
,
dtype
=
np
.
float64
):
if
np
.
issubdtype
(
self
.
_ldiag
.
dtype
,
np
.
complexfloating
)
:
if
self
.
_complex
:
raise
ValueError
(
"operator not positive definite"
)
if
self
.
_diagmin
<
0.
:
raise
ValueError
(
"operator not positive definite"
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment