Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
ef11e4e2
Commit
ef11e4e2
authored
May 14, 2020
by
Philipp Frank
Browse files
first version precalculate paths for lin
parent
5c745090
Changes
1
Hide whitespace changes
Inline
Side-by-side
nifty6/operators/einsum.py
View file @
ef11e4e2
...
...
@@ -67,7 +67,7 @@ class MultiLinearEinsum(Operator):
if
rest
or
not
sscr_consist
or
","
in
oss
or
not
len_consist
:
raise
ValueError
(
f
"invalid subscripts specified; got
{
subscripts
}
"
)
ve
=
f
"invalid order of keys
{
self
.
_key_order
}
for subscripts
{
subscripts
}
"
shapes
,
my
subscripts
,
subscriptmap
=
()
,
''
,{}
shapes
,
numpy_
subscripts
,
subscriptmap
=
{}
,
''
,{}
alphabet
=
list
(
string
.
ascii_lowercase
)
for
k
,
ss
in
zip
(
self
.
_key_order
,
iss_spl
):
dom
=
self
.
_domain
[
k
]
if
k
in
self
.
_domain
.
keys
(
...
...
@@ -78,10 +78,10 @@ class MultiLinearEinsum(Operator):
if
a
not
in
subscriptmap
.
keys
():
subscriptmap
[
a
]
=
alphabet
[:
len
(
dom
[
i
].
shape
)].
copy
()
del
alphabet
[:
len
(
dom
[
i
].
shape
)]
my
subscripts
+=
''
.
join
(
subscriptmap
[
a
])
my
subscripts
+=
','
shapes
+
=
(
dom
.
shape
,
)
my
subscripts
=
my
subscripts
[:
-
1
]
+
'->'
numpy_
subscripts
+=
''
.
join
(
subscriptmap
[
a
])
numpy_
subscripts
+=
','
shapes
[
k
]
=
dom
.
shape
numpy_
subscripts
=
numpy_
subscripts
[:
-
1
]
+
'->'
dom_sscr
=
dict
(
zip
(
self
.
_key_order
,
iss_spl
))
tgt
=
[]
for
o
in
oss
:
...
...
@@ -94,18 +94,31 @@ class MultiLinearEinsum(Operator):
ve
=
f
"
{
k_hit
}
is not in domain nor in static_mf"
raise
ValueError
(
ve
)
tgt
+=
[
self
.
_stat_mf
[
k_hit
].
domain
[
dom_k_idx
]]
my
subscripts
+=
''
.
join
(
subscriptmap
[
o
])
numpy_
subscripts
+=
''
.
join
(
subscriptmap
[
o
])
self
.
_target
=
DomainTuple
.
make
(
tgt
)
numpy_iss
,
numpy_oss
,
*
_
=
numpy_subscripts
.
split
(
"->"
)
numpy_iss_spl
=
numpy_iss
.
split
(
","
)
self
.
_sscr_endswith
=
dict
()
for
k
,
(
i
,
ss
)
in
zip
(
self
.
_key_order
,
enumerate
(
iss_spl
)):
self
.
_linpaths
=
dict
()
for
k
,
(
i
,
ss
),
nss
in
zip
(
self
.
_key_order
,
enumerate
(
iss_spl
),
numpy_iss_spl
):
left_ss_spl
=
(
*
iss_spl
[:
i
],
*
iss_spl
[
i
+
1
:],
ss
)
self
.
_sscr_endswith
[
k
]
=
'->'
.
join
(
(
','
.
join
(
left_ss_spl
),
oss
)
)
plc
=
(
np
.
broadcast_to
(
np
.
nan
,
shp
)
for
shp
in
shapes
)
path
=
np
.
einsum_path
(
mysubscripts
,
*
plc
,
optimize
=
optimize
)[
0
]
self
.
_sscr
=
mysubscripts
left_ss_spl
=
(
*
numpy_iss_spl
[:
i
],
*
numpy_iss_spl
[
i
+
1
:],
nss
)
linpath
=
'->'
.
join
((
','
.
join
(
left_ss_spl
),
numpy_oss
))
plc
=
tuple
(
np
.
broadcast_to
(
np
.
nan
,
shapes
[
q
])
for
q
in
shapes
.
keys
()
if
q
!=
k
)
plc
+=
(
np
.
broadcast_to
(
np
.
nan
,
shapes
[
k
]),)
linpath
=
np
.
einsum_path
(
linpath
,
*
plc
,
optimize
=
optimize
)[
0
]
self
.
_linpaths
[
k
]
=
linpath
plc
=
(
np
.
broadcast_to
(
np
.
nan
,
shapes
[
k
])
for
k
in
shapes
.
keys
())
path
=
np
.
einsum_path
(
numpy_subscripts
,
*
plc
,
optimize
=
optimize
)[
0
]
self
.
_sscr
=
numpy_subscripts
self
.
_ein_kw
=
{
"optimize"
:
path
}
def
apply
(
self
,
x
):
...
...
@@ -136,7 +149,7 @@ class MultiLinearEinsum(Operator):
mf_wo_k
,
ss
,
key_order
=
tuple
(
plc
.
keys
()),
**
self
.
_
e
in
_kw
optimize
=
self
.
_
l
in
paths
[
wrt
]
).
ducktape
(
wrt
)
jac
=
jac
+
jac_k
if
jac
is
not
None
else
jac_k
return
x
.
new
(
Field
.
from_raw
(
self
.
target
,
res
),
jac
)
...
...
@@ -180,7 +193,7 @@ class LinearEinsum(LinearOperator):
if
rest
or
not
sscr_consist
or
","
in
oss
or
not
len_consist
:
raise
ValueError
(
f
"invalid subscripts specified; got
{
subscripts
}
"
)
ve
=
f
"invalid order of keys
{
self
.
_key_order
}
for subscripts
{
subscripts
}
"
shapes
,
my
subscripts
,
subscriptmap
=
(),
''
,{}
shapes
,
numpy_
subscripts
,
subscriptmap
=
(),
''
,{}
alphabet
=
list
(
string
.
ascii_lowercase
)
for
k
,
ss
in
zip
(
self
.
_key_order
,
iss_spl
[:
-
1
]):
dom
=
self
.
_mf
[
k
].
domain
...
...
@@ -190,8 +203,8 @@ class LinearEinsum(LinearOperator):
if
a
not
in
subscriptmap
.
keys
():
subscriptmap
[
a
]
=
alphabet
[:
len
(
dom
[
i
].
shape
)].
copy
()
del
alphabet
[:
len
(
dom
[
i
].
shape
)]
my
subscripts
+=
''
.
join
(
subscriptmap
[
a
])
my
subscripts
+=
','
numpy_
subscripts
+=
''
.
join
(
subscriptmap
[
a
])
numpy_
subscripts
+=
','
shapes
+=
(
dom
.
shape
,)
if
len
(
self
.
_domain
)
!=
len
(
iss_spl
[
-
1
]):
raise
ValueError
(
ve
)
...
...
@@ -199,9 +212,9 @@ class LinearEinsum(LinearOperator):
if
a
not
in
subscriptmap
.
keys
():
subscriptmap
[
a
]
=
alphabet
[:
len
(
self
.
_domain
[
i
].
shape
)].
copy
()
del
alphabet
[:
len
(
self
.
_domain
[
i
].
shape
)]
my
subscripts
+=
''
.
join
(
subscriptmap
[
a
])
numpy_
subscripts
+=
''
.
join
(
subscriptmap
[
a
])
shapes
+=
(
self
.
_domain
.
shape
,)
my
subscripts
+=
'->'
numpy_
subscripts
+=
'->'
dom_sscr
=
dict
(
zip
(
self
.
_key_order
,
iss_spl
[:
-
1
]))
dom_sscr
[
id
(
self
)]
=
iss_spl
[
-
1
]
...
...
@@ -214,14 +227,18 @@ class LinearEinsum(LinearOperator):
else
:
assert
k_hit
==
id
(
self
)
tgt
+=
[
self
.
_domain
[
dom_k_idx
]]
my
subscripts
+=
""
.
join
(
subscriptmap
[
o
])
numpy_
subscripts
+=
""
.
join
(
subscriptmap
[
o
])
self
.
_target
=
DomainTuple
.
make
(
tgt
)
self
.
_sscr
=
mysubscripts
plc
=
(
np
.
broadcast_to
(
np
.
nan
,
shp
)
for
shp
in
shapes
)
path
=
np
.
einsum_path
(
mysubscripts
,
*
plc
,
optimize
=
optimize
)[
0
]
self
.
_ein_kw
=
{
"optimize"
:
path
}
self
.
_sscr
=
numpy_subscripts
if
isinstance
(
optimize
,
list
):
self
.
_ein_kw
=
{
"optimize"
:
optimize
}
else
:
plc
=
(
np
.
broadcast_to
(
np
.
nan
,
shp
)
for
shp
in
shapes
)
path
=
np
.
einsum_path
(
numpy_subscripts
,
*
plc
,
optimize
=
optimize
)[
0
]
self
.
_ein_kw
=
{
"optimize"
:
path
}
iss
,
oss
,
*
_
=
my
subscripts
.
split
(
"->"
)
iss
,
oss
,
*
_
=
numpy_
subscripts
.
split
(
"->"
)
iss_spl
=
iss
.
split
(
","
)
adj_iss
=
","
.
join
((
","
.
join
(
iss_spl
[:
-
1
]),
oss
))
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment