Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
133bf484
Commit
133bf484
authored
Jan 17, 2019
by
Martin Reinecke
Browse files
fix interface of BlockDiagonalOperator
parent
3996cb92
Changes
4
Hide whitespace changes
Inline
Side-by-side
nifty5/linearization.py
View file @
133bf484
...
...
@@ -461,7 +461,7 @@ class Linearization(object):
if
len
(
constants
)
==
0
:
return
Linearization
.
make_var
(
field
,
want_metric
)
else
:
ops
=
[
ScalingOperator
(
0.
if
key
in
constants
else
1.
,
dom
)
for
key
,
dom
in
field
.
domain
.
items
()
]
bdop
=
BlockDiagonalOperator
(
field
.
domain
,
tuple
(
ops
)
)
ops
=
{
key
:
ScalingOperator
(
0.
if
key
in
constants
else
1.
,
dom
)
for
key
,
dom
in
field
.
domain
.
items
()
}
bdop
=
BlockDiagonalOperator
(
field
.
domain
,
ops
)
return
Linearization
(
field
,
bdop
,
want_metric
=
want_metric
)
nifty5/operators/block_diagonal_operator.py
View file @
133bf484
...
...
@@ -24,17 +24,16 @@ class BlockDiagonalOperator(EndomorphicOperator):
"""
Parameters
----------
domain : MultiDomain
Domain and target of the operator.
operators : dict
Dictionary with operators domain names as keys and LinearOperators as
items.
Dictionary with subdomain names as keys and LinearOperators as items.
"""
def
__init__
(
self
,
domain
,
operators
):
if
not
isinstance
(
domain
,
MultiDomain
):
raise
TypeError
(
"MultiDomain expected"
)
if
not
isinstance
(
operators
,
tuple
):
raise
TypeError
(
"tuple expected"
)
self
.
_domain
=
domain
self
.
_ops
=
operators
self
.
_ops
=
tuple
(
operators
[
key
]
for
key
in
domain
.
keys
())
self
.
_capability
=
self
.
_all_ops
for
op
in
self
.
_ops
:
if
op
is
not
None
:
...
...
@@ -55,13 +54,14 @@ class BlockDiagonalOperator(EndomorphicOperator):
def
_combine_chain
(
self
,
op
):
if
self
.
_domain
!=
op
.
_domain
:
raise
ValueError
(
"domain mismatch"
)
res
=
tuple
(
v1
(
v2
)
for
v1
,
v2
in
zip
(
self
.
_ops
,
op
.
_ops
))
res
=
{
key
:
v1
(
v2
)
for
key
,
v1
,
v2
in
zip
(
self
.
_domain
.
keys
(),
self
.
_ops
,
op
.
_ops
)}
return
BlockDiagonalOperator
(
self
.
_domain
,
res
)
def
_combine_sum
(
self
,
op
,
selfneg
,
opneg
):
from
..operators.sum_operator
import
SumOperator
if
self
.
_domain
!=
op
.
_domain
:
raise
ValueError
(
"domain mismatch"
)
res
=
tuple
(
SumOperator
.
make
([
v1
,
v2
],
[
selfneg
,
opneg
])
for
v1
,
v2
in
zip
(
self
.
_ops
,
op
.
_ops
)
)
res
=
{
key
:
SumOperator
.
make
([
v1
,
v2
],
[
selfneg
,
opneg
])
for
key
,
v1
,
v2
in
zip
(
self
.
_domain
.
keys
(),
self
.
_ops
,
op
.
_ops
)
}
return
BlockDiagonalOperator
(
self
.
_domain
,
res
)
nifty5/sugar.py
View file @
133bf484
...
...
@@ -363,7 +363,7 @@ def makeOp(input):
return
DiagonalOperator
(
input
)
if
isinstance
(
input
,
MultiField
):
return
BlockDiagonalOperator
(
input
.
domain
,
tuple
(
makeOp
(
val
)
for
val
in
input
.
values
())
)
input
.
domain
,
{
key
:
makeOp
(
val
)
for
key
,
val
in
enumerate
(
input
)}
)
raise
NotImplementedError
...
...
test/test_multi_field.py
View file @
133bf484
...
...
@@ -43,7 +43,8 @@ def test_dataconv():
def
test_blockdiagonal
():
op
=
ift
.
BlockDiagonalOperator
(
dom
,
(
ift
.
ScalingOperator
(
20.
,
dom
[
"d1"
]),))
op
=
ift
.
BlockDiagonalOperator
(
dom
,
{
"d1"
:
ift
.
ScalingOperator
(
20.
,
dom
[
"d1"
])})
op2
=
op
(
op
)
ift
.
extra
.
consistency_check
(
op2
)
assert_equal
(
type
(
op2
),
ift
.
BlockDiagonalOperator
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment