# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

Martin Reinecke's avatar
tweaks    
Martin Reinecke committed
18
19
20
from .operators.operator import _OpChain, _OpSum, _OpProd
from .sugar import domain_union
from .operators.simple_linear_operators import FieldAdapter
Rouven Lemmerz's avatar
Rouven Lemmerz committed
21

Rouven Lemmerz's avatar
Fixes    
Rouven Lemmerz committed
22
def _optimise_operator(op):
    """Optimise an operator tree so that equal subtrees are not computed twice.

    Recognizes same subtrees and replaces them at nodes.
    Recognizes same leaves and structures them.
    Works partly inplace, rendering the old operator unusable.

    Operators are compared by ``id``, i.e. only the very same Python object
    occurring at several places of the tree is recognised as a duplicate.
    Every duplicate is replaced by a ``FieldAdapter`` and the shared operator
    is re-attached once via ``partial_insert`` at the end.

    Parameters
    ----------
    op : Operator
        Operator tree to optimise. It is modified in place.

    Returns
    -------
    Operator
        The optimised operator. If `op` contains no ``_OpSum``/``_OpProd``
        node there is nothing to merge and `op` itself is returned.
    """

    # Format: List of tuple [op, parent_index, left=True right=False]
    nodes = []
    # Format: ID: list of indices in nodes[]
    id_dic = {}
    # Format: [parent_index, left]
    leaves = set()

    # helper functions
    def readable_id():
        # Gives out letters to prepend field_adapter ids for cosmetics:
        # 'A'..'Z', 'a'..'z', then 'AA'..'ZZ', 'aa'..'zz', and so on.
        # Start at 'A'
        current_letter = 65
        repeats = 1
        while True:
            yield chr(current_letter)*repeats
            current_letter += 1
            if current_letter == 91:
                # skip the six special characters between 'Z' and 'a'
                current_letter += 6
            elif current_letter == 123:
                # End at z and start at AA
                current_letter = 65
                repeats += 1

    prepend_id = readable_id()

    def isnode(op):
        # A node is a binary operator with the two children _op1 and _op2
        return isinstance(op, (_OpSum, _OpProd))

    def left_parser(left_bool):
        # Translate the left/right flag into the child attribute name
        return '_op1' if left_bool else '_op2'

    def get_duplicate_keys(k_list, dic):
        # Append every key of dic whose value holds more than one entry
        for item in list(dic.items()):
            if len(item[1]) > 1:
                k_list.append(item[0])

    # Main algorithm functions
    def rebuild_domains(index):
        """Goes bottom up to fix domains which were destroyed by plugging in field adapters"""
        cond = True
        while cond:
            op = nodes[index][0]
            for attr in ('_op1', '_op2'):
                if isinstance(getattr(op, attr), _OpChain):
                    # The domain of a chain is the domain of its last element
                    getattr(op, attr)._domain = getattr(op, attr)._ops[-1].domain
            if isnode(op):
                # Some problems doing this on non-multidomains, because one side becomes a multidomain and the other not
                try:
                    op._domain = domain_union((op._op1.domain, op._op2.domain))
                except AttributeError:
                    import warnings
                    warnings.warn('Operator should be defined on a MultiDomain')

            # Walk up to the parent; the root entries carry None as parent
            index = nodes[index][1]
            cond = isinstance(index, int)

    def recognize_nodes(op, active_node, left):
        # If nothing added - is a leaf!
        isleaf = True
        if isinstance(op, _OpChain):
            for sub_op in op._ops:
                if isnode(sub_op):
                    nodes.append((sub_op, active_node, left))
                    isleaf = False
        elif isnode(op):
            nodes.append((op, active_node, left))
            isleaf = False
        # A leaf without a parent (the root operator contains no node at all)
        # has no parent attribute that could be replaced. Recording it would
        # make equal_leaves evaluate nodes[None] and raise a TypeError.
        if isleaf and active_node is not None:
            leaves.add((active_node, left))

    def equal_nodes(op):
        # BFS-Algorithm which fills the nodes list and id_dic dictionary
        # Does not scan equal subtrees multiple times
        list_index_traversed = 0
        recognize_nodes(op, None, None)

        while list_index_traversed < len(nodes):
            # Visit node
            active_node = nodes[list_index_traversed][0]
            node_id = id(active_node)

            # Check whether exists already
            match = node_id in id_dic
            if match:
                id_dic[node_id] = id_dic[node_id] + [list_index_traversed]
            else:
                id_dic[node_id] = [list_index_traversed]
                # Check vertices for nodes (only on the first encounter)
                recognize_nodes(active_node._op1, list_index_traversed, True)
                recognize_nodes(active_node._op2, list_index_traversed, False)

            list_index_traversed += 1

    # Indices (into nodes) of all nodes whose children have been replaced
    edited = set()

    def equal_leaves(leaves):
        # Replaces leaves that start with the same operator(s) by a common
        # FieldAdapter. Returns the list of keys of the merged leaf groups
        # and a dict key -> [shared operator, FieldAdapter replacing it].
        id_leaf = {}

        # Find matching leaves
        def write_to_dic(leaf, leaf_op_id):
            try:
                id_leaf[leaf_op_id] = id_leaf[leaf_op_id] + (leaf,)
            except KeyError:
                id_leaf[leaf_op_id] = (leaf,)

        for leaf in leaves:
            parent = nodes[leaf[0]][0]
            attr = left_parser(leaf[1])
            leaf_op = getattr(parent, attr)
            if isinstance(leaf_op, _OpChain):
                id_parts = []
                for i in reversed(leaf_op._ops):
                    id_parts.append(str(id(i)))
                    if not isinstance(i, FieldAdapter):
                        # Do not optimise leaves which only have equal FieldAdapters
                        # The ids are joined with a separator so that two
                        # different id sequences can never yield the same key.
                        write_to_dic(leaf, '_'.join(id_parts))
                        break
            else:
                if not isinstance(leaf_op, FieldAdapter):
                    write_to_dic(leaf, str(id(leaf_op)))

        # Unroll their OpChain and see how far they are equal
        key_list_leaf = []
        same_leaf = {}
        get_duplicate_keys(key_list_leaf, id_leaf)

        for key in key_list_leaf:
            to_compare = []
            for leaf in id_leaf[key]:
                parent = nodes[leaf[0]][0]
                attr = left_parser(leaf[1])
                leaf_op = getattr(parent, attr)
                if isinstance(leaf_op, _OpChain):
                    to_compare.append(tuple(reversed(leaf_op._ops)))
                else:
                    to_compare.append((leaf_op,))
            # Number of leading (reversed-chain) operators shared by all leaves
            first_difference = 1
            max_diff = min(len(i) for i in to_compare)
            if max_diff != 1:
                compare_iterator = iter(to_compare)
                first = next(compare_iterator)
                while all(first[first_difference] == rest[first_difference] for rest in compare_iterator):
                    first_difference += 1
                    if first_difference >= max_diff:
                        break
                    compare_iterator = iter(to_compare)
                    first = next(compare_iterator)

            # Recombine the shared operators into one operator
            common_op = to_compare[0][:first_difference]
            res_op = common_op[0]
            for ops in common_op[1:]:
                res_op = ops @ res_op

            same_leaf[key] = [res_op, FieldAdapter(res_op.target, next(prepend_id) + str(id(res_op)))]

            for leaf in id_leaf[key]:
                parent = nodes[leaf[0]][0]
                edited.add(id_dic[id(parent)][0])
                attr = left_parser(leaf[1])
                leaf_op = getattr(parent, attr)
                if isinstance(leaf_op, _OpChain):
                    if first_difference == len(leaf_op._ops):
                        # The whole chain is shared: replace it completely
                        setattr(parent, attr, same_leaf[key][1])
                    else:
                        # Only the shared tail of the chain is replaced
                        leaf_op._ops = leaf_op._ops[:-first_difference] + (same_leaf[key][1],)
                else:
                    setattr(parent, attr, same_leaf[key][1])
        return key_list_leaf, same_leaf

    equal_nodes(op)

    # Merge equal leaves repeatedly until a pass finds nothing to merge
    key_temp = []
    key_list_op, same_op = equal_leaves(leaves)
    cond = True
    while cond:
        key_temp, same_op_temp = equal_leaves(leaves)
        key_list_op += key_temp
        same_op.update(same_op_temp)
        cond = len(same_op_temp) > 0
    key_temp.clear()

    # Cut subtrees
    key_list_node = []
    key_list_subtrees = []
    same_node = {}
    same_subtrees = {}
    subtree_leaves = set()

    get_duplicate_keys(key_list_node, id_dic)

    for key in key_list_node:
        same_node[key] = [nodes[id_dic[key][0]][0],
                          FieldAdapter(nodes[id_dic[key][0]][0].target, next(prepend_id) + str(key))]

        for node_indices in id_dic[key]:
            edited.add(node_indices)
            parent = nodes[nodes[node_indices][1]][0]
            attr = left_parser(nodes[node_indices][2])
            if isinstance(getattr(parent, attr), _OpChain):
                getattr(parent, attr)._ops = getattr(parent, attr)._ops[:-1] + (same_node[key][1],)
            else:
                setattr(parent, attr, same_node[key][1])
            # Nodes have been replaced - treat replacements now as leaves
            subtree_leaves.add((nodes[node_indices][1], nodes[node_indices][2]))
        cond = True
        while cond:
            key_temp1, same_temp = equal_leaves(subtree_leaves)
            # Later merges are put in front of earlier ones
            key_temp = key_temp1 + key_temp
            same_subtrees.update(same_temp)
            cond = len(same_temp) > 0
        key_list_subtrees += key_temp + [key, ]
        key_temp.clear()
        subtree_leaves.clear()
    same_subtrees.update(same_node)

    for index in edited:
        rebuild_domains(index)
    if isinstance(op, _OpChain):
        op._domain = op._ops[-1].domain

    # Insert trees before leaves
    for key in key_list_subtrees:
        op = op.partial_insert(same_subtrees[key][1].adjoint(same_subtrees[key][0]))
    for key in reversed(key_list_op):
        op = op.partial_insert(same_op[key][1].adjoint(same_op[key][0]))
    return op
Rouven Lemmerz's avatar
Rouven Lemmerz committed
261

Martin Reinecke's avatar
tweaks    
Martin Reinecke committed
262

Rouven Lemmerz's avatar
Rouven Lemmerz committed
263
from copy import deepcopy
Martin Reinecke's avatar
tweaks    
Martin Reinecke committed
264
265
from .sugar import from_random
from .multi_field import MultiField
Rouven Lemmerz's avatar
Rouven Lemmerz committed
266
from numpy import allclose
Martin Reinecke's avatar
tweaks    
Martin Reinecke committed
267

Rouven Lemmerz's avatar
Fixes    
Rouven Lemmerz committed
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301

def optimise_operator(op):
    """
    Merges redundant operations in the tree structure of an operator.
    For example it is ensured that for ``(f@x + x)`` ``x`` is only computed once.

    The optimisation runs on a deep copy, so `op` itself stays usable. The
    result is checked by applying both operators to one random field.

    Parameters
    ----------
    op: Operator

    Returns
    -------
    op_optimised : Operator

    Raises
    ------
    AssertionError
        If the optimised operator does not reproduce the output of `op` on
        the random test field (relative tolerance 1e-10).

    Notes
    -----
    Since operators are compared by id best results are achieved when the following code

    >>> from nifty6 import UniformOperator, DomainTuple
    >>> uni1 = UniformOperator(DomainTuple.scalar_domain())
    >>> uni2 = UniformOperator(DomainTuple.scalar_domain())
    >>> op = (uni1 + uni2)*(uni1 + uni2)

    is replaced by something comparable to

    >>> from nifty6 import UniformOperator, DomainTuple
    >>> uni = UniformOperator(DomainTuple.scalar_domain())
    >>> uni_add = uni + uni
    >>> op = uni_add * uni_add

    After optimisation the operator is comparable in speed to

    >>> op = (2*uni)**2
    """
    op_optimised = deepcopy(op)
    op_optimised = _optimise_operator(op_optimised)
    test_field = from_random('normal', op.domain)
    # Apply each operator exactly once; evaluating them again for every key
    # would repeat the full computation the optimisation is meant to save.
    expected = op(test_field)
    actual = op_optimised(test_field)
    if isinstance(expected, MultiField):
        expected_val = expected.val
        actual_val = actual.val
        for key in expected.keys():
            # Third positional argument of allclose is the relative tolerance
            assert allclose(expected_val[key], actual_val[key], 1e-10)
    else:
        assert allclose(expected.val, actual.val, 1e-10)
    return op_optimised