Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
5c745090
Commit
5c745090
authored
May 14, 2020
by
Philipp Frank
Browse files
cleanup
parent
2cdc9e00
Pipeline
#74960
passed with stages
in 26 minutes and 20 seconds
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty6/operators/einsum.py
View file @
5c745090
...
...
@@ -49,14 +49,8 @@ class MultiLinearEinsum(Operator):
optimize: bool, optional
Parameter passed on to einsum.
"""
def
__init__
(
self
,
domain
,
subscripts
,
key_order
=
None
,
static_mf
=
None
,
optimize
=
True
):
def
__init__
(
self
,
domain
,
subscripts
,
key_order
=
None
,
static_mf
=
None
,
optimize
=
True
):
self
.
_domain
=
MultiDomain
.
make
(
domain
)
if
key_order
is
None
:
self
.
_key_order
=
tuple
(
self
.
_domain
.
keys
())
...
...
@@ -73,9 +67,7 @@ class MultiLinearEinsum(Operator):
if
rest
or
not
sscr_consist
or
","
in
oss
or
not
len_consist
:
raise
ValueError
(
f
"invalid subscripts specified; got
{
subscripts
}
"
)
ve
=
f
"invalid order of keys
{
self
.
_key_order
}
for subscripts
{
subscripts
}
"
shapes
=
()
mysubscripts
=
""
subscriptmap
=
{}
shapes
,
mysubscripts
,
subscriptmap
=
(),
''
,{}
alphabet
=
list
(
string
.
ascii_lowercase
)
for
k
,
ss
in
zip
(
self
.
_key_order
,
iss_spl
):
dom
=
self
.
_domain
[
k
]
if
k
in
self
.
_domain
.
keys
(
...
...
@@ -85,8 +77,7 @@ class MultiLinearEinsum(Operator):
for
i
,
a
in
enumerate
(
list
(
ss
)):
if
a
not
in
subscriptmap
.
keys
():
subscriptmap
[
a
]
=
alphabet
[:
len
(
dom
[
i
].
shape
)].
copy
()
for
j
in
range
(
len
(
dom
[
i
].
shape
)):
del
alphabet
[
0
]
del
alphabet
[:
len
(
dom
[
i
].
shape
)]
mysubscripts
+=
''
.
join
(
subscriptmap
[
a
])
mysubscripts
+=
','
shapes
+=
(
dom
.
shape
,
)
...
...
@@ -103,14 +94,14 @@ class MultiLinearEinsum(Operator):
ve
=
f
"
{
k_hit
}
is not in domain nor in static_mf"
raise
ValueError
(
ve
)
tgt
+=
[
self
.
_stat_mf
[
k_hit
].
domain
[
dom_k_idx
]]
mysubscripts
+=
""
.
join
(
subscriptmap
[
o
])
mysubscripts
+=
''
.
join
(
subscriptmap
[
o
])
self
.
_target
=
DomainTuple
.
make
(
tgt
)
self
.
_sscr_endswith
=
dict
()
for
k
,
(
i
,
ss
)
in
zip
(
self
.
_key_order
,
enumerate
(
iss_spl
)):
left_ss_spl
=
(
*
iss_spl
[:
i
],
*
iss_spl
[
i
+
1
:],
ss
)
self
.
_sscr_endswith
[
k
]
=
"
->
"
.
join
(
(
","
.
join
(
left_ss_spl
),
oss
)
self
.
_sscr_endswith
[
k
]
=
'
->
'
.
join
(
(
','
.
join
(
left_ss_spl
),
oss
)
)
plc
=
(
np
.
broadcast_to
(
np
.
nan
,
shp
)
for
shp
in
shapes
)
path
=
np
.
einsum_path
(
mysubscripts
,
*
plc
,
optimize
=
optimize
)[
0
]
...
...
@@ -189,9 +180,7 @@ class LinearEinsum(LinearOperator):
if
rest
or
not
sscr_consist
or
","
in
oss
or
not
len_consist
:
raise
ValueError
(
f
"invalid subscripts specified; got
{
subscripts
}
"
)
ve
=
f
"invalid order of keys
{
self
.
_key_order
}
for subscripts
{
subscripts
}
"
shapes
=
()
mysubscripts
=
""
subscriptmap
=
{}
shapes
,
mysubscripts
,
subscriptmap
=
(),
''
,{}
alphabet
=
list
(
string
.
ascii_lowercase
)
for
k
,
ss
in
zip
(
self
.
_key_order
,
iss_spl
[:
-
1
]):
dom
=
self
.
_mf
[
k
].
domain
...
...
@@ -200,8 +189,7 @@ class LinearEinsum(LinearOperator):
for
i
,
a
in
enumerate
(
list
(
ss
)):
if
a
not
in
subscriptmap
.
keys
():
subscriptmap
[
a
]
=
alphabet
[:
len
(
dom
[
i
].
shape
)].
copy
()
for
j
in
range
(
len
(
dom
[
i
].
shape
)):
del
alphabet
[
0
]
del
alphabet
[:
len
(
dom
[
i
].
shape
)]
mysubscripts
+=
''
.
join
(
subscriptmap
[
a
])
mysubscripts
+=
','
shapes
+=
(
dom
.
shape
,)
...
...
@@ -210,8 +198,7 @@ class LinearEinsum(LinearOperator):
for
i
,
a
in
enumerate
(
list
(
iss_spl
[
-
1
])):
if
a
not
in
subscriptmap
.
keys
():
subscriptmap
[
a
]
=
alphabet
[:
len
(
self
.
_domain
[
i
].
shape
)].
copy
()
for
j
in
range
(
len
(
self
.
_domain
[
i
].
shape
)):
del
alphabet
[
0
]
del
alphabet
[:
len
(
self
.
_domain
[
i
].
shape
)]
mysubscripts
+=
''
.
join
(
subscriptmap
[
a
])
shapes
+=
(
self
.
_domain
.
shape
,)
mysubscripts
+=
'->'
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment