Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
2e434668
Commit
2e434668
authored
Aug 08, 2018
by
Martin Reinecke
Browse files
simplification and cosmetics
parent
5d2241a3
Changes
4
Hide whitespace changes
Inline
Side-by-side
nifty5/__init__.py
View file @
2e434668
...
@@ -78,8 +78,6 @@ from .library.amplitude_model import AmplitudeModel
...
@@ -78,8 +78,6 @@ from .library.amplitude_model import AmplitudeModel
from
.library.inverse_gamma_model
import
InverseGammaModel
from
.library.inverse_gamma_model
import
InverseGammaModel
from
.library.los_response
import
LOSResponse
from
.library.los_response
import
LOSResponse
#from .library.inverse_gamma_model import InverseGammaModel
from
.library.wiener_filter_curvature
import
WienerFilterCurvature
from
.library.wiener_filter_curvature
import
WienerFilterCurvature
from
.library.correlated_fields
import
CorrelatedField
from
.library.correlated_fields
import
CorrelatedField
# make_mf_correlated_field)
# make_mf_correlated_field)
...
...
nifty5/field.py
View file @
2e434668
...
@@ -47,13 +47,11 @@ class Field(object):
...
@@ -47,13 +47,11 @@ class Field(object):
"""
"""
def
__init__
(
self
,
domain
,
val
):
def
__init__
(
self
,
domain
,
val
):
self
.
_uni
=
None
if
not
isinstance
(
domain
,
DomainTuple
):
if
not
isinstance
(
domain
,
DomainTuple
):
raise
TypeError
(
"domain must be of type DomainTuple"
)
raise
TypeError
(
"domain must be of type DomainTuple"
)
if
not
isinstance
(
val
,
dobj
.
data_object
)
:
if
type
(
val
)
is
not
dobj
.
data_object
:
if
np
.
isscalar
(
val
):
if
np
.
isscalar
(
val
):
self
.
_uni
=
val
val
=
dobj
.
full
(
domain
.
shape
,
val
)
val
=
dobj
.
uniform_full
(
domain
.
shape
,
val
)
else
:
else
:
raise
TypeError
(
"val must be of type dobj.data_object"
)
raise
TypeError
(
"val must be of type dobj.data_object"
)
if
domain
.
shape
!=
val
.
shape
:
if
domain
.
shape
!=
val
.
shape
:
...
@@ -394,14 +392,10 @@ class Field(object):
...
@@ -394,14 +392,10 @@ class Field(object):
return
self
return
self
def
__neg__
(
self
):
def
__neg__
(
self
):
if
self
.
_uni
is
None
:
return
Field
(
self
.
_domain
,
-
self
.
_val
)
return
Field
(
self
.
_domain
,
-
self
.
_val
)
return
Field
(
self
.
_domain
,
-
self
.
_uni
)
def
__abs__
(
self
):
def
__abs__
(
self
):
if
self
.
_uni
is
None
:
return
Field
(
self
.
_domain
,
abs
(
self
.
_val
))
return
Field
(
self
.
_domain
,
abs
(
self
.
_val
))
return
Field
(
self
.
_domain
,
abs
(
self
.
_uni
))
def
_contraction_helper
(
self
,
op
,
spaces
):
def
_contraction_helper
(
self
,
op
,
spaces
):
if
spaces
is
None
:
if
spaces
is
None
:
...
@@ -617,96 +611,12 @@ class Field(object):
...
@@ -617,96 +611,12 @@ class Field(object):
return
self
+
other
return
self
+
other
def
positive_tanh
(
self
):
def
positive_tanh
(
self
):
if
self
.
_uni
is
None
:
return
0.5
*
(
1.
+
self
.
tanh
())
return
0.5
*
(
1.
+
self
.
tanh
())
return
Field
(
self
.
_domain
,
0.5
*
(
1.
+
np
.
tanh
(
self
.
_uni
)))
for
op
in
[
"__add__"
,
"__radd__"
,
def
__add__
(
self
,
other
):
"__sub__"
,
"__rsub__"
,
# if other is a field, make sure that the domains match
"__mul__"
,
"__rmul__"
,
if
isinstance
(
other
,
Field
):
if
other
.
_domain
is
not
self
.
_domain
:
raise
ValueError
(
"domains are incompatible."
)
if
self
.
_uni
is
None
:
if
other
.
_uni
is
None
:
return
Field
(
self
.
_domain
,
self
.
_val
+
other
.
_val
)
if
other
.
_uni
==
0
:
return
self
return
Field
(
self
.
_domain
,
self
.
_val
+
other
.
_uni
)
else
:
if
self
.
_uni
==
0
:
return
other
if
other
.
_uni
is
None
:
return
Field
(
self
.
_domain
,
other
.
_val
+
self
.
_uni
)
return
Field
(
self
.
_domain
,
self
.
_uni
+
other
.
_uni
)
if
np
.
isscalar
(
other
):
if
self
.
_uni
is
None
:
return
Field
(
self
.
_domain
,
self
.
_val
+
other
)
return
Field
(
self
.
_domain
,
self
.
_uni
+
other
)
return
NotImplemented
def
__radd__
(
self
,
other
):
return
self
.
__add__
(
other
)
def
__sub__
(
self
,
other
):
# if other is a field, make sure that the domains match
if
isinstance
(
other
,
Field
):
if
other
.
_domain
is
not
self
.
_domain
:
raise
ValueError
(
"domains are incompatible."
)
if
self
.
_uni
is
None
:
if
other
.
_uni
is
None
:
return
Field
(
self
.
_domain
,
self
.
_val
-
other
.
_val
)
if
other
.
_uni
==
0
:
return
self
return
Field
(
self
.
_domain
,
self
.
_val
-
other
.
_uni
)
else
:
if
self
.
_uni
==
0
:
return
-
other
if
other
.
_uni
is
None
:
return
Field
(
self
.
_domain
,
self
.
_uni
-
other
.
_val
)
return
Field
(
self
.
_domain
,
self
.
_uni
-
other
.
_uni
)
if
np
.
isscalar
(
other
):
if
self
.
_uni
is
None
:
return
Field
(
self
.
_domain
,
self
.
_val
-
other
)
return
Field
(
self
.
_domain
,
self
.
_uni
-
other
)
return
NotImplemented
def
__mul__
(
self
,
other
):
# if other is a field, make sure that the domains match
if
isinstance
(
other
,
Field
):
if
other
.
_domain
is
not
self
.
_domain
:
raise
ValueError
(
"domains are incompatible."
)
if
self
.
_uni
is
None
:
if
other
.
_uni
is
None
:
return
Field
(
self
.
_domain
,
self
.
_val
*
other
.
_val
)
if
other
.
_uni
==
1
:
return
self
if
other
.
_uni
==
0
:
return
other
return
Field
(
self
.
_domain
,
self
.
_val
*
other
.
_uni
)
else
:
if
self
.
_uni
==
1
:
return
other
if
self
.
_uni
==
0
:
return
self
if
other
.
_uni
is
None
:
return
Field
(
self
.
_domain
,
other
.
_val
*
self
.
_uni
)
return
Field
(
self
.
_domain
,
self
.
_uni
*
other
.
_uni
)
if
np
.
isscalar
(
other
):
if
self
.
_uni
is
None
:
if
other
==
1
:
return
self
if
other
==
0
:
return
Field
(
self
.
_domain
,
other
)
return
Field
(
self
.
_domain
,
self
.
_val
*
other
)
return
Field
(
self
.
_domain
,
self
.
_uni
*
other
)
return
NotImplemented
for
op
in
[
"__rsub__"
,
"__rmul__"
,
"__div__"
,
"__rdiv__"
,
"__div__"
,
"__rdiv__"
,
"__truediv__"
,
"__rtruediv__"
,
"__truediv__"
,
"__rtruediv__"
,
"__floordiv__"
,
"__rfloordiv__"
,
"__floordiv__"
,
"__rfloordiv__"
,
...
@@ -739,11 +649,7 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
...
@@ -739,11 +649,7 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
for
f
in
[
"sqrt"
,
"exp"
,
"log"
,
"tanh"
]:
for
f
in
[
"sqrt"
,
"exp"
,
"log"
,
"tanh"
]:
def
func
(
f
):
def
func
(
f
):
def
func2
(
self
):
def
func2
(
self
):
if
self
.
_uni
is
None
:
fu
=
getattr
(
dobj
,
f
)
fu
=
getattr
(
dobj
,
f
)
return
Field
(
domain
=
self
.
_domain
,
val
=
fu
(
self
.
val
))
return
Field
(
domain
=
self
.
_domain
,
val
=
fu
(
self
.
val
))
else
:
fu
=
getattr
(
np
,
f
)
return
Field
(
domain
=
self
.
_domain
,
val
=
fu
(
self
.
_uni
))
return
func2
return
func2
setattr
(
Field
,
f
,
func
(
f
))
setattr
(
Field
,
f
,
func
(
f
))
nifty5/linearization.py
View file @
2e434668
...
@@ -102,10 +102,10 @@ class Linearization(object):
...
@@ -102,10 +102,10 @@ class Linearization(object):
from
.operators.simple_linear_operators
import
VdotOperator
from
.operators.simple_linear_operators
import
VdotOperator
if
isinstance
(
other
,
(
Field
,
MultiField
)):
if
isinstance
(
other
,
(
Field
,
MultiField
)):
return
Linearization
(
return
Linearization
(
Field
(
DomainTuple
.
scalar_domain
(),
self
.
_val
.
vdot
(
other
)),
Field
(
DomainTuple
.
scalar_domain
(),
self
.
_val
.
vdot
(
other
)),
VdotOperator
(
other
)(
self
.
_jac
))
VdotOperator
(
other
)(
self
.
_jac
))
return
Linearization
(
return
Linearization
(
Field
(
DomainTuple
.
scalar_domain
(),
self
.
_val
.
vdot
(
other
.
_val
)),
Field
(
DomainTuple
.
scalar_domain
(),
self
.
_val
.
vdot
(
other
.
_val
)),
VdotOperator
(
self
.
_val
)(
other
.
_jac
)
+
VdotOperator
(
self
.
_val
)(
other
.
_jac
)
+
VdotOperator
(
other
.
_val
)(
self
.
_jac
))
VdotOperator
(
other
.
_val
)(
self
.
_jac
))
...
...
nifty5/minimization/scipy_minimizer.py
View file @
2e434668
...
@@ -26,12 +26,12 @@ from .iteration_controller import IterationController
...
@@ -26,12 +26,12 @@ from .iteration_controller import IterationController
from
.minimizer
import
Minimizer
from
.minimizer
import
Minimizer
def
_to
Nda
rray
(
fld
):
def
_to
A
rray
(
fld
):
return
fld
.
to_global_data
().
reshape
(
-
1
)
return
fld
.
to_global_data
().
reshape
(
-
1
)
def
_to
FlatNda
rray
(
fld
):
def
_to
A
rray
_rw
(
fld
):
return
fld
.
val
.
flatten
(
)
return
fld
.
to_global_data_rw
().
reshape
(
-
1
)
def
_toField
(
arr
,
dom
):
def
_toField
(
arr
,
dom
):
...
@@ -54,12 +54,12 @@ class _MinHelper(object):
...
@@ -54,12 +54,12 @@ class _MinHelper(object):
def
jac
(
self
,
x
):
def
jac
(
self
,
x
):
self
.
_update
(
x
)
self
.
_update
(
x
)
return
_to
FlatNda
rray
(
self
.
_energy
.
gradient
)
return
_to
A
rray
_rw
(
self
.
_energy
.
gradient
)
def
hessp
(
self
,
x
,
p
):
def
hessp
(
self
,
x
,
p
):
self
.
_update
(
x
)
self
.
_update
(
x
)
res
=
self
.
_energy
.
metric
(
_toField
(
p
,
self
.
_domain
))
res
=
self
.
_energy
.
metric
(
_toField
(
p
,
self
.
_domain
))
return
_to
FlatNda
rray
(
res
)
return
_to
A
rray
_rw
(
res
)
class
ScipyMinimizer
(
Minimizer
):
class
ScipyMinimizer
(
Minimizer
):
...
@@ -95,7 +95,7 @@ class ScipyMinimizer(Minimizer):
...
@@ -95,7 +95,7 @@ class ScipyMinimizer(Minimizer):
else
:
else
:
raise
ValueError
(
"unrecognized bounds"
)
raise
ValueError
(
"unrecognized bounds"
)
x
=
hlp
.
_energy
.
position
.
val
.
flatten
(
)
x
=
_toArray_rw
(
hlp
.
_energy
.
position
)
hessp
=
hlp
.
hessp
if
self
.
_need_hessp
else
None
hessp
=
hlp
.
hessp
if
self
.
_need_hessp
else
None
r
=
opt
.
minimize
(
hlp
.
fun
,
x
,
method
=
self
.
_method
,
jac
=
hlp
.
jac
,
r
=
opt
.
minimize
(
hlp
.
fun
,
x
,
method
=
self
.
_method
,
jac
=
hlp
.
jac
,
hessp
=
hessp
,
options
=
self
.
_options
,
bounds
=
bounds
)
hessp
=
hessp
,
options
=
self
.
_options
,
bounds
=
bounds
)
...
@@ -147,11 +147,11 @@ class ScipyCG(Minimizer):
...
@@ -147,11 +147,11 @@ class ScipyCG(Minimizer):
self
.
_op
=
op
self
.
_op
=
op
def
__call__
(
self
,
inp
):
def
__call__
(
self
,
inp
):
return
_to
Nda
rray
(
self
.
_op
(
_toField
(
inp
,
self
.
_op
.
domain
)))
return
_to
A
rray
(
self
.
_op
(
_toField
(
inp
,
self
.
_op
.
domain
)))
op
=
energy
.
_A
op
=
energy
.
_A
b
=
_to
Nda
rray
(
energy
.
_b
)
b
=
_to
A
rray
(
energy
.
_b
)
sx
=
_to
Nda
rray
(
energy
.
position
)
sx
=
_to
A
rray
(
energy
.
position
)
sci_op
=
scipy_linop
(
shape
=
(
op
.
domain
.
size
,
op
.
target
.
size
),
sci_op
=
scipy_linop
(
shape
=
(
op
.
domain
.
size
,
op
.
target
.
size
),
matvec
=
mymatvec
(
op
))
matvec
=
mymatvec
(
op
))
prec_op
=
None
prec_op
=
None
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment