einsum.py 11.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
Philipp Arras's avatar
Philipp Arras committed
14
# Copyright(C) 2013-2020 Max-Planck-Society
15
# Authors: Gordian Edenhofer, Philipp Frank
16
17
18
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

19
import string
Philipp Arras's avatar
Philipp Arras committed
20
21
22

import numpy as np

23
24
from ..domain_tuple import DomainTuple
from ..field import Field
Philipp Arras's avatar
Philipp Arras committed
25
from ..linearization import Linearization
26
27
28
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from .linear_operator import LinearOperator
Philipp Arras's avatar
Philipp Arras committed
29
from .operator import Operator
30
31
32


class MultiLinearEinsum(Operator):
    """Multi-linear Einsum operator with corresponding derivatives.

    Parameters
    ----------
    domain : MultiDomain or dict{name: DomainTuple}
        The operator's input domain.
    subscripts : str
        The subscripts which are passed to einsum, in explicit mode
        (``"...->..."``). Each comma-separated group refers to the entry of
        `key_order` at the same position; each letter of a group refers to
        one sub-domain of the corresponding DomainTuple (a sub-domain may
        span several array axes).
    key_order: tuple of str, optional
        The order of the keys in the multi-field. If not specified, defaults to
        the order of the keys in `domain`. Must be specified if `static_mf`
        is given.
    static_mf: MultiField or dict{name: Field}, optional
        A dictionary like type from which Fields are to be taken if the key from
        `key_order` is not part of the `domain`. Fields in this object are
        supposed to be static as they will not appear as FieldAdapter in the
        Linearization.
    optimize: bool, String or List, optional
        Parameter passed on to einsum_path.

    Notes
    -----
    By convention :class:`MultiLinearEinsum` only performs operations with
    lower indices. Therefore no complex conjugation is performed on complex
    inputs. To achieve operations with upper/lower indices use
    :class:`PartialConjugate` before applying this operator.
    """

    def __init__(self, domain, subscripts,
                 key_order=None, static_mf=None, optimize='optimal'):
        self._domain = MultiDomain.make(domain)
        if static_mf is not None and key_order is None:
            ve = "`key_order` must be specified if additional fields are munged"
            raise ValueError(ve)
        if key_order is None:
            self._key_order = tuple(self._domain.keys())
        else:
            # Materialise once: the order is iterated several times below.
            self._key_order = tuple(key_order)
        self._stat_mf = static_mf

        parts = subscripts.split("->")
        if len(parts) != 2:
            raise ValueError(f"invalid subscripts specified; got {subscripts}")
        iss, oss = parts
        iss_spl = iss.split(",")
        len_consist = len(self._key_order) == len(iss_spl)
        sscr_consist = all(o in iss for o in oss)
        if not sscr_consist or "," in oss or not len_consist:
            raise ValueError(f"invalid subscripts specified; got {subscripts}")
        ve = f"invalid order of keys {self._key_order} for subscripts {subscripts}"

        # Translate the per-sub-domain subscripts into numpy's per-axis ones.
        # Reversed so that pop() hands out 'a', 'b', ..., 'z', 'A', ..., 'Z'.
        alphabet = list(string.ascii_letters)[::-1]
        subscriptmap, numpy_iss_spl, shapes = {}, [], []
        for k, ss in zip(self._key_order, iss_spl):
            dom = self._domain_of_key(k)
            if len(dom) != len(ss):
                raise ValueError(ve)
            numpy_iss_spl.append(
                self._to_numpy_subscripts(dom, ss, subscriptmap, alphabet))
            shapes.append(dom.shape)

        # Each output subscript takes its sub-domain from the first operand
        # in which it occurs.
        operands = tuple(zip(self._key_order, iss_spl))
        tgt, numpy_oss = [], ''
        for o in oss:
            k_hit, ss_hit = next((k, ss) for k, ss in operands if o in ss)
            tgt.append(self._domain_of_key(k_hit)[ss_hit.index(o)])
            numpy_oss += ''.join(subscriptmap[o])
        self._target = DomainTuple.make(tgt)

        # For every differentiable key precompute the subscripts with that
        # key's operand moved last (the layout LinearEinsum expects) and the
        # corresponding einsum path.
        self._sscr_endswith = dict()
        self._linpaths = dict()
        for i, k in enumerate(self._key_order):
            if k not in self._domain.keys():
                # Static fields are never differentiated; skip the
                # (potentially expensive) path search for them.
                continue
            order = [j for j in range(len(shapes)) if j != i] + [i]
            linpath = '->'.join(
                (','.join(numpy_iss_spl[j] for j in order), numpy_oss))
            plc = (np.broadcast_to(np.nan, shapes[j]) for j in order)
            self._sscr_endswith[k] = linpath
            self._linpaths[k] = np.einsum_path(
                linpath, *plc, optimize=optimize)[0]

        numpy_subscripts = ','.join(numpy_iss_spl) + '->' + numpy_oss
        if isinstance(optimize, list):
            path = optimize
        else:
            # Only shapes matter for the path search; broadcast views avoid
            # allocating real arrays.
            plc = (np.broadcast_to(np.nan, shp) for shp in shapes)
            path = np.einsum_path(numpy_subscripts, *plc, optimize=optimize)[0]
        self._sscr = numpy_subscripts
        self._ein_kw = {"optimize": path}

    def _domain_of_key(self, key):
        """Return the DomainTuple of operand `key`.

        The domain is taken from `domain` if it contains `key`, otherwise from
        `static_mf`. Raises ValueError if `key` is in neither.
        """
        if key in self._domain.keys():
            return self._domain[key]
        if self._stat_mf is None or key not in self._stat_mf.keys():
            raise ValueError(f"{key} is not in domain nor in static_mf")
        return self._stat_mf[key].domain

    @staticmethod
    def _to_numpy_subscripts(dom, ss, subscriptmap, alphabet):
        """Translate the subscripts `ss` of `dom` into numpy axis letters.

        Each letter of `ss` not yet in `subscriptmap` is assigned as many
        fresh letters, popped from `alphabet`, as its sub-domain has axes.
        Both `subscriptmap` and `alphabet` are updated in place.
        """
        out = ''
        for i, a in enumerate(ss):
            if a not in subscriptmap:
                nax = len(dom[i].shape)
                if nax > len(alphabet):
                    raise ValueError("too many array axes for einsum")
                subscriptmap[a] = [alphabet.pop() for _ in range(nax)]
            out += ''.join(subscriptmap[a])
        return out

    def apply(self, x):
        self._check_input(x)
        lin = isinstance(x, Linearization)
        fld = x.val if lin else x
        val = fld.val  # dict of raw arrays, keyed like self.domain
        res = np.einsum(
            self._sscr,
            *(val[k] if k in val else self._stat_mf[k].val
              for k in self._key_order),
            **self._ein_kw
        )
        if not lin:
            return Field.from_raw(self.target, res)

        jac = None
        for wrt in self.domain.keys():
            plc = {
                k: fld[k] if k in val else self._stat_mf[k]
                for k in self._key_order if k != wrt
            }
            mf_wo_k = MultiField.from_dict(plc)
            # The dict's insertion order is the order of the fixed operands
            # in the precomputed subscripts; pass it on explicitly as
            # `key_order`.
            jac_k = LinearEinsum(
                self.domain[wrt],
                mf_wo_k,
                self._sscr_endswith[wrt],
                key_order=tuple(plc.keys()),
                optimize=self._linpaths[wrt],
                _target=self._target,
                _calling_as_lin=True
            ).ducktape(wrt)
            jac = jac + jac_k if jac is not None else jac_k
        return x.new(Field.from_raw(self.target, res), jac)


class LinearEinsum(LinearOperator):
    """Linear Einsum operator with exactly one freely varying field.

    Parameters
    ----------
    domain : Domain, DomainTuple or tuple of Domain
        The operator's input domain, i.e. the domain of the freely varying
        field.
    mf : MultiField
        The fixed fields forming the first part of the left-hand side of the
        einsum.
    subscripts : str
        The subscripts which are passed to einsum, in explicit mode. All
        comma-separated groups but the last one before the '->' refer to the
        fields of `mf` (in the order of `key_order`), while the last group
        refers to the freely varying field.
    key_order: tuple of str, optional
        The order of the keys in the multi-field. If not specified, defaults to
        the order of the keys in the multi-field.
    optimize: bool, String or List, optional
        Parameter passed on to einsum_path.

    Notes
    -----
    By convention :class:`LinearEinsum` only performs operations with
    lower indices. Therefore no complex conjugation is performed on complex
    inputs or mf. To achieve operations with upper/lower indices use
    :class:`PartialConjugate` before applying this operator.

    Along axes of the free field whose subscripts appear neither in `mf` nor
    in the output, the adjoint is constant and is therefore broadcast.
    """

    def __init__(self, domain, mf, subscripts, key_order=None, optimize='optimal',
                 _target=None, _calling_as_lin=False):
        # `_target` and `_calling_as_lin` are private: MultiLinearEinsum passes
        # already translated numpy subscripts, an einsum path and the target,
        # so none of the preprocessing below needs to be repeated.
        self._domain = DomainTuple.make(domain)
        if _calling_as_lin:
            self._init_wo_preproc(mf, subscripts, key_order, optimize, _target)
            return

        if key_order is None:
            key_order = tuple(mf.domain.keys())
        parts = subscripts.split("->")
        if len(parts) != 2:
            raise ValueError(f"invalid subscripts specified; got {subscripts}")
        iss, oss = parts
        iss_spl = iss.split(",")
        fixed_spl, free_ss = iss_spl[:-1], iss_spl[-1]
        sscr_consist = all(o in iss for o in oss)
        len_consist = len(key_order) == len(fixed_spl)
        if not sscr_consist or "," in oss or not len_consist:
            raise ValueError(f"invalid subscripts specified; got {subscripts}")
        ve = f"invalid order of keys {key_order} for subscripts {subscripts}"

        # (domain, subscripts) of every operand, the free field last.
        operands = [(mf[k].domain, ss) for k, ss in zip(key_order, fixed_spl)]
        operands.append((self._domain, free_ss))

        # pop() hands out 'z', 'y', ..., 'a', then 'Z', ..., 'A'.
        alphabet = list(string.ascii_uppercase + string.ascii_lowercase)
        subscriptmap, numpy_iss_spl = {}, []
        for dom, ss in operands:
            if len(dom) != len(ss):
                raise ValueError(ve)
            numpy_iss_spl.append(
                self._to_numpy_subscripts(dom, ss, subscriptmap, alphabet))

        # Each output subscript takes its sub-domain from the first operand
        # in which it occurs.
        tgt, numpy_oss = [], ''
        for o in oss:
            dom, ss = next((d, s) for d, s in operands if o in s)
            tgt.append(dom[ss.index(o)])
            numpy_oss += ''.join(subscriptmap[o])
        numpy_subscripts = ','.join(numpy_iss_spl) + '->' + numpy_oss

        if isinstance(optimize, list):
            path = optimize
        else:
            # Only shapes matter for the path search; broadcast views avoid
            # allocating real arrays.
            plc = (np.broadcast_to(np.nan, dom.shape) for dom, _ in operands)
            path = np.einsum_path(numpy_subscripts, *plc, optimize=optimize)[0]
        self._init_wo_preproc(mf, numpy_subscripts, key_order, path,
                              DomainTuple.make(tgt))

    @staticmethod
    def _to_numpy_subscripts(dom, ss, subscriptmap, alphabet):
        """Translate the subscripts `ss` of `dom` into numpy axis letters.

        Each letter of `ss` not yet in `subscriptmap` is assigned as many
        fresh letters, popped from `alphabet`, as its sub-domain has axes.
        Both `subscriptmap` and `alphabet` are updated in place.
        """
        out = ''
        for i, a in enumerate(ss):
            if a not in subscriptmap:
                nax = len(dom[i].shape)
                if nax > len(alphabet):
                    raise ValueError("too many array axes for einsum")
                subscriptmap[a] = [alphabet.pop() for _ in range(nax)]
            out += ''.join(subscriptmap[a])
        return out

    def _init_wo_preproc(self, mf, subscripts, keyorder, optimize, target):
        """Store translated numpy `subscripts` and derive the adjoint's.

        `subscripts` must use one numpy letter per array axis, with the free
        field as the last input operand. `optimize` is handed to np.einsum
        for both directions.
        """
        self._ein_kw = {"optimize": optimize}
        self._mf = mf
        self._sscr = subscripts
        self._key_order = keyorder
        self._target = target
        iss, oss, *_ = subscripts.split("->")
        iss_spl = iss.split(",")
        free_ss = iss_spl[-1]
        # The adjoint takes the fixed fields and the output as inputs. Using a
        # tuple avoids a leading comma when there are no fixed fields.
        adj_inputs = (*iss_spl[:-1], oss)
        if len(set(free_ss)) == len(free_ss):
            # einsum cannot create output axes absent from all inputs; compute
            # over the present ones and broadcast along the rest in `apply`.
            present = set(''.join(adj_inputs))
            adj_oss = ''.join(c for c in free_ss if c in present)
        else:
            # Repeated (diagonal) free indices: keep the subscripts unchanged;
            # broadcasting would give a wrong result here.
            adj_oss = free_ss
        self._adj_sscr = "->".join((",".join(adj_inputs), adj_oss))
        if adj_oss == free_ss:
            self._adj_shape = None
        else:
            self._adj_shape = tuple(
                n if c in adj_oss else 1
                for c, n in zip(free_ss, self._domain.shape))
        self._capability = self.TIMES | self.ADJOINT_TIMES

    def apply(self, x, mode):
        self._check_input(x, mode)
        if mode == self.TIMES:
            dom, ss, mf = self.target, self._sscr, self._mf
        else:
            dom, ss, mf = self.domain, self._adj_sscr, self._mf.conjugate()
        res = np.einsum(
            ss, *(mf[k].val for k in self._key_order), x.val,
            **self._ein_kw
        )
        if mode != self.TIMES and self._adj_shape is not None:
            # Constant along the missing axes: re-insert them as length-1 axes
            # and broadcast; copy so the Field owns a regular array.
            res = np.broadcast_to(np.reshape(res, self._adj_shape),
                                  dom.shape).copy()
        return Field.from_raw(dom, res)