Commit bb066937 authored by Philipp Frank's avatar Philipp Frank
Browse files

pop list

parent 724aec55
Pipeline #74963 passed with stages
in 26 minutes and 17 seconds
...@@ -46,7 +46,7 @@ class MultiLinearEinsum(Operator): ...@@ -46,7 +46,7 @@ class MultiLinearEinsum(Operator):
`key_order` is not part of the `domain`. Fields in this object are `key_order` is not part of the `domain`. Fields in this object are
supposed to be static as they will not appear as FieldAdapter in the supposed to be static as they will not appear as FieldAdapter in the
Linearization. Linearization.
optimize: bool, String or List optional optimize: bool, String or List, optional
Parameter passed on to einsum_path. Parameter passed on to einsum_path.
""" """
def __init__(self, domain, subscripts, def __init__(self, domain, subscripts,
...@@ -68,7 +68,7 @@ class MultiLinearEinsum(Operator): ...@@ -68,7 +68,7 @@ class MultiLinearEinsum(Operator):
raise ValueError(f"invalid subscripts specified; got {subscripts}") raise ValueError(f"invalid subscripts specified; got {subscripts}")
ve = f"invalid order of keys {self._key_order} for subscripts {subscripts}" ve = f"invalid order of keys {self._key_order} for subscripts {subscripts}"
shapes, numpy_subscripts, subscriptmap = {},'',{} shapes, numpy_subscripts, subscriptmap = {},'',{}
alphabet = list(string.ascii_lowercase) alphabet = list(string.ascii_lowercase)[::-1]
for k, ss in zip(self._key_order, iss_spl): for k, ss in zip(self._key_order, iss_spl):
dom = self._domain[k] if k in self._domain.keys( dom = self._domain[k] if k in self._domain.keys(
) else self._stat_mf[k].domain ) else self._stat_mf[k].domain
...@@ -76,8 +76,8 @@ class MultiLinearEinsum(Operator): ...@@ -76,8 +76,8 @@ class MultiLinearEinsum(Operator):
raise ValueError(ve) raise ValueError(ve)
for i, a in enumerate(list(ss)): for i, a in enumerate(list(ss)):
if a not in subscriptmap.keys(): if a not in subscriptmap.keys():
subscriptmap[a] = alphabet[:len(dom[i].shape)].copy() subscriptmap[a] = [alphabet.pop() for _ in
del alphabet[:len(dom[i].shape)] range(len(dom[i].shape))]
numpy_subscripts += ''.join(subscriptmap[a]) numpy_subscripts += ''.join(subscriptmap[a])
numpy_subscripts += ',' numpy_subscripts += ','
shapes[k] = dom.shape shapes[k] = dom.shape
...@@ -177,7 +177,7 @@ class LinearEinsum(LinearOperator): ...@@ -177,7 +177,7 @@ class LinearEinsum(LinearOperator):
key_order: tuple of str, optional key_order: tuple of str, optional
The order of the keys in the multi-field. If not specified, defaults to The order of the keys in the multi-field. If not specified, defaults to
the order of the keys in the multi-field. the order of the keys in the multi-field.
optimize: bool, String or List optional optimize: bool, String or List, optional
Parameter passed on to einsum_path. Parameter passed on to einsum_path.
""" """
def __init__(self, domain, mf, subscripts, key_order=None, optimize='optimal'): def __init__(self, domain, mf, subscripts, key_order=None, optimize='optimal'):
...@@ -203,8 +203,8 @@ class LinearEinsum(LinearOperator): ...@@ -203,8 +203,8 @@ class LinearEinsum(LinearOperator):
raise ValueError(ve) raise ValueError(ve)
for i, a in enumerate(list(ss)): for i, a in enumerate(list(ss)):
if a not in subscriptmap.keys(): if a not in subscriptmap.keys():
subscriptmap[a] = alphabet[:len(dom[i].shape)].copy() subscriptmap[a] = [alphabet.pop() for _ in
del alphabet[:len(dom[i].shape)] range(len(dom[i].shape))]
numpy_subscripts += ''.join(subscriptmap[a]) numpy_subscripts += ''.join(subscriptmap[a])
numpy_subscripts += ',' numpy_subscripts += ','
shapes +=(dom.shape,) shapes +=(dom.shape,)
...@@ -212,8 +212,8 @@ class LinearEinsum(LinearOperator): ...@@ -212,8 +212,8 @@ class LinearEinsum(LinearOperator):
raise ValueError(ve) raise ValueError(ve)
for i, a in enumerate(list(iss_spl[-1])): for i, a in enumerate(list(iss_spl[-1])):
if a not in subscriptmap.keys(): if a not in subscriptmap.keys():
subscriptmap[a] = alphabet[:len(self._domain[i].shape)].copy() subscriptmap[a] = [alphabet.pop() for _ in
del alphabet[:len(self._domain[i].shape)] range(len(self._domain[i].shape))]
numpy_subscripts += ''.join(subscriptmap[a]) numpy_subscripts += ''.join(subscriptmap[a])
shapes += (self._domain.shape,) shapes += (self._domain.shape,)
numpy_subscripts += '->' numpy_subscripts += '->'
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment