Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
78a5e223
Commit
78a5e223
authored
Apr 03, 2020
by
Rouven Lemmerz
Browse files
Fixes
parent
395ab44f
Pipeline
#72207
passed with stages
in 15 minutes and 3 seconds
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty6/operator_tree_optimiser.py
View file @
78a5e223
...
...
@@ -19,9 +19,9 @@ from .operators.operator import _OpChain, _OpSum, _OpProd
from
.sugar
import
domain_union
from
.operators.simple_linear_operators
import
FieldAdapter
def
_optimi
z
e_operator
(
op
):
def
_optimi
s
e_operator
(
op
):
"""
O
ptimi
z
es operator trees, so that same operator subtrees are not computed twice.
o
ptimi
s
es operator trees, so that same operator subtrees are not computed twice.
Recognizes same subtrees and replaces them at nodes.
Recognizes same leaves and structures them.
Works partly inplace, rendering the old operator unusable"""
...
...
@@ -33,6 +33,24 @@ def _optimize_operator(op):
# Format: [parent_index, left]
leaves
=
set
()
# helper functions
def
readable_id
():
# Gives out letters to prepend field_adapter ids for cosmetics
# Start at 'A'
current_letter
=
65
repeats
=
1
while
True
:
yield
chr
(
current_letter
)
*
repeats
current_letter
+=
1
if
current_letter
==
91
:
# skip specials
current_letter
+=
6
elif
current_letter
==
123
:
# End at z and start at AA
current_letter
=
65
repeats
+=
1
prepend_id
=
readable_id
()
def
isnode
(
op
):
return
isinstance
(
op
,
(
_OpSum
,
_OpProd
))
...
...
@@ -41,7 +59,12 @@ def _optimize_operator(op):
def
left_parser
(
left_bool
):
return
'_op1'
if
left_bool
else
'_op2'
def
get_duplicate_keys
(
k_list
,
dic
):
for
item
in
list
(
dic
.
items
()):
if
len
(
item
[
1
])
>
1
:
k_list
.
append
(
item
[
0
])
# Main algorithm functions
def
rebuild_domains
(
index
):
"""Goes bottom up to fix domains which were destroyed by plugging in field adapters"""
cond
=
True
...
...
@@ -49,12 +72,12 @@ def _optimize_operator(op):
op
=
nodes
[
index
][
0
]
for
attr
in
(
'_op1'
,
'_op2'
):
if
isinstance
(
getattr
(
op
,
attr
),
_OpChain
):
getattr
(
op
,
attr
).
_domain
=
getattr
(
op
,
attr
).
_ops
[
-
1
].
domain
getattr
(
op
,
attr
).
_domain
=
getattr
(
op
,
attr
).
_ops
[
-
1
].
domain
if
isnode
(
op
):
# Some problems doing this on non-multidomains, because one side becomes a multidomain and the other not
try
:
op
.
_domain
=
domain_union
((
op
.
_op1
.
domain
,
op
.
_op2
.
domain
))
except
:
except
AttributeError
:
import
warnings
warnings
.
warn
(
'Operator should be defined on a MultiDomain'
)
pass
...
...
@@ -79,11 +102,9 @@ def _optimize_operator(op):
def
equal_nodes
(
op
):
"""
BFS-Algorithm which fills the nodes list and id_dic dictionary
.
Does not scan equal subtrees multiple times
."""
#
BFS-Algorithm which fills the nodes list and id_dic dictionary
#
Does not scan equal subtrees multiple times
list_index_traversed
=
0
# if isnode(op):
# nodes.append((op, None, None, None))
recognize_nodes
(
op
,
None
,
None
)
while
list_index_traversed
<
len
(
nodes
):
...
...
@@ -104,8 +125,9 @@ def _optimize_operator(op):
list_index_traversed
+=
1
edited
=
set
()
def
equal_leaves
(
leaves
,
edited
):
def
equal_leaves
(
leaves
):
id_leaf
=
{}
# Find matching leaves
def
write_to_dic
(
leaf
,
leaf_op_id
):
...
...
@@ -132,13 +154,11 @@ def _optimize_operator(op):
# Unroll their OpChain and see how far they are equal
key_list_op
=
[]
same_op
=
{}
key_list_leaf
=
[]
same_leaf
=
{}
get_duplicate_keys
(
key_list_leaf
,
id_leaf
)
for
item
in
list
(
id_leaf
.
items
()):
if
len
(
item
[
1
])
>
1
:
key_list_op
.
append
(
item
[
0
])
for
key
in
key_list_op
:
for
key
in
key_list_leaf
:
to_compare
=
[]
for
leaf
in
id_leaf
[
key
]:
parent
=
nodes
[
leaf
[
0
]][
0
]
...
...
@@ -165,7 +185,7 @@ def _optimize_operator(op):
for
ops
in
common_op
[
1
:]:
res_op
=
ops
@
res_op
same_
op
[
key
]
=
[
res_op
,
FieldAdapter
(
res_op
.
target
,
str
(
id
(
res_op
)))]
same_
leaf
[
key
]
=
[
res_op
,
FieldAdapter
(
res_op
.
target
,
next
(
prepend_id
)
+
str
(
id
(
res_op
)))]
for
leaf
in
id_leaf
[
key
]:
parent
=
nodes
[
leaf
[
0
]][
0
]
...
...
@@ -174,58 +194,58 @@ def _optimize_operator(op):
leaf_op
=
getattr
(
parent
,
attr
)
if
isinstance
(
leaf_op
,
_OpChain
):
if
first_difference
==
len
(
leaf_op
.
_ops
):
setattr
(
parent
,
attr
,
same_
op
[
key
][
1
])
setattr
(
parent
,
attr
,
same_
leaf
[
key
][
1
])
else
:
leaf_op
.
_ops
=
leaf_op
.
_ops
[:
-
first_difference
]
+
(
same_
op
[
key
][
1
],)
leaf_op
.
_ops
=
leaf_op
.
_ops
[:
-
first_difference
]
+
(
same_
leaf
[
key
][
1
],)
else
:
setattr
(
parent
,
attr
,
same_
op
[
key
][
1
])
return
key_list_
op
,
same_
op
,
edited
setattr
(
parent
,
attr
,
same_
leaf
[
key
][
1
])
return
key_list_
leaf
,
same_
leaf
equal_nodes
(
op
)
edited
=
set
()
key_list_op
,
same_op
,
edited
=
equal_leaves
(
leaves
,
edited
)
key_temp
=
[]
key_list_op
,
same_op
=
equal_leaves
(
leaves
)
cond
=
True
while
cond
:
key_temp
,
same_op_temp
,
edited_temp
=
equal_leaves
(
leaves
,
edited
)
key_temp
,
same_op_temp
=
equal_leaves
(
leaves
)
key_list_op
+=
key_temp
same_op
.
update
(
same_op_temp
)
edited
.
update
(
edited
)
cond
=
len
(
same_op_temp
)
>
0
key_temp
.
clear
()
# Cut subtrees
key_list_tree
=
[]
same_tree
=
{}
key_list_node
=
[]
key_list_subtrees
=
[]
same_node
=
{}
same_subtrees
=
{}
subtree_leaves
=
set
()
key_list_tree_w_leaves
=
[]
same_tree_w_leaves
=
{}
for
item
in
list
(
id_dic
.
items
()):
if
len
(
item
[
1
])
>
1
:
key_list_tree
.
append
(
item
[
0
])
for
key
in
key_list_tree
:
same_tree
[
key
]
=
[
nodes
[
id_dic
[
key
][
0
]][
0
],]
performance_adapter
=
FieldAdapter
(
same_tree
[
key
][
0
].
target
,
str
(
key
))
same_tree
[
key
]
+=
[
performance_adapter
]
get_duplicate_keys
(
key_list_node
,
id_dic
)
for
key
in
key_list_node
:
same_node
[
key
]
=
[
nodes
[
id_dic
[
key
][
0
]][
0
],
FieldAdapter
(
nodes
[
id_dic
[
key
][
0
]][
0
].
target
,
next
(
prepend_id
)
+
str
(
key
))]
for
node_indices
in
id_dic
[
key
]:
edited
.
add
(
node_indices
)
parent
=
nodes
[
nodes
[
node_indices
][
1
]][
0
]
attr
=
left_parser
(
nodes
[
node_indices
][
2
])
if
isinstance
(
getattr
(
parent
,
attr
),
_OpChain
):
getattr
(
parent
,
attr
).
_ops
=
getattr
(
parent
,
attr
).
_ops
[:
-
1
]
+
(
performance_adapter
,)
getattr
(
parent
,
attr
).
_ops
=
getattr
(
parent
,
attr
).
_ops
[:
-
1
]
+
(
same_node
[
key
][
1
]
,)
else
:
setattr
(
parent
,
attr
,
performance_adapter
)
setattr
(
parent
,
attr
,
same_node
[
key
][
1
])
# Nodes have been replaced - treat replacements now as leaves
subtree_leaves
.
add
((
nodes
[
node_indices
][
1
],
nodes
[
node_indices
][
2
]))
cond
=
True
while
cond
:
key_temp
,
same_op_temp
,
_
=
equal_leaves
(
subtree_leaves
,
edited
)
key_list_tree_w_leaves
+=
key_temp
same_tree_w_leaves
.
update
(
same_op_temp
)
cond
=
len
(
same_op_temp
)
>
0
key_list_tree_w_leaves
+=
[
key
,]
key_temp1
,
same_temp
=
equal_leaves
(
subtree_leaves
)
key_temp
=
key_temp1
+
key_temp
same_subtrees
.
update
(
same_temp
)
cond
=
len
(
same_temp
)
>
0
key_list_subtrees
+=
key_temp
+
[
key
,
]
key_temp
.
clear
()
subtree_leaves
.
clear
()
same_tre
e_w_leav
es
.
update
(
same_
tre
e
)
same_
sub
trees
.
update
(
same_
nod
e
)
for
index
in
edited
:
rebuild_domains
(
index
)
...
...
@@ -233,8 +253,8 @@ def _optimize_operator(op):
op
.
_domain
=
op
.
_ops
[
-
1
].
domain
# Insert trees before leaves
for
key
in
key_list_tre
e_w_leav
es
:
op
=
op
.
partial_insert
(
same_tre
e_w_leav
es
[
key
][
1
].
adjoint
(
same_tre
e_w_leav
es
[
key
][
0
]))
for
key
in
key_list_
sub
trees
:
op
=
op
.
partial_insert
(
same_
sub
trees
[
key
][
1
].
adjoint
(
same_
sub
trees
[
key
][
0
]))
for
key
in
reversed
(
key_list_op
):
op
=
op
.
partial_insert
(
same_op
[
key
][
1
].
adjoint
(
same_op
[
key
][
0
]))
return
op
...
...
@@ -245,13 +265,44 @@ from .sugar import from_random
from
.multi_field
import
MultiField
from
numpy
import
allclose
def
optimize_operator
(
op
):
op_optimized
=
deepcopy
(
op
)
op_optimized
=
_optimize_operator
(
op_optimized
)
def
optimise_operator
(
op
):
"""
Merges redundant operations in the tree structure of an operator.
For example it is ensured that for ``(f@x + x)`` ``x`` is only computed once.
Parameters
----------
op: Operator
Returns
-------
op_optimised : Operator
Notes
-----
Since operators are compared by id best results are achieved when the following code
>>> from nifty6 import UniformOperator, DomainTuple
>>> uni1 = UniformOperator(DomainTuple.scalar_domain()
>>> uni2 = UniformOperator(DomainTuple.scalar_domain()
>>> op = (uni1 + uni2)*(uni1 + uni2)
is replaced by something comparable to
>>> from nifty6 import UniformOperator, DomainTuple
>>> uni = UniformOperator(DomainTuple.scalar_domain())
>>> uni_add = uni + uni
>>> op = uni_add * uni_add
After optimisation the operator is comparable in speed to
>>> op = (2*uni)**2
"""
op_optimised
=
deepcopy
(
op
)
op_optimised
=
_optimise_operator
(
op_optimised
)
test_field
=
from_random
(
'normal'
,
op
.
domain
)
if
isinstance
(
op
(
test_field
),
MultiField
):
for
key
in
op
(
test_field
).
keys
():
assert
allclose
(
op
(
test_field
).
val
[
key
],
op_optimi
z
ed
(
test_field
).
val
[
key
],
1e-10
)
assert
allclose
(
op
(
test_field
).
val
[
key
],
op_optimi
s
ed
(
test_field
).
val
[
key
],
1e-10
)
else
:
assert
allclose
(
op
(
test_field
).
val
,
op_optimi
z
ed
(
test_field
).
val
,
1e-10
)
return
op_optimi
z
ed
assert
allclose
(
op
(
test_field
).
val
,
op_optimi
s
ed
(
test_field
).
val
,
1e-10
)
return
op_optimi
s
ed
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment