Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
abc660d4
Commit
abc660d4
authored
Mar 04, 2019
by
Martin Reinecke
Browse files
Merge branch 'simplify_for_const' into 'NIFTy_5'
Simplify for const See merge request
!295
parents
11f686dd
61c290f5
Pipeline
#44819
passed with stages
in 19 minutes and 22 seconds
Changes
8
Pipelines
4
Hide whitespace changes
Inline
Side-by-side
nifty5/field.py
View file @
abc660d4
...
...
@@ -626,6 +626,11 @@ class Field(object):
raise
ValueError
(
"domain mismatch"
)
return
self
def
extract_part
(
self
,
dom
):
if
dom
!=
self
.
_domain
:
raise
ValueError
(
"domain mismatch"
)
return
self
def
unite
(
self
,
other
):
return
self
+
other
...
...
nifty5/multi_field.py
View file @
abc660d4
...
...
@@ -217,6 +217,12 @@ class MultiField(object):
return
MultiField
(
subset
,
tuple
(
self
[
key
]
for
key
in
subset
.
keys
()))
def
extract_part
(
self
,
subset
):
if
subset
is
self
.
_domain
:
return
self
return
MultiField
.
from_dict
({
key
:
self
[
key
]
for
key
in
subset
.
keys
()
if
key
in
self
})
def
unite
(
self
,
other
):
"""Merges two MultiFields on potentially different MultiDomains.
...
...
nifty5/operators/chain_operator.py
View file @
abc660d4
...
...
@@ -138,6 +138,17 @@ class ChainOperator(LinearOperator):
subs
=
"
\n
"
.
join
(
sub
.
__repr__
()
for
sub
in
self
.
_ops
)
return
"ChainOperator:
\n
"
+
utilities
.
indent
(
subs
)
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
from
..multi_domain
import
MultiDomain
if
not
isinstance
(
self
.
_domain
,
MultiDomain
):
return
None
,
self
newop
=
None
for
op
in
reversed
(
self
.
_ops
):
c_inp
,
t_op
=
op
.
simplify_for_constant_input
(
c_inp
)
newop
=
t_op
if
newop
is
None
else
op
(
newop
)
return
c_inp
,
newop
# def draw_sample(self, from_inverse=False, dtype=np.float64):
# from ..sugar import from_random
# if len(self._ops) == 1:
...
...
nifty5/operators/operator.py
View file @
abc660d4
...
...
@@ -146,6 +146,17 @@ class Operator(metaclass=NiftyMeta):
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
def
simplify_for_constant_input
(
self
,
c_inp
):
if
c_inp
is
None
:
return
None
,
self
if
c_inp
.
domain
==
self
.
domain
:
op
=
_ConstantOperator
(
self
.
domain
,
self
(
c_inp
))
return
op
(
c_inp
),
op
return
self
.
_simplify_for_constant_input_nontrivial
(
c_inp
)
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
return
None
,
self
for
f
in
[
"sqrt"
,
"exp"
,
"log"
,
"tanh"
,
"sigmoid"
,
'sin'
,
'cos'
,
'tan'
,
'sinh'
,
'cosh'
,
'absolute'
,
'sinc'
,
'one_over'
]:
...
...
@@ -157,6 +168,72 @@ for f in ["sqrt", "exp", "log", "tanh", "sigmoid", 'sin', 'cos', 'tan',
setattr
(
Operator
,
f
,
func
(
f
))
class
_ConstCollector
(
object
):
def
__init__
(
self
):
self
.
_const
=
None
self
.
_nc
=
set
()
def
mult
(
self
,
const
,
fulldom
):
if
const
is
None
:
self
.
_nc
|=
set
(
fulldom
)
else
:
self
.
_nc
|=
set
(
fulldom
)
-
set
(
const
)
if
self
.
_const
is
None
:
from
..multi_field
import
MultiField
self
.
_const
=
MultiField
.
from_dict
(
{
key
:
const
[
key
]
for
key
in
const
if
key
not
in
self
.
_nc
})
else
:
from
..multi_field
import
MultiField
self
.
_const
=
MultiField
.
from_dict
(
{
key
:
self
.
_const
[
key
]
*
const
[
key
]
for
key
in
const
if
key
not
in
self
.
_nc
})
def
add
(
self
,
const
,
fulldom
):
if
const
is
None
:
self
.
_nc
|=
set
(
fulldom
.
keys
())
else
:
from
..multi_field
import
MultiField
self
.
_nc
|=
set
(
fulldom
.
keys
())
-
set
(
const
.
keys
())
if
self
.
_const
is
None
:
self
.
_const
=
MultiField
.
from_dict
(
{
key
:
const
[
key
]
for
key
in
const
.
keys
()
if
key
not
in
self
.
_nc
})
else
:
self
.
_const
=
self
.
_const
.
unite
(
const
)
self
.
_const
=
MultiField
.
from_dict
(
{
key
:
self
.
_const
[
key
]
for
key
in
self
.
_const
if
key
not
in
self
.
_nc
})
@
property
def
constfield
(
self
):
return
self
.
_const
class
_ConstantOperator
(
Operator
):
def
__init__
(
self
,
dom
,
output
):
from
..sugar
import
makeDomain
self
.
_domain
=
makeDomain
(
dom
)
self
.
_target
=
output
.
domain
self
.
_output
=
output
def
apply
(
self
,
x
):
from
..linearization
import
Linearization
from
.simple_linear_operators
import
NullOperator
from
..domain_tuple
import
DomainTuple
self
.
_check_input
(
x
)
if
not
isinstance
(
x
,
Linearization
):
return
self
.
_output
if
x
.
want_metric
and
self
.
_target
is
DomainTuple
.
scalar_domain
():
met
=
NullOperator
(
self
.
_domain
,
self
.
_domain
)
else
:
met
=
None
return
x
.
new
(
self
.
_output
,
NullOperator
(
self
.
_domain
,
self
.
_target
),
met
)
def
__repr__
(
self
):
return
'ConstantOperator <- {}'
.
format
(
self
.
domain
.
keys
())
class
_FunctionApplier
(
Operator
):
def
__init__
(
self
,
domain
,
funcname
):
from
..sugar
import
makeDomain
...
...
@@ -229,6 +306,17 @@ class _OpChain(_CombinedOperator):
x
=
op
(
x
)
return
x
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
from
..multi_domain
import
MultiDomain
if
not
isinstance
(
self
.
_domain
,
MultiDomain
):
return
None
,
self
newop
=
None
for
op
in
reversed
(
self
.
_ops
):
c_inp
,
t_op
=
op
.
simplify_for_constant_input
(
c_inp
)
newop
=
t_op
if
newop
is
None
else
op
(
newop
)
return
c_inp
,
newop
def
__repr__
(
self
):
subs
=
"
\n
"
.
join
(
sub
.
__repr__
()
for
sub
in
self
.
_ops
)
return
"_OpChain:
\n
"
+
indent
(
subs
)
...
...
@@ -261,6 +349,21 @@ class _OpProd(Operator):
makeOp
(
lin2
.
_val
)(
lin1
.
_jac
),
False
)
return
lin1
.
new
(
lin1
.
_val
*
lin2
.
_val
,
op
(
x
.
jac
))
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
f1
,
o1
=
self
.
_op1
.
simplify_for_constant_input
(
c_inp
.
extract_part
(
self
.
_op1
.
domain
))
f2
,
o2
=
self
.
_op2
.
simplify_for_constant_input
(
c_inp
.
extract_part
(
self
.
_op2
.
domain
))
from
..multi_domain
import
MultiDomain
if
not
isinstance
(
self
.
_target
,
MultiDomain
):
return
None
,
_OpProd
(
o1
,
o2
)
cc
=
_ConstCollector
()
cc
.
mult
(
f1
,
o1
.
target
)
cc
.
mult
(
f2
,
o2
.
target
)
return
cc
.
constfield
,
_OpProd
(
o1
,
o2
)
def
__repr__
(
self
):
subs
=
"
\n
"
.
join
(
sub
.
__repr__
()
for
sub
in
(
self
.
_op1
,
self
.
_op2
))
return
"_OpProd:
\n
"
+
indent
(
subs
)
...
...
@@ -293,6 +396,21 @@ class _OpSum(Operator):
res
=
res
.
add_metric
(
lin1
.
_metric
+
lin2
.
_metric
)
return
res
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
f1
,
o1
=
self
.
_op1
.
simplify_for_constant_input
(
c_inp
.
extract_part
(
self
.
_op1
.
domain
))
f2
,
o2
=
self
.
_op2
.
simplify_for_constant_input
(
c_inp
.
extract_part
(
self
.
_op2
.
domain
))
from
..multi_domain
import
MultiDomain
if
not
isinstance
(
self
.
_target
,
MultiDomain
):
return
None
,
_OpSum
(
o1
,
o2
)
cc
=
_ConstCollector
()
cc
.
add
(
f1
,
o1
.
target
)
cc
.
add
(
f2
,
o2
.
target
)
return
cc
.
constfield
,
_OpSum
(
o1
,
o2
)
def
__repr__
(
self
):
subs
=
"
\n
"
.
join
(
sub
.
__repr__
()
for
sub
in
(
self
.
_op1
,
self
.
_op2
))
return
"_OpSum:
\n
"
+
indent
(
subs
)
nifty5/operators/scaling_operator.py
View file @
abc660d4
...
...
@@ -35,14 +35,6 @@ class ScalingOperator(EndomorphicOperator):
-----
:class:`Operator` supports the multiplication with a scalar. So one does
not need instantiate :class:`ScalingOperator` explicitly in most cases.
Formally, this operator always supports all operation modes (times,
adjoint_times, inverse_times and inverse_adjoint_times), even if `factor`
is 0 or infinity. It is the user's responsibility to apply the operator
only in appropriate ways (e.g. call inverse_times only if `factor` is
nonzero).
This shortcoming will hopefully be fixed in the future.
"""
def
__init__
(
self
,
factor
,
domain
):
...
...
@@ -52,7 +44,10 @@ class ScalingOperator(EndomorphicOperator):
raise
TypeError
(
"Scalar required"
)
self
.
_factor
=
factor
self
.
_domain
=
makeDomain
(
domain
)
self
.
_capability
=
self
.
_all_ops
if
self
.
_factor
==
0.
:
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
else
:
self
.
_capability
=
self
.
_all_ops
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
...
...
nifty5/operators/simple_linear_operators.py
View file @
abc660d4
...
...
@@ -315,3 +315,23 @@ class NullOperator(LinearOperator):
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
return
self
.
_nullfield
(
self
.
_tgt
(
mode
))
class
_PartialExtractor
(
LinearOperator
):
def
__init__
(
self
,
domain
,
target
):
if
not
isinstance
(
domain
,
MultiDomain
):
raise
TypeError
(
"MultiDomain expected"
)
if
not
isinstance
(
target
,
MultiDomain
):
raise
TypeError
(
"MultiDomain expected"
)
self
.
_domain
=
domain
self
.
_target
=
target
for
key
in
self
.
_target
.
keys
():
if
not
(
self
.
_domain
[
key
]
is
not
self
.
_target
[
key
]):
raise
ValueError
(
"domain mismatch"
)
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
if
mode
==
self
.
TIMES
:
return
x
.
extract
(
self
.
_target
)
return
MultiField
.
from_dict
({
key
:
x
[
key
]
for
key
in
x
.
domain
.
keys
()})
nifty5/operators/sum_operator.py
View file @
abc660d4
...
...
@@ -23,6 +23,7 @@ from ..sugar import domain_union
from
..utilities
import
indent
from
.block_diagonal_operator
import
BlockDiagonalOperator
from
.linear_operator
import
LinearOperator
from
.simple_linear_operators
import
NullOperator
class
SumOperator
(
LinearOperator
):
...
...
@@ -59,6 +60,9 @@ class SumOperator(LinearOperator):
negnew
+=
[
not
n
for
n
in
op
.
_neg
]
else
:
negnew
+=
list
(
op
.
_neg
)
# FIXME: this needs some more work to keep the domain and target unchanged!
# elif isinstance(op, NullOperator):
# pass
else
:
opsnew
.
append
(
op
)
negnew
.
append
(
ng
)
...
...
@@ -193,6 +197,9 @@ class SumOperator(LinearOperator):
"cannot draw from inverse of this operator"
)
res
=
None
for
op
in
self
.
_ops
:
from
.simple_linear_operators
import
NullOperator
if
isinstance
(
op
,
NullOperator
):
continue
tmp
=
op
.
draw_sample
(
from_inverse
,
dtype
)
res
=
tmp
if
res
is
None
else
res
.
unite
(
tmp
)
return
res
...
...
@@ -200,3 +207,29 @@ class SumOperator(LinearOperator):
def
__repr__
(
self
):
subs
=
"
\n
"
.
join
(
sub
.
__repr__
()
for
sub
in
self
.
_ops
)
return
"SumOperator:
\n
"
+
indent
(
subs
)
def
_simplify_for_constant_input_nontrivial
(
self
,
c_inp
):
f
=
[]
o
=
[]
for
op
in
self
.
_ops
:
tf
,
to
=
op
.
simplify_for_constant_input
(
c_inp
.
extract_part
(
op
.
domain
))
f
.
append
(
tf
)
o
.
append
(
to
)
from
..multi_domain
import
MultiDomain
if
not
isinstance
(
self
.
_target
,
MultiDomain
):
fullop
=
None
for
to
,
n
in
zip
(
o
,
self
.
_neg
):
op
=
to
if
not
n
else
-
to
fullop
=
op
if
fullop
is
None
else
fullop
+
op
return
None
,
fullop
from
.operator
import
_ConstCollector
cc
=
_ConstCollector
()
fullop
=
None
for
tf
,
to
,
n
in
zip
(
f
,
o
,
self
.
_neg
):
cc
.
add
(
tf
,
to
.
target
)
op
=
to
if
not
n
else
-
to
fullop
=
op
if
fullop
is
None
else
fullop
+
op
return
cc
.
constfield
,
fullop
test/test_operators/test_simplification.py
0 → 100644
View file @
abc660d4
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import
pytest
from
numpy.testing
import
assert_allclose
,
assert_equal
import
nifty5
as
ift
def
test_simplification
():
from
nifty5.operators.operator
import
_ConstantOperator
f1
=
ift
.
Field
.
full
(
ift
.
RGSpace
(
10
),
2.
)
op
=
ift
.
FFTOperator
(
f1
.
domain
)
_
,
op2
=
op
.
simplify_for_constant_input
(
f1
)
assert_equal
(
isinstance
(
op2
,
_ConstantOperator
),
True
)
assert_allclose
(
op
(
f1
).
local_data
,
op2
(
f1
).
local_data
)
dom
=
{
"a"
:
ift
.
RGSpace
(
10
)}
f1
=
ift
.
full
(
dom
,
2.
)
op
=
ift
.
FFTOperator
(
f1
.
domain
[
"a"
]).
ducktape
(
"a"
)
_
,
op2
=
op
.
simplify_for_constant_input
(
f1
)
assert_equal
(
isinstance
(
op2
,
_ConstantOperator
),
True
)
assert_allclose
(
op
(
f1
).
local_data
,
op2
(
f1
).
local_data
)
dom
=
{
"a"
:
ift
.
RGSpace
(
10
),
"b"
:
ift
.
RGSpace
(
5
)}
f1
=
ift
.
full
(
dom
,
2.
)
pdom
=
{
"a"
:
ift
.
RGSpace
(
10
)}
f2
=
ift
.
full
(
pdom
,
2.
)
o1
=
ift
.
FFTOperator
(
f1
.
domain
[
"a"
])
o2
=
ift
.
FFTOperator
(
f1
.
domain
[
"b"
])
op
=
(
o1
.
ducktape
(
"a"
).
ducktape_left
(
"a"
)
+
o2
.
ducktape
(
"b"
).
ducktape_left
(
"b"
))
_
,
op2
=
op
.
simplify_for_constant_input
(
f2
)
assert_equal
(
isinstance
(
op2
.
_op1
,
_ConstantOperator
),
True
)
assert_allclose
(
op
(
f1
)[
"a"
].
local_data
,
op2
(
f1
)[
"a"
].
local_data
)
assert_allclose
(
op
(
f1
)[
"b"
].
local_data
,
op2
(
f1
)[
"b"
].
local_data
)
lin
=
ift
.
Linearization
.
make_var
(
ift
.
MultiField
.
full
(
op2
.
domain
,
2.
),
True
)
assert_allclose
(
op
(
lin
).
val
[
"a"
].
local_data
,
op2
(
lin
).
val
[
"a"
].
local_data
)
assert_allclose
(
op
(
lin
).
val
[
"b"
].
local_data
,
op2
(
lin
).
val
[
"b"
].
local_data
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment