Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
4234baff
Commit
4234baff
authored
Apr 30, 2018
by
Martin Reinecke
Browse files
add chain and sum operators
parent
a0e5a346
Pipeline
#28378
failed with stages
in 2 minutes and 25 seconds
Changes
4
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty4/multi/__init__.py
View file @
4234baff
from
.multi_domain
import
MultiDomain
from
.multi_field
import
MultiField
from
.multi_linear_operator
import
MultiLinearOperator
from
.multi_chain_operator
import
MultiChainOperator
from
.multi_sum_operator
import
MultiSumOperator
__all__
=
[
"MultiDomain"
,
"MultiField"
,
"MultiLinearOperator"
]
__all__
=
[
"MultiDomain"
,
"MultiField"
,
"MultiLinearOperator"
,
"MultiChainOperator"
,
"MultiSumOperator"
]
nifty4/multi/multi_chain_operator.py
0 → 100644
View file @
4234baff
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from
.multi_linear_operator
import
MultiLinearOperator
class
MultiChainOperator
(
MultiLinearOperator
):
"""Class representing chains of multi-operators."""
def
__init__
(
self
,
ops
,
_callingfrommake
=
False
):
print
"CHAINOP"
if
not
_callingfrommake
:
raise
NotImplementedError
super
(
MultiChainOperator
,
self
).
__init__
()
self
.
_ops
=
ops
self
.
_capability
=
self
.
_all_ops
for
op
in
ops
:
self
.
_capability
&=
op
.
capability
@
staticmethod
def
make
(
ops
):
ops
=
tuple
(
ops
)
if
len
(
ops
)
==
1
:
return
ops
[
0
]
return
MultiChainOperator
(
ops
,
_callingfrommake
=
True
)
@
property
def
domain
(
self
):
return
self
.
_ops
[
-
1
].
domain
@
property
def
target
(
self
):
return
self
.
_ops
[
0
].
target
def
_flip_modes
(
self
,
trafo
):
ADJ
=
self
.
ADJOINT_BIT
INV
=
self
.
INVERSE_BIT
if
trafo
==
0
:
return
self
if
trafo
==
ADJ
or
trafo
==
INV
:
return
self
.
make
([
op
.
_flip_modes
(
trafo
)
for
op
in
reversed
(
self
.
_ops
)])
if
trafo
==
ADJ
|
INV
:
return
self
.
make
([
op
.
_flip_modes
(
trafo
)
for
op
in
self
.
_ops
])
raise
ValueError
(
"invalid operator transformation"
)
@
property
def
capability
(
self
):
return
self
.
_capability
def
apply
(
self
,
x
,
mode
):
self
.
_check_mode
(
mode
)
t_ops
=
self
.
_ops
if
mode
&
self
.
_backwards
else
reversed
(
self
.
_ops
)
for
op
in
t_ops
:
x
=
op
.
apply
(
x
,
mode
)
return
x
nifty4/multi/multi_linear_operator.py
View file @
4234baff
...
...
@@ -2,4 +2,39 @@ from ..operators.linear_operator import LinearOperator
class
MultiLinearOperator
(
LinearOperator
):
pass
@
staticmethod
def
_toOperator
(
thing
,
dom
):
#from .multi_scaling_operator import ScalingOperator
if
isinstance
(
thing
,
MultiLinearOperator
):
return
thing
#if np.isscalar(thing):
# return MultiScalingOperator(thing, dom)
return
NotImplemented
def
__mul__
(
self
,
other
):
from
.multi_chain_operator
import
MultiChainOperator
other
=
self
.
_toOperator
(
other
,
self
.
domain
)
return
MultiChainOperator
.
make
([
self
,
other
])
def
__rmul__
(
self
,
other
):
from
.multi_chain_operator
import
MultiChainOperator
other
=
self
.
_toOperator
(
other
,
self
.
target
)
return
MultiChainOperator
.
make
([
other
,
self
])
def
__add__
(
self
,
other
):
from
.multi_sum_operator
import
MultiSumOperator
other
=
self
.
_toOperator
(
other
,
self
.
domain
)
return
MultiSumOperator
.
make
([
self
,
other
],
[
False
,
False
])
def
__radd__
(
self
,
other
):
return
self
.
__add__
(
other
)
def
__sub__
(
self
,
other
):
from
.multi_sum_operator
import
MultiSumOperator
other
=
self
.
_toOperator
(
other
,
self
.
domain
)
return
MultiSumOperator
.
make
([
self
,
other
],
[
False
,
True
])
def
__rsub__
(
self
,
other
):
from
.multi_sum_operator
import
MultiSumOperator
other
=
self
.
_toOperator
(
other
,
self
.
domain
)
return
MultiSumOperator
.
make
([
other
,
self
],
[
False
,
True
])
nifty4/multi/multi_sum_operator.py
0 → 100644
View file @
4234baff
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from
.multi_linear_operator
import
MultiLinearOperator
import
numpy
as
np
class
MultiSumOperator
(
MultiLinearOperator
):
"""Class representing sums of multi-operators."""
def
__init__
(
self
,
ops
,
neg
,
_callingfrommake
=
False
):
if
not
_callingfrommake
:
raise
NotImplementedError
super
(
MultiSumOperator
,
self
).
__init__
()
self
.
_ops
=
ops
self
.
_neg
=
neg
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
for
op
in
ops
:
self
.
_capability
&=
op
.
capability
@
staticmethod
def
make
(
ops
,
neg
):
ops
=
tuple
(
ops
)
neg
=
tuple
(
neg
)
if
len
(
ops
)
!=
len
(
neg
):
raise
ValueError
(
"length mismatch between ops and neg"
)
#ops, neg = MultiSumOperator.simplify(ops, neg)
if
len
(
ops
)
==
1
and
not
neg
[
0
]:
return
ops
[
0
]
return
MultiSumOperator
(
ops
,
neg
,
_callingfrommake
=
True
)
@
property
def
domain
(
self
):
return
self
.
_ops
[
0
].
domain
@
property
def
target
(
self
):
return
self
.
_ops
[
0
].
target
@
property
def
adjoint
(
self
):
return
self
.
make
([
op
.
adjoint
for
op
in
self
.
_ops
],
self
.
_neg
)
@
property
def
capability
(
self
):
return
self
.
_capability
def
apply
(
self
,
x
,
mode
):
self
.
_check_mode
(
mode
)
for
i
,
op
in
enumerate
(
self
.
_ops
):
if
i
==
0
:
res
=
-
op
.
apply
(
x
,
mode
)
if
self
.
_neg
[
i
]
else
op
.
apply
(
x
,
mode
)
else
:
if
self
.
_neg
[
i
]:
res
-=
op
.
apply
(
x
,
mode
)
else
:
res
+=
op
.
apply
(
x
,
mode
)
return
res
def
draw_sample
(
self
,
from_inverse
=
False
,
dtype
=
np
.
float64
):
if
from_inverse
:
raise
ValueError
(
"cannot draw from inverse of this operator"
)
res
=
self
.
_ops
[
0
].
draw_sample
(
from_inverse
,
dtype
)
for
op
in
self
.
_ops
[
1
:]:
res
+=
op
.
draw_sample
(
from_inverse
,
dtype
)
return
res
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment