Commit 2cdc9e00 authored by Philipp Frank's avatar Philipp Frank
Browse files

use einsum indices for spaces instead of numpy indices

parent 5835e470
Pipeline #74959 failed with stages
in 20 minutes and 51 seconds
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np import numpy as np
import string
from ..domain_tuple import DomainTuple from ..domain_tuple import DomainTuple
from ..linearization import Linearization from ..linearization import Linearization
from ..field import Field from ..field import Field
...@@ -57,7 +58,6 @@ class MultiLinearEinsum(Operator): ...@@ -57,7 +58,6 @@ class MultiLinearEinsum(Operator):
optimize=True optimize=True
): ):
self._domain = MultiDomain.make(domain) self._domain = MultiDomain.make(domain)
self._sscr = subscripts
if key_order is None: if key_order is None:
self._key_order = tuple(self._domain.keys()) self._key_order = tuple(self._domain.keys())
else: else:
...@@ -66,24 +66,34 @@ class MultiLinearEinsum(Operator): ...@@ -66,24 +66,34 @@ class MultiLinearEinsum(Operator):
ve = "`key_order` mus be specified if additional fields are munged" ve = "`key_order` mus be specified if additional fields are munged"
raise ValueError(ve) raise ValueError(ve)
self._stat_mf = static_mf self._stat_mf = static_mf
iss, self._oss, *rest = subscripts.split("->") iss, oss, *rest = subscripts.split("->")
iss_spl = iss.split(",") iss_spl = iss.split(",")
len_consist = len(self._key_order) == len(iss_spl) len_consist = len(self._key_order) == len(iss_spl)
sscr_consist = all(o in iss for o in self._oss) sscr_consist = all(o in iss for o in oss)
if rest or not sscr_consist or "," in self._oss or not len_consist: if rest or not sscr_consist or "," in oss or not len_consist:
raise ValueError(f"invalid subscripts specified; got {subscripts}") raise ValueError(f"invalid subscripts specified; got {subscripts}")
ve = f"invalid order of keys {self._key_order} for subscripts {subscripts}"
shapes = () shapes = ()
mysubscripts = ""
subscriptmap = {}
alphabet = list(string.ascii_lowercase)
for k, ss in zip(self._key_order, iss_spl): for k, ss in zip(self._key_order, iss_spl):
dom = self._domain[k] if k in self._domain.keys( dom = self._domain[k] if k in self._domain.keys(
) else self._stat_mf[k].domain ) else self._stat_mf[k].domain
if len(dom.shape) != len(ss): if len(dom) != len(ss):
ve = f"invalid order of keys {self._key_order} for subscripts {subscripts}"
raise ValueError(ve) raise ValueError(ve)
for i, a in enumerate(list(ss)):
if a not in subscriptmap.keys():
subscriptmap[a] = alphabet[:len(dom[i].shape)].copy()
for j in range(len(dom[i].shape)):
del alphabet[0]
mysubscripts += ''.join(subscriptmap[a])
mysubscripts += ','
shapes += (dom.shape, ) shapes += (dom.shape, )
mysubscripts = mysubscripts[:-1] + '->'
dom_sscr = dict(zip(self._key_order, iss_spl)) dom_sscr = dict(zip(self._key_order, iss_spl))
tgt = [] tgt = []
for o in self._oss: for o in oss:
k_hit = tuple(k for k, sscr in dom_sscr.items() if o in sscr)[0] k_hit = tuple(k for k, sscr in dom_sscr.items() if o in sscr)[0]
dom_k_idx = dom_sscr[k_hit].index(o) dom_k_idx = dom_sscr[k_hit].index(o)
if k_hit in self._domain.keys(): if k_hit in self._domain.keys():
...@@ -93,16 +103,18 @@ class MultiLinearEinsum(Operator): ...@@ -93,16 +103,18 @@ class MultiLinearEinsum(Operator):
ve = f"{k_hit} is not in domain nor in static_mf" ve = f"{k_hit} is not in domain nor in static_mf"
raise ValueError(ve) raise ValueError(ve)
tgt += [self._stat_mf[k_hit].domain[dom_k_idx]] tgt += [self._stat_mf[k_hit].domain[dom_k_idx]]
mysubscripts += "".join(subscriptmap[o])
self._target = DomainTuple.make(tgt) self._target = DomainTuple.make(tgt)
self._sscr_endswith = dict() self._sscr_endswith = dict()
for k, (i, ss) in zip(self._key_order, enumerate(iss_spl)): for k, (i, ss) in zip(self._key_order, enumerate(iss_spl)):
left_ss_spl = (*iss_spl[:i], *iss_spl[i + 1:], ss) left_ss_spl = (*iss_spl[:i], *iss_spl[i + 1:], ss)
self._sscr_endswith[k] = "->".join( self._sscr_endswith[k] = "->".join(
(",".join(left_ss_spl), self._oss) (",".join(left_ss_spl), oss)
) )
plc = (np.broadcast_to(np.nan, shp) for shp in shapes) plc = (np.broadcast_to(np.nan, shp) for shp in shapes)
path = np.einsum_path(self._sscr, *plc, optimize=optimize)[0] path = np.einsum_path(mysubscripts, *plc, optimize=optimize)[0]
self._sscr = mysubscripts
self._ein_kw = {"optimize": path} self._ein_kw = {"optimize": path}
def apply(self, x): def apply(self, x):
...@@ -165,7 +177,6 @@ class LinearEinsum(LinearOperator): ...@@ -165,7 +177,6 @@ class LinearEinsum(LinearOperator):
def __init__(self, domain, mf, subscripts, key_order=None, optimize=True): def __init__(self, domain, mf, subscripts, key_order=None, optimize=True):
self._domain = DomainTuple.make(domain) self._domain = DomainTuple.make(domain)
self._mf = mf self._mf = mf
self._sscr = subscripts
if key_order is None: if key_order is None:
self._key_order = tuple(self._mf.domain.keys()) self._key_order = tuple(self._mf.domain.keys())
else: else:
...@@ -177,15 +188,33 @@ class LinearEinsum(LinearOperator): ...@@ -177,15 +188,33 @@ class LinearEinsum(LinearOperator):
len_consist = len(self._key_order) == len(iss_spl[:-1]) len_consist = len(self._key_order) == len(iss_spl[:-1])
if rest or not sscr_consist or "," in oss or not len_consist: if rest or not sscr_consist or "," in oss or not len_consist:
raise ValueError(f"invalid subscripts specified; got {subscripts}") raise ValueError(f"invalid subscripts specified; got {subscripts}")
ve = f"invalid order of keys {key_order} for subscripts {subscripts}" ve = f"invalid order of keys {self._key_order} for subscripts {subscripts}"
shapes = () shapes = ()
mysubscripts = ""
subscriptmap = {}
alphabet = list(string.ascii_lowercase)
for k, ss in zip(self._key_order, iss_spl[:-1]): for k, ss in zip(self._key_order, iss_spl[:-1]):
if len(self._mf[k].shape) != len(ss): dom = self._mf[k].domain
if len(dom) != len(ss):
raise ValueError(ve) raise ValueError(ve)
shapes +=(self._mf[k].shape,) for i, a in enumerate(list(ss)):
if len(self._domain.shape) != len(iss_spl[-1]): if a not in subscriptmap.keys():
subscriptmap[a] = alphabet[:len(dom[i].shape)].copy()
for j in range(len(dom[i].shape)):
del alphabet[0]
mysubscripts += ''.join(subscriptmap[a])
mysubscripts += ','
shapes +=(dom.shape,)
if len(self._domain) != len(iss_spl[-1]):
raise ValueError(ve) raise ValueError(ve)
for i, a in enumerate(list(iss_spl[-1])):
if a not in subscriptmap.keys():
subscriptmap[a] = alphabet[:len(self._domain[i].shape)].copy()
for j in range(len(self._domain[i].shape)):
del alphabet[0]
mysubscripts += ''.join(subscriptmap[a])
shapes += (self._domain.shape,) shapes += (self._domain.shape,)
mysubscripts += '->'
dom_sscr = dict(zip(self._key_order, iss_spl[:-1])) dom_sscr = dict(zip(self._key_order, iss_spl[:-1]))
dom_sscr[id(self)] = iss_spl[-1] dom_sscr[id(self)] = iss_spl[-1]
...@@ -198,11 +227,16 @@ class LinearEinsum(LinearOperator): ...@@ -198,11 +227,16 @@ class LinearEinsum(LinearOperator):
else: else:
assert k_hit == id(self) assert k_hit == id(self)
tgt += [self._domain[dom_k_idx]] tgt += [self._domain[dom_k_idx]]
mysubscripts += "".join(subscriptmap[o])
self._target = DomainTuple.make(tgt) self._target = DomainTuple.make(tgt)
self._sscr = mysubscripts
plc = (np.broadcast_to(np.nan, shp) for shp in shapes) plc = (np.broadcast_to(np.nan, shp) for shp in shapes)
path = np.einsum_path(self._sscr, *plc, optimize=optimize)[0] path = np.einsum_path(mysubscripts, *plc, optimize=optimize)[0]
self._ein_kw = {"optimize": path} self._ein_kw = {"optimize": path}
iss, oss, *_ = mysubscripts.split("->")
iss_spl = iss.split(",")
adj_iss = ",".join((",".join(iss_spl[:-1]), oss)) adj_iss = ",".join((",".join(iss_spl[:-1]), oss))
self._adj_sscr = "->".join((adj_iss, iss_spl[-1])) self._adj_sscr = "->".join((adj_iss, iss_spl[-1]))
self._capability = self.TIMES | self.ADJOINT_TIMES self._capability = self.TIMES | self.ADJOINT_TIMES
......
...@@ -21,31 +21,35 @@ from numpy.testing import assert_allclose ...@@ -21,31 +21,35 @@ from numpy.testing import assert_allclose
from nifty6.extra import check_jacobian_consistency, consistency_check from nifty6.extra import check_jacobian_consistency, consistency_check
import nifty6 as ift import nifty6 as ift
from ..common import setup_function, teardown_function from ..common import list2fixture, setup_function, teardown_function
pmp = pytest.mark.parametrize pmp = pytest.mark.parametrize
spaces = (ift.UnstructuredDomain(4),
ift.RGSpace((3,2)),
ift.LMSpace(5),
ift.HPSpace(4),
ift.GLSpace(4))
space1 = list2fixture(spaces)
space2 = list2fixture(spaces)
@pmp("n_unstructured", (3, 9))
@pmp("nside", (4, 8)) def test_linear_einsum_outer(space1, space2, n_invocations=10):
def test_linear_einsum_outer(n_unstructured, nside, n_invocations=10):
setup_function() setup_function()
pos_space = ift.HPSpace(nside)
mf_dom = ift.MultiDomain.make( mf_dom = ift.MultiDomain.make(
{ {
"dom01": "dom01": space1,
ift.UnstructuredDomain(n_unstructured),
"dom02": "dom02":
ift.DomainTuple.make( ift.DomainTuple.make(
(ift.UnstructuredDomain(n_unstructured), pos_space) (space1, space2)
) )
} }
) )
mf = ift.from_random("normal", mf_dom) mf = ift.from_random("normal", mf_dom)
ss = "i,ij,j->ij" ss = "i,ij,j->ij"
key_order = ("dom01", "dom02") key_order = ("dom01", "dom02")
le = ift.LinearEinsum(pos_space, mf, ss, key_order=key_order) le = ift.LinearEinsum(space2, mf, ss, key_order=key_order)
assert consistency_check(le) is None assert consistency_check(le) is None
le_ift = ift.DiagonalOperator( le_ift = ift.DiagonalOperator(
...@@ -63,26 +67,22 @@ def test_linear_einsum_outer(n_unstructured, nside, n_invocations=10): ...@@ -63,26 +67,22 @@ def test_linear_einsum_outer(n_unstructured, nside, n_invocations=10):
teardown_function() teardown_function()
@pmp("n_unstructured", (3, 9)) def test_linear_einsum_contraction(space1, space2, n_invocations=10):
@pmp("nside", (4, 8))
def test_linear_einsum_contraction(n_unstructured, nside, n_invocations=10):
setup_function() setup_function()
pos_space = ift.HPSpace(nside)
mf_dom = ift.MultiDomain.make( mf_dom = ift.MultiDomain.make(
{ {
"dom01": "dom01": space1,
ift.UnstructuredDomain(n_unstructured),
"dom02": "dom02":
ift.DomainTuple.make( ift.DomainTuple.make(
(ift.UnstructuredDomain(n_unstructured), pos_space) (space1, space2)
) )
} }
) )
mf = ift.from_random("normal", mf_dom) mf = ift.from_random("normal", mf_dom)
ss = "i,ij,j->i" ss = "i,ij,j->i"
key_order = ("dom01", "dom02") key_order = ("dom01", "dom02")
le = ift.LinearEinsum(pos_space, mf, ss, key_order=key_order) le = ift.LinearEinsum(space2, mf, ss, key_order=key_order)
assert consistency_check(le) is None assert consistency_check(le) is None
le_ift = ift.ContractionOperator(mf_dom["dom02"], 1) @ ift.DiagonalOperator( le_ift = ift.ContractionOperator(mf_dom["dom02"], 1) @ ift.DiagonalOperator(
...@@ -100,24 +100,16 @@ def test_linear_einsum_contraction(n_unstructured, nside, n_invocations=10): ...@@ -100,24 +100,16 @@ def test_linear_einsum_contraction(n_unstructured, nside, n_invocations=10):
teardown_function() teardown_function()
@pmp("n_unstructured", (3, 9))
@pmp("nside", (4, 8))
def test_multi_linear_einsum_outer( def test_multi_linear_einsum_outer(
n_unstructured, nside, n_invocations=10, ntries=100 space1, space2, n_invocations=10, ntries=100
): ):
setup_function() setup_function()
pos_space = ift.HPSpace(nside)
mf_dom = ift.MultiDomain.make( mf_dom = ift.MultiDomain.make(
{ {
"dom01": "dom01": space1,
ift.UnstructuredDomain(n_unstructured), "dom02":ift.DomainTuple.make((space1, space2)),
"dom02": "dom03": space2
ift.DomainTuple.make(
(ift.UnstructuredDomain(n_unstructured), pos_space)
),
"dom03":
pos_space
} }
) )
ss = "i,ij,j->ij" ss = "i,ij,j->ij"
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment