Commit 5c745090 authored by Philipp Frank's avatar Philipp Frank
Browse files

cleanup

parent 2cdc9e00
Pipeline #74960 passed with stages
in 26 minutes and 20 seconds
......@@ -49,14 +49,8 @@ class MultiLinearEinsum(Operator):
optimize: bool, optional
Parameter passed on to einsum.
"""
def __init__(
self,
domain,
subscripts,
key_order=None,
static_mf=None,
optimize=True
):
def __init__(self, domain, subscripts,
key_order=None, static_mf=None, optimize=True):
self._domain = MultiDomain.make(domain)
if key_order is None:
self._key_order = tuple(self._domain.keys())
......@@ -73,9 +67,7 @@ class MultiLinearEinsum(Operator):
if rest or not sscr_consist or "," in oss or not len_consist:
raise ValueError(f"invalid subscripts specified; got {subscripts}")
ve = f"invalid order of keys {self._key_order} for subscripts {subscripts}"
shapes = ()
mysubscripts = ""
subscriptmap = {}
shapes, mysubscripts, subscriptmap = (),'',{}
alphabet = list(string.ascii_lowercase)
for k, ss in zip(self._key_order, iss_spl):
dom = self._domain[k] if k in self._domain.keys(
......@@ -85,8 +77,7 @@ class MultiLinearEinsum(Operator):
for i, a in enumerate(list(ss)):
if a not in subscriptmap.keys():
subscriptmap[a] = alphabet[:len(dom[i].shape)].copy()
for j in range(len(dom[i].shape)):
del alphabet[0]
del alphabet[:len(dom[i].shape)]
mysubscripts += ''.join(subscriptmap[a])
mysubscripts += ','
shapes += (dom.shape, )
......@@ -103,14 +94,14 @@ class MultiLinearEinsum(Operator):
ve = f"{k_hit} is not in domain nor in static_mf"
raise ValueError(ve)
tgt += [self._stat_mf[k_hit].domain[dom_k_idx]]
mysubscripts += "".join(subscriptmap[o])
mysubscripts += ''.join(subscriptmap[o])
self._target = DomainTuple.make(tgt)
self._sscr_endswith = dict()
for k, (i, ss) in zip(self._key_order, enumerate(iss_spl)):
left_ss_spl = (*iss_spl[:i], *iss_spl[i + 1:], ss)
self._sscr_endswith[k] = "->".join(
(",".join(left_ss_spl), oss)
self._sscr_endswith[k] = '->'.join(
(','.join(left_ss_spl), oss)
)
plc = (np.broadcast_to(np.nan, shp) for shp in shapes)
path = np.einsum_path(mysubscripts, *plc, optimize=optimize)[0]
......@@ -189,9 +180,7 @@ class LinearEinsum(LinearOperator):
if rest or not sscr_consist or "," in oss or not len_consist:
raise ValueError(f"invalid subscripts specified; got {subscripts}")
ve = f"invalid order of keys {self._key_order} for subscripts {subscripts}"
shapes = ()
mysubscripts = ""
subscriptmap = {}
shapes, mysubscripts, subscriptmap = (),'',{}
alphabet = list(string.ascii_lowercase)
for k, ss in zip(self._key_order, iss_spl[:-1]):
dom = self._mf[k].domain
......@@ -200,8 +189,7 @@ class LinearEinsum(LinearOperator):
for i, a in enumerate(list(ss)):
if a not in subscriptmap.keys():
subscriptmap[a] = alphabet[:len(dom[i].shape)].copy()
for j in range(len(dom[i].shape)):
del alphabet[0]
del alphabet[:len(dom[i].shape)]
mysubscripts += ''.join(subscriptmap[a])
mysubscripts += ','
shapes +=(dom.shape,)
......@@ -210,8 +198,7 @@ class LinearEinsum(LinearOperator):
for i, a in enumerate(list(iss_spl[-1])):
if a not in subscriptmap.keys():
subscriptmap[a] = alphabet[:len(self._domain[i].shape)].copy()
for j in range(len(self._domain[i].shape)):
del alphabet[0]
del alphabet[:len(self._domain[i].shape)]
mysubscripts += ''.join(subscriptmap[a])
shapes += (self._domain.shape,)
mysubscripts += '->'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment