Commit a1fcf437 authored by Rouven Lemmerz's avatar Rouven Lemmerz
Browse files

Improved algorithm

parent 1fbf8e2a
Pipeline #71682 passed with stages
in 15 minutes and 13 seconds
import nifty6 as ift
import numpy as np
def optimize_node(op):
    """Factor out the operators shared by both vertices of a node.

    Takes an operator node (_OpSum or _OpProd) and compares the operator
    chains of its two vertices, starting with the operator that is applied
    first. Operators are compared by identity (`is`). The common leading
    part is composed into one operator, replaced in both vertices by a
    FieldAdapter, and inserted into the rebuilt node via `partial_insert`.
    Only searches one vertex deep.

    Parameters
    ----------
    op : operator
        The node to optimise.

    Returns
    -------
    operator
        `op` itself if it is neither an _OpSum nor an _OpProd, or if the two
        vertices do not start with the same operator; otherwise a newly
        built operator.
    """
    sum_type = ift.operators.operator._OpSum
    prod_type = ift.operators.operator._OpProd
    if not isinstance(op, (sum_type, prod_type)):
        return op
    # Named `is_sum` so that the builtin `sum` is not shadowed.
    is_sum = isinstance(op, sum_type)

    def to_list(x):
        # Unroll a chain so that the operator applied first comes first.
        if isinstance(x, ift.operators.operator._OpChain):
            return list(reversed(x._ops))
        return [x]

    def compose(ops):
        # Inverse of to_list: ops[-1] @ ops[-2] @ ... @ ops[0].
        applied_last, *earlier = reversed(ops)
        result = applied_last
        for earlier_op in earlier:
            result = result @ earlier_op
        return result

    op_list = [to_list(op._op1), to_list(op._op2)]

    # Length of the common prefix; zip stops at the shorter chain.
    first_difference = 0
    for left, right in zip(*op_list):
        if left is not right:
            break
        first_difference += 1
    if first_difference == 0:
        return op

    res_op = compose(op_list[0][:first_difference])
    performance_adapter = ift.FieldAdapter(res_op.target, str(id(res_op)))

    vertices = []
    for ops in op_list:
        # Swap the shared prefix for the adapter, then rebuild the vertex.
        ops[:first_difference] = [performance_adapter]
        vertices.append(compose(ops))

    # A new node is built instead of assigning op._op1/op._op2 and
    # recomputing op._domain in place; the original author noted that the
    # in-place variant seemed broken.
    if is_sum:
        rebuilt = vertices[0] + vertices[1]
    else:
        rebuilt = vertices[0] * vertices[1]
    return rebuilt.partial_insert(performance_adapter.adjoint(res_op))
def optimize_all_nodes(op):
    """Apply optimize_node to every sum/product node reachable from `op`.

    The tree is walked in post-order: both vertices of a node are optimised
    and written back to the node before the node itself is passed to
    optimize_node. For an _OpChain only its last entry (`_ops[-1]`) is
    inspected. The chain and the visited nodes are modified in place.
    """
    node_types = (ift.operators.operator._OpSum,
                  ift.operators.operator._OpProd)

    is_chain = isinstance(op, ift.operators.operator._OpChain)
    node = op._ops[-1] if is_chain else op
    if not isinstance(node, node_types):
        return op

    # Post-order traversal: children first, then the node itself.
    node._op1 = optimize_all_nodes(node._op1)
    node._op2 = optimize_all_nodes(node._op2)
    node = optimize_node(node)

    if not is_chain:
        return node

    # Put the optimised node back at the end of the chain and let the
    # chain's domain follow it.
    op._ops = op._ops[:-1] + (node,)
    op._domain = node.domain
    return op
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Some Examples for above
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from nifty6.operators.operator import _OpChain, _OpSum, _OpProd
from nifty6.sugar import domain_union
from nifty6 import FieldAdapter
class CountingOp(ift.LinearOperator):
    """Linear operator that returns its input and counts `apply` calls."""

    def __init__(self, domain):
        space = ift.sugar.makeDomain(domain)
        # Domain and target are the very same object.
        self._domain = space
        self._target = space
        self._capability = self.TIMES | self.ADJOINT_TIMES
        # Number of apply() calls seen so far.
        self._count = 0

    def apply(self, x, mode):
        """Return `x` unchanged and increase the counter; `mode` is unused."""
        self._count = self._count + 1
        return x

    @property
    def count(self):
        """Number of times `apply` has been called."""
        return self._count
# Demo: wrap each operator in a CountingOp and print the counters before and
# after optimisation.
dom = ift.DomainTuple.scalar_domain()
h = ift.from_random('normal', dom)
# One counter per wrapped operator.
copuni = CountingOp(dom)
copig = CountingOp(dom)
uni = copuni @ ift.UniformOperator(dom)
ig = copig @ ift.InverseGammaOperator(dom, 1, 1)
# `uni` occurs in both summands of the unoptimised operator.
op = uni + ig(uni)
op(h)
print(copuni.count) # is increased by 2
print(copig.count)
op2 = optimize_all_nodes(op)
op2(h)
print(copuni.count) # only increased by 1
print(copig.count)
# NOTE(review): the result of this comparison is discarded; presumably it
# was meant to be asserted or printed - confirm.
np.allclose(op(h).val, op2(h).val)
# More complex operators are only partially optimised:
op = ig(uni + uni +uni(ig))
op2 = optimize_all_nodes(op)
# However, since the search depth is only one vertex, this is not optimised:
op = ig(uni + uni(ig) + uni)
op2 = optimize_all_nodes(op)
# To find larger shared subtrees, the following function is defined:
def optimize_subtrees(op):
"""Recognizes same subtrees and replaces them.
Currently only works on operators defined on Multidomains.
Should work inplace"""
if not isinstance(op.domain, ift.MultiDomain):
raise TypeError('Operator needs to be defined on a multidomain')
dic = {}
dic_list = {}
def equal_vertices(x, coord):
if isinstance(x, ift.operators.operator._OpChain):
def optimize_operator(op):
"""
Optimizes operator trees, so that same operator subtrees are not computed twice.
Recognizes same subtrees and replaces them at nodes.
Recognizes same leaves and structures them.
Works partly inplace, rendering the old operator unusable"""
# Format: List of tuple [op, parent_index, left=True right=False, pos in op_chain]
nodes = []
# Format: ID: index in nodes[]
id_dic = {}
# Format: [parent_index, left]
leaves = set()
def isnode(op):
return isinstance(op, _OpSum) or isinstance(op, _OpProd)
def left_parser(left_bool):
if left_bool:
return '_op1'
return '_op2'
def rebuild_domains(index):
"""Goes bottom up to fix domains which were destroyed by plugging in field adapters"""
cond = True
while cond:
op = nodes[index][0]
for attr in ('_op1', '_op2'):
if isinstance(getattr(op, attr), _OpChain):
getattr(op, attr)._domain = getattr(op, attr)._ops[-1].domain
if isnode(op):
op._domain = domain_union((op._op1.domain, op._op2.domain))
index = nodes[index][1]
cond = type(index) is int
return
def recognize_nodes(op, active_node, left):
# If nothing was added, this is a leaf!
isleave = True
if isinstance(op, _OpChain):
for i in range(len(op._ops)):
if isnode(op._ops[i]):
nodes.append((op._ops[i], active_node, left, i))
isleave = False
elif isnode(op):
nodes.append((op, active_node, left, -1))
isleave = False
if isleave:
leaves.add((active_node, left))
return
def equal_nodes(op):
"""BFS-Algorithm which fills the nodes list and id_dic dictionary.
Does not scan equal subtrees multiple times."""
list_index_traversed = 0
# if isnode(op):
# nodes.append((op, None, None, None))
recognize_nodes(op, None, None)
while list_index_traversed < len(nodes):
# Visit node
active_node = nodes[list_index_traversed][0]
# Check whether exists already
try:
# Might be better to save parent + left or right
# However, this makes adjusting the operator domains hard later on
dic[id(x)] = [coord, ] + dic[id(x)]
#dic_list[id(x)] = dic_list[id(x)]
id_dic[id(active_node)] = id_dic[id(active_node)] + [list_index_traversed]
match = True
except KeyError:
dic[id(x)] = [coord, ]
dic_list[id(x)] = x
x = x._ops[-1]
if isinstance(x, ift.operators.operator._OpSum) or isinstance(x, ift.operators.operator._OpProd):
equal_vertices(x._op1, coord + [1])
equal_vertices(x._op2, coord + [2])
equal_vertices(op, [])
key = None
#Heuristically, the first entry should be the largest subtree
for items in list(dic.items()):
if len(items[1]) > 1:
# Multiple Ops are the same
key = items[0]
break
if key is None:
return op
same_op = dic_list[key]
performance_adapter = ift.FieldAdapter(same_op.target, str(key))
visited = []
for coord_list in dic[key]:
x = op
for coord in coord_list[:-1]:
# Travel to the nodes
if isinstance(x, ift.operators.operator._OpChain):
visited.append(x)
x = x._ops[-1]
visited.append(x)
if coord == 1:
x = x._op1
id_dic[id(active_node)] = [list_index_traversed]
match = False
# Check vertices for nodes
if not match:
recognize_nodes(active_node._op1, list_index_traversed, True)
recognize_nodes(active_node._op2, list_index_traversed, False)
list_index_traversed += 1
return
equal_nodes(op)
# Cut subtrees
key_list_tree = []
same_tree = {}
edited = set()
for item in list(id_dic.items()):
if len(item[1]) > 1:
key_list_tree.append(item[0])
for key in key_list_tree:
same_tree[key] = [nodes[id_dic[key][0]][0],]
performance_adapter = FieldAdapter(same_tree[key][0].target, str(key))
same_tree[key] += [performance_adapter]
for node_indices in id_dic[key]:
edited.add(node_indices)
pos = nodes[node_indices][3]
parent = nodes[nodes[node_indices][1]][0]
#TODO: cut subtrees can also be regarded as nodes
#Some kind of order is needed to get the domains right at the end...
#leaves.add((node_indices, nodes[node_indices][2]))
attr = left_parser(nodes[node_indices][2])
if pos is not -1:
getattr(parent, attr)._ops = getattr(parent, attr)._ops[:pos] + (performance_adapter,) + \
getattr(parent, attr)._ops[pos+1:]
else:
setattr(parent, attr, performance_adapter)
# Compare leaves
id_leave = {}
# Find matching leaves
for leave in leaves:
parent = nodes[leave[0]][0]
attr = left_parser(leave[1])
leave_op = getattr(parent, attr)
if isinstance(leave_op, _OpChain):
leave_op = leave_op._ops[-1]
try:
id_leave[id(leave_op)] = id_leave[id(leave_op)] + (leave,)
except KeyError:
id_leave[id(leave_op)] = (leave, )
# Unroll their OpChain and see how far they are equal
# TODO: Repeat detection, to handle eg. op2(op) + op3(op) + op2(op)
key_list_op = []
same_op = {}
for item in list(id_leave.items()):
if len(item[1]) > 1:
key_list_op.append(item[0])
for key in key_list_op:
to_compare = []
for leave in id_leave[key]:
parent = nodes[leave[0]][0]
attr = left_parser(leave[1])
leave_op = getattr(parent, attr)
if isinstance(leave_op, _OpChain):
to_compare.append(leave_op._ops)
else:
x = x._op2
if isinstance(x, ift.operators.operator._OpChain):
visited.append(x)
x = x._ops[-1]
visited.append(x)
# Substitute subtrees
if coord_list[-1] == 1:
x._op1 = performance_adapter
x._op1._domain = performance_adapter.domain
else:
x._op2 = performance_adapter
x._op2._domain = performance_adapter.domain
for v in reversed(visited):
if isinstance(v, ift.operators.operator._OpChain):
v._domain = v._ops[-1].domain
if isinstance(v, ift.operators.operator._OpSum) or isinstance(v, ift.operators.operator._OpProd):
v._domain = ift.sugar.domain_union((v._op1.domain, v._op2.domain))
op = op.partial_insert(performance_adapter.adjoint(same_op))
op = optimize_subtrees(op)
to_compare.append((leave_op,))
first_difference = 1
try:
while to_compare[1:][first_difference] == to_compare[-1:][first_difference]:
first_difference += 1
except IndexError:
pass
# Do not optimise leaves which only have equal FieldAdapters
if first_difference >= int(isinstance(to_compare[0][0], FieldAdapter)):
to_compare[0][:first_difference]
common_op = to_compare[0][:first_difference]
res_op = common_op[-1]
for ops in reversed(common_op[:-1]):
res_op = res_op @ ops
same_op[key] = [res_op, FieldAdapter(res_op.target, str(id(res_op)))]
for leave in id_leave[key]:
parent = nodes[leave[0]][0]
edited.add(id_dic[id(parent)][0])
attr = left_parser(leave[1])
leave_op = getattr(parent, attr)
if isinstance(leave_op, _OpChain):
leave_op._ops = leave_op._ops[:first_difference] + (same_op[key][1],)
else:
setattr(parent, attr, same_op[key][1])
for index in edited:
rebuild_domains(index)
if isinstance(op, _OpChain):
op._domain = op._ops[-1].domain
# Insert trees before leaves
for key in key_list_tree:
op = op.partial_insert(same_tree[key][1].adjoint(same_tree[key][0]))
for key in reversed(key_list_op):
op = op.partial_insert(same_op[key][1].adjoint(same_op[key][0]))
return op
# Examples for the operator optimisation above
dom = ift.UnstructuredDomain([10000])
uni = ift.UniformOperator(dom)
# It needs to be defined on a multidomain, so that one can recognize the leaves
uni_t = uni.ducktape('test')
ig = ift.InverseGammaOperator(dom, 1, 1)
# Now this is optimised:
op = ig(uni_t + ig(uni_t) + uni_t)
h = ift.from_random('normal', op.domain)
# %timeit op(h)
# 2.7 ms
op = optimize_all_nodes(op)
# %timeit op(h)
# 2.3 ms
op = optimize_subtrees(op)
# %timeit op(h)
# 2.3 ms (only replaces one very fast uni operation in this example)
# However, some improvements are still possible:
# 0. Bugs
# 1. Currently only nodes at the end of operator chains are searched, but
optimize_all_nodes( (uni + uni)(uni + uni) )
# can happen
# 2. Subtrees are only replaced at node points; the vertices above should also be compared and replaced
# 3. Operators are only compared by id; one might add a cache to prevent multiple equal operators from going unnoticed,
# even though this shouldn't be the norm
# Further calls on freshly built operators (the return values are not used):
optimize_subtrees(op = ig(uni_t + ig(uni_t) + uni_t))
optimize_subtrees(op = ig(uni.ducktape('test') + ig(uni.ducktape('test')) + uni.ducktape('test')))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment