Commit ef11e4e2 authored by Philipp Frank's avatar Philipp Frank
Browse files

first version precalculate paths for lin

parent 5c745090
......@@ -67,7 +67,7 @@ class MultiLinearEinsum(Operator):
if rest or not sscr_consist or "," in oss or not len_consist:
raise ValueError(f"invalid subscripts specified; got {subscripts}")
ve = f"invalid order of keys {self._key_order} for subscripts {subscripts}"
shapes, mysubscripts, subscriptmap = (),'',{}
shapes, numpy_subscripts, subscriptmap = {},'',{}
alphabet = list(string.ascii_lowercase)
for k, ss in zip(self._key_order, iss_spl):
dom = self._domain[k] if k in self._domain.keys(
......@@ -78,10 +78,10 @@ class MultiLinearEinsum(Operator):
if a not in subscriptmap.keys():
subscriptmap[a] = alphabet[:len(dom[i].shape)].copy()
del alphabet[:len(dom[i].shape)]
mysubscripts += ''.join(subscriptmap[a])
mysubscripts += ','
shapes += (dom.shape, )
mysubscripts = mysubscripts[:-1] + '->'
numpy_subscripts += ''.join(subscriptmap[a])
numpy_subscripts += ','
shapes[k] = dom.shape
numpy_subscripts = numpy_subscripts[:-1] + '->'
dom_sscr = dict(zip(self._key_order, iss_spl))
tgt = []
for o in oss:
......@@ -94,18 +94,31 @@ class MultiLinearEinsum(Operator):
ve = f"{k_hit} is not in domain nor in static_mf"
raise ValueError(ve)
tgt += [self._stat_mf[k_hit].domain[dom_k_idx]]
mysubscripts += ''.join(subscriptmap[o])
numpy_subscripts += ''.join(subscriptmap[o])
self._target = DomainTuple.make(tgt)
numpy_iss, numpy_oss, *_ = numpy_subscripts.split("->")
numpy_iss_spl = numpy_iss.split(",")
self._sscr_endswith = dict()
for k, (i, ss) in zip(self._key_order, enumerate(iss_spl)):
self._linpaths = dict()
for k, (i, ss), nss in zip(self._key_order, enumerate(iss_spl),
numpy_iss_spl):
left_ss_spl = (*iss_spl[:i], *iss_spl[i + 1:], ss)
self._sscr_endswith[k] = '->'.join(
(','.join(left_ss_spl), oss)
)
plc = (np.broadcast_to(np.nan, shp) for shp in shapes)
path = np.einsum_path(mysubscripts, *plc, optimize=optimize)[0]
self._sscr = mysubscripts
left_ss_spl = (*numpy_iss_spl[:i], *numpy_iss_spl[i + 1:], nss)
linpath = '->'.join((','.join(left_ss_spl), numpy_oss))
plc = tuple(np.broadcast_to(np.nan, shapes[q]) for q in shapes.keys() if q!=k)
plc += (np.broadcast_to(np.nan, shapes[k]),)
linpath = np.einsum_path(linpath, *plc, optimize=optimize)[0]
self._linpaths[k] = linpath
plc = (np.broadcast_to(np.nan, shapes[k]) for k in shapes.keys())
path = np.einsum_path(numpy_subscripts, *plc, optimize=optimize)[0]
self._sscr = numpy_subscripts
self._ein_kw = {"optimize": path}
def apply(self, x):
......@@ -136,7 +149,7 @@ class MultiLinearEinsum(Operator):
mf_wo_k,
ss,
key_order=tuple(plc.keys()),
**self._ein_kw
optimize=self._linpaths[wrt]
).ducktape(wrt)
jac = jac + jac_k if jac is not None else jac_k
return x.new(Field.from_raw(self.target, res), jac)
......@@ -180,7 +193,7 @@ class LinearEinsum(LinearOperator):
if rest or not sscr_consist or "," in oss or not len_consist:
raise ValueError(f"invalid subscripts specified; got {subscripts}")
ve = f"invalid order of keys {self._key_order} for subscripts {subscripts}"
shapes, mysubscripts, subscriptmap = (),'',{}
shapes, numpy_subscripts, subscriptmap = (),'',{}
alphabet = list(string.ascii_lowercase)
for k, ss in zip(self._key_order, iss_spl[:-1]):
dom = self._mf[k].domain
......@@ -190,8 +203,8 @@ class LinearEinsum(LinearOperator):
if a not in subscriptmap.keys():
subscriptmap[a] = alphabet[:len(dom[i].shape)].copy()
del alphabet[:len(dom[i].shape)]
mysubscripts += ''.join(subscriptmap[a])
mysubscripts += ','
numpy_subscripts += ''.join(subscriptmap[a])
numpy_subscripts += ','
shapes +=(dom.shape,)
if len(self._domain) != len(iss_spl[-1]):
raise ValueError(ve)
......@@ -199,9 +212,9 @@ class LinearEinsum(LinearOperator):
if a not in subscriptmap.keys():
subscriptmap[a] = alphabet[:len(self._domain[i].shape)].copy()
del alphabet[:len(self._domain[i].shape)]
mysubscripts += ''.join(subscriptmap[a])
numpy_subscripts += ''.join(subscriptmap[a])
shapes += (self._domain.shape,)
mysubscripts += '->'
numpy_subscripts += '->'
dom_sscr = dict(zip(self._key_order, iss_spl[:-1]))
dom_sscr[id(self)] = iss_spl[-1]
......@@ -214,14 +227,18 @@ class LinearEinsum(LinearOperator):
else:
assert k_hit == id(self)
tgt += [self._domain[dom_k_idx]]
mysubscripts += "".join(subscriptmap[o])
numpy_subscripts += "".join(subscriptmap[o])
self._target = DomainTuple.make(tgt)
self._sscr = mysubscripts
plc = (np.broadcast_to(np.nan, shp) for shp in shapes)
path = np.einsum_path(mysubscripts, *plc, optimize=optimize)[0]
self._ein_kw = {"optimize": path}
self._sscr = numpy_subscripts
if isinstance(optimize, list):
self._ein_kw = {"optimize": optimize}
else:
plc = (np.broadcast_to(np.nan, shp) for shp in shapes)
path = np.einsum_path(numpy_subscripts, *plc, optimize=optimize)[0]
self._ein_kw = {"optimize": path}
iss, oss, *_ = mysubscripts.split("->")
iss, oss, *_ = numpy_subscripts.split("->")
iss_spl = iss.split(",")
adj_iss = ",".join((",".join(iss_spl[:-1]), oss))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment