Commit 4b8621d3 authored by Rouven Lemmerz's avatar Rouven Lemmerz
Browse files

Checking leaves multiple times

parent de5e30c4
Pipeline #71780 passed with stages
in 15 minutes and 10 seconds
......@@ -50,7 +50,13 @@ def optimize_operator(op):
if isinstance(getattr(op, attr), _OpChain):
getattr(op, attr)._domain = getattr(op, attr)._ops[-1].domain
if isnode(op):
op._domain = domain_union((op._op1.domain, op._op2.domain))
# Some problems doing this on non-multidomains, because one side becomes a multidomain and the other not
try:
op._domain = domain_union((op._op1.domain, op._op2.domain))
except:
import warnings
warnings.warn('Operator should be defined on a MultiDomain')
pass
index = nodes[index][1]
cond = type(index) is int
......@@ -100,82 +106,61 @@ def optimize_operator(op):
equal_nodes(op)
# Cut subtrees
key_list_tree = []
same_tree = {}
edited = set()
for item in list(id_dic.items()):
if len(item[1]) > 1:
key_list_tree.append(item[0])
for key in key_list_tree:
same_tree[key] = [nodes[id_dic[key][0]][0],]
performance_adapter = FieldAdapter(same_tree[key][0].target, str(key))
same_tree[key] += [performance_adapter]
for node_indices in id_dic[key]:
edited.add(node_indices)
pos = nodes[node_indices][3]
parent = nodes[nodes[node_indices][1]][0]
#TODO: cut subtrees can also be regarded as nodes
#Some kind of order is needed to get the domains right at the end...
#leaves.add((node_indices, nodes[node_indices][2]))
attr = left_parser(nodes[node_indices][2])
if pos is not -1:
getattr(parent, attr)._ops = getattr(parent, attr)._ops[:pos] + (performance_adapter,) + \
getattr(parent, attr)._ops[pos+1:]
else:
setattr(parent, attr, performance_adapter)
# Compare leaves
id_leave = {}
# Find matching leaves
for leave in leaves:
parent = nodes[leave[0]][0]
attr = left_parser(leave[1])
leave_op = getattr(parent, attr)
if isinstance(leave_op, _OpChain):
leave_op = leave_op._ops[-1]
try:
id_leave[id(leave_op)] = id_leave[id(leave_op)] + (leave,)
except KeyError:
id_leave[id(leave_op)] = (leave, )
# Unroll their OpChain and see how far they are equal
# TODO: Repeat detection, to handle eg. op2(op) + op3(op) + op2(op)
key_list_op = []
same_op = {}
def equal_leaves(leaves, edited):
id_leave = {}
# Find matching leaves
def write_to_dic(leave, leave_op_id):
try:
id_leave[leave_op_id] = id_leave[leave_op_id] + (leave,)
except KeyError:
id_leave[leave_op_id] = (leave,)
for item in list(id_leave.items()):
if len(item[1]) > 1:
key_list_op.append(item[0])
for key in key_list_op:
to_compare = []
for leave in id_leave[key]:
for leave in leaves:
parent = nodes[leave[0]][0]
attr = left_parser(leave[1])
leave_op = getattr(parent, attr)
if isinstance(leave_op, _OpChain):
to_compare.append(tuple(reversed(leave_op._ops)))
leave_op_id = ''
for i in reversed(leave_op._ops):
leave_op_id += str(id(i))
if not isinstance(i, FieldAdapter):
# Do not optimise leaves which only have equal FieldAdapters
write_to_dic(leave, leave_op_id)
break
else:
to_compare.append((leave_op,))
first_difference = 1
max_diff = min(len(i) for i in to_compare)
if not max_diff == 1:
compare_iterator = iter(to_compare)
first = next(compare_iterator)
while all(first[first_difference] == rest[first_difference] for rest in compare_iterator):
first_difference +=1
if first_difference >= max_diff:
break
if not isinstance(leave_op, FieldAdapter):
write_to_dic(leave, str(id(leave_op)))
# Unroll their OpChain and see how far they are equal
key_list_op = []
same_op = {}
for item in list(id_leave.items()):
if len(item[1]) > 1:
key_list_op.append(item[0])
for key in key_list_op:
to_compare = []
for leave in id_leave[key]:
parent = nodes[leave[0]][0]
attr = left_parser(leave[1])
leave_op = getattr(parent, attr)
if isinstance(leave_op, _OpChain):
to_compare.append(tuple(reversed(leave_op._ops)))
else:
to_compare.append((leave_op,))
first_difference = 1
max_diff = min(len(i) for i in to_compare)
if not max_diff == 1:
compare_iterator = iter(to_compare)
first = next(compare_iterator)
while all(first[first_difference] == rest[first_difference] for rest in compare_iterator):
first_difference += 1
if first_difference >= max_diff:
break
compare_iterator = iter(to_compare)
first = next(compare_iterator)
# Do not optimise leaves which only have equal FieldAdapters
if first_difference > int(isinstance(to_compare[0][-1], FieldAdapter)):
common_op = to_compare[0][:first_difference]
res_op = common_op[0]
for ops in common_op[1:]:
......@@ -195,6 +180,47 @@ def optimize_operator(op):
leave_op._ops = leave_op._ops[:-first_difference] + (same_op[key][1],)
else:
setattr(parent, attr, same_op[key][1])
return key_list_op, same_op, edited
edited = set()
key_list_op, same_op, edited = equal_leaves(leaves, edited)
cond = True
while cond:
key_temp, same_op_temp, edited_temp = equal_leaves(leaves, edited)
key_list_op += key_temp
same_op.update(same_op_temp)
edited.update(edited)
cond = len(same_op_temp) > 0
# Cut subtrees
key_list_tree = []
same_tree = {}
subtree_leaves = set()
for item in list(id_dic.items()):
if len(item[1]) > 1:
key_list_tree.append(item[0])
for key in key_list_tree:
same_tree[key] = [nodes[id_dic[key][0]][0],]
performance_adapter = FieldAdapter(same_tree[key][0].target, str(key))
same_tree[key] += [performance_adapter]
for node_indices in id_dic[key]:
edited.add(node_indices)
pos = nodes[node_indices][3]
parent = nodes[nodes[node_indices][1]][0]
#TODO: cut subtrees can also be regarded as leaves
#Some kind of order is needed to get the domains right at the end...
#leaves.add((node_indices, nodes[node_indices][2]))
attr = left_parser(nodes[node_indices][2])
if pos is not -1:
getattr(parent, attr)._ops = getattr(parent, attr)._ops[:pos] + (performance_adapter,) + \
getattr(parent, attr)._ops[pos+1:]
else:
setattr(parent, attr, performance_adapter)
for index in edited:
rebuild_domains(index)
......@@ -204,7 +230,7 @@ def optimize_operator(op):
# Insert trees before leaves
for key in key_list_tree:
op = op.partial_insert(same_tree[key][1].adjoint(same_tree[key][0]))
for key in key_list_op:
for key in reversed(key_list_op):
op = op.partial_insert(same_op[key][1].adjoint(same_op[key][0]))
return op
......@@ -215,7 +241,6 @@ def optimize_operator_safe(op):
op_optimized = deepcopy(op)
op_optimized = optimize_operator(op_optimized)
test_field = from_random('normal', op.domain)
if isinstance(op(test_field), MultiField):
for key in op(test_field).keys():
assert allclose(op(test_field).val[key], op_optimized(test_field).val[key], 1e-10)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment