Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Neel Shah
NIFTy
Commits
59bf1166
Commit
59bf1166
authored
May 21, 2018
by
Martin Reinecke
Browse files
fixes
parent
570f8d8a
Changes
3
Hide whitespace changes
Inline
Side-by-side
nifty4/multi/block_diagonal_operator.py
View file @
59bf1166
...
...
@@ -15,7 +15,7 @@ class BlockDiagonalOperator(EndomorphicOperator):
"""
super
(
BlockDiagonalOperator
,
self
).
__init__
()
self
.
_operators
=
operators
self
.
_domain
=
MultiDomain
(
self
.
_domain
=
MultiDomain
.
make
(
{
key
:
op
.
domain
for
key
,
op
in
self
.
_operators
.
items
()})
self
.
_cap
=
self
.
_all_ops
for
op
in
self
.
_operators
.
values
():
...
...
@@ -43,12 +43,13 @@ class BlockDiagonalOperator(EndomorphicOperator):
res
=
{}
for
key
in
self
.
_operators
.
keys
():
res
[
key
]
=
self
.
_operators
[
key
]
*
op
.
_operators
[
key
]
return
res
return
BlockDiagonalOperator
(
res
)
def
_combine_sum
(
self
,
op
,
selfneg
,
opneg
):
from
..operators.sum_operator
import
SumOperator
res
=
{}
for
key
in
self
.
_operators
.
keys
():
res
[
key
]
=
SumOperator
.
make
([
self
.
_operators
[
key
],
op
.
_operators
[
key
]],
[
selfneg
,
opneg
])
return
res
return
BlockDiagonalOperator
(
res
)
nifty4/multi/multi_field.py
View file @
59bf1166
...
...
@@ -121,6 +121,15 @@ class MultiField(object):
return
MultiField
({
key
:
Field
.
full
(
dom
,
val
)
for
key
,
dom
in
domain
.
items
()})
def
to_global_data
(
self
):
return
{
key
:
val
.
to_global_data
()
for
key
,
val
in
self
.
_val
.
items
()}
@
staticmethod
def
from_global_data
(
domain
,
arr
,
sum_up
=
False
):
return
MultiField
({
key
:
Field
.
from_global_data
(
domain
[
key
],
val
,
sum_up
)
for
key
,
val
in
arr
.
items
()})
def
norm
(
self
):
""" Computes the L2-norm of the field values.
...
...
test/test_multi_field.py
View file @
59bf1166
...
...
@@ -32,7 +32,6 @@ class Test_Functionality(unittest.TestCase):
assert_allclose
(
f1
.
vdot
(
f2
),
np
.
conj
(
f2
.
vdot
(
f1
)))
def
test_lock
(
self
):
s1
=
ift
.
RGSpace
((
10
,))
f1
=
ift
.
full
(
dom
,
27
)
assert_equal
(
f1
.
locked
,
False
)
f1
.
lock
()
...
...
@@ -42,13 +41,26 @@ class Test_Functionality(unittest.TestCase):
assert_equal
(
f1
.
locked_copy
()
is
f1
,
True
)
def
test_fill
(
self
):
s1
=
ift
.
RGSpace
((
10
,))
f1
=
ift
.
full
(
s1
,
27
)
assert_equal
((
f1
.
fill
(
10
)
==
10
).
all
(),
True
)
f1
=
ift
.
full
(
dom
,
27
)
f1
.
fill
(
10
)
for
val
in
f1
.
values
():
assert_equal
((
val
==
10
).
all
(),
True
)
def
test_dataconv
(
self
):
s1
=
ift
.
RGSpace
((
10
,))
ld
=
np
.
arange
(
ift
.
dobj
.
local_shape
(
s1
.
shape
)[
0
])
gd
=
np
.
arange
(
s1
.
shape
[
0
])
assert_equal
(
ld
,
ift
.
from_local_data
(
s1
,
ld
).
local_data
)
assert_equal
(
gd
,
ift
.
from_global_data
(
s1
,
gd
).
to_global_data
())
f1
=
ift
.
full
(
dom
,
27
)
f2
=
ift
.
from_global_data
(
dom
,
f1
.
to_global_data
())
for
key
,
val
in
f1
.
items
():
assert_equal
(
val
.
local_data
,
f2
[
key
].
local_data
)
def
test_blockdiagonal
(
self
):
op
=
ift
.
BlockDiagonalOperator
({
"d1"
:
ift
.
ScalingOperator
(
20.
,
dom
[
"d1"
])})
op2
=
op
*
op
assert_equal
(
type
(
op2
),
ift
.
BlockDiagonalOperator
)
f1
=
op2
(
ift
.
full
(
dom
,
1
))
for
val
in
f1
.
values
():
assert_equal
((
val
==
400
).
all
(),
True
)
op2
=
op
+
op
assert_equal
(
type
(
op2
),
ift
.
BlockDiagonalOperator
)
f1
=
op2
(
ift
.
full
(
dom
,
1
))
for
val
in
f1
.
values
():
assert_equal
((
val
==
40
).
all
(),
True
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment