Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
59d7def9
Commit
59d7def9
authored
Jul 10, 2018
by
Martin Reinecke
Browse files
intermediate state
parent
8bdf38d0
Changes
6
Show whitespace changes
Inline
Side-by-side
nifty5/field.py
View file @
59d7def9
...
@@ -606,11 +606,6 @@ class Field(object):
...
@@ -606,11 +606,6 @@ class Field(object):
return
False
return
False
return
(
self
.
_val
==
other
.
_val
).
all
()
return
(
self
.
_val
==
other
.
_val
).
all
()
def
isSubsetOf
(
self
,
other
):
"""Identical to `Field.isEquivalentTo()`. This method is provided for
easier interoperability with `MultiField`."""
return
self
.
isEquivalentTo
(
other
)
for
op
in
[
"__add__"
,
"__radd__"
,
for
op
in
[
"__add__"
,
"__radd__"
,
"__sub__"
,
"__rsub__"
,
"__sub__"
,
"__rsub__"
,
...
...
nifty5/minimization/energy.py
View file @
59d7def9
...
@@ -170,8 +170,6 @@ class MetricInversionEnabler(Energy):
...
@@ -170,8 +170,6 @@ class MetricInversionEnabler(Energy):
self
.
_preconditioner
=
preconditioner
self
.
_preconditioner
=
preconditioner
def
at
(
self
,
position
):
def
at
(
self
,
position
):
if
self
.
_position
.
isSubsetOf
(
position
):
return
self
return
MetricInversionEnabler
(
return
MetricInversionEnabler
(
self
.
_energy
.
at
(
position
),
self
.
_controller
,
self
.
_preconditioner
)
self
.
_energy
.
at
(
position
),
self
.
_controller
,
self
.
_preconditioner
)
...
...
nifty5/multi/block_diagonal_operator.py
View file @
59d7def9
...
@@ -34,8 +34,12 @@ class BlockDiagonalOperator(EndomorphicOperator):
...
@@ -34,8 +34,12 @@ class BlockDiagonalOperator(EndomorphicOperator):
LinearOperators as items
LinearOperators as items
"""
"""
super
(
BlockDiagonalOperator
,
self
).
__init__
()
super
(
BlockDiagonalOperator
,
self
).
__init__
()
if
not
isinstance
(
domain
,
MultiDomain
):
raise
TypeError
(
"MultiDomain expected"
)
if
not
isinstance
(
operators
,
tuple
):
raise
TypeError
(
"tuple expected"
)
self
.
_domain
=
domain
self
.
_domain
=
domain
self
.
_ops
=
tuple
(
operators
[
key
]
for
key
in
self
.
domain
.
keys
())
self
.
_ops
=
operators
self
.
_cap
=
self
.
_all_ops
self
.
_cap
=
self
.
_all_ops
for
op
in
self
.
_ops
:
for
op
in
self
.
_ops
:
if
op
is
not
None
:
if
op
is
not
None
:
...
@@ -64,15 +68,13 @@ class BlockDiagonalOperator(EndomorphicOperator):
...
@@ -64,15 +68,13 @@ class BlockDiagonalOperator(EndomorphicOperator):
def
_combine_chain
(
self
,
op
):
def
_combine_chain
(
self
,
op
):
if
self
.
_domain
is
not
op
.
_domain
:
if
self
.
_domain
is
not
op
.
_domain
:
raise
ValueError
(
"domain mismatch"
)
raise
ValueError
(
"domain mismatch"
)
res
=
{
key
:
v1
*
v2
res
=
tuple
(
v1
*
v2
for
v1
,
v2
in
zip
(
self
.
_ops
,
op
.
_ops
))
for
key
,
v1
,
v2
in
zip
(
self
.
_domain
.
keys
(),
self
.
_ops
,
op
.
_ops
)}
return
BlockDiagonalOperator
(
self
.
_domain
,
res
)
return
BlockDiagonalOperator
(
self
.
_domain
,
res
)
def
_combine_sum
(
self
,
op
,
selfneg
,
opneg
):
def
_combine_sum
(
self
,
op
,
selfneg
,
opneg
):
from
..operators.sum_operator
import
SumOperator
from
..operators.sum_operator
import
SumOperator
if
self
.
_domain
is
not
op
.
_domain
:
if
self
.
_domain
is
not
op
.
_domain
:
raise
ValueError
(
"domain mismatch"
)
raise
ValueError
(
"domain mismatch"
)
res
=
{}
res
=
tuple
(
SumOperator
.
make
([
v1
,
v2
],
[
selfneg
,
opneg
])
for
key
,
v1
,
v2
in
zip
(
self
.
_domain
.
keys
(),
self
.
_ops
,
op
.
_ops
):
for
v1
,
v2
in
zip
(
self
.
_ops
,
op
.
_ops
))
res
[
key
]
=
SumOperator
.
make
([
v1
,
v2
],
[
selfneg
,
opneg
])
return
BlockDiagonalOperator
(
self
.
_domain
,
res
)
return
BlockDiagonalOperator
(
self
.
_domain
,
res
)
nifty5/multi/multi_field.py
View file @
59d7def9
...
@@ -191,26 +191,45 @@ class MultiField(object):
...
@@ -191,26 +191,45 @@ class MultiField(object):
return
False
return
False
return
True
return
True
def
isSubsetOf
(
self
,
other
):
"""Determines (as quickly as possible) whether `self`'s content is
for
op
in
[
"__add__"
,
"__radd__"
]:
a subset of `other`'s content."""
def
func
(
op
):
if
self
is
other
:
def
func2
(
self
,
other
):
return
True
if
isinstance
(
other
,
MultiField
):
if
not
isinstance
(
other
,
MultiField
):
if
self
.
_domain
is
not
other
.
_domain
:
return
False
raise
ValueError
(
"domain mismatch"
)
if
len
(
set
(
self
.
_domain
.
keys
())
-
set
(
other
.
_domain
.
keys
()))
>
0
:
val
=
[]
return
False
for
v1
,
v2
in
zip
(
self
.
_val
,
other
.
_val
):
for
key
in
self
.
_domain
.
keys
():
if
v1
is
not
None
:
if
other
.
_domain
[
key
]
is
not
self
.
_domain
[
key
]:
val
.
append
(
v1
if
v2
is
None
else
(
v1
+
v2
))
return
False
else
:
if
not
other
[
key
].
isSubsetOf
(
self
[
key
]):
val
.
append
(
None
if
v2
is
None
else
v2
)
return
False
val
=
tuple
(
val
)
return
True
else
:
val
=
tuple
(
other
if
v1
is
None
else
(
v1
+
other
)
for
v1
in
self
.
_val
)
return
MultiField
(
self
.
_domain
,
val
)
return
func2
setattr
(
MultiField
,
op
,
func
(
op
))
for
op
in
[
"__mul__"
,
"__rmul__"
]:
def
func
(
op
):
def
func2
(
self
,
other
):
if
isinstance
(
other
,
MultiField
):
if
self
.
_domain
is
not
other
.
_domain
:
raise
ValueError
(
"domain mismatch"
)
val
=
tuple
(
None
if
v1
is
None
or
v2
is
None
else
v1
*
v2
for
v1
,
v2
in
zip
(
self
.
_val
,
other
.
_val
))
else
:
val
=
tuple
(
None
if
v1
is
None
else
(
v1
*
other
)
for
v1
in
self
.
_val
)
return
MultiField
(
self
.
_domain
,
val
)
return
func2
setattr
(
MultiField
,
op
,
func
(
op
))
for
op
in
[
"__add__"
,
"__radd__"
,
for
op
in
[
"__sub__"
,
"__rsub__"
,
"__sub__"
,
"__rsub__"
,
"__mul__"
,
"__rmul__"
,
"__div__"
,
"__rdiv__"
,
"__div__"
,
"__rdiv__"
,
"__truediv__"
,
"__rtruediv__"
,
"__truediv__"
,
"__rtruediv__"
,
"__floordiv__"
,
"__rfloordiv__"
,
"__floordiv__"
,
"__rfloordiv__"
,
...
@@ -218,27 +237,18 @@ for op in ["__add__", "__radd__",
...
@@ -218,27 +237,18 @@ for op in ["__add__", "__radd__",
"__lt__"
,
"__le__"
,
"__gt__"
,
"__ge__"
,
"__eq__"
,
"__ne__"
]:
"__lt__"
,
"__le__"
,
"__gt__"
,
"__ge__"
,
"__eq__"
,
"__ne__"
]:
def
func
(
op
):
def
func
(
op
):
def
func2
(
self
,
other
):
def
func2
(
self
,
other
):
res
=
[]
if
isinstance
(
other
,
MultiField
):
if
isinstance
(
other
,
MultiField
):
if
self
.
_domain
is
not
other
.
_domain
:
if
self
.
_domain
is
not
other
.
_domain
:
raise
ValueError
(
"domain mismatch"
)
raise
ValueError
(
"domain mismatch"
)
for
v1
,
v2
in
zip
(
self
.
_val
,
other
.
_val
):
val
=
tuple
(
getattr
(
v1
,
op
)(
v2
)
if
v1
is
not
None
:
for
v1
,
v2
in
zip
(
self
.
_val
,
other
.
_val
))
if
v2
is
None
:
res
.
append
(
getattr
(
v1
,
op
)(
v1
*
0
))
else
:
else
:
res
.
append
(
getattr
(
v1
,
op
)(
v2
))
val
=
tuple
(
getattr
(
v1
,
op
)(
other
)
for
v1
in
self
.
_val
)
else
:
return
MultiField
(
self
.
_domain
,
val
)
if
v2
is
None
:
res
.
append
(
None
)
else
:
res
.
append
(
getattr
(
v2
*
0
,
op
)(
v2
))
return
MultiField
(
self
.
_domain
,
tuple
(
res
))
else
:
return
self
.
_transform
(
lambda
x
:
getattr
(
x
,
op
)(
other
))
return
func2
return
func2
setattr
(
MultiField
,
op
,
func
(
op
))
setattr
(
MultiField
,
op
,
func
(
op
))
for
op
in
[
"__iadd__"
,
"__isub__"
,
"__imul__"
,
"__idiv__"
,
for
op
in
[
"__iadd__"
,
"__isub__"
,
"__imul__"
,
"__idiv__"
,
"__itruediv__"
,
"__ifloordiv__"
,
"__ipow__"
]:
"__itruediv__"
,
"__ifloordiv__"
,
"__ipow__"
]:
def
func
(
op
):
def
func
(
op
):
...
...
nifty5/sugar.py
View file @
59d7def9
...
@@ -236,7 +236,7 @@ def makeOp(input):
...
@@ -236,7 +236,7 @@ def makeOp(input):
return
DiagonalOperator
(
input
)
return
DiagonalOperator
(
input
)
if
isinstance
(
input
,
MultiField
):
if
isinstance
(
input
,
MultiField
):
return
BlockDiagonalOperator
(
return
BlockDiagonalOperator
(
input
.
domain
,
{
key
:
makeOp
(
val
)
for
key
,
val
in
input
.
item
s
()
}
)
input
.
domain
,
tuple
(
makeOp
(
val
)
for
val
in
input
.
value
s
()
)
)
raise
NotImplementedError
raise
NotImplementedError
# Arithmetic functions working on Fields
# Arithmetic functions working on Fields
...
...
test/test_multi_field.py
View file @
59d7def9
...
@@ -40,7 +40,7 @@ class Test_Functionality(unittest.TestCase):
...
@@ -40,7 +40,7 @@ class Test_Functionality(unittest.TestCase):
def
test_blockdiagonal
(
self
):
def
test_blockdiagonal
(
self
):
op
=
ift
.
BlockDiagonalOperator
(
op
=
ift
.
BlockDiagonalOperator
(
dom
,
{
"d1"
:
ift
.
ScalingOperator
(
20.
,
dom
[
"d1"
])
}
)
dom
,
(
ift
.
ScalingOperator
(
20.
,
dom
[
"d1"
])
,)
)
op2
=
op
*
op
op2
=
op
*
op
ift
.
extra
.
consistency_check
(
op2
)
ift
.
extra
.
consistency_check
(
op2
)
assert_equal
(
type
(
op2
),
ift
.
BlockDiagonalOperator
)
assert_equal
(
type
(
op2
),
ift
.
BlockDiagonalOperator
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment