Commit d296d7a7 authored by Philipp Frank's avatar Philipp Frank
Browse files

split LinearEinsum init for calling as lin

parent bb066937
Pipeline #74967 passed with stages
in 26 minutes and 16 seconds
...@@ -103,18 +103,14 @@ class MultiLinearEinsum(Operator): ...@@ -103,18 +103,14 @@ class MultiLinearEinsum(Operator):
self._sscr_endswith = dict() self._sscr_endswith = dict()
self._linpaths = dict() self._linpaths = dict()
for k, (i, ss), nss in zip(self._key_order, enumerate(iss_spl), for k, (i, ss) in zip(self._key_order, enumerate(numpy_iss_spl)):
numpy_iss_spl): left_ss_spl = (*numpy_iss_spl[:i], *numpy_iss_spl[i + 1:], ss)
left_ss_spl = (*iss_spl[:i], *iss_spl[i + 1:], ss)
self._sscr_endswith[k] = '->'.join(
(','.join(left_ss_spl), oss)
)
left_ss_spl = (*numpy_iss_spl[:i], *numpy_iss_spl[i + 1:], nss)
linpath = '->'.join((','.join(left_ss_spl), numpy_oss)) linpath = '->'.join((','.join(left_ss_spl), numpy_oss))
plc = tuple(np.broadcast_to(np.nan, shapes[q]) for q in shapes.keys() if q!=k) plc = tuple(np.broadcast_to(np.nan, shapes[q]) for q in shapes.keys() if q!=k)
plc += (np.broadcast_to(np.nan, shapes[k]),) plc += (np.broadcast_to(np.nan, shapes[k]),)
linpath = np.einsum_path(linpath, *plc, optimize=optimize)[0] self._sscr_endswith[k] = linpath
self._linpaths[k] = linpath self._linpaths[k] = np.einsum_path(linpath, *plc, optimize=optimize)[0]
if isinstance(optimize, list): if isinstance(optimize, list):
path = optimize path = optimize
else: else:
...@@ -151,7 +147,9 @@ class MultiLinearEinsum(Operator): ...@@ -151,7 +147,9 @@ class MultiLinearEinsum(Operator):
mf_wo_k, mf_wo_k,
ss, ss,
key_order=tuple(plc.keys()), key_order=tuple(plc.keys()),
optimize=self._linpaths[wrt] optimize=self._linpaths[wrt],
_target = self._target,
_calling_as_lin = True
).ducktape(wrt) ).ducktape(wrt)
jac = jac + jac_k if jac is not None else jac_k jac = jac + jac_k if jac is not None else jac_k
return x.new(Field.from_raw(self.target, res), jac) return x.new(Field.from_raw(self.target, res), jac)
...@@ -180,24 +178,28 @@ class LinearEinsum(LinearOperator): ...@@ -180,24 +178,28 @@ class LinearEinsum(LinearOperator):
optimize: bool, String or List, optional optimize: bool, String or List, optional
Parameter passed on to einsum_path. Parameter passed on to einsum_path.
""" """
def __init__(self, domain, mf, subscripts, key_order=None, optimize='optimal'): def __init__(self, domain, mf, subscripts, key_order=None, optimize='optimal',
_target=None, _calling_as_lin=False):
self._domain = DomainTuple.make(domain) self._domain = DomainTuple.make(domain)
if _calling_as_lin:
self._init2(mf, subscripts, key_order, optimize, _target)
else:
self._mf = mf self._mf = mf
if key_order is None: if key_order is None:
self._key_order = tuple(self._mf.domain.keys()) _key_order = tuple(self._mf.domain.keys())
else: else:
self._key_order = key_order _key_order = key_order
self._ein_kw = {"optimize": optimize} self._ein_kw = {"optimize": optimize}
iss, oss, *rest = subscripts.split("->") iss, oss, *rest = subscripts.split("->")
iss_spl = iss.split(",") iss_spl = iss.split(",")
sscr_consist = all(o in iss for o in oss) sscr_consist = all(o in iss for o in oss)
len_consist = len(self._key_order) == len(iss_spl[:-1]) len_consist = len(_key_order) == len(iss_spl[:-1])
if rest or not sscr_consist or "," in oss or not len_consist: if rest or not sscr_consist or "," in oss or not len_consist:
raise ValueError(f"invalid subscripts specified; got {subscripts}") raise ValueError(f"invalid subscripts specified; got {subscripts}")
ve = f"invalid order of keys {self._key_order} for subscripts {subscripts}" ve = f"invalid order of keys {_key_order} for subscripts {subscripts}"
shapes, numpy_subscripts, subscriptmap = (),'',{} shapes, numpy_subscripts, subscriptmap = (),'',{}
alphabet = list(string.ascii_lowercase) alphabet = list(string.ascii_lowercase)
for k, ss in zip(self._key_order, iss_spl[:-1]): for k, ss in zip(_key_order, iss_spl[:-1]):
dom = self._mf[k].domain dom = self._mf[k].domain
if len(dom) != len(ss): if len(dom) != len(ss):
raise ValueError(ve) raise ValueError(ve)
...@@ -218,19 +220,19 @@ class LinearEinsum(LinearOperator): ...@@ -218,19 +220,19 @@ class LinearEinsum(LinearOperator):
shapes += (self._domain.shape,) shapes += (self._domain.shape,)
numpy_subscripts += '->' numpy_subscripts += '->'
dom_sscr = dict(zip(self._key_order, iss_spl[:-1])) dom_sscr = dict(zip(_key_order, iss_spl[:-1]))
dom_sscr[id(self)] = iss_spl[-1] dom_sscr[id(self)] = iss_spl[-1]
tgt = [] tgt = []
for o in oss: for o in oss:
k_hit = tuple(k for k, sscr in dom_sscr.items() if o in sscr)[0] k_hit = tuple(k for k, sscr in dom_sscr.items() if o in sscr)[0]
dom_k_idx = dom_sscr[k_hit].index(o) dom_k_idx = dom_sscr[k_hit].index(o)
if k_hit in self._key_order: if k_hit in _key_order:
tgt += [self._mf.domain[k_hit][dom_k_idx]] tgt += [self._mf.domain[k_hit][dom_k_idx]]
else: else:
assert k_hit == id(self) assert k_hit == id(self)
tgt += [self._domain[dom_k_idx]] tgt += [self._domain[dom_k_idx]]
numpy_subscripts += "".join(subscriptmap[o]) numpy_subscripts += "".join(subscriptmap[o])
self._target = DomainTuple.make(tgt) _target = DomainTuple.make(tgt)
self._sscr = numpy_subscripts self._sscr = numpy_subscripts
if isinstance(optimize, list): if isinstance(optimize, list):
...@@ -238,9 +240,16 @@ class LinearEinsum(LinearOperator): ...@@ -238,9 +240,16 @@ class LinearEinsum(LinearOperator):
else: else:
plc = (np.broadcast_to(np.nan, shp) for shp in shapes) plc = (np.broadcast_to(np.nan, shp) for shp in shapes)
path = np.einsum_path(numpy_subscripts, *plc, optimize=optimize)[0] path = np.einsum_path(numpy_subscripts, *plc, optimize=optimize)[0]
self._ein_kw = {"optimize": path} self._init2(mf, numpy_subscripts, _key_order, path, _target)
iss, oss, *_ = numpy_subscripts.split("->") def _init2(self, mf, subscripts, keyorder, optimize, target):
self._ein_kw = {"optimize": optimize}
self._mf = mf
self._sscr = subscripts
self._key_order = keyorder
self._target = target
iss, oss, *_ = subscripts.split("->")
iss_spl = iss.split(",") iss_spl = iss.split(",")
adj_iss = ",".join((",".join(iss_spl[:-1]), oss)) adj_iss = ",".join((",".join(iss_spl[:-1]), oss))
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment