Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
2cdc9e00
Commit
2cdc9e00
authored
May 14, 2020
by
Philipp Frank
Browse files
use einsum indices for spaces instead of numpy indices
parent
5835e470
Pipeline
#74959
failed with stages
in 20 minutes and 51 seconds
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty6/operators/einsum.py
View file @
2cdc9e00
...
...
@@ -17,6 +17,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import
numpy
as
np
import
string
from
..domain_tuple
import
DomainTuple
from
..linearization
import
Linearization
from
..field
import
Field
...
...
@@ -57,7 +58,6 @@ class MultiLinearEinsum(Operator):
optimize
=
True
):
self
.
_domain
=
MultiDomain
.
make
(
domain
)
self
.
_sscr
=
subscripts
if
key_order
is
None
:
self
.
_key_order
=
tuple
(
self
.
_domain
.
keys
())
else
:
...
...
@@ -66,24 +66,34 @@ class MultiLinearEinsum(Operator):
ve
=
"`key_order` mus be specified if additional fields are munged"
raise
ValueError
(
ve
)
self
.
_stat_mf
=
static_mf
iss
,
self
.
_
oss
,
*
rest
=
subscripts
.
split
(
"->"
)
iss
,
oss
,
*
rest
=
subscripts
.
split
(
"->"
)
iss_spl
=
iss
.
split
(
","
)
len_consist
=
len
(
self
.
_key_order
)
==
len
(
iss_spl
)
sscr_consist
=
all
(
o
in
iss
for
o
in
self
.
_
oss
)
if
rest
or
not
sscr_consist
or
","
in
self
.
_
oss
or
not
len_consist
:
sscr_consist
=
all
(
o
in
iss
for
o
in
oss
)
if
rest
or
not
sscr_consist
or
","
in
oss
or
not
len_consist
:
raise
ValueError
(
f
"invalid subscripts specified; got
{
subscripts
}
"
)
ve
=
f
"invalid order of keys
{
self
.
_key_order
}
for subscripts
{
subscripts
}
"
shapes
=
()
mysubscripts
=
""
subscriptmap
=
{}
alphabet
=
list
(
string
.
ascii_lowercase
)
for
k
,
ss
in
zip
(
self
.
_key_order
,
iss_spl
):
dom
=
self
.
_domain
[
k
]
if
k
in
self
.
_domain
.
keys
(
)
else
self
.
_stat_mf
[
k
].
domain
if
len
(
dom
.
shape
)
!=
len
(
ss
):
ve
=
f
"invalid order of keys
{
self
.
_key_order
}
for subscripts
{
subscripts
}
"
if
len
(
dom
)
!=
len
(
ss
):
raise
ValueError
(
ve
)
for
i
,
a
in
enumerate
(
list
(
ss
)):
if
a
not
in
subscriptmap
.
keys
():
subscriptmap
[
a
]
=
alphabet
[:
len
(
dom
[
i
].
shape
)].
copy
()
for
j
in
range
(
len
(
dom
[
i
].
shape
)):
del
alphabet
[
0
]
mysubscripts
+=
''
.
join
(
subscriptmap
[
a
])
mysubscripts
+=
','
shapes
+=
(
dom
.
shape
,
)
mysubscripts
=
mysubscripts
[:
-
1
]
+
'->'
dom_sscr
=
dict
(
zip
(
self
.
_key_order
,
iss_spl
))
tgt
=
[]
for
o
in
self
.
_
oss
:
for
o
in
oss
:
k_hit
=
tuple
(
k
for
k
,
sscr
in
dom_sscr
.
items
()
if
o
in
sscr
)[
0
]
dom_k_idx
=
dom_sscr
[
k_hit
].
index
(
o
)
if
k_hit
in
self
.
_domain
.
keys
():
...
...
@@ -93,16 +103,18 @@ class MultiLinearEinsum(Operator):
ve
=
f
"
{
k_hit
}
is not in domain nor in static_mf"
raise
ValueError
(
ve
)
tgt
+=
[
self
.
_stat_mf
[
k_hit
].
domain
[
dom_k_idx
]]
mysubscripts
+=
""
.
join
(
subscriptmap
[
o
])
self
.
_target
=
DomainTuple
.
make
(
tgt
)
self
.
_sscr_endswith
=
dict
()
for
k
,
(
i
,
ss
)
in
zip
(
self
.
_key_order
,
enumerate
(
iss_spl
)):
left_ss_spl
=
(
*
iss_spl
[:
i
],
*
iss_spl
[
i
+
1
:],
ss
)
self
.
_sscr_endswith
[
k
]
=
"->"
.
join
(
(
","
.
join
(
left_ss_spl
),
self
.
_
oss
)
(
","
.
join
(
left_ss_spl
),
oss
)
)
plc
=
(
np
.
broadcast_to
(
np
.
nan
,
shp
)
for
shp
in
shapes
)
path
=
np
.
einsum_path
(
self
.
_sscr
,
*
plc
,
optimize
=
optimize
)[
0
]
path
=
np
.
einsum_path
(
mysubscripts
,
*
plc
,
optimize
=
optimize
)[
0
]
self
.
_sscr
=
mysubscripts
self
.
_ein_kw
=
{
"optimize"
:
path
}
def
apply
(
self
,
x
):
...
...
@@ -165,7 +177,6 @@ class LinearEinsum(LinearOperator):
def
__init__
(
self
,
domain
,
mf
,
subscripts
,
key_order
=
None
,
optimize
=
True
):
self
.
_domain
=
DomainTuple
.
make
(
domain
)
self
.
_mf
=
mf
self
.
_sscr
=
subscripts
if
key_order
is
None
:
self
.
_key_order
=
tuple
(
self
.
_mf
.
domain
.
keys
())
else
:
...
...
@@ -177,15 +188,33 @@ class LinearEinsum(LinearOperator):
len_consist
=
len
(
self
.
_key_order
)
==
len
(
iss_spl
[:
-
1
])
if
rest
or
not
sscr_consist
or
","
in
oss
or
not
len_consist
:
raise
ValueError
(
f
"invalid subscripts specified; got
{
subscripts
}
"
)
ve
=
f
"invalid order of keys
{
key_order
}
for subscripts
{
subscripts
}
"
ve
=
f
"invalid order of keys
{
self
.
_
key_order
}
for subscripts
{
subscripts
}
"
shapes
=
()
mysubscripts
=
""
subscriptmap
=
{}
alphabet
=
list
(
string
.
ascii_lowercase
)
for
k
,
ss
in
zip
(
self
.
_key_order
,
iss_spl
[:
-
1
]):
if
len
(
self
.
_mf
[
k
].
shape
)
!=
len
(
ss
):
dom
=
self
.
_mf
[
k
].
domain
if
len
(
dom
)
!=
len
(
ss
):
raise
ValueError
(
ve
)
shapes
+=
(
self
.
_mf
[
k
].
shape
,)
if
len
(
self
.
_domain
.
shape
)
!=
len
(
iss_spl
[
-
1
]):
for
i
,
a
in
enumerate
(
list
(
ss
)):
if
a
not
in
subscriptmap
.
keys
():
subscriptmap
[
a
]
=
alphabet
[:
len
(
dom
[
i
].
shape
)].
copy
()
for
j
in
range
(
len
(
dom
[
i
].
shape
)):
del
alphabet
[
0
]
mysubscripts
+=
''
.
join
(
subscriptmap
[
a
])
mysubscripts
+=
','
shapes
+=
(
dom
.
shape
,)
if
len
(
self
.
_domain
)
!=
len
(
iss_spl
[
-
1
]):
raise
ValueError
(
ve
)
for
i
,
a
in
enumerate
(
list
(
iss_spl
[
-
1
])):
if
a
not
in
subscriptmap
.
keys
():
subscriptmap
[
a
]
=
alphabet
[:
len
(
self
.
_domain
[
i
].
shape
)].
copy
()
for
j
in
range
(
len
(
self
.
_domain
[
i
].
shape
)):
del
alphabet
[
0
]
mysubscripts
+=
''
.
join
(
subscriptmap
[
a
])
shapes
+=
(
self
.
_domain
.
shape
,)
mysubscripts
+=
'->'
dom_sscr
=
dict
(
zip
(
self
.
_key_order
,
iss_spl
[:
-
1
]))
dom_sscr
[
id
(
self
)]
=
iss_spl
[
-
1
]
...
...
@@ -198,10 +227,15 @@ class LinearEinsum(LinearOperator):
else
:
assert
k_hit
==
id
(
self
)
tgt
+=
[
self
.
_domain
[
dom_k_idx
]]
mysubscripts
+=
""
.
join
(
subscriptmap
[
o
])
self
.
_target
=
DomainTuple
.
make
(
tgt
)
self
.
_sscr
=
mysubscripts
plc
=
(
np
.
broadcast_to
(
np
.
nan
,
shp
)
for
shp
in
shapes
)
path
=
np
.
einsum_path
(
self
.
_sscr
,
*
plc
,
optimize
=
optimize
)[
0
]
path
=
np
.
einsum_path
(
mysubscripts
,
*
plc
,
optimize
=
optimize
)[
0
]
self
.
_ein_kw
=
{
"optimize"
:
path
}
iss
,
oss
,
*
_
=
mysubscripts
.
split
(
"->"
)
iss_spl
=
iss
.
split
(
","
)
adj_iss
=
","
.
join
((
","
.
join
(
iss_spl
[:
-
1
]),
oss
))
self
.
_adj_sscr
=
"->"
.
join
((
adj_iss
,
iss_spl
[
-
1
]))
...
...
test/test_operators/test_einsum.py
View file @
2cdc9e00
...
...
@@ -21,31 +21,35 @@ from numpy.testing import assert_allclose
from
nifty6.extra
import
check_jacobian_consistency
,
consistency_check
import
nifty6
as
ift
from
..common
import
setup_function
,
teardown_function
from
..common
import
list2fixture
,
setup_function
,
teardown_function
pmp
=
pytest
.
mark
.
parametrize
spaces
=
(
ift
.
UnstructuredDomain
(
4
),
ift
.
RGSpace
((
3
,
2
)),
ift
.
LMSpace
(
5
),
ift
.
HPSpace
(
4
),
ift
.
GLSpace
(
4
))
space1
=
list2fixture
(
spaces
)
space2
=
list2fixture
(
spaces
)
@
pmp
(
"n_unstructured"
,
(
3
,
9
))
@
pmp
(
"nside"
,
(
4
,
8
))
def
test_linear_einsum_outer
(
n_unstructured
,
nside
,
n_invocations
=
10
):
def
test_linear_einsum_outer
(
space1
,
space2
,
n_invocations
=
10
):
setup_function
()
pos_space
=
ift
.
HPSpace
(
nside
)
mf_dom
=
ift
.
MultiDomain
.
make
(
{
"dom01"
:
ift
.
UnstructuredDomain
(
n_unstructured
),
"dom01"
:
space1
,
"dom02"
:
ift
.
DomainTuple
.
make
(
(
ift
.
UnstructuredDomain
(
n_unstructured
),
pos_
space
)
(
space1
,
space
2
)
)
}
)
mf
=
ift
.
from_random
(
"normal"
,
mf_dom
)
ss
=
"i,ij,j->ij"
key_order
=
(
"dom01"
,
"dom02"
)
le
=
ift
.
LinearEinsum
(
pos_
space
,
mf
,
ss
,
key_order
=
key_order
)
le
=
ift
.
LinearEinsum
(
space
2
,
mf
,
ss
,
key_order
=
key_order
)
assert
consistency_check
(
le
)
is
None
le_ift
=
ift
.
DiagonalOperator
(
...
...
@@ -63,26 +67,22 @@ def test_linear_einsum_outer(n_unstructured, nside, n_invocations=10):
teardown_function
()
@
pmp
(
"n_unstructured"
,
(
3
,
9
))
@
pmp
(
"nside"
,
(
4
,
8
))
def
test_linear_einsum_contraction
(
n_unstructured
,
nside
,
n_invocations
=
10
):
def
test_linear_einsum_contraction
(
space1
,
space2
,
n_invocations
=
10
):
setup_function
()
pos_space
=
ift
.
HPSpace
(
nside
)
mf_dom
=
ift
.
MultiDomain
.
make
(
{
"dom01"
:
ift
.
UnstructuredDomain
(
n_unstructured
),
"dom01"
:
space1
,
"dom02"
:
ift
.
DomainTuple
.
make
(
(
ift
.
UnstructuredDomain
(
n_unstructured
),
pos_
space
)
(
space1
,
space
2
)
)
}
)
mf
=
ift
.
from_random
(
"normal"
,
mf_dom
)
ss
=
"i,ij,j->i"
key_order
=
(
"dom01"
,
"dom02"
)
le
=
ift
.
LinearEinsum
(
pos_
space
,
mf
,
ss
,
key_order
=
key_order
)
le
=
ift
.
LinearEinsum
(
space
2
,
mf
,
ss
,
key_order
=
key_order
)
assert
consistency_check
(
le
)
is
None
le_ift
=
ift
.
ContractionOperator
(
mf_dom
[
"dom02"
],
1
)
@
ift
.
DiagonalOperator
(
...
...
@@ -100,24 +100,16 @@ def test_linear_einsum_contraction(n_unstructured, nside, n_invocations=10):
teardown_function
()
@
pmp
(
"n_unstructured"
,
(
3
,
9
))
@
pmp
(
"nside"
,
(
4
,
8
))
def
test_multi_linear_einsum_outer
(
n_unstructured
,
nside
,
n_invocations
=
10
,
ntries
=
100
space1
,
space2
,
n_invocations
=
10
,
ntries
=
100
):
setup_function
()
pos_space
=
ift
.
HPSpace
(
nside
)
mf_dom
=
ift
.
MultiDomain
.
make
(
{
"dom01"
:
ift
.
UnstructuredDomain
(
n_unstructured
),
"dom02"
:
ift
.
DomainTuple
.
make
(
(
ift
.
UnstructuredDomain
(
n_unstructured
),
pos_space
)
),
"dom03"
:
pos_space
"dom01"
:
space1
,
"dom02"
:
ift
.
DomainTuple
.
make
((
space1
,
space2
)),
"dom03"
:
space2
}
)
ss
=
"i,ij,j->ij"
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment