# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

from .operators.operator import _OpChain, _OpSum, _OpProd
from .sugar import domain_union
from .operators.simple_linear_operators import FieldAdapter


def _optimize_operator(op):
    """Optimize an operator tree so that identical subtrees are not computed twice.

    The tree is scanned breadth-first for ``_OpSum``/``_OpProd`` nodes.

    * Node objects that occur more than once in the tree (same ``id``) are
      cut out and replaced at every occurrence by a ``FieldAdapter`` named
      ``str(id(node))``.
    * Leaves (operands of the nodes that contain no node themselves) whose
      input-side operators are the same objects are grouped, and their
      longest common input-side prefix is replaced by a ``FieldAdapter``
      named ``str(id(prefix_op))``. This is repeated until no group is left.
      The positions where nodes were cut out are treated the same way.
    * Every cut-out operator ``X`` with adapter ``A`` is put back in with
      ``op.partial_insert(A.adjoint(X))``.

    Works partly in place: operators inside ``op`` are mutated, which renders
    the original operator unusable. Pass a copy if it is still needed.

    Parameters
    ----------
    op : Operator
        The operator tree to optimize.

    Returns
    -------
    Operator
        The optimized operator. If ``op`` contains no ``_OpSum``/``_OpProd``
        there is nothing to share and ``op`` is returned unchanged.
    """
    import warnings

    # nodes[i] = (node_op, parent_index, is_left): every _OpSum/_OpProd found,
    # the index of its parent in ``nodes`` (None at the root) and whether it
    # sits in the parent's ``_op1`` (True) or ``_op2`` (False) operand.
    nodes = []
    # id(node_op) -> list of every index in ``nodes`` at which it occurs.
    id_dic = {}
    # Set of (parent_index, is_left) positions whose operand contains no node.
    leaves = set()

    def isnode(candidate):
        """Return True for the binary nodes this optimizer looks into."""
        return isinstance(candidate, (_OpSum, _OpProd))

    def left_parser(left_bool):
        """Map the ``is_left`` flag to the attribute holding that operand."""
        return '_op1' if left_bool else '_op2'

    def get_operand(position):
        """Return (parent, attribute name, operand) for a (parent_index, is_left) position."""
        parent = nodes[position[0]][0]
        attr = left_parser(position[1])
        return parent, attr, getattr(parent, attr)

    def rebuild_domains(index):
        """Walk from ``nodes[index]`` up to the root and recompute domains.

        Plugging in field adapters changes the domains of rewritten operands,
        so chain and node domains along the path are refreshed bottom-up.
        """
        while index is not None:
            node, parent_index, _ = nodes[index]
            for attr in ('_op1', '_op2'):
                child = getattr(node, attr)
                if isinstance(child, _OpChain):
                    # The chain's domain is the domain of its last operator.
                    child._domain = child._ops[-1].domain
            if isnode(node):
                # Can fail on non-multidomains, because one side becomes a
                # MultiDomain and the other does not.
                try:
                    node._domain = domain_union((node._op1.domain, node._op2.domain))
                except Exception:
                    warnings.warn('Operator should be defined on a MultiDomain')
            index = parent_index

    def recognize_nodes(operand, parent_index, is_left):
        """Append every node inside ``operand`` to ``nodes``, or record a leaf.

        ``operand`` is searched directly (if it is a node) or among the
        operators of an ``_OpChain``. If no node is found, its position is
        added to ``leaves``.
        """
        if isinstance(operand, _OpChain):
            found = [sub for sub in operand._ops if isnode(sub)]
        elif isnode(operand):
            found = [operand]
        else:
            found = []
        for node in found:
            nodes.append((node, parent_index, is_left))
        if not found:
            leaves.add((parent_index, is_left))

    def equal_nodes(root):
        """Breadth-first scan filling ``nodes``, ``id_dic`` and ``leaves``.

        The operands of a node object are only scanned at its first
        occurrence, so repeated subtrees are not traversed again.
        """
        recognize_nodes(root, None, None)
        index = 0
        while index < len(nodes):
            node = nodes[index][0]
            occurrences = id_dic.setdefault(id(node), [])
            occurrences.append(index)
            if len(occurrences) == 1:
                recognize_nodes(node._op1, index, True)
                recognize_nodes(node._op2, index, False)
            index += 1

    def leaf_key(operand):
        """Return the grouping key of a leaf operand, or None to skip it.

        The key is the tuple of ids of the operand's operators, taken in
        ``reversed(_ops)`` order, up to and including the first one that is
        not a ``FieldAdapter``. Operands made only of FieldAdapters are not
        optimized. (A tuple rather than concatenated id strings, so that
        different id sequences can never produce the same key.)
        """
        if not isinstance(operand, _OpChain):
            return None if isinstance(operand, FieldAdapter) else (id(operand),)
        key = []
        for sub in reversed(operand._ops):
            key.append(id(sub))
            if not isinstance(sub, FieldAdapter):
                return tuple(key)
        return None

    def input_first(operand):
        """Return the operators of ``operand`` as a tuple in ``reversed(_ops)`` order."""
        if isinstance(operand, _OpChain):
            return tuple(reversed(operand._ops))
        return (operand,)

    def equal_leaves(leaf_positions, edited):
        """Factor out the common leading operators of leaves with equal keys.

        Returns ``(keys, shared, edited)``. ``keys`` lists the group keys in
        processing order, ``shared[key]`` is ``[prefix_op, adapter]``, and
        ``edited`` (mutated in place) gains the ``nodes`` index of every
        parent whose operand was rewritten.
        """
        groups = {}
        for position in leaf_positions:
            key = leaf_key(get_operand(position)[2])
            if key is not None:
                groups.setdefault(key, []).append(position)

        key_list = [key for key, group in groups.items() if len(group) > 1]
        shared = {}
        for key in key_list:
            group = groups[key]
            sequences = [input_first(get_operand(pos)[2]) for pos in group]
            shortest = min(len(seq) for seq in sequences)
            first, rest = sequences[0], sequences[1:]
            # Element 0 is shared by construction of the key; extend the
            # common prefix while every member agrees.
            first_difference = 1
            while (first_difference < shortest
                   and all(first[first_difference] == seq[first_difference]
                           for seq in rest)):
                first_difference += 1

            prefix = first[:first_difference]
            shared_op = prefix[0]
            for nxt in prefix[1:]:
                shared_op = nxt @ shared_op
            adapter = FieldAdapter(shared_op.target, str(id(shared_op)))
            shared[key] = [shared_op, adapter]

            rewritten_chains = set()
            for pos in group:
                parent, attr, operand = get_operand(pos)
                edited.add(id_dic[id(parent)][0])
                if id(operand) in rewritten_chains:
                    # The same chain object is the operand of several parents
                    # and was already shortened in place; cutting it again
                    # would drop operators.
                    continue
                if (isinstance(operand, _OpChain)
                        and first_difference < len(operand._ops)):
                    operand._ops = operand._ops[:-first_difference] + (adapter,)
                    rewritten_chains.add(id(operand))
                else:
                    setattr(parent, attr, adapter)
        return key_list, shared, edited

    equal_nodes(op)
    if not nodes:
        # No _OpSum/_OpProd anywhere: nothing can be shared. (The only leaf
        # would be the root position (None, None), which has no parent.)
        return op

    # Indices in ``nodes`` whose domains must be rebuilt afterwards.
    edited = set()

    # Factor out shared leaf prefixes until a pass finds nothing more.
    key_list_op = []
    same_op = {}
    while True:
        keys, shared, _ = equal_leaves(leaves, edited)
        if not shared:
            break
        key_list_op += keys
        same_op.update(shared)

    # Cut out node objects that occur more than once.
    key_list_tree = [key for key, indices in id_dic.items() if len(indices) > 1]
    same_tree = {}
    key_list_tree_w_leaves = []
    same_tree_w_leaves = {}
    for key in key_list_tree:
        shared_node = nodes[id_dic[key][0]][0]
        performance_adapter = FieldAdapter(shared_node.target, str(key))
        same_tree[key] = [shared_node, performance_adapter]

        subtree_leaves = set()
        for node_index in id_dic[key]:
            edited.add(node_index)
            _, parent_index, is_left = nodes[node_index]
            parent = nodes[parent_index][0]
            attr = left_parser(is_left)
            operand = getattr(parent, attr)
            if isinstance(operand, _OpChain):
                # NOTE(review): replaces the chain's last operator, i.e.
                # assumes the node is the last entry of ``_ops`` — confirm.
                operand._ops = operand._ops[:-1] + (performance_adapter,)
            else:
                setattr(parent, attr, performance_adapter)
            subtree_leaves.add((parent_index, is_left))

        # The adapter positions are new leaves: share common operators on them.
        while True:
            keys, shared, _ = equal_leaves(subtree_leaves, edited)
            key_list_tree_w_leaves += keys
            same_tree_w_leaves.update(shared)
            if not shared:
                break
        key_list_tree_w_leaves.append(key)
    same_tree_w_leaves.update(same_tree)

    for index in edited:
        rebuild_domains(index)
    if isinstance(op, _OpChain):
        op._domain = op._ops[-1].domain

    # Insert trees before leaves. NOTE(review): the insertion order looks
    # chosen so that each cut-out operator is inserted before the ones it
    # consumes adapters from — confirm against partial_insert semantics.
    for key in key_list_tree_w_leaves:
        shared_op, adapter = same_tree_w_leaves[key]
        op = op.partial_insert(adapter.adjoint(shared_op))
    for key in reversed(key_list_op):
        shared_op, adapter = same_op[key]
        op = op.partial_insert(adapter.adjoint(shared_op))
    return op


from copy import deepcopy
from .sugar import from_random
from .multi_field import MultiField
from numpy import allclose


def optimize_operator(op):
    """Return an optimized copy of ``op`` and check it numerically.

    ``op`` is deep-copied before being passed to ``_optimize_operator``, which
    works partly in place, so the caller's operator is not modified. The
    optimized operator is then applied to a random ``'normal'`` field drawn
    on ``op.domain``. Its result is compared with the original's using
    ``numpy.allclose`` with a relative tolerance of ``1e-10``.

    Parameters
    ----------
    op : Operator
        The operator to optimize.

    Returns
    -------
    Operator
        The optimized operator.

    Raises
    ------
    AssertionError
        If the optimized operator does not reproduce the original result.
        The check uses ``assert`` and is therefore skipped under ``python -O``.
    """
    op_optimized = _optimize_operator(deepcopy(op))

    test_field = from_random('normal', op.domain)
    # Evaluate each operator only once; applying an operator can be expensive.
    expected = op(test_field)
    result = op_optimized(test_field)
    if isinstance(expected, MultiField):
        for key in expected.keys():
            assert allclose(expected.val[key], result.val[key], 1e-10)
    else:
        assert allclose(expected.val, result.val, 1e-10)
    return op_optimized