operator_tree_optimiser.py 11.2 KB
Newer Older
Rouven Lemmerz's avatar
Rouven Lemmerz committed
1
2
3
4
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
Rouven Lemmerz's avatar
Rouven Lemmerz committed
5
#
Rouven Lemmerz's avatar
Rouven Lemmerz committed
6
7
8
9
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
Rouven Lemmerz's avatar
Rouven Lemmerz committed
10
#
Rouven Lemmerz's avatar
Rouven Lemmerz committed
11
12
13
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
14
# Copyright(C) 2013-2021 Max-Planck-Society
Rouven Lemmerz's avatar
Rouven Lemmerz committed
15
16
17
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

Philipp Arras's avatar
Philipp Arras committed
18
19
20
21
22
23
from copy import deepcopy

from numpy import allclose

from .multi_field import MultiField
from .operators.operator import _OpChain, _OpProd, _OpSum
Martin Reinecke's avatar
tweaks    
Martin Reinecke committed
24
from .operators.simple_linear_operators import FieldAdapter
Philipp Arras's avatar
Philipp Arras committed
25
from .sugar import domain_union, from_random
26
from .utilities import myassert
Philipp Arras's avatar
Philipp Arras committed
27

Rouven Lemmerz's avatar
Rouven Lemmerz committed
28

Rouven Lemmerz's avatar
Fixes    
Rouven Lemmerz committed
29
def _optimise_operator(op):
    """Optimise an operator tree so that equal subtrees are not computed twice.

    The tree spanned by ``_OpSum``/``_OpProd`` nodes (optionally wrapped in
    ``_OpChain``) is scanned breadth-first. Nodes and leaf operators that
    occur more than once (compared by ``id``) are replaced by a
    ``FieldAdapter``; the replaced operator is afterwards put in front of the
    tree with ``partial_insert``.

    Works partly in place, rendering the passed operator unusable.

    Parameters
    ----------
    op : Operator
        Operator tree to optimise. It is modified by this function.

    Returns
    -------
    Operator
        The operator obtained after all ``partial_insert`` calls. If `op`
        contains no ``_OpSum``/``_OpProd`` node it is returned unchanged.
    """
    import warnings

    # Entries are tuples (node_op, parent_index, is_left). parent_index is the
    # position of the parent in ``nodes`` (None for entries found at the top
    # level); is_left is True for the parent's ``_op1``, False for ``_op2``.
    nodes = []
    # Maps id(node_op) -> list of indices into ``nodes`` where it occurs.
    id_dic = {}
    # Entries are tuples (parent_index, is_left) locating a leaf operator.
    leaves = set()
    # Indices into ``nodes`` whose surroundings were modified and whose
    # domains therefore have to be rebuilt.
    edited = set()

    # helper functions
    def readable_id():
        """Yield 'A'..'Z', 'a'..'z', 'AA'..'ZZ', 'aa'.. endlessly.

        The letters are prepended to field adapter ids for cosmetics only.
        """
        letters = [chr(c) for c in range(ord('A'), ord('Z') + 1)]
        letters += [chr(c) for c in range(ord('a'), ord('z') + 1)]
        repeats = 1
        while True:
            for letter in letters:
                yield letter*repeats
            repeats += 1

    prepend_id = readable_id()

    def isnode(op):
        """Return True if `op` is a binary node (``_OpSum`` or ``_OpProd``)."""
        return isinstance(op, (_OpSum, _OpProd))

    def left_parser(left_bool):
        """Translate the left/right flag into the node attribute name."""
        return '_op1' if left_bool else '_op2'

    def get_duplicate_keys(k_list, dic):
        """Append to `k_list` all keys of `dic` whose value has > 1 entry."""
        k_list.extend(key for key, entries in dic.items() if len(entries) > 1)

    def locate_leaf(leaf):
        """Return (parent node, attribute name, operator) for a leaf tuple."""
        parent = nodes[leaf[0]][0]
        attr = left_parser(leaf[1])
        return parent, attr, getattr(parent, attr)

    # Main algorithm functions
    def rebuild_domains(index):
        """Go bottom up from ``nodes[index]`` and fix the domains which were
        destroyed by plugging in field adapters."""
        # A parent index of None marks an entry without parent, i.e. the top.
        while index is not None:
            node = nodes[index][0]
            for attr in ('_op1', '_op2'):
                child = getattr(node, attr)
                if isinstance(child, _OpChain):
                    # Take the domain of the last entry of the chain.
                    child._domain = child._ops[-1].domain
            if isnode(node):
                # Some problems doing this on non-multidomains, because one
                # side becomes a multidomain and the other not.
                try:
                    node._domain = domain_union((node._op1.domain,
                                                 node._op2.domain))
                except AttributeError:
                    warnings.warn('Operator should be defined on a MultiDomain')
            index = nodes[index][1]

    def recognize_nodes(op, active_node, left):
        """Register the nodes found directly in `op`; if there are none,
        register the position (active_node, left) as a leaf."""
        found_node = False
        if isinstance(op, _OpChain):
            for sub_op in op._ops:
                if isnode(sub_op):
                    nodes.append((sub_op, active_node, left))
                    found_node = True
        elif isnode(op):
            nodes.append((op, active_node, left))
            found_node = True
        if not found_node:
            leaves.add((active_node, left))

    def equal_nodes(op):
        """BFS which fills the ``nodes`` list and the ``id_dic`` dictionary.

        Does not scan equal subtrees multiple times.
        """
        recognize_nodes(op, None, None)
        position = 0
        while position < len(nodes):
            active_node = nodes[position][0]
            occurrences = id_dic.setdefault(id(active_node), [])
            first_visit = not occurrences
            occurrences.append(position)
            # Children of an already known node were scanned before.
            if first_visit:
                recognize_nodes(active_node._op1, position, True)
                recognize_nodes(active_node._op2, position, False)
            position += 1

    def equal_leaves(leaves):
        """Replace leaf operators occurring more than once by field adapters.

        Returns
        -------
        key_list_leaf : list
            Keys of the leaf groups that were replaced.
        same_leaf : dict
            Maps each key to ``[common operator, FieldAdapter]``.
        """
        # Find matching leaves. The key is the tuple of ids of the chain
        # entries (traversed in reversed order) up to and including the first
        # entry that is not a FieldAdapter. A tuple is used instead of
        # concatenated strings so that different id sequences cannot collide.
        id_leaf = {}
        for leaf in leaves:
            _, _, leaf_op = locate_leaf(leaf)
            if isinstance(leaf_op, _OpChain):
                leaf_key = ()
                for sub_op in reversed(leaf_op._ops):
                    leaf_key += (id(sub_op),)
                    if not isinstance(sub_op, FieldAdapter):
                        # Do not optimise leaves which only have equal
                        # FieldAdapters
                        id_leaf.setdefault(leaf_key, []).append(leaf)
                        break
            elif not isinstance(leaf_op, FieldAdapter):
                id_leaf.setdefault((id(leaf_op),), []).append(leaf)

        key_list_leaf = []
        same_leaf = {}
        get_duplicate_keys(key_list_leaf, id_leaf)

        for key in key_list_leaf:
            # Unroll the OpChains and see how far they are equal
            to_compare = []
            for leaf in id_leaf[key]:
                _, _, leaf_op = locate_leaf(leaf)
                if isinstance(leaf_op, _OpChain):
                    to_compare.append(tuple(reversed(leaf_op._ops)))
                else:
                    to_compare.append((leaf_op,))

            first, *others = to_compare
            max_diff = min(len(ops) for ops in to_compare)
            # Entry 0 is treated as common; extend as long as all agree.
            first_difference = 1
            while (first_difference < max_diff and
                   all(first[first_difference] == rest[first_difference]
                       for rest in others)):
                first_difference += 1

            common_op = first[:first_difference]
            res_op = common_op[0]
            for next_op in common_op[1:]:
                res_op = next_op @ res_op

            adapter = FieldAdapter(res_op.target,
                                   next(prepend_id) + str(id(res_op)))
            same_leaf[key] = [res_op, adapter]

            for leaf in id_leaf[key]:
                parent, attr, leaf_op = locate_leaf(leaf)
                edited.add(id_dic[id(parent)][0])
                if (isinstance(leaf_op, _OpChain) and
                        first_difference != len(leaf_op._ops)):
                    # Only the common part of the chain is replaced.
                    leaf_op._ops = leaf_op._ops[:-first_difference] + (adapter,)
                else:
                    setattr(parent, attr, adapter)
        return key_list_leaf, same_leaf

    equal_nodes(op)
    if not nodes:
        # Without any _OpSum/_OpProd there is no tree to optimise. The only
        # registered leaf would be (None, None), which cannot be looked up in
        # the empty ``nodes`` list.
        return op

    # Repeat until a pass does not find equal leaves any more.
    key_list_op, same_op = equal_leaves(leaves)
    while True:
        key_temp, same_op_temp = equal_leaves(leaves)
        key_list_op += key_temp
        same_op.update(same_op_temp)
        if not same_op_temp:
            break

    # Cut subtrees
    key_list_node = []
    key_list_subtrees = []
    same_node = {}
    same_subtrees = {}

    get_duplicate_keys(key_list_node, id_dic)

    for key in key_list_node:
        node_op = nodes[id_dic[key][0]][0]
        adapter = FieldAdapter(node_op.target, next(prepend_id) + str(key))
        same_node[key] = [node_op, adapter]

        subtree_leaves = set()
        for node_index in id_dic[key]:
            edited.add(node_index)
            _, parent_index, is_left = nodes[node_index]
            # NOTE(review): parent_index is None for nodes found at the top
            # level; this lookup presumably fails if such a node occurs more
            # than once - confirm.
            parent = nodes[parent_index][0]
            attr = left_parser(is_left)
            child = getattr(parent, attr)
            if isinstance(child, _OpChain):
                # NOTE(review): assumes the node is the last entry of the
                # chain - confirm.
                child._ops = child._ops[:-1] + (adapter,)
            else:
                setattr(parent, attr, adapter)
            # Nodes have been replaced - treat replacements now as leaves
            subtree_leaves.add((parent_index, is_left))

        # Keys found in later passes are put in front of earlier ones.
        pending_keys = []
        while True:
            new_keys, same_temp = equal_leaves(subtree_leaves)
            pending_keys = new_keys + pending_keys
            same_subtrees.update(same_temp)
            if not same_temp:
                break
        key_list_subtrees += pending_keys + [key]
    same_subtrees.update(same_node)

    for index in edited:
        rebuild_domains(index)
    if isinstance(op, _OpChain):
        op._domain = op._ops[-1].domain

    # Insert trees before leaves
    for key in key_list_subtrees:
        sub_op, adapter = same_subtrees[key]
        op = op.partial_insert(adapter.adjoint(sub_op))
    for key in reversed(key_list_op):
        sub_op, adapter = same_op[key]
        op = op.partial_insert(adapter.adjoint(sub_op))
    return op
Rouven Lemmerz's avatar
Rouven Lemmerz committed
268

Martin Reinecke's avatar
tweaks    
Martin Reinecke committed
269
270


Rouven Lemmerz's avatar
Fixes    
Rouven Lemmerz committed
271

272

Rouven Lemmerz's avatar
Fixes    
Rouven Lemmerz committed
273
274
275
def optimise_operator(op):
    """
    Merges redundant operations in the tree structure of an operator.
    For example it is ensured that for ``f@x + x`` the operator ``x`` is only
    computed once. It is supposed to be used on the whole operator chain
    before doing minimisation.

    Currently optimises only ``_OpChain``, ``_OpSum`` and ``_OpProd`` and not
    their linear pendants ``ChainOp`` and ``SumOperator``.

    The passed operator is left untouched: the optimisation is performed on a
    deep copy. Afterwards both operators are applied once to the same random
    field and the results are checked for agreement with ``myassert``.

    Parameters
    ----------
    op : Operator
        Operator with a tree structure.

    Returns
    -------
    op_optimised : Operator
        Operator with same input/output, but optimised tree structure.

    Notes
    -----
    Operators are compared only by id, so best results are achieved when the
    following code

    >>> from nifty7 import UniformOperator, DomainTuple
    >>> uni1 = UniformOperator(DomainTuple.scalar_domain())
    >>> uni2 = UniformOperator(DomainTuple.scalar_domain())
    >>> op = (uni1 + uni2)*(uni1 + uni2)

    is replaced by something comparable to

    >>> uni = UniformOperator(DomainTuple.scalar_domain())
    >>> uni_add = uni + uni
    >>> op = uni_add * uni_add

    After optimisation the operator is as fast as

    >>> op = (2*uni)**2
    """
    # _optimise_operator works partly in place, hence operate on a copy.
    op_optimised = _optimise_operator(deepcopy(op))

    # Consistency check: evaluate each operator exactly once on the same
    # random input and compare the results with a relative tolerance.
    test_field = from_random(op.domain)
    expected = op(test_field)
    obtained = op_optimised(test_field)
    if isinstance(expected, MultiField):
        for key in expected.keys():
            myassert(allclose(expected.val[key], obtained.val[key],
                              rtol=1e-10))
    else:
        myassert(allclose(expected.val, obtained.val, rtol=1e-10))
    return op_optimised