Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
d296d7a7
Commit
d296d7a7
authored
May 15, 2020
by
Philipp Frank
Browse files
split LinearEinsum init for calling as lin
parent
bb066937
Pipeline
#74967
passed with stages
in 26 minutes and 16 seconds
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty6/operators/einsum.py
View file @
d296d7a7
...
...
@@ -103,18 +103,14 @@ class MultiLinearEinsum(Operator):
self
.
_sscr_endswith
=
dict
()
self
.
_linpaths
=
dict
()
for
k
,
(
i
,
ss
),
nss
in
zip
(
self
.
_key_order
,
enumerate
(
iss_spl
),
numpy_iss_spl
):
left_ss_spl
=
(
*
iss_spl
[:
i
],
*
iss_spl
[
i
+
1
:],
ss
)
self
.
_sscr_endswith
[
k
]
=
'->'
.
join
(
(
','
.
join
(
left_ss_spl
),
oss
)
)
left_ss_spl
=
(
*
numpy_iss_spl
[:
i
],
*
numpy_iss_spl
[
i
+
1
:],
nss
)
for
k
,
(
i
,
ss
)
in
zip
(
self
.
_key_order
,
enumerate
(
numpy_iss_spl
)):
left_ss_spl
=
(
*
numpy_iss_spl
[:
i
],
*
numpy_iss_spl
[
i
+
1
:],
ss
)
linpath
=
'->'
.
join
((
','
.
join
(
left_ss_spl
),
numpy_oss
))
plc
=
tuple
(
np
.
broadcast_to
(
np
.
nan
,
shapes
[
q
])
for
q
in
shapes
.
keys
()
if
q
!=
k
)
plc
+=
(
np
.
broadcast_to
(
np
.
nan
,
shapes
[
k
]),)
linpath
=
np
.
einsum_path
(
linpath
,
*
plc
,
optimize
=
optimize
)[
0
]
self
.
_linpaths
[
k
]
=
linpath
self
.
_sscr_endswith
[
k
]
=
linpath
self
.
_linpaths
[
k
]
=
np
.
einsum_path
(
linpath
,
*
plc
,
optimize
=
optimize
)[
0
]
if
isinstance
(
optimize
,
list
):
path
=
optimize
else
:
...
...
@@ -151,7 +147,9 @@ class MultiLinearEinsum(Operator):
mf_wo_k
,
ss
,
key_order
=
tuple
(
plc
.
keys
()),
optimize
=
self
.
_linpaths
[
wrt
]
optimize
=
self
.
_linpaths
[
wrt
],
_target
=
self
.
_target
,
_calling_as_lin
=
True
).
ducktape
(
wrt
)
jac
=
jac
+
jac_k
if
jac
is
not
None
else
jac_k
return
x
.
new
(
Field
.
from_raw
(
self
.
target
,
res
),
jac
)
...
...
@@ -180,67 +178,78 @@ class LinearEinsum(LinearOperator):
optimize: bool, String or List, optional
Parameter passed on to einsum_path.
"""
def
__init__
(
self
,
domain
,
mf
,
subscripts
,
key_order
=
None
,
optimize
=
'optimal'
):
def
__init__
(
self
,
domain
,
mf
,
subscripts
,
key_order
=
None
,
optimize
=
'optimal'
,
_target
=
None
,
_calling_as_lin
=
False
):
self
.
_domain
=
DomainTuple
.
make
(
domain
)
self
.
_mf
=
mf
if
key_order
is
None
:
self
.
_key_order
=
tuple
(
self
.
_mf
.
domain
.
keys
())
if
_calling_as_lin
:
self
.
_init2
(
mf
,
subscripts
,
key_order
,
optimize
,
_target
)
else
:
self
.
_key_order
=
key_order
self
.
_ein_kw
=
{
"optimize"
:
optimize
}
iss
,
oss
,
*
rest
=
subscripts
.
split
(
"->"
)
iss_spl
=
iss
.
split
(
","
)
sscr_consist
=
all
(
o
in
iss
for
o
in
oss
)
len_consist
=
len
(
self
.
_key_order
)
==
len
(
iss_spl
[:
-
1
])
if
rest
or
not
sscr_consist
or
","
in
oss
or
not
len_consist
:
raise
ValueError
(
f
"invalid subscripts specified; got
{
subscripts
}
"
)
ve
=
f
"invalid order of keys
{
self
.
_key_order
}
for subscripts
{
subscripts
}
"
shapes
,
numpy_subscripts
,
subscriptmap
=
(),
''
,{}
alphabet
=
list
(
string
.
ascii_lowercase
)
for
k
,
ss
in
zip
(
self
.
_key_order
,
iss_spl
[:
-
1
]):
dom
=
self
.
_mf
[
k
].
domain
if
len
(
dom
)
!=
len
(
ss
):
self
.
_mf
=
mf
if
key_order
is
None
:
_key_order
=
tuple
(
self
.
_mf
.
domain
.
keys
())
else
:
_key_order
=
key_order
self
.
_ein_kw
=
{
"optimize"
:
optimize
}
iss
,
oss
,
*
rest
=
subscripts
.
split
(
"->"
)
iss_spl
=
iss
.
split
(
","
)
sscr_consist
=
all
(
o
in
iss
for
o
in
oss
)
len_consist
=
len
(
_key_order
)
==
len
(
iss_spl
[:
-
1
])
if
rest
or
not
sscr_consist
or
","
in
oss
or
not
len_consist
:
raise
ValueError
(
f
"invalid subscripts specified; got
{
subscripts
}
"
)
ve
=
f
"invalid order of keys
{
_key_order
}
for subscripts
{
subscripts
}
"
shapes
,
numpy_subscripts
,
subscriptmap
=
(),
''
,{}
alphabet
=
list
(
string
.
ascii_lowercase
)
for
k
,
ss
in
zip
(
_key_order
,
iss_spl
[:
-
1
]):
dom
=
self
.
_mf
[
k
].
domain
if
len
(
dom
)
!=
len
(
ss
):
raise
ValueError
(
ve
)
for
i
,
a
in
enumerate
(
list
(
ss
)):
if
a
not
in
subscriptmap
.
keys
():
subscriptmap
[
a
]
=
[
alphabet
.
pop
()
for
_
in
range
(
len
(
dom
[
i
].
shape
))]
numpy_subscripts
+=
''
.
join
(
subscriptmap
[
a
])
numpy_subscripts
+=
','
shapes
+=
(
dom
.
shape
,)
if
len
(
self
.
_domain
)
!=
len
(
iss_spl
[
-
1
]):
raise
ValueError
(
ve
)
for
i
,
a
in
enumerate
(
list
(
ss
)):
for
i
,
a
in
enumerate
(
list
(
i
ss
_spl
[
-
1
]
)):
if
a
not
in
subscriptmap
.
keys
():
subscriptmap
[
a
]
=
[
alphabet
.
pop
()
for
_
in
range
(
len
(
dom
[
i
].
shape
))]
range
(
len
(
self
.
_domain
[
i
].
shape
))]
numpy_subscripts
+=
''
.
join
(
subscriptmap
[
a
])
numpy_subscripts
+=
','
shapes
+=
(
dom
.
shape
,)
if
len
(
self
.
_domain
)
!=
len
(
iss_spl
[
-
1
]):
raise
ValueError
(
ve
)
for
i
,
a
in
enumerate
(
list
(
iss_spl
[
-
1
]
)):
if
a
not
in
subscriptmap
.
keys
():
subscriptmap
[
a
]
=
[
alphabet
.
pop
()
for
_
in
range
(
len
(
self
.
_domain
[
i
].
shape
))
]
numpy_subscripts
+=
''
.
join
(
subscriptmap
[
a
]
)
shapes
+=
(
self
.
_domain
.
shape
,)
numpy_subscripts
+=
'->'
dom_sscr
=
dict
(
zip
(
self
.
_key_order
,
iss_spl
[:
-
1
])
)
dom_sscr
[
id
(
self
)]
=
iss_spl
[
-
1
]
tgt
=
[]
for
o
in
oss
:
k_hit
=
tuple
(
k
for
k
,
sscr
in
dom_sscr
.
items
()
if
o
in
sscr
)[
0
]
dom_k_idx
=
dom_sscr
[
k_hit
].
index
(
o
)
if
k_hit
in
self
.
_key_order
:
tgt
+=
[
self
.
_mf
.
domain
[
k_hit
][
dom_k_idx
]]
shapes
+=
(
self
.
_domain
.
shape
,)
numpy_subscripts
+=
'->'
dom_sscr
=
dict
(
zip
(
_key_order
,
iss_spl
[:
-
1
])
)
dom_sscr
[
id
(
self
)]
=
iss_spl
[
-
1
]
tgt
=
[]
for
o
in
oss
:
k_hit
=
tuple
(
k
for
k
,
sscr
in
dom_sscr
.
items
()
if
o
in
sscr
)[
0
]
dom_k_idx
=
dom_sscr
[
k_hit
].
index
(
o
)
if
k_hit
in
_key_order
:
tgt
+=
[
self
.
_mf
.
domain
[
k_hit
][
dom_k_idx
]]
else
:
assert
k_hit
==
id
(
self
)
tgt
+=
[
self
.
_domain
[
dom_k_idx
]
]
numpy_subscripts
+=
""
.
join
(
subscriptmap
[
o
])
_target
=
DomainTuple
.
make
(
tgt
)
self
.
_sscr
=
numpy_subscripts
if
isinstance
(
optimize
,
list
)
:
path
=
optimize
else
:
assert
k_hit
==
id
(
self
)
tgt
+=
[
self
.
_domain
[
dom_k_idx
]]
numpy_subscripts
+=
""
.
join
(
subscriptmap
[
o
])
self
.
_target
=
DomainTuple
.
make
(
tgt
)
self
.
_sscr
=
numpy_subscripts
if
isinstance
(
optimize
,
list
):
path
=
optimize
else
:
plc
=
(
np
.
broadcast_to
(
np
.
nan
,
shp
)
for
shp
in
shapes
)
path
=
np
.
einsum_path
(
numpy_subscripts
,
*
plc
,
optimize
=
optimize
)[
0
]
self
.
_ein_kw
=
{
"optimize"
:
path
}
plc
=
(
np
.
broadcast_to
(
np
.
nan
,
shp
)
for
shp
in
shapes
)
path
=
np
.
einsum_path
(
numpy_subscripts
,
*
plc
,
optimize
=
optimize
)[
0
]
self
.
_init2
(
mf
,
numpy_subscripts
,
_key_order
,
path
,
_target
)
iss
,
oss
,
*
_
=
numpy_subscripts
.
split
(
"->"
)
def
_init2
(
self
,
mf
,
subscripts
,
keyorder
,
optimize
,
target
):
self
.
_ein_kw
=
{
"optimize"
:
optimize
}
self
.
_mf
=
mf
self
.
_sscr
=
subscripts
self
.
_key_order
=
keyorder
self
.
_target
=
target
iss
,
oss
,
*
_
=
subscripts
.
split
(
"->"
)
iss_spl
=
iss
.
split
(
","
)
adj_iss
=
","
.
join
((
","
.
join
(
iss_spl
[:
-
1
]),
oss
))
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment