Commit 78a5e223 authored by Rouven Lemmerz's avatar Rouven Lemmerz
Browse files

Fixes

parent 395ab44f
Pipeline #72207 passed with stages
in 15 minutes and 3 seconds
......@@ -19,9 +19,9 @@ from .operators.operator import _OpChain, _OpSum, _OpProd
from .sugar import domain_union
from .operators.simple_linear_operators import FieldAdapter
def _optimize_operator(op):
def _optimise_operator(op):
"""
Optimizes operator trees, so that same operator subtrees are not computed twice.
optimises operator trees, so that same operator subtrees are not computed twice.
Recognizes same subtrees and replaces them at nodes.
Recognizes same leaves and structures them.
Works partly inplace, rendering the old operator unusable"""
......@@ -33,6 +33,24 @@ def _optimize_operator(op):
# Format: [parent_index, left]
leaves = set()
# helper functions
def readable_id():
# Gives out letters to prepend field_adapter ids for cosmetics
# Start at 'A'
current_letter = 65
repeats = 1
while True:
yield chr(current_letter)*repeats
current_letter += 1
if current_letter == 91:
# skip specials
current_letter += 6
elif current_letter == 123:
# End at z and start at AA
current_letter = 65
repeats += 1
prepend_id = readable_id()
def isnode(op):
return isinstance(op, (_OpSum, _OpProd))
......@@ -41,7 +59,12 @@ def _optimize_operator(op):
def left_parser(left_bool):
return '_op1' if left_bool else '_op2'
def get_duplicate_keys(k_list, dic):
for item in list(dic.items()):
if len(item[1]) > 1:
k_list.append(item[0])
# Main algorithm functions
def rebuild_domains(index):
"""Goes bottom up to fix domains which were destroyed by plugging in field adapters"""
cond = True
......@@ -49,12 +72,12 @@ def _optimize_operator(op):
op = nodes[index][0]
for attr in ('_op1', '_op2'):
if isinstance(getattr(op, attr), _OpChain):
getattr(op, attr)._domain = getattr(op, attr)._ops[-1].domain
getattr(op, attr)._domain = getattr(op, attr)._ops[-1].domain
if isnode(op):
# Some problems doing this on non-multidomains, because one side becomes a multidomain and the other not
try:
op._domain = domain_union((op._op1.domain, op._op2.domain))
except:
except AttributeError:
import warnings
warnings.warn('Operator should be defined on a MultiDomain')
pass
......@@ -79,11 +102,9 @@ def _optimize_operator(op):
def equal_nodes(op):
"""BFS-Algorithm which fills the nodes list and id_dic dictionary.
Does not scan equal subtrees multiple times."""
# BFS-Algorithm which fills the nodes list and id_dic dictionary
# Does not scan equal subtrees multiple times
list_index_traversed = 0
# if isnode(op):
# nodes.append((op, None, None, None))
recognize_nodes(op, None, None)
while list_index_traversed < len(nodes):
......@@ -104,8 +125,9 @@ def _optimize_operator(op):
list_index_traversed += 1
edited = set()
def equal_leaves(leaves, edited):
def equal_leaves(leaves):
id_leaf = {}
# Find matching leaves
def write_to_dic(leaf, leaf_op_id):
......@@ -132,13 +154,11 @@ def _optimize_operator(op):
# Unroll their OpChain and see how far they are equal
key_list_op = []
same_op = {}
key_list_leaf = []
same_leaf = {}
get_duplicate_keys(key_list_leaf, id_leaf)
for item in list(id_leaf.items()):
if len(item[1]) > 1:
key_list_op.append(item[0])
for key in key_list_op:
for key in key_list_leaf:
to_compare = []
for leaf in id_leaf[key]:
parent = nodes[leaf[0]][0]
......@@ -165,7 +185,7 @@ def _optimize_operator(op):
for ops in common_op[1:]:
res_op = ops @ res_op
same_op[key] = [res_op, FieldAdapter(res_op.target, str(id(res_op)))]
same_leaf[key] = [res_op, FieldAdapter(res_op.target, next(prepend_id) + str(id(res_op)))]
for leaf in id_leaf[key]:
parent = nodes[leaf[0]][0]
......@@ -174,58 +194,58 @@ def _optimize_operator(op):
leaf_op = getattr(parent, attr)
if isinstance(leaf_op, _OpChain):
if first_difference == len(leaf_op._ops):
setattr(parent, attr, same_op[key][1])
setattr(parent, attr, same_leaf[key][1])
else:
leaf_op._ops = leaf_op._ops[:-first_difference] + (same_op[key][1],)
leaf_op._ops = leaf_op._ops[:-first_difference] + (same_leaf[key][1],)
else:
setattr(parent, attr, same_op[key][1])
return key_list_op, same_op, edited
setattr(parent, attr, same_leaf[key][1])
return key_list_leaf, same_leaf
equal_nodes(op)
edited = set()
key_list_op, same_op, edited = equal_leaves(leaves, edited)
key_temp = []
key_list_op, same_op = equal_leaves(leaves)
cond = True
while cond:
key_temp, same_op_temp, edited_temp = equal_leaves(leaves, edited)
key_temp, same_op_temp = equal_leaves(leaves)
key_list_op += key_temp
same_op.update(same_op_temp)
edited.update(edited)
cond = len(same_op_temp) > 0
key_temp.clear()
# Cut subtrees
key_list_tree = []
same_tree = {}
key_list_node = []
key_list_subtrees = []
same_node = {}
same_subtrees = {}
subtree_leaves = set()
key_list_tree_w_leaves = []
same_tree_w_leaves = {}
for item in list(id_dic.items()):
if len(item[1]) > 1:
key_list_tree.append(item[0])
for key in key_list_tree:
same_tree[key] = [nodes[id_dic[key][0]][0],]
performance_adapter = FieldAdapter(same_tree[key][0].target, str(key))
same_tree[key] += [performance_adapter]
get_duplicate_keys(key_list_node, id_dic)
for key in key_list_node:
same_node[key] = [nodes[id_dic[key][0]][0],
FieldAdapter(nodes[id_dic[key][0]][0].target, next(prepend_id) + str(key))]
for node_indices in id_dic[key]:
edited.add(node_indices)
parent = nodes[nodes[node_indices][1]][0]
attr = left_parser(nodes[node_indices][2])
if isinstance(getattr(parent, attr), _OpChain):
getattr(parent, attr)._ops = getattr(parent, attr)._ops[:-1] + (performance_adapter,)
getattr(parent, attr)._ops = getattr(parent, attr)._ops[:-1] + (same_node[key][1],)
else:
setattr(parent, attr, performance_adapter)
setattr(parent, attr, same_node[key][1])
# Nodes have been replaced - treat replacements now as leaves
subtree_leaves.add((nodes[node_indices][1], nodes[node_indices][2]))
cond = True
while cond:
key_temp, same_op_temp, _ = equal_leaves(subtree_leaves, edited)
key_list_tree_w_leaves += key_temp
same_tree_w_leaves.update(same_op_temp)
cond = len(same_op_temp) > 0
key_list_tree_w_leaves += [key,]
key_temp1, same_temp = equal_leaves(subtree_leaves)
key_temp = key_temp1 + key_temp
same_subtrees.update(same_temp)
cond = len(same_temp) > 0
key_list_subtrees += key_temp + [key, ]
key_temp.clear()
subtree_leaves.clear()
same_tree_w_leaves.update(same_tree)
same_subtrees.update(same_node)
for index in edited:
rebuild_domains(index)
......@@ -233,8 +253,8 @@ def _optimize_operator(op):
op._domain = op._ops[-1].domain
# Insert trees before leaves
for key in key_list_tree_w_leaves:
op = op.partial_insert(same_tree_w_leaves[key][1].adjoint(same_tree_w_leaves[key][0]))
for key in key_list_subtrees:
op = op.partial_insert(same_subtrees[key][1].adjoint(same_subtrees[key][0]))
for key in reversed(key_list_op):
op = op.partial_insert(same_op[key][1].adjoint(same_op[key][0]))
return op
......@@ -245,13 +265,44 @@ from .sugar import from_random
from .multi_field import MultiField
from numpy import allclose
def optimize_operator(op):
op_optimized = deepcopy(op)
op_optimized = _optimize_operator(op_optimized)
def optimise_operator(op):
"""
Merges redundant operations in the tree structure of an operator.
For example it is ensured that for ``(f@x + x)`` ``x`` is only computed once.
Parameters
----------
op: Operator
Returns
-------
op_optimised : Operator
Notes
-----
Since operators are compared by id best results are achieved when the following code
>>> from nifty6 import UniformOperator, DomainTuple
>>> uni1 = UniformOperator(DomainTuple.scalar_domain()
>>> uni2 = UniformOperator(DomainTuple.scalar_domain()
>>> op = (uni1 + uni2)*(uni1 + uni2)
is replaced by something comparable to
>>> from nifty6 import UniformOperator, DomainTuple
>>> uni = UniformOperator(DomainTuple.scalar_domain())
>>> uni_add = uni + uni
>>> op = uni_add * uni_add
After optimisation the operator is comparable in speed to
>>> op = (2*uni)**2
"""
op_optimised = deepcopy(op)
op_optimised = _optimise_operator(op_optimised)
test_field = from_random('normal', op.domain)
if isinstance(op(test_field), MultiField):
for key in op(test_field).keys():
assert allclose(op(test_field).val[key], op_optimized(test_field).val[key], 1e-10)
assert allclose(op(test_field).val[key], op_optimised(test_field).val[key], 1e-10)
else:
assert allclose(op(test_field).val, op_optimized(test_field).val, 1e-10)
return op_optimized
assert allclose(op(test_field).val, op_optimised(test_field).val, 1e-10)
return op_optimised
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment