Commit 724aec55 authored by Philipp Frank's avatar Philipp Frank
Browse files

fixup

parent ef11e4e2
......@@ -46,11 +46,11 @@ class MultiLinearEinsum(Operator):
`key_order` is not part of the `domain`. Fields in this object are
supposed to be static as they will not appear as FieldAdapter in the
Linearization.
optimize: bool, optional
Parameter passed on to einsum.
optimize: bool, String or List optional
Parameter passed on to einsum_path.
"""
def __init__(self, domain, subscripts,
key_order=None, static_mf=None, optimize=True):
key_order=None, static_mf=None, optimize='optimal'):
self._domain = MultiDomain.make(domain)
if key_order is None:
self._key_order = tuple(self._domain.keys())
......@@ -115,9 +115,11 @@ class MultiLinearEinsum(Operator):
plc += (np.broadcast_to(np.nan, shapes[k]),)
linpath = np.einsum_path(linpath, *plc, optimize=optimize)[0]
self._linpaths[k] = linpath
plc = (np.broadcast_to(np.nan, shapes[k]) for k in shapes.keys())
path = np.einsum_path(numpy_subscripts, *plc, optimize=optimize)[0]
if isinstance(optimize, list):
path = optimize
else:
plc = (np.broadcast_to(np.nan, shapes[k]) for k in shapes.keys())
path = np.einsum_path(numpy_subscripts, *plc, optimize=optimize)[0]
self._sscr = numpy_subscripts
self._ein_kw = {"optimize": path}
......@@ -175,10 +177,10 @@ class LinearEinsum(LinearOperator):
key_order: tuple of str, optional
The order of the keys in the multi-field. If not specified, defaults to
the order of the keys in the multi-field.
optimize: bool, optional
Parameter passed on to einsum.
optimize: bool, String or List optional
Parameter passed on to einsum_path.
"""
def __init__(self, domain, mf, subscripts, key_order=None, optimize=True):
def __init__(self, domain, mf, subscripts, key_order=None, optimize='optimal'):
self._domain = DomainTuple.make(domain)
self._mf = mf
if key_order is None:
......@@ -232,11 +234,11 @@ class LinearEinsum(LinearOperator):
self._sscr = numpy_subscripts
if isinstance(optimize, list):
self._ein_kw = {"optimize": optimize}
path = optimize
else:
plc = (np.broadcast_to(np.nan, shp) for shp in shapes)
path = np.einsum_path(numpy_subscripts, *plc, optimize=optimize)[0]
self._ein_kw = {"optimize": path}
self._ein_kw = {"optimize": path}
iss, oss, *_ = numpy_subscripts.split("->")
iss_spl = iss.split(",")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment