Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
On Thursday, 7th July from 1 to 3 pm there will be a maintenance with a short downtime of GitLab.
Open sidebar
ift
NIFTy
Commits
a1fcf437
Commit
a1fcf437
authored
Mar 30, 2020
by
Rouven Lemmerz
Browse files
Improved algorithm
parent
1fbf8e2a
Pipeline
#71682
passed with stages
in 15 minutes and 13 seconds
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty6/operators/operator_tree_optimiser.py
View file @
a1fcf437
import
nifty6
as
ift
import
numpy
as
np
def
optimize_node
(
op
):
"""Takes an operator node (_OpSum or _OpProd) and searches for similar substructures in both vertices.
Only searches one vertex deep.
"""
if
isinstance
(
op
,
ift
.
operators
.
operator
.
_OpSum
):
sum
=
True
elif
isinstance
(
op
,
ift
.
operators
.
operator
.
_OpProd
):
sum
=
False
else
:
return
op
def
to_list
(
x
):
if
isinstance
(
x
,
ift
.
operators
.
operator
.
_OpChain
):
op_list
=
list
(
reversed
(
x
.
_ops
))
else
:
op_list
=
[
x
,]
return
op_list
op_list
=
[
to_list
(
op
.
_op1
),
to_list
(
op
.
_op2
)]
first_difference
=
0
try
:
while
op_list
[
0
][
first_difference
]
is
op_list
[
1
][
first_difference
]:
first_difference
=
first_difference
+
1
except
IndexError
:
pass
if
first_difference
==
0
:
return
op
common_op
=
op_list
[
0
][:
first_difference
]
res_op
=
common_op
[
-
1
]
for
ops
in
reversed
(
common_op
[:
-
1
]):
res_op
=
res_op
@
ops
performance_adapter
=
ift
.
FieldAdapter
(
res_op
.
target
,
str
(
id
(
res_op
)))
vertex
=
[
0
,
0
]
for
i
in
range
(
len
(
op_list
)):
op_list
[
i
][:
first_difference
]
=
[
performance_adapter
]
vertex
[
i
]
=
op_list
[
i
][
-
1
]
for
ops
in
reversed
(
op_list
[
i
][:
-
1
]):
vertex
[
i
]
=
vertex
[
i
]
@
ops
# This seems broken
# op._op1 = vertex[0]
# op._op2 = vertex[1]
# op._domain = ift.sugar.domain_union((op._op1.domain, op._op2.domain))
if
sum
:
op
=
vertex
[
0
]
+
vertex
[
1
]
else
:
op
=
vertex
[
0
]
*
vertex
[
1
]
op
=
op
.
partial_insert
(
performance_adapter
.
adjoint
(
res_op
))
return
op
def
optimize_all_nodes
(
op
):
"""Traverses the tree and applies optimization to every node"""
if
isinstance
(
op
,
ift
.
operators
.
operator
.
_OpChain
):
x
=
op
.
_ops
[
-
1
]
chained
=
True
else
:
x
=
op
chained
=
False
if
isinstance
(
x
,
ift
.
operators
.
operator
.
_OpSum
)
or
isinstance
(
x
,
ift
.
operators
.
operator
.
_OpProd
):
#postorder traversing
x
.
_op1
=
optimize_all_nodes
(
x
.
_op1
)
x
.
_op2
=
optimize_all_nodes
(
x
.
_op2
)
x
=
optimize_node
(
x
)
if
chained
:
op
.
_ops
=
op
.
_ops
[:
-
1
]
+
(
x
,)
op
.
_domain
=
x
.
domain
else
:
op
=
x
return
op
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Some Examples for above
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from
nifty6.operators.operator
import
_OpChain
,
_OpSum
,
_OpProd
from
nifty6.sugar
import
domain_union
from
nifty6
import
FieldAdapter
class
CountingOp
(
ift
.
LinearOperator
):
def
__init__
(
self
,
domain
):
self
.
_domain
=
self
.
_target
=
ift
.
sugar
.
makeDomain
(
domain
)
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
self
.
_count
=
0
def
apply
(
self
,
x
,
mode
):
self
.
_count
+=
1
return
x
@
property
def
count
(
self
):
return
self
.
_count
dom
=
ift
.
DomainTuple
.
scalar_domain
()
h
=
ift
.
from_random
(
'normal'
,
dom
)
copuni
=
CountingOp
(
dom
)
copig
=
CountingOp
(
dom
)
uni
=
copuni
@
ift
.
UniformOperator
(
dom
)
ig
=
copig
@
ift
.
InverseGammaOperator
(
dom
,
1
,
1
)
op
=
uni
+
ig
(
uni
)
op
(
h
)
print
(
copuni
.
count
)
# is increased by 2
print
(
copig
.
count
)
op2
=
optimize_all_nodes
(
op
)
op2
(
h
)
print
(
copuni
.
count
)
# only increased by 1
print
(
copig
.
count
)
np
.
allclose
(
op
(
h
).
val
,
op2
(
h
).
val
)
# More complex things work partially:
op
=
ig
(
uni
+
uni
+
uni
(
ig
))
op2
=
optimize_all_nodes
(
op
)
# However, since the search depth is only one vertex, this is not optimised:
op
=
ig
(
uni
+
uni
(
ig
)
+
uni
)
op2
=
optimize_all_nodes
(
op
)
# To find bigger chunks of subtrees the following function is defined:
def
optimize_subtrees
(
op
):
"""Recognizes same subtrees and replaces them.
Currently only works on operators defined on Multidomains.
Should work inplace"""
if
not
isinstance
(
op
.
domain
,
ift
.
MultiDomain
):
raise
TypeError
(
'Operator needs to be defined on a multidomain'
)
dic
=
{}
dic_list
=
{}
def
equal_vertices
(
x
,
coord
):
if
isinstance
(
x
,
ift
.
operators
.
operator
.
_OpChain
):
def
optimize_operator
(
op
):
"""
Optimizes operator trees, so that same operator subtrees are not computed twice.
Recognizes same subtrees and replaces them at nodes.
Recognizes same leaves and structures them.
Works partly inplace, rendering the old operator unusable"""
# Format: List of tuple [op, parent_index, left=True right=False, pos in op_chain]
nodes
=
[]
# Format: ID: index in nodes[]
id_dic
=
{}
# Format: [parent_index, left]
leaves
=
set
()
def
isnode
(
op
):
return
isinstance
(
op
,
_OpSum
)
or
isinstance
(
op
,
_OpProd
)
def
left_parser
(
left_bool
):
if
left_bool
:
return
'_op1'
return
'_op2'
def
rebuild_domains
(
index
):
"""Goes bottom up to fix domains which were destroyed by plugging in field adapters"""
cond
=
True
while
cond
:
op
=
nodes
[
index
][
0
]
for
attr
in
(
'_op1'
,
'_op2'
):
if
isinstance
(
getattr
(
op
,
attr
),
_OpChain
):
getattr
(
op
,
attr
).
_domain
=
getattr
(
op
,
attr
).
_ops
[
-
1
].
domain
if
isnode
(
op
):
op
.
_domain
=
domain_union
((
op
.
_op1
.
domain
,
op
.
_op2
.
domain
))
index
=
nodes
[
index
][
1
]
cond
=
type
(
index
)
is
int
return
def
recognize_nodes
(
op
,
active_node
,
left
):
# If nothing added - is a leave!
isleave
=
True
if
isinstance
(
op
,
_OpChain
):
for
i
in
range
(
len
(
op
.
_ops
)):
if
isnode
(
op
.
_ops
[
i
]):
nodes
.
append
((
op
.
_ops
[
i
],
active_node
,
left
,
i
))
isleave
=
False
elif
isnode
(
op
):
nodes
.
append
((
op
,
active_node
,
left
,
-
1
))
isleave
=
False
if
isleave
:
leaves
.
add
((
active_node
,
left
))
return
def
equal_nodes
(
op
):
"""BFS-Algorithm which fills the nodes list and id_dic dictionary.
Does not scan equal subtrees multiple times."""
list_index_traversed
=
0
# if isnode(op):
# nodes.append((op, None, None, None))
recognize_nodes
(
op
,
None
,
None
)
while
list_index_traversed
<
len
(
nodes
):
# Visit node
active_node
=
nodes
[
list_index_traversed
][
0
]
# Check whether exists already
try
:
# Might be better to save parent + left or right
# However, this makes adjusting the operator domains hard later on
dic
[
id
(
x
)]
=
[
coord
,
]
+
dic
[
id
(
x
)]
#dic_list[id(x)] = dic_list[id(x)]
id_dic
[
id
(
active_node
)]
=
id_dic
[
id
(
active_node
)]
+
[
list_index_traversed
]
match
=
True
except
KeyError
:
dic
[
id
(
x
)]
=
[
coord
,
]
dic_list
[
id
(
x
)]
=
x
x
=
x
.
_ops
[
-
1
]
if
isinstance
(
x
,
ift
.
operators
.
operator
.
_OpSum
)
or
isinstance
(
x
,
ift
.
operators
.
operator
.
_OpProd
):
equal_vertices
(
x
.
_op1
,
coord
+
[
1
])
equal_vertices
(
x
.
_op2
,
coord
+
[
2
])
equal_vertices
(
op
,
[])
key
=
None
#Heuristically, the first entry should be the largest subtree
for
items
in
list
(
dic
.
items
()):
if
len
(
items
[
1
])
>
1
:
# Multiple Ops are the same
key
=
items
[
0
]
break
if
key
is
None
:
return
op
same_op
=
dic_list
[
key
]
performance_adapter
=
ift
.
FieldAdapter
(
same_op
.
target
,
str
(
key
))
visited
=
[]
for
coord_list
in
dic
[
key
]:
x
=
op
for
coord
in
coord_list
[:
-
1
]:
# Travel to the nodes
if
isinstance
(
x
,
ift
.
operators
.
operator
.
_OpChain
):
visited
.
append
(
x
)
x
=
x
.
_ops
[
-
1
]
visited
.
append
(
x
)
if
coord
==
1
:
x
=
x
.
_op1
id_dic
[
id
(
active_node
)]
=
[
list_index_traversed
]
match
=
False
# Check vertices for nodes
if
not
match
:
recognize_nodes
(
active_node
.
_op1
,
list_index_traversed
,
True
)
recognize_nodes
(
active_node
.
_op2
,
list_index_traversed
,
False
)
list_index_traversed
+=
1
return
equal_nodes
(
op
)
# Cut subtrees
key_list_tree
=
[]
same_tree
=
{}
edited
=
set
()
for
item
in
list
(
id_dic
.
items
()):
if
len
(
item
[
1
])
>
1
:
key_list_tree
.
append
(
item
[
0
])
for
key
in
key_list_tree
:
same_tree
[
key
]
=
[
nodes
[
id_dic
[
key
][
0
]][
0
],]
performance_adapter
=
FieldAdapter
(
same_tree
[
key
][
0
].
target
,
str
(
key
))
same_tree
[
key
]
+=
[
performance_adapter
]
for
node_indices
in
id_dic
[
key
]:
edited
.
add
(
node_indices
)
pos
=
nodes
[
node_indices
][
3
]
parent
=
nodes
[
nodes
[
node_indices
][
1
]][
0
]
#TODO: cut subtrees can also be regarded as nodes
#Some kind of order is needed to get the domains right at the end...
#leaves.add((node_indices, nodes[node_indices][2]))
attr
=
left_parser
(
nodes
[
node_indices
][
2
])
if
pos
is
not
-
1
:
getattr
(
parent
,
attr
).
_ops
=
getattr
(
parent
,
attr
).
_ops
[:
pos
]
+
(
performance_adapter
,)
+
\
getattr
(
parent
,
attr
).
_ops
[
pos
+
1
:]
else
:
setattr
(
parent
,
attr
,
performance_adapter
)
# Compare leaves
id_leave
=
{}
# Find matching leaves
for
leave
in
leaves
:
parent
=
nodes
[
leave
[
0
]][
0
]
attr
=
left_parser
(
leave
[
1
])
leave_op
=
getattr
(
parent
,
attr
)
if
isinstance
(
leave_op
,
_OpChain
):
leave_op
=
leave_op
.
_ops
[
-
1
]
try
:
id_leave
[
id
(
leave_op
)]
=
id_leave
[
id
(
leave_op
)]
+
(
leave
,)
except
KeyError
:
id_leave
[
id
(
leave_op
)]
=
(
leave
,
)
# Unroll their OpChain and see how far they are equal
# TODO: Repeat detection, to handle eg. op2(op) + op3(op) + op2(op)
key_list_op
=
[]
same_op
=
{}
for
item
in
list
(
id_leave
.
items
()):
if
len
(
item
[
1
])
>
1
:
key_list_op
.
append
(
item
[
0
])
for
key
in
key_list_op
:
to_compare
=
[]
for
leave
in
id_leave
[
key
]:
parent
=
nodes
[
leave
[
0
]][
0
]
attr
=
left_parser
(
leave
[
1
])
leave_op
=
getattr
(
parent
,
attr
)
if
isinstance
(
leave_op
,
_OpChain
):
to_compare
.
append
(
leave_op
.
_ops
)
else
:
x
=
x
.
_op2
if
isinstance
(
x
,
ift
.
operators
.
operator
.
_OpChain
):
visited
.
append
(
x
)
x
=
x
.
_ops
[
-
1
]
visited
.
append
(
x
)
# Substitute subtrees
if
coord_list
[
-
1
]
==
1
:
x
.
_op1
=
performance_adapter
x
.
_op1
.
_domain
=
performance_adapter
.
domain
else
:
x
.
_op2
=
performance_adapter
x
.
_op2
.
_domain
=
performance_adapter
.
domain
for
v
in
reversed
(
visited
):
if
isinstance
(
v
,
ift
.
operators
.
operator
.
_OpChain
):
v
.
_domain
=
v
.
_ops
[
-
1
].
domain
if
isinstance
(
v
,
ift
.
operators
.
operator
.
_OpSum
)
or
isinstance
(
v
,
ift
.
operators
.
operator
.
_OpProd
):
v
.
_domain
=
ift
.
sugar
.
domain_union
((
v
.
_op1
.
domain
,
v
.
_op2
.
domain
))
op
=
op
.
partial_insert
(
performance_adapter
.
adjoint
(
same_op
))
op
=
optimize_subtrees
(
op
)
to_compare
.
append
((
leave_op
,))
first_difference
=
1
try
:
while
to_compare
[
1
:][
first_difference
]
==
to_compare
[
-
1
:][
first_difference
]:
first_difference
+=
1
except
IndexError
:
pass
# Do not optimise leaves which only have equal FieldAdapters
if
first_difference
>=
int
(
isinstance
(
to_compare
[
0
][
0
],
FieldAdapter
)):
to_compare
[
0
][:
first_difference
]
common_op
=
to_compare
[
0
][:
first_difference
]
res_op
=
common_op
[
-
1
]
for
ops
in
reversed
(
common_op
[:
-
1
]):
res_op
=
res_op
@
ops
same_op
[
key
]
=
[
res_op
,
FieldAdapter
(
res_op
.
target
,
str
(
id
(
res_op
)))]
for
leave
in
id_leave
[
key
]:
parent
=
nodes
[
leave
[
0
]][
0
]
edited
.
add
(
id_dic
[
id
(
parent
)][
0
])
attr
=
left_parser
(
leave
[
1
])
leave_op
=
getattr
(
parent
,
attr
)
if
isinstance
(
leave_op
,
_OpChain
):
leave_op
.
_ops
=
leave_op
.
_ops
[:
first_difference
]
+
(
same_op
[
key
][
1
],)
else
:
setattr
(
parent
,
attr
,
same_op
[
key
][
1
])
for
index
in
edited
:
rebuild_domains
(
index
)
if
isinstance
(
op
,
_OpChain
):
op
.
_domain
=
op
.
_ops
[
-
1
].
domain
# Insert trees before leaves
for
key
in
key_list_tree
:
op
=
op
.
partial_insert
(
same_tree
[
key
][
1
].
adjoint
(
same_tree
[
key
][
0
]))
for
key
in
reversed
(
key_list_op
):
op
=
op
.
partial_insert
(
same_op
[
key
][
1
].
adjoint
(
same_op
[
key
][
0
]))
return
op
# Examples for operator above
dom
=
ift
.
UnstructuredDomain
([
10000
])
uni
=
ift
.
UniformOperator
(
dom
)
# It needs to be defined on a multidomain, so that one can recognize the leaves
uni_t
=
uni
.
ducktape
(
'test'
)
ig
=
ift
.
InverseGammaOperator
(
dom
,
1
,
1
)
#Now this is optimised:
op
=
ig
(
uni_t
+
ig
(
uni_t
)
+
uni_t
)
h
=
ift
.
from_random
(
'normal'
,
op
.
domain
)
# %timeit op(h)
# 2.7 ms
op
=
optimize_all_nodes
(
op
)
# %timeit op(h)
# 2.3 ms
op
=
optimize_subtrees
(
op
)
# %timeit op(h)
# 2.3 ms (only replaces one very fast uni operation in this example)
# However, still some improvements:
# 0. Bugs
# 1. Currently only searching for nodes at the end of operator chains, but
optimize_all_nodes
(
(
uni
+
uni
)(
uni
+
uni
)
)
# can happen
# 2. Subtrees are only replaced at node points, should also compare the vertices above and replace it
# 3. Only comparing by ids is done, one might add a cache to prevent multiple operators from going unnoticed,
# even though this shouldn't be the norm
optimize_subtrees
(
op
=
ig
(
uni_t
+
ig
(
uni_t
)
+
uni_t
))
optimize_subtrees
(
op
=
ig
(
uni
.
ducktape
(
'test'
)
+
ig
(
uni
.
ducktape
(
'test'
))
+
uni
.
ducktape
(
'test'
)))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment