Commit 395ab44f authored by Martin Reinecke's avatar Martin Reinecke
Browse files

tweaks

parent 143c96ec
Pipeline #72186 passed with stages
in 15 minutes and 4 seconds
......@@ -379,7 +379,7 @@ class Linearization(object):
def one_over(self):
tmp = 1./self._val
tmp2 = - tmp/self._val
tmp2 = - tmp*tmp
return self.new(tmp, makeOp(tmp2)(self._jac))
def add_metric(self, metric):
......
......@@ -15,11 +15,11 @@
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from nifty6.operators.operator import _OpChain, _OpSum, _OpProd
from nifty6.sugar import domain_union
from nifty6 import FieldAdapter
from .operators.operator import _OpChain, _OpSum, _OpProd
from .sugar import domain_union
from .operators.simple_linear_operators import FieldAdapter
def optimize_operator(op):
def _optimize_operator(op):
"""
Optimizes operator trees, so that same operator subtrees are not computed twice.
Recognizes same subtrees and replaces them at nodes.
......@@ -33,13 +33,14 @@ def optimize_operator(op):
# Format: [parent_index, left]
leaves = set()
def isnode(op):
return isinstance(op, _OpSum) or isinstance(op, _OpProd)
return isinstance(op, (_OpSum, _OpProd))
def left_parser(left_bool):
if left_bool:
return '_op1'
return '_op2'
return '_op1' if left_bool else '_op2'
def rebuild_domains(index):
"""Goes bottom up to fix domains which were destroyed by plugging in field adapters"""
......@@ -60,22 +61,22 @@ def optimize_operator(op):
index = nodes[index][1]
cond = type(index) is int
return
def recognize_nodes(op, active_node, left):
# If nothing added - is a leave!
isleave = True
# If nothing added - is a leaf!
isleaf = True
if isinstance(op, _OpChain):
for i in range(len(op._ops)):
if isnode(op._ops[i]):
nodes.append((op._ops[i], active_node, left))
isleave = False
isleaf = False
elif isnode(op):
nodes.append((op, active_node, left))
isleave = False
if isleave:
isleaf = False
if isleaf:
leaves.add((active_node, left))
return
def equal_nodes(op):
"""BFS-Algorithm which fills the nodes list and id_dic dictionary.
......@@ -102,51 +103,51 @@ def optimize_operator(op):
recognize_nodes(active_node._op2, list_index_traversed, False)
list_index_traversed += 1
return
def equal_leaves(leaves, edited):
id_leave = {}
id_leaf = {}
# Find matching leaves
def write_to_dic(leave, leave_op_id):
def write_to_dic(leaf, leaf_op_id):
try:
id_leave[leave_op_id] = id_leave[leave_op_id] + (leave,)
id_leaf[leaf_op_id] = id_leaf[leaf_op_id] + (leaf,)
except KeyError:
id_leave[leave_op_id] = (leave,)
for leave in leaves:
parent = nodes[leave[0]][0]
attr = left_parser(leave[1])
leave_op = getattr(parent, attr)
if isinstance(leave_op, _OpChain):
leave_op_id = ''
for i in reversed(leave_op._ops):
leave_op_id += str(id(i))
id_leaf[leaf_op_id] = (leaf,)
for leaf in leaves:
parent = nodes[leaf[0]][0]
attr = left_parser(leaf[1])
leaf_op = getattr(parent, attr)
if isinstance(leaf_op, _OpChain):
leaf_op_id = ''
for i in reversed(leaf_op._ops):
leaf_op_id += str(id(i))
if not isinstance(i, FieldAdapter):
# Do not optimise leaves which only have equal FieldAdapters
write_to_dic(leave, leave_op_id)
write_to_dic(leaf, leaf_op_id)
break
else:
if not isinstance(leave_op, FieldAdapter):
write_to_dic(leave, str(id(leave_op)))
if not isinstance(leaf_op, FieldAdapter):
write_to_dic(leaf, str(id(leaf_op)))
# Unroll their OpChain and see how far they are equal
key_list_op = []
same_op = {}
for item in list(id_leave.items()):
for item in list(id_leaf.items()):
if len(item[1]) > 1:
key_list_op.append(item[0])
for key in key_list_op:
to_compare = []
for leave in id_leave[key]:
parent = nodes[leave[0]][0]
attr = left_parser(leave[1])
leave_op = getattr(parent, attr)
if isinstance(leave_op, _OpChain):
to_compare.append(tuple(reversed(leave_op._ops)))
for leaf in id_leaf[key]:
parent = nodes[leaf[0]][0]
attr = left_parser(leaf[1])
leaf_op = getattr(parent, attr)
if isinstance(leaf_op, _OpChain):
to_compare.append(tuple(reversed(leaf_op._ops)))
else:
to_compare.append((leave_op,))
to_compare.append((leaf_op,))
first_difference = 1
max_diff = min(len(i) for i in to_compare)
if not max_diff == 1:
......@@ -166,16 +167,16 @@ def optimize_operator(op):
same_op[key] = [res_op, FieldAdapter(res_op.target, str(id(res_op)))]
for leave in id_leave[key]:
parent = nodes[leave[0]][0]
for leaf in id_leaf[key]:
parent = nodes[leaf[0]][0]
edited.add(id_dic[id(parent)][0])
attr = left_parser(leave[1])
leave_op = getattr(parent, attr)
if isinstance(leave_op, _OpChain):
if first_difference == len(leave_op._ops):
attr = left_parser(leaf[1])
leaf_op = getattr(parent, attr)
if isinstance(leaf_op, _OpChain):
if first_difference == len(leaf_op._ops):
setattr(parent, attr, same_op[key][1])
else:
leave_op._ops = leave_op._ops[:-first_difference] + (same_op[key][1],)
leaf_op._ops = leaf_op._ops[:-first_difference] + (same_op[key][1],)
else:
setattr(parent, attr, same_op[key][1])
return key_list_op, same_op, edited
......@@ -238,12 +239,15 @@ def optimize_operator(op):
op = op.partial_insert(same_op[key][1].adjoint(same_op[key][0]))
return op
from copy import deepcopy
from nifty6 import from_random, MultiField
from .sugar import from_random
from .multi_field import MultiField
from numpy import allclose
def optimize_operator_safe(op):
def optimize_operator(op):
op_optimized = deepcopy(op)
op_optimized = optimize_operator(op_optimized)
op_optimized = _optimize_operator(op_optimized)
test_field = from_random('normal', op.domain)
if isinstance(op(test_field), MultiField):
for key in op(test_field).keys():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment