Commit 657e3f58 authored by Gordian Edenhofer's avatar Gordian Edenhofer
Browse files

einsum.py: Style and cosmetic changes

Among other things, rename `_init2` to the more descriptive name
`_init_wo_preproc` for initialization without (w/o) preprocessing.
parent 6532a1f8
......@@ -28,7 +28,7 @@ from .linear_operator import LinearOperator
class MultiLinearEinsum(Operator):
"""Multi-linear Einsum operator with corresponding derivates.
"""Multi-linear Einsum operator with corresponding derivates
Parameters
----------
......@@ -72,7 +72,7 @@ class MultiLinearEinsum(Operator):
if rest or not sscr_consist or "," in oss or not len_consist:
raise ValueError(f"invalid subscripts specified; got {subscripts}")
ve = f"invalid order of keys {self._key_order} for subscripts {subscripts}"
shapes, numpy_subscripts, subscriptmap = {},'',{}
shapes, numpy_subscripts, subscriptmap = {}, '', {}
alphabet = list(string.ascii_lowercase)[::-1]
for k, ss in zip(self._key_order, iss_spl):
dom = self._domain[k] if k in self._domain.keys(
......@@ -102,7 +102,6 @@ class MultiLinearEinsum(Operator):
numpy_subscripts += ''.join(subscriptmap[o])
self._target = DomainTuple.make(tgt)
numpy_iss, numpy_oss, *_ = numpy_subscripts.split("->")
numpy_iss_spl = numpy_iss.split(",")
......@@ -111,15 +110,15 @@ class MultiLinearEinsum(Operator):
for k, (i, ss) in zip(self._key_order, enumerate(numpy_iss_spl)):
left_ss_spl = (*numpy_iss_spl[:i], *numpy_iss_spl[i + 1:], ss)
linpath = '->'.join((','.join(left_ss_spl), numpy_oss))
plc = tuple(np.broadcast_to(np.nan, shapes[q]) for q in shapes.keys() if q!=k)
plc = tuple(np.broadcast_to(np.nan, shapes[q]) for q in shapes if q != k)
plc += (np.broadcast_to(np.nan, shapes[k]),)
self._sscr_endswith[k] = linpath
self._linpaths[k] = np.einsum_path(linpath, *plc, optimize=optimize)[0]
if isinstance(optimize, list):
path = optimize
else:
plc = (np.broadcast_to(np.nan, shapes[k]) for k in shapes.keys())
plc = (np.broadcast_to(np.nan, shapes[k]) for k in shapes)
path = np.einsum_path(numpy_subscripts, *plc, optimize=optimize)[0]
self._sscr = numpy_subscripts
self._ein_kw = {"optimize": path}
......@@ -153,8 +152,8 @@ class MultiLinearEinsum(Operator):
ss,
key_order=tuple(plc.keys()),
optimize=self._linpaths[wrt],
_target = self._target,
_calling_as_lin = True
_target=self._target,
_calling_as_lin=True
).ducktape(wrt)
jac = jac + jac_k if jac is not None else jac_k
return x.new(Field.from_raw(self.target, res), jac)
......@@ -164,7 +163,6 @@ class MultiLinearEinsum(Operator):
class LinearEinsum(LinearOperator):
"""Linear Einsum operator with exactly one freely varying field
Parameters
----------
domain : Domain, DomainTuple or tuple of Domain
......@@ -193,7 +191,7 @@ class LinearEinsum(LinearOperator):
_target=None, _calling_as_lin=False):
self._domain = DomainTuple.make(domain)
if _calling_as_lin:
self._init2(mf, subscripts, key_order, optimize, _target)
self._init_wo_preproc(mf, subscripts, key_order, optimize, _target)
else:
self._mf = mf
if key_order is None:
......@@ -208,7 +206,7 @@ class LinearEinsum(LinearOperator):
if rest or not sscr_consist or "," in oss or not len_consist:
raise ValueError(f"invalid subscripts specified; got {subscripts}")
ve = f"invalid order of keys {_key_order} for subscripts {subscripts}"
shapes, numpy_subscripts, subscriptmap = (),'',{}
shapes, numpy_subscripts, subscriptmap = (), '', {}
alphabet = list(string.ascii_lowercase)
for k, ss in zip(_key_order, iss_spl[:-1]):
dom = self._mf[k].domain
......@@ -220,7 +218,7 @@ class LinearEinsum(LinearOperator):
range(len(dom[i].shape))]
numpy_subscripts += ''.join(subscriptmap[a])
numpy_subscripts += ','
shapes +=(dom.shape,)
shapes += (dom.shape,)
if len(self._domain) != len(iss_spl[-1]):
raise ValueError(ve)
for i, a in enumerate(list(iss_spl[-1])):
......@@ -230,7 +228,7 @@ class LinearEinsum(LinearOperator):
numpy_subscripts += ''.join(subscriptmap[a])
shapes += (self._domain.shape,)
numpy_subscripts += '->'
dom_sscr = dict(zip(_key_order, iss_spl[:-1]))
dom_sscr[id(self)] = iss_spl[-1]
tgt = []
......@@ -245,16 +243,16 @@ class LinearEinsum(LinearOperator):
numpy_subscripts += "".join(subscriptmap[o])
_target = DomainTuple.make(tgt)
self._sscr = numpy_subscripts
if isinstance(optimize, list):
path = optimize
else:
plc = (np.broadcast_to(np.nan, shp) for shp in shapes)
path = np.einsum_path(numpy_subscripts, *plc, optimize=optimize)[0]
self._init2(mf, numpy_subscripts, _key_order, path, _target)
self._init_wo_preproc(mf, numpy_subscripts, _key_order, path, _target)
def _init2(self, mf, subscripts, keyorder, optimize, target):
def _init_wo_preproc(self, mf, subscripts, keyorder, optimize, target):
self._ein_kw = {"optimize": optimize}
self._mf = mf
self._sscr = subscripts
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment