Commit 5835e470 authored by Philipp Frank's avatar Philipp Frank
Browse files

compute einsum_path on initialization

parent cd38eef1
...@@ -54,7 +54,7 @@ class MultiLinearEinsum(Operator): ...@@ -54,7 +54,7 @@ class MultiLinearEinsum(Operator):
subscripts, subscripts,
key_order=None, key_order=None,
static_mf=None, static_mf=None,
optimize=False optimize=True
): ):
self._domain = MultiDomain.make(domain) self._domain = MultiDomain.make(domain)
self._sscr = subscripts self._sscr = subscripts
...@@ -66,19 +66,20 @@ class MultiLinearEinsum(Operator): ...@@ -66,19 +66,20 @@ class MultiLinearEinsum(Operator):
ve = "`key_order` mus be specified if additional fields are munged" ve = "`key_order` mus be specified if additional fields are munged"
raise ValueError(ve) raise ValueError(ve)
self._stat_mf = static_mf self._stat_mf = static_mf
self._ein_kw = {"optimize": optimize}
iss, self._oss, *rest = subscripts.split("->") iss, self._oss, *rest = subscripts.split("->")
iss_spl = iss.split(",") iss_spl = iss.split(",")
len_consist = len(self._key_order) == len(iss_spl) len_consist = len(self._key_order) == len(iss_spl)
sscr_consist = all(o in iss for o in self._oss) sscr_consist = all(o in iss for o in self._oss)
if rest or not sscr_consist or "," in self._oss or not len_consist: if rest or not sscr_consist or "," in self._oss or not len_consist:
raise ValueError(f"invalid subscripts specified; got {subscripts}") raise ValueError(f"invalid subscripts specified; got {subscripts}")
shapes = ()
for k, ss in zip(self._key_order, iss_spl): for k, ss in zip(self._key_order, iss_spl):
dom = self._domain[k] if k in self._domain.keys( dom = self._domain[k] if k in self._domain.keys(
) else self._stat_mf[k].domain ) else self._stat_mf[k].domain
if len(dom.shape) != len(ss): if len(dom.shape) != len(ss):
ve = f"invalid order of keys {key_order} for subscripts {subscripts}" ve = f"invalid order of keys {self._key_order} for subscripts {subscripts}"
raise ValueError(ve) raise ValueError(ve)
shapes += (dom.shape, )
dom_sscr = dict(zip(self._key_order, iss_spl)) dom_sscr = dict(zip(self._key_order, iss_spl))
tgt = [] tgt = []
...@@ -100,6 +101,9 @@ class MultiLinearEinsum(Operator): ...@@ -100,6 +101,9 @@ class MultiLinearEinsum(Operator):
self._sscr_endswith[k] = "->".join( self._sscr_endswith[k] = "->".join(
(",".join(left_ss_spl), self._oss) (",".join(left_ss_spl), self._oss)
) )
plc = (np.broadcast_to(np.nan, shp) for shp in shapes)
path = np.einsum_path(self._sscr, *plc, optimize=optimize)[0]
self._ein_kw = {"optimize": path}
def apply(self, x): def apply(self, x):
self._check_input(x) self._check_input(x)
...@@ -158,7 +162,7 @@ class LinearEinsum(LinearOperator): ...@@ -158,7 +162,7 @@ class LinearEinsum(LinearOperator):
optimize: bool, optional optimize: bool, optional
Parameter passed on to einsum. Parameter passed on to einsum.
""" """
def __init__(self, domain, mf, subscripts, key_order=None, optimize=False): def __init__(self, domain, mf, subscripts, key_order=None, optimize=True):
self._domain = DomainTuple.make(domain) self._domain = DomainTuple.make(domain)
self._mf = mf self._mf = mf
self._sscr = subscripts self._sscr = subscripts
...@@ -174,11 +178,14 @@ class LinearEinsum(LinearOperator): ...@@ -174,11 +178,14 @@ class LinearEinsum(LinearOperator):
if rest or not sscr_consist or "," in oss or not len_consist: if rest or not sscr_consist or "," in oss or not len_consist:
raise ValueError(f"invalid subscripts specified; got {subscripts}") raise ValueError(f"invalid subscripts specified; got {subscripts}")
ve = f"invalid order of keys {key_order} for subscripts {subscripts}" ve = f"invalid order of keys {key_order} for subscripts {subscripts}"
shapes = ()
for k, ss in zip(self._key_order, iss_spl[:-1]): for k, ss in zip(self._key_order, iss_spl[:-1]):
if len(self._mf[k].shape) != len(ss): if len(self._mf[k].shape) != len(ss):
raise ValueError(ve) raise ValueError(ve)
shapes +=(self._mf[k].shape,)
if len(self._domain.shape) != len(iss_spl[-1]): if len(self._domain.shape) != len(iss_spl[-1]):
raise ValueError(ve) raise ValueError(ve)
shapes += (self._domain.shape,)
dom_sscr = dict(zip(self._key_order, iss_spl[:-1])) dom_sscr = dict(zip(self._key_order, iss_spl[:-1]))
dom_sscr[id(self)] = iss_spl[-1] dom_sscr[id(self)] = iss_spl[-1]
...@@ -192,6 +199,9 @@ class LinearEinsum(LinearOperator): ...@@ -192,6 +199,9 @@ class LinearEinsum(LinearOperator):
assert k_hit == id(self) assert k_hit == id(self)
tgt += [self._domain[dom_k_idx]] tgt += [self._domain[dom_k_idx]]
self._target = DomainTuple.make(tgt) self._target = DomainTuple.make(tgt)
plc = (np.broadcast_to(np.nan, shp) for shp in shapes)
path = np.einsum_path(self._sscr, *plc, optimize=optimize)[0]
self._ein_kw = {"optimize": path}
adj_iss = ",".join((",".join(iss_spl[:-1]), oss)) adj_iss = ",".join((",".join(iss_spl[:-1]), oss))
self._adj_sscr = "->".join((adj_iss, iss_spl[-1])) self._adj_sscr = "->".join((adj_iss, iss_spl[-1]))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment