Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
bb066937
Commit
bb066937
authored
May 14, 2020
by
Philipp Frank
Browse files
pop list
parent
724aec55
Pipeline
#74963
passed with stages
in 26 minutes and 17 seconds
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty6/operators/einsum.py
View file @
bb066937
...
@@ -46,7 +46,7 @@ class MultiLinearEinsum(Operator):
...
@@ -46,7 +46,7 @@ class MultiLinearEinsum(Operator):
`key_order` is not part of the `domain`. Fields in this object are
`key_order` is not part of the `domain`. Fields in this object are
supposed to be static as they will not appear as FieldAdapter in the
supposed to be static as they will not appear as FieldAdapter in the
Linearization.
Linearization.
optimize: bool, String or List optional
optimize: bool, String or List
,
optional
Parameter passed on to einsum_path.
Parameter passed on to einsum_path.
"""
"""
def
__init__
(
self
,
domain
,
subscripts
,
def
__init__
(
self
,
domain
,
subscripts
,
...
@@ -68,7 +68,7 @@ class MultiLinearEinsum(Operator):
...
@@ -68,7 +68,7 @@ class MultiLinearEinsum(Operator):
raise
ValueError
(
f
"invalid subscripts specified; got
{
subscripts
}
"
)
raise
ValueError
(
f
"invalid subscripts specified; got
{
subscripts
}
"
)
ve
=
f
"invalid order of keys
{
self
.
_key_order
}
for subscripts
{
subscripts
}
"
ve
=
f
"invalid order of keys
{
self
.
_key_order
}
for subscripts
{
subscripts
}
"
shapes
,
numpy_subscripts
,
subscriptmap
=
{},
''
,{}
shapes
,
numpy_subscripts
,
subscriptmap
=
{},
''
,{}
alphabet
=
list
(
string
.
ascii_lowercase
)
alphabet
=
list
(
string
.
ascii_lowercase
)
[::
-
1
]
for
k
,
ss
in
zip
(
self
.
_key_order
,
iss_spl
):
for
k
,
ss
in
zip
(
self
.
_key_order
,
iss_spl
):
dom
=
self
.
_domain
[
k
]
if
k
in
self
.
_domain
.
keys
(
dom
=
self
.
_domain
[
k
]
if
k
in
self
.
_domain
.
keys
(
)
else
self
.
_stat_mf
[
k
].
domain
)
else
self
.
_stat_mf
[
k
].
domain
...
@@ -76,8 +76,8 @@ class MultiLinearEinsum(Operator):
...
@@ -76,8 +76,8 @@ class MultiLinearEinsum(Operator):
raise
ValueError
(
ve
)
raise
ValueError
(
ve
)
for
i
,
a
in
enumerate
(
list
(
ss
)):
for
i
,
a
in
enumerate
(
list
(
ss
)):
if
a
not
in
subscriptmap
.
keys
():
if
a
not
in
subscriptmap
.
keys
():
subscriptmap
[
a
]
=
alphabet
[:
len
(
dom
[
i
].
shape
)].
c
op
y
()
subscriptmap
[
a
]
=
[
alphabet
.
p
op
()
for
_
in
del
alphabet
[:
len
(
dom
[
i
].
shape
)]
range
(
len
(
dom
[
i
].
shape
)
)
]
numpy_subscripts
+=
''
.
join
(
subscriptmap
[
a
])
numpy_subscripts
+=
''
.
join
(
subscriptmap
[
a
])
numpy_subscripts
+=
','
numpy_subscripts
+=
','
shapes
[
k
]
=
dom
.
shape
shapes
[
k
]
=
dom
.
shape
...
@@ -177,7 +177,7 @@ class LinearEinsum(LinearOperator):
...
@@ -177,7 +177,7 @@ class LinearEinsum(LinearOperator):
key_order: tuple of str, optional
key_order: tuple of str, optional
The order of the keys in the multi-field. If not specified, defaults to
The order of the keys in the multi-field. If not specified, defaults to
the order of the keys in the multi-field.
the order of the keys in the multi-field.
optimize: bool, String or List optional
optimize: bool, String or List
,
optional
Parameter passed on to einsum_path.
Parameter passed on to einsum_path.
"""
"""
def
__init__
(
self
,
domain
,
mf
,
subscripts
,
key_order
=
None
,
optimize
=
'optimal'
):
def
__init__
(
self
,
domain
,
mf
,
subscripts
,
key_order
=
None
,
optimize
=
'optimal'
):
...
@@ -203,8 +203,8 @@ class LinearEinsum(LinearOperator):
...
@@ -203,8 +203,8 @@ class LinearEinsum(LinearOperator):
raise
ValueError
(
ve
)
raise
ValueError
(
ve
)
for
i
,
a
in
enumerate
(
list
(
ss
)):
for
i
,
a
in
enumerate
(
list
(
ss
)):
if
a
not
in
subscriptmap
.
keys
():
if
a
not
in
subscriptmap
.
keys
():
subscriptmap
[
a
]
=
alphabet
[:
len
(
dom
[
i
].
shape
)].
c
op
y
()
subscriptmap
[
a
]
=
[
alphabet
.
p
op
()
for
_
in
del
alphabet
[:
len
(
dom
[
i
].
shape
)]
range
(
len
(
dom
[
i
].
shape
)
)
]
numpy_subscripts
+=
''
.
join
(
subscriptmap
[
a
])
numpy_subscripts
+=
''
.
join
(
subscriptmap
[
a
])
numpy_subscripts
+=
','
numpy_subscripts
+=
','
shapes
+=
(
dom
.
shape
,)
shapes
+=
(
dom
.
shape
,)
...
@@ -212,8 +212,8 @@ class LinearEinsum(LinearOperator):
...
@@ -212,8 +212,8 @@ class LinearEinsum(LinearOperator):
raise
ValueError
(
ve
)
raise
ValueError
(
ve
)
for
i
,
a
in
enumerate
(
list
(
iss_spl
[
-
1
])):
for
i
,
a
in
enumerate
(
list
(
iss_spl
[
-
1
])):
if
a
not
in
subscriptmap
.
keys
():
if
a
not
in
subscriptmap
.
keys
():
subscriptmap
[
a
]
=
alphabet
[:
len
(
self
.
_domain
[
i
].
shape
)].
c
op
y
()
subscriptmap
[
a
]
=
[
alphabet
.
p
op
()
for
_
in
del
alphabet
[:
len
(
self
.
_domain
[
i
].
shape
)]
range
(
len
(
self
.
_domain
[
i
].
shape
)
)
]
numpy_subscripts
+=
''
.
join
(
subscriptmap
[
a
])
numpy_subscripts
+=
''
.
join
(
subscriptmap
[
a
])
shapes
+=
(
self
.
_domain
.
shape
,)
shapes
+=
(
self
.
_domain
.
shape
,)
numpy_subscripts
+=
'->'
numpy_subscripts
+=
'->'
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment