Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
2e434668
Commit
2e434668
authored
Aug 08, 2018
by
Martin Reinecke
Browse files
simplification and cosmetics
parent
5d2241a3
Changes
4
Show whitespace changes
Inline
Side-by-side
nifty5/__init__.py
View file @
2e434668
...
...
@@ -78,8 +78,6 @@ from .library.amplitude_model import AmplitudeModel
from
.library.inverse_gamma_model
import
InverseGammaModel
from
.library.los_response
import
LOSResponse
#from .library.inverse_gamma_model import InverseGammaModel
from
.library.wiener_filter_curvature
import
WienerFilterCurvature
from
.library.correlated_fields
import
CorrelatedField
# make_mf_correlated_field)
...
...
nifty5/field.py
View file @
2e434668
...
...
@@ -47,13 +47,11 @@ class Field(object):
"""
def
__init__
(
self
,
domain
,
val
):
self
.
_uni
=
None
if
not
isinstance
(
domain
,
DomainTuple
):
raise
TypeError
(
"domain must be of type DomainTuple"
)
if
not
isinstance
(
val
,
dobj
.
data_object
)
:
if
type
(
val
)
is
not
dobj
.
data_object
:
if
np
.
isscalar
(
val
):
self
.
_uni
=
val
val
=
dobj
.
uniform_full
(
domain
.
shape
,
val
)
val
=
dobj
.
full
(
domain
.
shape
,
val
)
else
:
raise
TypeError
(
"val must be of type dobj.data_object"
)
if
domain
.
shape
!=
val
.
shape
:
...
...
@@ -394,14 +392,10 @@ class Field(object):
return
self
def
__neg__
(
self
):
if
self
.
_uni
is
None
:
return
Field
(
self
.
_domain
,
-
self
.
_val
)
return
Field
(
self
.
_domain
,
-
self
.
_uni
)
def
__abs__
(
self
):
if
self
.
_uni
is
None
:
return
Field
(
self
.
_domain
,
abs
(
self
.
_val
))
return
Field
(
self
.
_domain
,
abs
(
self
.
_uni
))
def
_contraction_helper
(
self
,
op
,
spaces
):
if
spaces
is
None
:
...
...
@@ -617,96 +611,12 @@ class Field(object):
return
self
+
other
def
positive_tanh
(
self
):
if
self
.
_uni
is
None
:
return
0.5
*
(
1.
+
self
.
tanh
())
return
Field
(
self
.
_domain
,
0.5
*
(
1.
+
np
.
tanh
(
self
.
_uni
)))
def
__add__
(
self
,
other
):
# if other is a field, make sure that the domains match
if
isinstance
(
other
,
Field
):
if
other
.
_domain
is
not
self
.
_domain
:
raise
ValueError
(
"domains are incompatible."
)
if
self
.
_uni
is
None
:
if
other
.
_uni
is
None
:
return
Field
(
self
.
_domain
,
self
.
_val
+
other
.
_val
)
if
other
.
_uni
==
0
:
return
self
return
Field
(
self
.
_domain
,
self
.
_val
+
other
.
_uni
)
else
:
if
self
.
_uni
==
0
:
return
other
if
other
.
_uni
is
None
:
return
Field
(
self
.
_domain
,
other
.
_val
+
self
.
_uni
)
return
Field
(
self
.
_domain
,
self
.
_uni
+
other
.
_uni
)
if
np
.
isscalar
(
other
):
if
self
.
_uni
is
None
:
return
Field
(
self
.
_domain
,
self
.
_val
+
other
)
return
Field
(
self
.
_domain
,
self
.
_uni
+
other
)
return
NotImplemented
def
__radd__
(
self
,
other
):
return
self
.
__add__
(
other
)
def
__sub__
(
self
,
other
):
# if other is a field, make sure that the domains match
if
isinstance
(
other
,
Field
):
if
other
.
_domain
is
not
self
.
_domain
:
raise
ValueError
(
"domains are incompatible."
)
if
self
.
_uni
is
None
:
if
other
.
_uni
is
None
:
return
Field
(
self
.
_domain
,
self
.
_val
-
other
.
_val
)
if
other
.
_uni
==
0
:
return
self
return
Field
(
self
.
_domain
,
self
.
_val
-
other
.
_uni
)
else
:
if
self
.
_uni
==
0
:
return
-
other
if
other
.
_uni
is
None
:
return
Field
(
self
.
_domain
,
self
.
_uni
-
other
.
_val
)
return
Field
(
self
.
_domain
,
self
.
_uni
-
other
.
_uni
)
if
np
.
isscalar
(
other
):
if
self
.
_uni
is
None
:
return
Field
(
self
.
_domain
,
self
.
_val
-
other
)
return
Field
(
self
.
_domain
,
self
.
_uni
-
other
)
return
NotImplemented
def
__mul__
(
self
,
other
):
# if other is a field, make sure that the domains match
if
isinstance
(
other
,
Field
):
if
other
.
_domain
is
not
self
.
_domain
:
raise
ValueError
(
"domains are incompatible."
)
if
self
.
_uni
is
None
:
if
other
.
_uni
is
None
:
return
Field
(
self
.
_domain
,
self
.
_val
*
other
.
_val
)
if
other
.
_uni
==
1
:
return
self
if
other
.
_uni
==
0
:
return
other
return
Field
(
self
.
_domain
,
self
.
_val
*
other
.
_uni
)
else
:
if
self
.
_uni
==
1
:
return
other
if
self
.
_uni
==
0
:
return
self
if
other
.
_uni
is
None
:
return
Field
(
self
.
_domain
,
other
.
_val
*
self
.
_uni
)
return
Field
(
self
.
_domain
,
self
.
_uni
*
other
.
_uni
)
if
np
.
isscalar
(
other
):
if
self
.
_uni
is
None
:
if
other
==
1
:
return
self
if
other
==
0
:
return
Field
(
self
.
_domain
,
other
)
return
Field
(
self
.
_domain
,
self
.
_val
*
other
)
return
Field
(
self
.
_domain
,
self
.
_uni
*
other
)
return
NotImplemented
for
op
in
[
"__rsub__"
,
"__rmul__"
,
for
op
in
[
"__add__"
,
"__radd__"
,
"__sub__"
,
"__rsub__"
,
"__mul__"
,
"__rmul__"
,
"__div__"
,
"__rdiv__"
,
"__truediv__"
,
"__rtruediv__"
,
"__floordiv__"
,
"__rfloordiv__"
,
...
...
@@ -739,11 +649,7 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
for
f
in
[
"sqrt"
,
"exp"
,
"log"
,
"tanh"
]:
def
func
(
f
):
def
func2
(
self
):
if
self
.
_uni
is
None
:
fu
=
getattr
(
dobj
,
f
)
return
Field
(
domain
=
self
.
_domain
,
val
=
fu
(
self
.
val
))
else
:
fu
=
getattr
(
np
,
f
)
return
Field
(
domain
=
self
.
_domain
,
val
=
fu
(
self
.
_uni
))
return
func2
setattr
(
Field
,
f
,
func
(
f
))
nifty5/linearization.py
View file @
2e434668
...
...
@@ -102,10 +102,10 @@ class Linearization(object):
from
.operators.simple_linear_operators
import
VdotOperator
if
isinstance
(
other
,
(
Field
,
MultiField
)):
return
Linearization
(
Field
(
DomainTuple
.
scalar_domain
(),
self
.
_val
.
vdot
(
other
)),
Field
(
DomainTuple
.
scalar_domain
(),
self
.
_val
.
vdot
(
other
)),
VdotOperator
(
other
)(
self
.
_jac
))
return
Linearization
(
Field
(
DomainTuple
.
scalar_domain
(),
self
.
_val
.
vdot
(
other
.
_val
)),
Field
(
DomainTuple
.
scalar_domain
(),
self
.
_val
.
vdot
(
other
.
_val
)),
VdotOperator
(
self
.
_val
)(
other
.
_jac
)
+
VdotOperator
(
other
.
_val
)(
self
.
_jac
))
...
...
nifty5/minimization/scipy_minimizer.py
View file @
2e434668
...
...
@@ -26,12 +26,12 @@ from .iteration_controller import IterationController
from
.minimizer
import
Minimizer
def
_to
Nda
rray
(
fld
):
def
_to
A
rray
(
fld
):
return
fld
.
to_global_data
().
reshape
(
-
1
)
def
_to
FlatNda
rray
(
fld
):
return
fld
.
val
.
flatten
(
)
def
_to
A
rray
_rw
(
fld
):
return
fld
.
to_global_data_rw
().
reshape
(
-
1
)
def
_toField
(
arr
,
dom
):
...
...
@@ -54,12 +54,12 @@ class _MinHelper(object):
def
jac
(
self
,
x
):
self
.
_update
(
x
)
return
_to
FlatNda
rray
(
self
.
_energy
.
gradient
)
return
_to
A
rray
_rw
(
self
.
_energy
.
gradient
)
def
hessp
(
self
,
x
,
p
):
self
.
_update
(
x
)
res
=
self
.
_energy
.
metric
(
_toField
(
p
,
self
.
_domain
))
return
_to
FlatNda
rray
(
res
)
return
_to
A
rray
_rw
(
res
)
class
ScipyMinimizer
(
Minimizer
):
...
...
@@ -95,7 +95,7 @@ class ScipyMinimizer(Minimizer):
else
:
raise
ValueError
(
"unrecognized bounds"
)
x
=
hlp
.
_energy
.
position
.
val
.
flatten
(
)
x
=
_toArray_rw
(
hlp
.
_energy
.
position
)
hessp
=
hlp
.
hessp
if
self
.
_need_hessp
else
None
r
=
opt
.
minimize
(
hlp
.
fun
,
x
,
method
=
self
.
_method
,
jac
=
hlp
.
jac
,
hessp
=
hessp
,
options
=
self
.
_options
,
bounds
=
bounds
)
...
...
@@ -147,11 +147,11 @@ class ScipyCG(Minimizer):
self
.
_op
=
op
def
__call__
(
self
,
inp
):
return
_to
Nda
rray
(
self
.
_op
(
_toField
(
inp
,
self
.
_op
.
domain
)))
return
_to
A
rray
(
self
.
_op
(
_toField
(
inp
,
self
.
_op
.
domain
)))
op
=
energy
.
_A
b
=
_to
Nda
rray
(
energy
.
_b
)
sx
=
_to
Nda
rray
(
energy
.
position
)
b
=
_to
A
rray
(
energy
.
_b
)
sx
=
_to
A
rray
(
energy
.
position
)
sci_op
=
scipy_linop
(
shape
=
(
op
.
domain
.
size
,
op
.
target
.
size
),
matvec
=
mymatvec
(
op
))
prec_op
=
None
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment