Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
395ab44f
Commit
395ab44f
authored
Apr 03, 2020
by
Martin Reinecke
Browse files
tweaks
parent
143c96ec
Pipeline
#72186
passed with stages
in 15 minutes and 4 seconds
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty6/linearization.py
View file @
395ab44f
...
...
@@ -379,7 +379,7 @@ class Linearization(object):
def
one_over
(
self
):
tmp
=
1.
/
self
.
_val
tmp2
=
-
tmp
/
self
.
_val
tmp2
=
-
tmp
*
tmp
return
self
.
new
(
tmp
,
makeOp
(
tmp2
)(
self
.
_jac
))
def
add_metric
(
self
,
metric
):
...
...
nifty6/
operators/
operator_tree_optimiser.py
→
nifty6/operator_tree_optimiser.py
View file @
395ab44f
...
...
@@ -15,11 +15,11 @@
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from
nifty6
.operators.operator
import
_OpChain
,
_OpSum
,
_OpProd
from
nifty6
.sugar
import
domain_union
from
nifty6
import
FieldAdapter
from
.operators.operator
import
_OpChain
,
_OpSum
,
_OpProd
from
.sugar
import
domain_union
from
.operators.simple_linear_operators
import
FieldAdapter
def
optimize_operator
(
op
):
def
_
optimize_operator
(
op
):
"""
Optimizes operator trees, so that same operator subtrees are not computed twice.
Recognizes same subtrees and replaces them at nodes.
...
...
@@ -33,13 +33,14 @@ def optimize_operator(op):
# Format: [parent_index, left]
leaves
=
set
()
def
isnode
(
op
):
return
isinstance
(
op
,
_OpSum
)
or
isinstance
(
op
,
_OpProd
)
return
isinstance
(
op
,
(
_OpSum
,
_OpProd
))
def
left_parser
(
left_bool
):
if
left_bool
:
return
'_op1'
return
'_op2'
return
'_op1'
if
left_bool
else
'_op2'
def
rebuild_domains
(
index
):
"""Goes bottom up to fix domains which were destroyed by plugging in field adapters"""
...
...
@@ -60,22 +61,22 @@ def optimize_operator(op):
index
=
nodes
[
index
][
1
]
cond
=
type
(
index
)
is
int
return
def
recognize_nodes
(
op
,
active_node
,
left
):
# If nothing added - is a lea
ve
!
islea
ve
=
True
# If nothing added - is a lea
f
!
islea
f
=
True
if
isinstance
(
op
,
_OpChain
):
for
i
in
range
(
len
(
op
.
_ops
)):
if
isnode
(
op
.
_ops
[
i
]):
nodes
.
append
((
op
.
_ops
[
i
],
active_node
,
left
))
islea
ve
=
False
islea
f
=
False
elif
isnode
(
op
):
nodes
.
append
((
op
,
active_node
,
left
))
islea
ve
=
False
if
islea
ve
:
islea
f
=
False
if
islea
f
:
leaves
.
add
((
active_node
,
left
))
return
def
equal_nodes
(
op
):
"""BFS-Algorithm which fills the nodes list and id_dic dictionary.
...
...
@@ -102,51 +103,51 @@ def optimize_operator(op):
recognize_nodes
(
active_node
.
_op2
,
list_index_traversed
,
False
)
list_index_traversed
+=
1
return
def
equal_leaves
(
leaves
,
edited
):
id_lea
ve
=
{}
id_lea
f
=
{}
# Find matching leaves
def
write_to_dic
(
lea
ve
,
lea
ve
_op_id
):
def
write_to_dic
(
lea
f
,
lea
f
_op_id
):
try
:
id_lea
ve
[
lea
ve
_op_id
]
=
id_lea
ve
[
lea
ve
_op_id
]
+
(
lea
ve
,)
id_lea
f
[
lea
f
_op_id
]
=
id_lea
f
[
lea
f
_op_id
]
+
(
lea
f
,)
except
KeyError
:
id_lea
ve
[
lea
ve
_op_id
]
=
(
lea
ve
,)
for
lea
ve
in
leaves
:
parent
=
nodes
[
lea
ve
[
0
]][
0
]
attr
=
left_parser
(
lea
ve
[
1
])
lea
ve
_op
=
getattr
(
parent
,
attr
)
if
isinstance
(
lea
ve
_op
,
_OpChain
):
lea
ve
_op_id
=
''
for
i
in
reversed
(
lea
ve
_op
.
_ops
):
lea
ve
_op_id
+=
str
(
id
(
i
))
id_lea
f
[
lea
f
_op_id
]
=
(
lea
f
,)
for
lea
f
in
leaves
:
parent
=
nodes
[
lea
f
[
0
]][
0
]
attr
=
left_parser
(
lea
f
[
1
])
lea
f
_op
=
getattr
(
parent
,
attr
)
if
isinstance
(
lea
f
_op
,
_OpChain
):
lea
f
_op_id
=
''
for
i
in
reversed
(
lea
f
_op
.
_ops
):
lea
f
_op_id
+=
str
(
id
(
i
))
if
not
isinstance
(
i
,
FieldAdapter
):
# Do not optimise leaves which only have equal FieldAdapters
write_to_dic
(
lea
ve
,
lea
ve
_op_id
)
write_to_dic
(
lea
f
,
lea
f
_op_id
)
break
else
:
if
not
isinstance
(
lea
ve
_op
,
FieldAdapter
):
write_to_dic
(
lea
ve
,
str
(
id
(
lea
ve
_op
)))
if
not
isinstance
(
lea
f
_op
,
FieldAdapter
):
write_to_dic
(
lea
f
,
str
(
id
(
lea
f
_op
)))
# Unroll their OpChain and see how far they are equal
key_list_op
=
[]
same_op
=
{}
for
item
in
list
(
id_lea
ve
.
items
()):
for
item
in
list
(
id_lea
f
.
items
()):
if
len
(
item
[
1
])
>
1
:
key_list_op
.
append
(
item
[
0
])
for
key
in
key_list_op
:
to_compare
=
[]
for
lea
ve
in
id_lea
ve
[
key
]:
parent
=
nodes
[
lea
ve
[
0
]][
0
]
attr
=
left_parser
(
lea
ve
[
1
])
lea
ve
_op
=
getattr
(
parent
,
attr
)
if
isinstance
(
lea
ve
_op
,
_OpChain
):
to_compare
.
append
(
tuple
(
reversed
(
lea
ve
_op
.
_ops
)))
for
lea
f
in
id_lea
f
[
key
]:
parent
=
nodes
[
lea
f
[
0
]][
0
]
attr
=
left_parser
(
lea
f
[
1
])
lea
f
_op
=
getattr
(
parent
,
attr
)
if
isinstance
(
lea
f
_op
,
_OpChain
):
to_compare
.
append
(
tuple
(
reversed
(
lea
f
_op
.
_ops
)))
else
:
to_compare
.
append
((
lea
ve
_op
,))
to_compare
.
append
((
lea
f
_op
,))
first_difference
=
1
max_diff
=
min
(
len
(
i
)
for
i
in
to_compare
)
if
not
max_diff
==
1
:
...
...
@@ -166,16 +167,16 @@ def optimize_operator(op):
same_op
[
key
]
=
[
res_op
,
FieldAdapter
(
res_op
.
target
,
str
(
id
(
res_op
)))]
for
lea
ve
in
id_lea
ve
[
key
]:
parent
=
nodes
[
lea
ve
[
0
]][
0
]
for
lea
f
in
id_lea
f
[
key
]:
parent
=
nodes
[
lea
f
[
0
]][
0
]
edited
.
add
(
id_dic
[
id
(
parent
)][
0
])
attr
=
left_parser
(
lea
ve
[
1
])
lea
ve
_op
=
getattr
(
parent
,
attr
)
if
isinstance
(
lea
ve
_op
,
_OpChain
):
if
first_difference
==
len
(
lea
ve
_op
.
_ops
):
attr
=
left_parser
(
lea
f
[
1
])
lea
f
_op
=
getattr
(
parent
,
attr
)
if
isinstance
(
lea
f
_op
,
_OpChain
):
if
first_difference
==
len
(
lea
f
_op
.
_ops
):
setattr
(
parent
,
attr
,
same_op
[
key
][
1
])
else
:
lea
ve
_op
.
_ops
=
lea
ve
_op
.
_ops
[:
-
first_difference
]
+
(
same_op
[
key
][
1
],)
lea
f
_op
.
_ops
=
lea
f
_op
.
_ops
[:
-
first_difference
]
+
(
same_op
[
key
][
1
],)
else
:
setattr
(
parent
,
attr
,
same_op
[
key
][
1
])
return
key_list_op
,
same_op
,
edited
...
...
@@ -238,12 +239,15 @@ def optimize_operator(op):
op
=
op
.
partial_insert
(
same_op
[
key
][
1
].
adjoint
(
same_op
[
key
][
0
]))
return
op
from
copy
import
deepcopy
from
nifty6
import
from_random
,
MultiField
from
.sugar
import
from_random
from
.multi_field
import
MultiField
from
numpy
import
allclose
def
optimize_operator_safe
(
op
):
def
optimize_operator
(
op
):
op_optimized
=
deepcopy
(
op
)
op_optimized
=
optimize_operator
(
op_optimized
)
op_optimized
=
_
optimize_operator
(
op_optimized
)
test_field
=
from_random
(
'normal'
,
op
.
domain
)
if
isinstance
(
op
(
test_field
),
MultiField
):
for
key
in
op
(
test_field
).
keys
():
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment