Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
5835e470
Commit
5835e470
authored
May 14, 2020
by
Philipp Frank
Browse files
compute einsum_path on initialization
parent
cd38eef1
Changes
1
Hide whitespace changes
Inline
Side-by-side
nifty6/operators/einsum.py
View file @
5835e470
...
...
@@ -54,7 +54,7 @@ class MultiLinearEinsum(Operator):
subscripts
,
key_order
=
None
,
static_mf
=
None
,
optimize
=
Fals
e
optimize
=
Tru
e
):
self
.
_domain
=
MultiDomain
.
make
(
domain
)
self
.
_sscr
=
subscripts
...
...
@@ -66,19 +66,20 @@ class MultiLinearEinsum(Operator):
ve
=
"`key_order` mus be specified if additional fields are munged"
raise
ValueError
(
ve
)
self
.
_stat_mf
=
static_mf
self
.
_ein_kw
=
{
"optimize"
:
optimize
}
iss
,
self
.
_oss
,
*
rest
=
subscripts
.
split
(
"->"
)
iss_spl
=
iss
.
split
(
","
)
len_consist
=
len
(
self
.
_key_order
)
==
len
(
iss_spl
)
sscr_consist
=
all
(
o
in
iss
for
o
in
self
.
_oss
)
if
rest
or
not
sscr_consist
or
","
in
self
.
_oss
or
not
len_consist
:
raise
ValueError
(
f
"invalid subscripts specified; got
{
subscripts
}
"
)
shapes
=
()
for
k
,
ss
in
zip
(
self
.
_key_order
,
iss_spl
):
dom
=
self
.
_domain
[
k
]
if
k
in
self
.
_domain
.
keys
(
)
else
self
.
_stat_mf
[
k
].
domain
if
len
(
dom
.
shape
)
!=
len
(
ss
):
ve
=
f
"invalid order of keys
{
key_order
}
for subscripts
{
subscripts
}
"
ve
=
f
"invalid order of keys
{
self
.
_
key_order
}
for subscripts
{
subscripts
}
"
raise
ValueError
(
ve
)
shapes
+=
(
dom
.
shape
,
)
dom_sscr
=
dict
(
zip
(
self
.
_key_order
,
iss_spl
))
tgt
=
[]
...
...
@@ -100,6 +101,9 @@ class MultiLinearEinsum(Operator):
self
.
_sscr_endswith
[
k
]
=
"->"
.
join
(
(
","
.
join
(
left_ss_spl
),
self
.
_oss
)
)
plc
=
(
np
.
broadcast_to
(
np
.
nan
,
shp
)
for
shp
in
shapes
)
path
=
np
.
einsum_path
(
self
.
_sscr
,
*
plc
,
optimize
=
optimize
)[
0
]
self
.
_ein_kw
=
{
"optimize"
:
path
}
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
...
...
@@ -158,7 +162,7 @@ class LinearEinsum(LinearOperator):
optimize: bool, optional
Parameter passed on to einsum.
"""
def
__init__
(
self
,
domain
,
mf
,
subscripts
,
key_order
=
None
,
optimize
=
Fals
e
):
def
__init__
(
self
,
domain
,
mf
,
subscripts
,
key_order
=
None
,
optimize
=
Tru
e
):
self
.
_domain
=
DomainTuple
.
make
(
domain
)
self
.
_mf
=
mf
self
.
_sscr
=
subscripts
...
...
@@ -174,11 +178,14 @@ class LinearEinsum(LinearOperator):
if
rest
or
not
sscr_consist
or
","
in
oss
or
not
len_consist
:
raise
ValueError
(
f
"invalid subscripts specified; got
{
subscripts
}
"
)
ve
=
f
"invalid order of keys
{
key_order
}
for subscripts
{
subscripts
}
"
shapes
=
()
for
k
,
ss
in
zip
(
self
.
_key_order
,
iss_spl
[:
-
1
]):
if
len
(
self
.
_mf
[
k
].
shape
)
!=
len
(
ss
):
raise
ValueError
(
ve
)
shapes
+=
(
self
.
_mf
[
k
].
shape
,)
if
len
(
self
.
_domain
.
shape
)
!=
len
(
iss_spl
[
-
1
]):
raise
ValueError
(
ve
)
shapes
+=
(
self
.
_domain
.
shape
,)
dom_sscr
=
dict
(
zip
(
self
.
_key_order
,
iss_spl
[:
-
1
]))
dom_sscr
[
id
(
self
)]
=
iss_spl
[
-
1
]
...
...
@@ -192,6 +199,9 @@ class LinearEinsum(LinearOperator):
assert
k_hit
==
id
(
self
)
tgt
+=
[
self
.
_domain
[
dom_k_idx
]]
self
.
_target
=
DomainTuple
.
make
(
tgt
)
plc
=
(
np
.
broadcast_to
(
np
.
nan
,
shp
)
for
shp
in
shapes
)
path
=
np
.
einsum_path
(
self
.
_sscr
,
*
plc
,
optimize
=
optimize
)[
0
]
self
.
_ein_kw
=
{
"optimize"
:
path
}
adj_iss
=
","
.
join
((
","
.
join
(
iss_spl
[:
-
1
]),
oss
))
self
.
_adj_sscr
=
"->"
.
join
((
adj_iss
,
iss_spl
[
-
1
]))
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment