Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
724aec55
Commit
724aec55
authored
May 14, 2020
by
Philipp Frank
Browse files
fixup
parent
ef11e4e2
Changes
1
Hide whitespace changes
Inline
Side-by-side
nifty6/operators/einsum.py
View file @
724aec55
...
...
@@ -46,11 +46,11 @@ class MultiLinearEinsum(Operator):
`key_order` is not part of the `domain`. Fields in this object are
supposed to be static as they will not appear as FieldAdapter in the
Linearization.
optimize: bool, optional
Parameter passed on to einsum.
optimize: bool,
String or List
optional
Parameter passed on to einsum
_path
.
"""
def
__init__
(
self
,
domain
,
subscripts
,
key_order
=
None
,
static_mf
=
None
,
optimize
=
True
):
key_order
=
None
,
static_mf
=
None
,
optimize
=
'optimal'
):
self
.
_domain
=
MultiDomain
.
make
(
domain
)
if
key_order
is
None
:
self
.
_key_order
=
tuple
(
self
.
_domain
.
keys
())
...
...
@@ -115,9 +115,11 @@ class MultiLinearEinsum(Operator):
plc
+=
(
np
.
broadcast_to
(
np
.
nan
,
shapes
[
k
]),)
linpath
=
np
.
einsum_path
(
linpath
,
*
plc
,
optimize
=
optimize
)[
0
]
self
.
_linpaths
[
k
]
=
linpath
plc
=
(
np
.
broadcast_to
(
np
.
nan
,
shapes
[
k
])
for
k
in
shapes
.
keys
())
path
=
np
.
einsum_path
(
numpy_subscripts
,
*
plc
,
optimize
=
optimize
)[
0
]
if
isinstance
(
optimize
,
list
):
path
=
optimize
else
:
plc
=
(
np
.
broadcast_to
(
np
.
nan
,
shapes
[
k
])
for
k
in
shapes
.
keys
())
path
=
np
.
einsum_path
(
numpy_subscripts
,
*
plc
,
optimize
=
optimize
)[
0
]
self
.
_sscr
=
numpy_subscripts
self
.
_ein_kw
=
{
"optimize"
:
path
}
...
...
@@ -175,10 +177,10 @@ class LinearEinsum(LinearOperator):
key_order: tuple of str, optional
The order of the keys in the multi-field. If not specified, defaults to
the order of the keys in the multi-field.
optimize: bool, optional
Parameter passed on to einsum.
optimize: bool,
String or List
optional
Parameter passed on to einsum
_path
.
"""
def
__init__
(
self
,
domain
,
mf
,
subscripts
,
key_order
=
None
,
optimize
=
True
):
def
__init__
(
self
,
domain
,
mf
,
subscripts
,
key_order
=
None
,
optimize
=
'optimal'
):
self
.
_domain
=
DomainTuple
.
make
(
domain
)
self
.
_mf
=
mf
if
key_order
is
None
:
...
...
@@ -232,11 +234,11 @@ class LinearEinsum(LinearOperator):
self
.
_sscr
=
numpy_subscripts
if
isinstance
(
optimize
,
list
):
self
.
_ein_kw
=
{
"optimize"
:
optimize
}
path
=
optimize
else
:
plc
=
(
np
.
broadcast_to
(
np
.
nan
,
shp
)
for
shp
in
shapes
)
path
=
np
.
einsum_path
(
numpy_subscripts
,
*
plc
,
optimize
=
optimize
)[
0
]
self
.
_ein_kw
=
{
"optimize"
:
path
}
self
.
_ein_kw
=
{
"optimize"
:
path
}
iss
,
oss
,
*
_
=
numpy_subscripts
.
split
(
"->"
)
iss_spl
=
iss
.
split
(
","
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment