Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
93c14275
Commit
93c14275
authored
Jul 24, 2018
by
Martin Reinecke
Browse files
begin redesign
parent
b8dbbbfa
Changes
5
Hide whitespace changes
Inline
Side-by-side
nifty5/__init__.py
View file @
93c14275
...
...
@@ -102,5 +102,7 @@ from .multi.block_diagonal_operator import BlockDiagonalOperator
from
.energies.kl
import
SampledKullbachLeiblerDivergence
from
.energies.hamiltonian
import
Hamiltonian
from
.
operator
import
Linearization
,
Operator
# We deliberately don't set __all__ here, because we don't want people to do a
# "from nifty5 import *"; that would swamp the global namespace.
nifty5/multi/multi_field.py
View file @
93c14275
...
...
@@ -32,7 +32,7 @@ class MultiField(object):
Parameters
----------
domain: MultiDomain
val: tuple containing Field
or None
entries
val: tuple containing Field entries
"""
if
not
isinstance
(
domain
,
MultiDomain
):
raise
TypeError
(
"domain must be of type MultiDomain"
)
...
...
@@ -44,8 +44,8 @@ class MultiField(object):
if
isinstance
(
v
,
Field
):
if
v
.
_domain
is
not
d
:
raise
ValueError
(
"domain mismatch"
)
el
if
v
is
not
Non
e
:
raise
TypeError
(
"bad entry in val (must be Field
or None
)"
)
el
s
e
:
raise
TypeError
(
"bad entry in val (must be Field)"
)
self
.
_domain
=
domain
self
.
_val
=
val
...
...
@@ -54,8 +54,7 @@ class MultiField(object):
if
domain
is
None
:
domain
=
MultiDomain
.
make
({
key
:
v
.
_domain
for
key
,
v
in
dict
.
items
()})
return
MultiField
(
domain
,
tuple
(
dict
[
key
]
if
key
in
dict
else
None
for
key
in
domain
.
keys
()))
return
MultiField
(
domain
,
tuple
(
dict
[
key
]
for
key
in
domain
.
keys
()))
def
to_dict
(
self
):
return
{
key
:
val
for
key
,
val
in
zip
(
self
.
_domain
.
keys
(),
self
.
_val
)}
...
...
@@ -81,9 +80,7 @@ class MultiField(object):
# return {key: val.dtype for key, val in self._val.items()}
def
_transform
(
self
,
op
):
return
MultiField
(
self
.
_domain
,
tuple
(
op
(
v
)
if
v
is
not
None
else
None
for
v
in
self
.
_val
))
return
MultiField
(
self
.
_domain
,
tuple
(
op
(
v
)
for
v
in
self
.
_val
))
@
property
def
real
(
self
):
...
...
@@ -111,8 +108,7 @@ class MultiField(object):
result
=
0.
self
.
_check_domain
(
x
)
for
v1
,
v2
in
zip
(
self
.
_val
,
x
.
_val
):
if
v1
is
not
None
and
v2
is
not
None
:
result
+=
v1
.
vdot
(
v2
)
result
+=
v1
.
vdot
(
v2
)
return
result
# @staticmethod
...
...
@@ -190,13 +186,13 @@ class MultiField(object):
def
all
(
self
):
for
v
in
self
.
_val
:
if
v
is
None
or
not
v
.
all
():
if
not
v
.
all
():
return
False
return
True
def
any
(
self
):
for
v
in
self
.
_val
:
if
v
is
not
None
and
v
.
any
():
if
v
.
any
():
return
True
return
False
...
...
@@ -215,44 +211,9 @@ class MultiField(object):
return
True
for
op
in
[
"__add__"
,
"__radd__"
]:
def
func
(
op
):
def
func2
(
self
,
other
):
if
isinstance
(
other
,
MultiField
):
if
self
.
_domain
is
not
other
.
_domain
:
raise
ValueError
(
"domain mismatch"
)
val
=
[]
for
v1
,
v2
in
zip
(
self
.
_val
,
other
.
_val
):
if
v1
is
not
None
:
val
.
append
(
v1
if
v2
is
None
else
(
v1
+
v2
))
else
:
val
.
append
(
None
if
v2
is
None
else
v2
)
val
=
tuple
(
val
)
else
:
val
=
tuple
(
other
if
v1
is
None
else
(
v1
+
other
)
for
v1
in
self
.
_val
)
return
MultiField
(
self
.
_domain
,
val
)
return
func2
setattr
(
MultiField
,
op
,
func
(
op
))
for
op
in
[
"__mul__"
,
"__rmul__"
]:
def
func
(
op
):
def
func2
(
self
,
other
):
if
isinstance
(
other
,
MultiField
):
if
self
.
_domain
is
not
other
.
_domain
:
raise
ValueError
(
"domain mismatch"
)
val
=
tuple
(
None
if
v1
is
None
or
v2
is
None
else
v1
*
v2
for
v1
,
v2
in
zip
(
self
.
_val
,
other
.
_val
))
else
:
val
=
tuple
(
None
if
v1
is
None
else
(
v1
*
other
)
for
v1
in
self
.
_val
)
return
MultiField
(
self
.
_domain
,
val
)
return
func2
setattr
(
MultiField
,
op
,
func
(
op
))
for
op
in
[
"__sub__"
,
"__rsub__"
,
for
op
in
[
"__add__"
,
"__radd__"
,
"__sub__"
,
"__rsub__"
,
"__mul__"
,
"__rmul__"
,
"__div__"
,
"__rdiv__"
,
"__truediv__"
,
"__rtruediv__"
,
"__floordiv__"
,
"__rfloordiv__"
,
...
...
nifty5/operator.py
0 → 100644
View file @
93c14275
from
__future__
import
absolute_import
,
division
,
print_function
import
abc
import
numpy
as
np
from
.compat
import
*
from
.utilities
import
NiftyMetaBase
#from ..domain_tuple import DomainTuple
#from ..multi.multi_domain import MultiDomain
from
.field
import
Field
from
.multi.multi_field
import
MultiField
from
.operators.scaling_operator
import
ScalingOperator
from
.operators.diagonal_operator
import
DiagonalOperator
class
Linearization
(
object
):
def
__init__
(
self
,
val
,
jac
):
self
.
_val
=
val
self
.
_jac
=
jac
@
property
def
domain
(
self
):
return
self
.
_jac
.
domain
@
property
def
target
(
self
):
return
self
.
_jac
.
target
@
property
def
val
(
self
):
return
self
.
_val
@
property
def
jac
(
self
):
return
self
.
_jac
def
__neg__
(
self
):
return
Linearization
(
-
self
.
_val
,
self
.
_jac
*
(
-
1
))
def
__add__
(
self
,
other
):
if
isinstance
(
other
,
Linearization
):
return
Linearization
(
self
.
_val
+
other
.
_val
,
self
.
_jac
+
other
.
_jac
)
if
isinstance
(
other
,
(
int
,
float
,
complex
,
Field
,
MultiField
)):
return
Linearization
(
self
.
_val
+
other
,
self
.
_jac
)
def
__radd__
(
self
,
other
):
return
self
.
__add__
(
other
)
def
__sub__
(
self
,
other
):
return
self
.
__add__
(
-
other
)
def
__rsub__
(
self
,
other
):
return
(
-
self
).
__add__
(
other
)
def
__mul__
(
self
,
other
):
if
isinstance
(
other
,
Linearization
):
d1
=
DiagonalOperator
(
self
.
_val
)
d2
=
DiagonalOperator
(
other
.
_val
)
return
Linearization
(
self
.
_val
*
other
.
_val
,
self
.
_jac
*
d2
+
d1
*
other
.
_jac
)
if
isinstance
(
other
,
(
int
,
float
,
complex
)):
#if other == 0:
# return ...
return
Linearization
(
self
.
_val
*
other
,
self
.
_jac
*
other
)
if
isinstance
(
other
,
(
Field
,
MultiField
)):
d2
=
DiagonalOperator
(
other
)
return
Linearization
(
self
.
_val
*
other
,
self
.
_jac
*
d2
)
raise
TypeError
def
__rmul__
(
self
,
other
):
if
isinstance
(
other
,
(
int
,
float
,
complex
)):
return
Linearization
(
self
.
_val
*
other
,
self
.
_jac
*
other
)
if
isinstance
(
other
,
(
Field
,
MultiField
)):
d1
=
DiagonalOperator
(
other
)
return
Linearization
(
self
.
_val
*
other
,
d1
*
self
.
_jac
)
@
staticmethod
def
make_var
(
field
):
return
Linearization
(
field
,
ScalingOperator
(
1.
,
field
.
domain
))
@
staticmethod
def
make_const
(
field
):
return
Linearization
(
field
,
ScalingOperator
(
0.
,
{}))
class
Operator
(
NiftyMetaBase
()):
"""Transforms values living on one domain into values living on another
domain, and can also provide the Jacobian.
"""
def
__call__
(
self
,
x
):
"""Returns transformed x
Parameters
----------
x : Linearization
input
Returns
-------
Linearization
output
"""
raise
NotImplementedError
nifty5/operators/central_zero_padder.py
View file @
93c14275
from
__future__
import
absolute_import
,
division
,
print_function
import
numpy
as
np
import
itertools
from
..compat
import
*
from
..
import
utilities
from
.linear_operator
import
LinearOperator
from
..domain_tuple
import
DomainTuple
...
...
test/test_operators/test_adjoint.py
View file @
93c14275
...
...
@@ -62,13 +62,13 @@ class Consistency_Tests(unittest.TestCase):
op
=
ift
.
SlopeOperator
(
dom
,
tgt
,
sig
)
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
@
expand
(
product
(
_h_spaces
+
_p_spaces
+
_pow_spaces
,
_h_spaces
+
_p_spaces
+
_pow_spaces
,
[
np
.
float64
,
np
.
complex128
]))
def
testSelectionOperator
(
self
,
sp1
,
sp2
,
dtype
):
mdom
=
ift
.
MultiDomain
.
make
({
'a'
:
sp1
,
'b'
:
sp2
})
op
=
ift
.
SelectionOperator
(
mdom
,
'a'
)
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
#
@expand(product(_h_spaces + _p_spaces + _pow_spaces,
#
_h_spaces + _p_spaces + _pow_spaces,
#
[np.float64, np.complex128]))
#
def testSelectionOperator(self, sp1, sp2, dtype):
#
mdom = ift.MultiDomain.make({'a': sp1, 'b': sp2})
#
op = ift.SelectionOperator(mdom, 'a')
#
ift.extra.consistency_check(op, dtype, dtype)
@
expand
(
product
(
_h_spaces
+
_p_spaces
+
_pow_spaces
,
[
np
.
float64
,
np
.
complex128
]))
...
...
@@ -80,20 +80,20 @@ class Consistency_Tests(unittest.TestCase):
ift
.
extra
.
consistency_check
(
op
.
inverse
.
adjoint
,
dtype
,
dtype
)
ift
.
extra
.
consistency_check
(
op
.
adjoint
.
inverse
,
dtype
,
dtype
)
@
expand
(
product
(
_h_spaces
+
_p_spaces
+
_pow_spaces
,
_h_spaces
+
_p_spaces
+
_pow_spaces
,
[
np
.
float64
,
np
.
complex128
]))
def
testNullOperator
(
self
,
sp1
,
sp2
,
dtype
):
op
=
ift
.
NullOperator
(
sp1
,
sp2
)
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
mdom1
=
ift
.
MultiDomain
.
make
({
'a'
:
sp1
})
mdom2
=
ift
.
MultiDomain
.
make
({
'b'
:
sp2
})
op
=
ift
.
NullOperator
(
mdom1
,
mdom2
)
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
op
=
ift
.
NullOperator
(
sp1
,
mdom2
)
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
op
=
ift
.
NullOperator
(
mdom1
,
sp2
)
ift
.
extra
.
consistency_check
(
op
,
dtype
,
dtype
)
#
@expand(product(_h_spaces + _p_spaces + _pow_spaces,
#
_h_spaces + _p_spaces + _pow_spaces,
#
[np.float64, np.complex128]))
#
def testNullOperator(self, sp1, sp2, dtype):
#
op = ift.NullOperator(sp1, sp2)
#
ift.extra.consistency_check(op, dtype, dtype)
#
mdom1 = ift.MultiDomain.make({'a': sp1})
#
mdom2 = ift.MultiDomain.make({'b': sp2})
#
op = ift.NullOperator(mdom1, mdom2)
#
ift.extra.consistency_check(op, dtype, dtype)
#
op = ift.NullOperator(sp1, mdom2)
#
ift.extra.consistency_check(op, dtype, dtype)
#
op = ift.NullOperator(mdom1, sp2)
#
ift.extra.consistency_check(op, dtype, dtype)
@
expand
(
product
(
_h_spaces
+
_p_spaces
+
_pow_spaces
,
[
np
.
float64
,
np
.
complex128
]))
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment