Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
59d7def9
Commit
59d7def9
authored
Jul 10, 2018
by
Martin Reinecke
Browse files
intermediate state
parent
8bdf38d0
Changes
6
Show whitespace changes
Inline
Side-by-side
nifty5/field.py
View file @
59d7def9
...
...
@@ -606,11 +606,6 @@ class Field(object):
return
False
return
(
self
.
_val
==
other
.
_val
).
all
()
def
isSubsetOf
(
self
,
other
):
"""Identical to `Field.isEquivalentTo()`. This method is provided for
easier interoperability with `MultiField`."""
return
self
.
isEquivalentTo
(
other
)
for
op
in
[
"__add__"
,
"__radd__"
,
"__sub__"
,
"__rsub__"
,
...
...
nifty5/minimization/energy.py
View file @
59d7def9
...
...
@@ -170,8 +170,6 @@ class MetricInversionEnabler(Energy):
self
.
_preconditioner
=
preconditioner
def
at
(
self
,
position
):
if
self
.
_position
.
isSubsetOf
(
position
):
return
self
return
MetricInversionEnabler
(
self
.
_energy
.
at
(
position
),
self
.
_controller
,
self
.
_preconditioner
)
...
...
nifty5/multi/block_diagonal_operator.py
View file @
59d7def9
...
...
@@ -34,8 +34,12 @@ class BlockDiagonalOperator(EndomorphicOperator):
LinearOperators as items
"""
super
(
BlockDiagonalOperator
,
self
).
__init__
()
if
not
isinstance
(
domain
,
MultiDomain
):
raise
TypeError
(
"MultiDomain expected"
)
if
not
isinstance
(
operators
,
tuple
):
raise
TypeError
(
"tuple expected"
)
self
.
_domain
=
domain
self
.
_ops
=
tuple
(
operators
[
key
]
for
key
in
self
.
domain
.
keys
())
self
.
_ops
=
operators
self
.
_cap
=
self
.
_all_ops
for
op
in
self
.
_ops
:
if
op
is
not
None
:
...
...
@@ -64,15 +68,13 @@ class BlockDiagonalOperator(EndomorphicOperator):
def
_combine_chain
(
self
,
op
):
if
self
.
_domain
is
not
op
.
_domain
:
raise
ValueError
(
"domain mismatch"
)
res
=
{
key
:
v1
*
v2
for
key
,
v1
,
v2
in
zip
(
self
.
_domain
.
keys
(),
self
.
_ops
,
op
.
_ops
)}
res
=
tuple
(
v1
*
v2
for
v1
,
v2
in
zip
(
self
.
_ops
,
op
.
_ops
))
return
BlockDiagonalOperator
(
self
.
_domain
,
res
)
def
_combine_sum
(
self
,
op
,
selfneg
,
opneg
):
from
..operators.sum_operator
import
SumOperator
if
self
.
_domain
is
not
op
.
_domain
:
raise
ValueError
(
"domain mismatch"
)
res
=
{}
for
key
,
v1
,
v2
in
zip
(
self
.
_domain
.
keys
(),
self
.
_ops
,
op
.
_ops
):
res
[
key
]
=
SumOperator
.
make
([
v1
,
v2
],
[
selfneg
,
opneg
])
res
=
tuple
(
SumOperator
.
make
([
v1
,
v2
],
[
selfneg
,
opneg
])
for
v1
,
v2
in
zip
(
self
.
_ops
,
op
.
_ops
))
return
BlockDiagonalOperator
(
self
.
_domain
,
res
)
nifty5/multi/multi_field.py
View file @
59d7def9
...
...
@@ -191,26 +191,45 @@ class MultiField(object):
return
False
return
True
def
isSubsetOf
(
self
,
other
):
"""Determines (as quickly as possible) whether `self`'s content is
a subset of `other`'s content."""
if
self
is
other
:
return
True
if
not
isinstance
(
other
,
MultiField
):
return
False
if
len
(
set
(
self
.
_domain
.
keys
())
-
set
(
other
.
_domain
.
keys
()))
>
0
:
return
False
for
key
in
self
.
_domain
.
keys
():
if
other
.
_domain
[
key
]
is
not
self
.
_domain
[
key
]:
return
False
if
not
other
[
key
].
isSubsetOf
(
self
[
key
]):
return
False
return
True
for
op
in
[
"__add__"
,
"__radd__"
]:
def
func
(
op
):
def
func2
(
self
,
other
):
if
isinstance
(
other
,
MultiField
):
if
self
.
_domain
is
not
other
.
_domain
:
raise
ValueError
(
"domain mismatch"
)
val
=
[]
for
v1
,
v2
in
zip
(
self
.
_val
,
other
.
_val
):
if
v1
is
not
None
:
val
.
append
(
v1
if
v2
is
None
else
(
v1
+
v2
))
else
:
val
.
append
(
None
if
v2
is
None
else
v2
)
val
=
tuple
(
val
)
else
:
val
=
tuple
(
other
if
v1
is
None
else
(
v1
+
other
)
for
v1
in
self
.
_val
)
return
MultiField
(
self
.
_domain
,
val
)
return
func2
setattr
(
MultiField
,
op
,
func
(
op
))
for
op
in
[
"__mul__"
,
"__rmul__"
]:
def
func
(
op
):
def
func2
(
self
,
other
):
if
isinstance
(
other
,
MultiField
):
if
self
.
_domain
is
not
other
.
_domain
:
raise
ValueError
(
"domain mismatch"
)
val
=
tuple
(
None
if
v1
is
None
or
v2
is
None
else
v1
*
v2
for
v1
,
v2
in
zip
(
self
.
_val
,
other
.
_val
))
else
:
val
=
tuple
(
None
if
v1
is
None
else
(
v1
*
other
)
for
v1
in
self
.
_val
)
return
MultiField
(
self
.
_domain
,
val
)
return
func2
setattr
(
MultiField
,
op
,
func
(
op
))
for
op
in
[
"__add__"
,
"__radd__"
,
"__sub__"
,
"__rsub__"
,
"__mul__"
,
"__rmul__"
,
for
op
in
[
"__sub__"
,
"__rsub__"
,
"__div__"
,
"__rdiv__"
,
"__truediv__"
,
"__rtruediv__"
,
"__floordiv__"
,
"__rfloordiv__"
,
...
...
@@ -218,27 +237,18 @@ for op in ["__add__", "__radd__",
"__lt__"
,
"__le__"
,
"__gt__"
,
"__ge__"
,
"__eq__"
,
"__ne__"
]:
def
func
(
op
):
def
func2
(
self
,
other
):
res
=
[]
if
isinstance
(
other
,
MultiField
):
if
self
.
_domain
is
not
other
.
_domain
:
raise
ValueError
(
"domain mismatch"
)
for
v1
,
v2
in
zip
(
self
.
_val
,
other
.
_val
):
if
v1
is
not
None
:
if
v2
is
None
:
res
.
append
(
getattr
(
v1
,
op
)(
v1
*
0
))
val
=
tuple
(
getattr
(
v1
,
op
)(
v2
)
for
v1
,
v2
in
zip
(
self
.
_val
,
other
.
_val
))
else
:
res
.
append
(
getattr
(
v1
,
op
)(
v2
))
else
:
if
v2
is
None
:
res
.
append
(
None
)
else
:
res
.
append
(
getattr
(
v2
*
0
,
op
)(
v2
))
return
MultiField
(
self
.
_domain
,
tuple
(
res
))
else
:
return
self
.
_transform
(
lambda
x
:
getattr
(
x
,
op
)(
other
))
val
=
tuple
(
getattr
(
v1
,
op
)(
other
)
for
v1
in
self
.
_val
)
return
MultiField
(
self
.
_domain
,
val
)
return
func2
setattr
(
MultiField
,
op
,
func
(
op
))
for
op
in
[
"__iadd__"
,
"__isub__"
,
"__imul__"
,
"__idiv__"
,
"__itruediv__"
,
"__ifloordiv__"
,
"__ipow__"
]:
def
func
(
op
):
...
...
nifty5/sugar.py
View file @
59d7def9
...
...
@@ -236,7 +236,7 @@ def makeOp(input):
return
DiagonalOperator
(
input
)
if
isinstance
(
input
,
MultiField
):
return
BlockDiagonalOperator
(
input
.
domain
,
{
key
:
makeOp
(
val
)
for
key
,
val
in
input
.
item
s
()
}
)
input
.
domain
,
tuple
(
makeOp
(
val
)
for
val
in
input
.
value
s
()
)
)
raise
NotImplementedError
# Arithmetic functions working on Fields
...
...
test/test_multi_field.py
View file @
59d7def9
...
...
@@ -40,7 +40,7 @@ class Test_Functionality(unittest.TestCase):
def
test_blockdiagonal
(
self
):
op
=
ift
.
BlockDiagonalOperator
(
dom
,
{
"d1"
:
ift
.
ScalingOperator
(
20.
,
dom
[
"d1"
])
}
)
dom
,
(
ift
.
ScalingOperator
(
20.
,
dom
[
"d1"
])
,)
)
op2
=
op
*
op
ift
.
extra
.
consistency_check
(
op2
)
assert_equal
(
type
(
op2
),
ift
.
BlockDiagonalOperator
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment