# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Authors: Gordian Edenhofer
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np
import string

from ..domain_tuple import DomainTuple
from ..linearization import Linearization
from ..field import Field
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from .operator import Operator
from .linear_operator import LinearOperator


class MultiLinearEinsum(Operator):
    """Multi-linear Einsum operator with corresponding derivatives

    FIXME: This operator does not perform any complex conjugation!

    Parameters
    ----------
    domain : MultiDomain or dict{name: DomainTuple}
        The operator's input domain.
    subscripts : str
        The subscripts which is passed to einsum. It must contain exactly one
        '->'. Every letter labels one sub-domain of the corresponding field;
        internally it is expanded to one einsum label per array axis of that
        sub-domain.
    key_order: tuple of str, optional
        The order of the keys in the multi-field. If not specified, defaults to
        the order of the keys in the multi-field.
    static_mf: MultiField or dict{name: Field}, optional
        A dictionary like type from which Fields are to be taken if the key from
        `key_order` is not part of the `domain`. Fields in this object are
        supposed to be static as they will not appear as FieldAdapter in the
        Linearization.
    optimize: bool, String or List, optional
        Parameter passed on to einsum_path.

    Raises
    ------
    ValueError
        If `static_mf` is given without `key_order`, if `subscripts` is
        malformed or inconsistent with `key_order` and the domains, if a key
        of `key_order` is found neither in `domain` nor in `static_mf`, or if
        more array axes need a label than there are ASCII letters.
    """
    def __init__(self, domain, subscripts,
                 key_order=None, static_mf=None, optimize='optimal'):
        self._domain = MultiDomain.make(domain)
        if static_mf is not None and key_order is None:
            ve = "`key_order` must be specified if additional fields are munged"
            raise ValueError(ve)
        if key_order is None:
            self._key_order = tuple(self._domain.keys())
        else:
            self._key_order = key_order
        self._stat_mf = static_mf

        sscr_parts = subscripts.split("->")
        if len(sscr_parts) != 2:
            raise ValueError(f"invalid subscripts specified; got {subscripts}")
        iss, oss = sscr_parts
        iss_spl = iss.split(",")
        len_consist = len(self._key_order) == len(iss_spl)
        sscr_consist = all(o in iss for o in oss)
        if not sscr_consist or "," in oss or not len_consist:
            raise ValueError(f"invalid subscripts specified; got {subscripts}")

        ve = f"invalid order of keys {self._key_order} for subscripts {subscripts}"
        domains, shapes, subscriptmap = {}, {}, {}
        numpy_iss_spl = []
        # Labels are popped from the end, i.e. handed out as 'a', 'b', ...,
        # 'z', 'A', ..., 'Z'.
        alphabet = list(string.ascii_letters)[::-1]
        for k, ss in zip(self._key_order, iss_spl):
            dom = self._domain_of(k)
            if len(dom) != len(ss):
                raise ValueError(ve)
            numpy_ss = ''
            for i, a in enumerate(ss):
                n_axes = len(dom[i].shape)
                if a not in subscriptmap:
                    if n_axes > len(alphabet):
                        raise ValueError(
                            f"subscripts {subscripts} require more einsum "
                            "labels than there are ASCII letters")
                    # One user-level subscript stands for a whole sub-domain
                    # and hence for all of its array axes.
                    subscriptmap[a] = [alphabet.pop() for _ in range(n_axes)]
                elif len(subscriptmap[a]) != n_axes:
                    raise ValueError(
                        f"subscript {a} in {subscripts} is used for domains "
                        "with a different number of axes")
                numpy_ss += ''.join(subscriptmap[a])
            numpy_iss_spl.append(numpy_ss)
            domains[k] = dom
            shapes[k] = dom.shape

        # The target is assembled from the sub-domains of the first field (in
        # `key_order`) which carries the respective output subscript.
        tgt = []
        numpy_oss = ''
        for o in oss:
            k_hit, ss_hit = next(
                (k, ss) for k, ss in zip(self._key_order, iss_spl) if o in ss
            )
            tgt.append(domains[k_hit][ss_hit.index(o)])
            numpy_oss += ''.join(subscriptmap[o])
        self._target = DomainTuple.make(tgt)
        self._sscr = '->'.join((','.join(numpy_iss_spl), numpy_oss))

        # Shape-only stand-ins for the operands; used to plan contractions.
        placeholders = {
            k: np.broadcast_to(np.nan, shp) for k, shp in shapes.items()
        }

        # For the derivative with respect to `k`, the subscripts of `k` are
        # moved to the very end as expected by `LinearEinsum`.
        self._sscr_endswith = {}
        self._linpaths = {}
        for i, (k, ss) in enumerate(zip(self._key_order, numpy_iss_spl)):
            left_ss_spl = (*numpy_iss_spl[:i], *numpy_iss_spl[i + 1:], ss)
            linpath = '->'.join((','.join(left_ss_spl), numpy_oss))
            plc = tuple(p for q, p in placeholders.items() if q != k)
            plc += (placeholders[k],)
            self._sscr_endswith[k] = linpath
            self._linpaths[k] = np.einsum_path(
                linpath, *plc, optimize=optimize)[0]

        if isinstance(optimize, list):
            path = optimize
        else:
            path = np.einsum_path(
                self._sscr, *placeholders.values(), optimize=optimize)[0]
        self._ein_kw = {"optimize": path}

    def _domain_of(self, key):
        """Return the domain of `key`, looked up in the operator's domain
        first and in `static_mf` second.

        Raises
        ------
        ValueError
            If `key` is found in neither of the two.
        """
        if key in self._domain.keys():
            return self._domain[key]
        if self._stat_mf is None or key not in self._stat_mf.keys():
            raise ValueError(f"{key} is not in domain nor in static_mf")
        return self._stat_mf[key].domain

    def apply(self, x):
        """Evaluate the einsum at `x`.

        If `x` is a Linearization, the result additionally carries the
        Jacobian, built as the sum over all keys of the domain of a
        `LinearEinsum` in which all other fields are held fixed.
        """
        self._check_input(x)
        is_lin = isinstance(x, Linearization)
        val = x.val.val if is_lin else x.val
        v = (
            val[k] if k in val else self._stat_mf[k].val
            for k in self._key_order
        )
        res = Field.from_raw(
            self.target, np.einsum(self._sscr, *v, **self._ein_kw))
        if not is_lin:
            return res

        jac = None
        for wrt in self.domain.keys():
            plc = {
                k: x.val[k] if k in x.val else self._stat_mf[k]
                for k in self._key_order if k != wrt
            }
            mf_wo_k = MultiField.from_dict(plc)
            ss = self._sscr_endswith[wrt]
            # Use the fact that the insertion order in a dictionary is the
            # ordering of keys as to pass on `key_order`
            jac_k = LinearEinsum(
                self.domain[wrt],
                mf_wo_k,
                ss,
                key_order=tuple(plc.keys()),
                optimize=self._linpaths[wrt],
                _target=self._target,
                _calling_as_lin=True
            ).ducktape(wrt)
            jac = jac + jac_k if jac is not None else jac_k
        return x.new(res, jac)


class LinearEinsum(LinearOperator):
    """Linear Einsum operator with exactly one freely varying field

    FIXME: This operator does not perform any complex conjugation!

    NOTE(review): the adjoint is evaluated with subscripts in which the output
    and the subscripts of the varying field are swapped. Whether einsum
    accepts these when the varying field carries a subscript that appears
    nowhere else should be confirmed.

    Parameters
    ----------
    domain : Domain, DomainTuple or tuple of Domain
        The operator's input domain.
    mf : MultiField
        The first part of the left-hand side of the einsum.
    subscripts : str
        The subscripts which is passed to einsum. Everything before the very
        last scripts before the '->' is treated as part of the fixed multi-
        field while the last scripts are taken to correspond to the freely
        varying field. It must contain exactly one '->'. Every letter labels
        one sub-domain; internally it is expanded to one einsum label per
        array axis of that sub-domain.
    key_order: tuple of str, optional
        The order of the keys in the multi-field. If not specified, defaults to
        the order of the keys in the multi-field.
    optimize: bool, String or List, optional
        Parameter passed on to einsum_path.

    Raises
    ------
    ValueError
        If `subscripts` is malformed or inconsistent with `key_order` and the
        domains, or if more array axes need a label than there are ASCII
        letters.
    """
    def __init__(self, domain, mf, subscripts, key_order=None, optimize='optimal',
                 _target=None, _calling_as_lin=False):
        self._domain = DomainTuple.make(domain)
        if _calling_as_lin:
            # Internal shortcut: `subscripts`, `optimize` and `_target` are
            # taken as they are, without any validation or translation.
            self._init2(mf, subscripts, key_order, optimize, _target)
            return

        if key_order is None:
            _key_order = tuple(mf.domain.keys())
        else:
            _key_order = key_order

        sscr_parts = subscripts.split("->")
        if len(sscr_parts) != 2:
            raise ValueError(f"invalid subscripts specified; got {subscripts}")
        iss, oss = sscr_parts
        *mf_sscr, var_sscr = iss.split(",")
        sscr_consist = all(o in iss for o in oss)
        len_consist = len(_key_order) == len(mf_sscr)
        if not sscr_consist or "," in oss or not len_consist:
            raise ValueError(f"invalid subscripts specified; got {subscripts}")
        ve = f"invalid order of keys {_key_order} for subscripts {subscripts}"

        # All operands in einsum order: the fixed fields first, the freely
        # varying field last.
        operands = [(mf[k].domain, ss) for k, ss in zip(_key_order, mf_sscr)]
        operands.append((self._domain, var_sscr))

        shapes, numpy_iss_spl, subscriptmap = [], [], {}
        # Labels are popped from the end, i.e. handed out as 'z', 'y', ...,
        # 'a', 'Z', ..., 'A'.
        alphabet = list(string.ascii_uppercase + string.ascii_lowercase)
        for dom, ss in operands:
            if len(dom) != len(ss):
                raise ValueError(ve)
            numpy_ss = ''
            for i, a in enumerate(ss):
                n_axes = len(dom[i].shape)
                if a not in subscriptmap:
                    if n_axes > len(alphabet):
                        raise ValueError(
                            f"subscripts {subscripts} require more einsum "
                            "labels than there are ASCII letters")
                    # One user-level subscript stands for a whole sub-domain
                    # and hence for all of its array axes.
                    subscriptmap[a] = [alphabet.pop() for _ in range(n_axes)]
                elif len(subscriptmap[a]) != n_axes:
                    raise ValueError(
                        f"subscript {a} in {subscripts} is used for domains "
                        "with a different number of axes")
                numpy_ss += ''.join(subscriptmap[a])
            numpy_iss_spl.append(numpy_ss)
            shapes.append(dom.shape)

        # The target is assembled from the sub-domains of the first operand
        # which carries the respective output subscript.
        tgt = []
        numpy_oss = ''
        for o in oss:
            dom_hit, ss_hit = next((d, s) for d, s in operands if o in s)
            tgt.append(dom_hit[ss_hit.index(o)])
            numpy_oss += ''.join(subscriptmap[o])
        numpy_subscripts = '->'.join((','.join(numpy_iss_spl), numpy_oss))

        if isinstance(optimize, list):
            path = optimize
        else:
            # Shape-only stand-ins suffice to plan the contraction.
            plc = (np.broadcast_to(np.nan, shp) for shp in shapes)
            path = np.einsum_path(numpy_subscripts, *plc, optimize=optimize)[0]
        self._init2(mf, numpy_subscripts, _key_order, path,
                    DomainTuple.make(tgt))

    def _init2(self, mf, subscripts, keyorder, optimize, target):
        """Store the operator's state and derive the adjoint subscripts.

        `subscripts` are handed to einsum as they are. The adjoint subscripts
        are obtained by exchanging the output with the last input term.
        """
        self._ein_kw = {"optimize": optimize}
        self._mf = mf
        self._sscr = subscripts
        self._key_order = keyorder
        self._target = target
        iss, oss, *_ = subscripts.split("->")
        *fixed_sscr, var_sscr = iss.split(",")
        # Join the terms individually: joining a pre-joined (possibly empty)
        # string would yield a leading comma if there are no fixed fields and
        # thereby one more einsum term than operands.
        adj_iss = ",".join((*fixed_sscr, oss))
        self._adj_sscr = "->".join((adj_iss, var_sscr))
        self._capability = self.TIMES | self.ADJOINT_TIMES

    def apply(self, x, mode):
        """Apply the einsum (TIMES) or the einsum with the adjoint subscripts
        (ADJOINT_TIMES) to `x`, keeping the fields in `mf` fixed."""
        self._check_input(x, mode)
        if mode == self.TIMES:
            dom, ss = self.target, self._sscr
        else:
            dom, ss = self.domain, self._adj_sscr
        fixed = (self._mf.val[k] for k in self._key_order)
        res = np.einsum(ss, *fixed, x.val, **self._ein_kw)
        return Field.from_raw(dom, res)