Commit 2990b099 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'operator_tree_optimiser' into 'NIFTy_6'

Operator tree optimiser

See merge request !434
parents d281a905 0b20d39b
Pipeline #75289 passed with stages
in 8 minutes and 22 seconds
......@@ -98,5 +98,7 @@ from .linearization import Linearization
from .operator_spectrum import operator_spectrum
from .operator_tree_optimiser import optimise_operator
# We deliberately don't set __all__ here, because we don't want people to do a
# "from nifty6 import *"; that would swamp the global namespace.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from .operators.operator import _OpChain, _OpSum, _OpProd
from .sugar import domain_union
from .operators.simple_linear_operators import FieldAdapter
def _optimise_operator(op):
"""
optimises operator trees, so that same operator subtrees are not computed twice.
Recognizes same subtrees and replaces them at nodes.
Recognizes same leaves and structures them.
Works partly inplace, rendering the old operator unusable"""
# Format: List of tuple [op, parent_index, left=True right=False]
nodes = []
# Format: ID: index in nodes[]
id_dic = {}
# Format: [parent_index, left]
leaves = set()
# helper functions
def readable_id():
# Gives out letters to prepend field_adapter ids for cosmetics
# Start at 'A'
current_letter = 65
repeats = 1
while True:
yield chr(current_letter)*repeats
current_letter += 1
if current_letter == 91:
# skip specials
current_letter += 6
elif current_letter == 123:
# End at z and start at AA
current_letter = 65
repeats += 1
prepend_id = readable_id()
def isnode(op):
return isinstance(op, (_OpSum, _OpProd))
def left_parser(left_bool):
return '_op1' if left_bool else '_op2'
def get_duplicate_keys(k_list, dic):
for item in list(dic.items()):
if len(item[1]) > 1:
k_list.append(item[0])
# Main algorithm functions
def rebuild_domains(index):
"""Goes bottom up to fix domains which were destroyed by plugging in field adapters"""
cond = True
while cond:
op = nodes[index][0]
for attr in ('_op1', '_op2'):
if isinstance(getattr(op, attr), _OpChain):
getattr(op, attr)._domain = getattr(op, attr)._ops[-1].domain
if isnode(op):
# Some problems doing this on non-multidomains, because one side becomes a multidomain and the other not
try:
op._domain = domain_union((op._op1.domain, op._op2.domain))
except AttributeError:
import warnings
warnings.warn('Operator should be defined on a MultiDomain')
pass
index = nodes[index][1]
cond = type(index) is int
def recognize_nodes(op, active_node, left):
# If nothing added - is a leaf!
isleaf = True
if isinstance(op, _OpChain):
for i in range(len(op._ops)):
if isnode(op._ops[i]):
nodes.append((op._ops[i], active_node, left))
isleaf = False
elif isnode(op):
nodes.append((op, active_node, left))
isleaf = False
if isleaf:
leaves.add((active_node, left))
def equal_nodes(op):
# BFS-Algorithm which fills the nodes list and id_dic dictionary
# Does not scan equal subtrees multiple times
list_index_traversed = 0
recognize_nodes(op, None, None)
while list_index_traversed < len(nodes):
# Visit node
active_node = nodes[list_index_traversed][0]
# Check whether exists already
try:
id_dic[id(active_node)] = id_dic[id(active_node)] + [list_index_traversed]
match = True
except KeyError:
id_dic[id(active_node)] = [list_index_traversed]
match = False
# Check vertices for nodes
if not match:
recognize_nodes(active_node._op1, list_index_traversed, True)
recognize_nodes(active_node._op2, list_index_traversed, False)
list_index_traversed += 1
edited = set()
def equal_leaves(leaves):
id_leaf = {}
# Find matching leaves
def write_to_dic(leaf, leaf_op_id):
try:
id_leaf[leaf_op_id] = id_leaf[leaf_op_id] + (leaf,)
except KeyError:
id_leaf[leaf_op_id] = (leaf,)
for leaf in leaves:
parent = nodes[leaf[0]][0]
attr = left_parser(leaf[1])
leaf_op = getattr(parent, attr)
if isinstance(leaf_op, _OpChain):
leaf_op_id = ''
for i in reversed(leaf_op._ops):
leaf_op_id += str(id(i))
if not isinstance(i, FieldAdapter):
# Do not optimise leaves which only have equal FieldAdapters
write_to_dic(leaf, leaf_op_id)
break
else:
if not isinstance(leaf_op, FieldAdapter):
write_to_dic(leaf, str(id(leaf_op)))
# Unroll their OpChain and see how far they are equal
key_list_leaf = []
same_leaf = {}
get_duplicate_keys(key_list_leaf, id_leaf)
for key in key_list_leaf:
to_compare = []
for leaf in id_leaf[key]:
parent = nodes[leaf[0]][0]
attr = left_parser(leaf[1])
leaf_op = getattr(parent, attr)
if isinstance(leaf_op, _OpChain):
to_compare.append(tuple(reversed(leaf_op._ops)))
else:
to_compare.append((leaf_op,))
first_difference = 1
max_diff = min(len(i) for i in to_compare)
if not max_diff == 1:
compare_iterator = iter(to_compare)
first = next(compare_iterator)
while all(first[first_difference] == rest[first_difference] for rest in compare_iterator):
first_difference += 1
if first_difference >= max_diff:
break
compare_iterator = iter(to_compare)
first = next(compare_iterator)
common_op = to_compare[0][:first_difference]
res_op = common_op[0]
for ops in common_op[1:]:
res_op = ops @ res_op
same_leaf[key] = [res_op, FieldAdapter(res_op.target, next(prepend_id) + str(id(res_op)))]
for leaf in id_leaf[key]:
parent = nodes[leaf[0]][0]
edited.add(id_dic[id(parent)][0])
attr = left_parser(leaf[1])
leaf_op = getattr(parent, attr)
if isinstance(leaf_op, _OpChain):
if first_difference == len(leaf_op._ops):
setattr(parent, attr, same_leaf[key][1])
else:
leaf_op._ops = leaf_op._ops[:-first_difference] + (same_leaf[key][1],)
else:
setattr(parent, attr, same_leaf[key][1])
return key_list_leaf, same_leaf
equal_nodes(op)
key_temp = []
key_list_op, same_op = equal_leaves(leaves)
cond = True
while cond:
key_temp, same_op_temp = equal_leaves(leaves)
key_list_op += key_temp
same_op.update(same_op_temp)
cond = len(same_op_temp) > 0
key_temp.clear()
# Cut subtrees
key_list_node = []
key_list_subtrees = []
same_node = {}
same_subtrees = {}
subtree_leaves = set()
get_duplicate_keys(key_list_node, id_dic)
for key in key_list_node:
same_node[key] = [nodes[id_dic[key][0]][0],
FieldAdapter(nodes[id_dic[key][0]][0].target, next(prepend_id) + str(key))]
for node_indices in id_dic[key]:
edited.add(node_indices)
parent = nodes[nodes[node_indices][1]][0]
attr = left_parser(nodes[node_indices][2])
if isinstance(getattr(parent, attr), _OpChain):
getattr(parent, attr)._ops = getattr(parent, attr)._ops[:-1] + (same_node[key][1],)
else:
setattr(parent, attr, same_node[key][1])
# Nodes have been replaced - treat replacements now as leaves
subtree_leaves.add((nodes[node_indices][1], nodes[node_indices][2]))
cond = True
while cond:
key_temp1, same_temp = equal_leaves(subtree_leaves)
key_temp = key_temp1 + key_temp
same_subtrees.update(same_temp)
cond = len(same_temp) > 0
key_list_subtrees += key_temp + [key, ]
key_temp.clear()
subtree_leaves.clear()
same_subtrees.update(same_node)
for index in edited:
rebuild_domains(index)
if isinstance(op, _OpChain):
op._domain = op._ops[-1].domain
# Insert trees before leaves
for key in key_list_subtrees:
op = op.partial_insert(same_subtrees[key][1].adjoint(same_subtrees[key][0]))
for key in reversed(key_list_op):
op = op.partial_insert(same_op[key][1].adjoint(same_op[key][0]))
return op
from copy import deepcopy
from .sugar import from_random
from .multi_field import MultiField
from numpy import allclose
def optimise_operator(op):
"""
Merges redundant operations in the tree structure of an operator.
For example it is ensured that for ``(f@x + x)`` ``x`` is only computed once.
Currently works only on ``_OpChain``, ``_OpSum`` and ``_OpProd`` and does not optimise their linear pendants
``ChainOp`` and ``SumOperator``.
Parameters
----------
op: Operator
Returns
-------
op_optimised : Operator
Notes
-----
Since operators are compared by id best results are achieved when the following code
>>> from nifty6 import UniformOperator, DomainTuple
>>> uni1 = UniformOperator(DomainTuple.scalar_domain()
>>> uni2 = UniformOperator(DomainTuple.scalar_domain()
>>> op = (uni1 + uni2)*(uni1 + uni2)
is replaced by something comparable to
>>> uni = UniformOperator(DomainTuple.scalar_domain())
>>> uni_add = uni + uni
>>> op = uni_add * uni_add
After optimisation the operator is comparable in speed to
>>> op = (2*uni)**2
"""
op_optimised = deepcopy(op)
op_optimised = _optimise_operator(op_optimised)
test_field = from_random('normal', op.domain)
if isinstance(op(test_field), MultiField):
for key in op(test_field).keys():
assert allclose(op(test_field).val[key], op_optimised(test_field).val[key], 1e-10)
else:
assert allclose(op(test_field).val, op_optimised(test_field).val, 1e-10)
return op_optimised
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2020 Max-Planck-Society
# Author: Rouven Lemmerz
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from numpy.testing import assert_, assert_allclose
import numpy as np
from copy import deepcopy
import nifty6 as ift
class CountingOp(ift.Operator):
#FIXME: Not a LinearOperator since ChainOps not supported yet
def __init__(self, domain):
self._domain = self._target = ift.sugar.makeDomain(domain)
self._count = 0
def apply(self, x):
self._count += 1
return x
@property
def count(self):
return self._count
def test_operator_tree_optimiser():
dom = ift.RGSpace(10, harmonic=True)
hdom = dom.get_default_codomain()
cop1 = CountingOp(dom)
op1 = (ift.UniformOperator(dom, -1, 2)@cop1).ducktape('a')
cop2 = CountingOp(dom)
op2 = ift.FieldZeroPadder(dom, (11,))@cop2
cop3 = CountingOp(op2.target)
op3 = ift.ScalingOperator(op2.target, 3)@cop3
cop4 = CountingOp(op2.target)
op4 = ift.ScalingOperator(op2.target, 1.5) @ cop4
op1 = op1 * op1
# test layering in between two levels
op = op3@op2@op1 + op2@op1 + op3@op2@op1 + op2@op1
op = op + op
op = op4@(op4@op + op4@op)
fld = ift.from_random(op.domain, 'normal', np.float64)
op_orig = deepcopy(op)
op = ift.operator_tree_optimiser._optimise_operator(op)
assert_allclose(op(fld).val, op_orig(fld).val, rtol=np.finfo(np.float64).eps)
assert_(1 == ( (cop4.count-1) * cop3.count * cop2.count * cop1.count))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment