Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
17013e3d
Commit
17013e3d
authored
Jan 09, 2018
by
Martin Reinecke
Browse files
more SumOperator optimizations; new tests
parent
1509dfd8
Pipeline
#23528
passed with stage
in 4 minutes and 36 seconds
Changes
3
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty/operators/__init__.py
View file @
17013e3d
...
...
@@ -13,5 +13,6 @@ from .power_projection_operator import PowerProjectionOperator
from
.dof_projection_operator
import
DOFProjectionOperator
from
.chain_operator
import
ChainOperator
from
.sum_operator
import
SumOperator
from
.scaling_operator
import
ScalingOperator
from
.inverse_operator
import
InverseOperator
from
.adjoint_operator
import
AdjointOperator
nifty/operators/sum_operator.py
View file @
17013e3d
...
...
@@ -41,7 +41,7 @@ class SumOperator(LinearOperator):
# Step 2: unpack SumOperators
opsnew
=
[]
negnew
=
[]
for
op
,
ng
in
zip
(
ops
,
neg
):
for
op
,
ng
in
zip
(
ops
,
neg
):
if
isinstance
(
op
,
SumOperator
):
opsnew
+=
op
.
_ops
if
ng
:
...
...
@@ -81,14 +81,33 @@ class SumOperator(LinearOperator):
ops
=
opsnew
neg
=
negnew
# Step 4: combine DiagonalOperators where possible
# (TBD)
processed
=
[
False
]
*
len
(
ops
)
opsnew
=
[]
negnew
=
[]
for
i
in
range
(
len
(
ops
)):
if
not
processed
[
i
]:
if
isinstance
(
ops
[
i
],
DiagonalOperator
):
diag
=
ops
[
i
].
diagonal
()
*
(
-
1
if
neg
[
i
]
else
1
)
for
j
in
range
(
i
+
1
,
len
(
ops
)):
if
(
isinstance
(
ops
[
j
],
DiagonalOperator
)
and
ops
[
i
].
_spaces
==
ops
[
j
].
_spaces
):
diag
+=
ops
[
j
].
diagonal
()
*
(
-
1
if
neg
[
j
]
else
1
)
processed
[
j
]
=
True
opsnew
.
append
(
DiagonalOperator
(
diag
,
ops
[
i
].
domain
,
ops
[
i
].
_spaces
))
negnew
.
append
(
False
)
else
:
opsnew
.
append
(
ops
[
i
])
negnew
.
append
(
neg
[
i
])
ops
=
opsnew
neg
=
negnew
return
ops
,
neg
@
staticmethod
def
make
(
ops
,
neg
):
ops
=
tuple
(
ops
)
neg
=
tuple
(
neg
)
if
len
(
ops
)
!=
len
(
neg
):
if
len
(
ops
)
!=
len
(
neg
):
raise
ValueError
(
"length mismatch between ops and neg"
)
ops
,
neg
=
SumOperator
.
simplify
(
ops
,
neg
)
if
len
(
ops
)
==
1
and
not
neg
[
0
]:
...
...
test/test_operators/test_composed_operator.py
View file @
17013e3d
import
unittest
from
numpy.testing
import
assert_allclose
from
numpy.testing
import
assert_allclose
,
assert_equal
import
nifty2go
as
ift
from
test.common
import
generate_spaces
from
itertools
import
product
...
...
@@ -41,3 +41,23 @@ class ComposedOperator_Tests(unittest.TestCase):
assert_allclose
(
ift
.
dobj
.
to_global_data
(
tt1
.
val
),
ift
.
dobj
.
to_global_data
(
rand1
.
val
))
@
expand
(
product
(
spaces
))
def
test_sum
(
self
,
space
):
op1
=
ift
.
DiagonalOperator
(
ift
.
Field
(
space
,
2.
))
op2
=
ift
.
ScalingOperator
(
3.
,
space
)
full_op
=
op1
+
op2
-
(
op2
-
op1
)
+
op1
+
op1
+
op2
x
=
ift
.
Field
(
space
,
1.
)
res
=
full_op
(
x
)
assert_equal
(
isinstance
(
full_op
,
ift
.
DiagonalOperator
),
True
)
assert_allclose
(
ift
.
dobj
.
to_global_data
(
res
.
val
),
11.
)
@
expand
(
product
(
spaces
))
def
test_chain
(
self
,
space
):
op1
=
ift
.
DiagonalOperator
(
ift
.
Field
(
space
,
2.
))
op2
=
ift
.
ScalingOperator
(
3.
,
space
)
full_op
=
op1
*
op2
*
(
op2
*
op1
)
*
op1
*
op1
*
op2
x
=
ift
.
Field
(
space
,
1.
)
res
=
full_op
(
x
)
assert_equal
(
isinstance
(
full_op
,
ift
.
DiagonalOperator
),
True
)
assert_allclose
(
ift
.
dobj
.
to_global_data
(
res
.
val
),
432.
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment