Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
88a0ae9f
Commit
88a0ae9f
authored
May 15, 2020
by
Philipp Frank
Browse files
Merge branch 'nifty6_einsum_pf' into 'nifty6_einsum'
Nifty6 einsum suggestions pf See merge request
!462
parents
cd38eef1
d296d7a7
Pipeline
#74968
passed with stages
in 26 minutes and 52 seconds
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty6/operators/einsum.py
View file @
88a0ae9f
...
...
@@ -17,6 +17,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import
numpy
as
np
import
string
from
..domain_tuple
import
DomainTuple
from
..linearization
import
Linearization
from
..field
import
Field
...
...
@@ -45,19 +46,12 @@ class MultiLinearEinsum(Operator):
`key_order` is not part of the `domain`. Fields in this object are
supposed to be static as they will not appear as FieldAdapter in the
Linearization.
optimize: bool, optional
Parameter passed on to einsum.
optimize: bool,
String or List,
optional
Parameter passed on to einsum
_path
.
"""
def
__init__
(
self
,
domain
,
subscripts
,
key_order
=
None
,
static_mf
=
None
,
optimize
=
False
):
def
__init__
(
self
,
domain
,
subscripts
,
key_order
=
None
,
static_mf
=
None
,
optimize
=
'optimal'
):
self
.
_domain
=
MultiDomain
.
make
(
domain
)
self
.
_sscr
=
subscripts
if
key_order
is
None
:
self
.
_key_order
=
tuple
(
self
.
_domain
.
keys
())
else
:
...
...
@@ -66,23 +60,31 @@ class MultiLinearEinsum(Operator):
ve
=
"`key_order` mus be specified if additional fields are munged"
raise
ValueError
(
ve
)
self
.
_stat_mf
=
static_mf
self
.
_ein_kw
=
{
"optimize"
:
optimize
}
iss
,
self
.
_oss
,
*
rest
=
subscripts
.
split
(
"->"
)
iss
,
oss
,
*
rest
=
subscripts
.
split
(
"->"
)
iss_spl
=
iss
.
split
(
","
)
len_consist
=
len
(
self
.
_key_order
)
==
len
(
iss_spl
)
sscr_consist
=
all
(
o
in
iss
for
o
in
self
.
_
oss
)
if
rest
or
not
sscr_consist
or
","
in
self
.
_
oss
or
not
len_consist
:
sscr_consist
=
all
(
o
in
iss
for
o
in
oss
)
if
rest
or
not
sscr_consist
or
","
in
oss
or
not
len_consist
:
raise
ValueError
(
f
"invalid subscripts specified; got
{
subscripts
}
"
)
ve
=
f
"invalid order of keys
{
self
.
_key_order
}
for subscripts
{
subscripts
}
"
shapes
,
numpy_subscripts
,
subscriptmap
=
{},
''
,{}
alphabet
=
list
(
string
.
ascii_lowercase
)[::
-
1
]
for
k
,
ss
in
zip
(
self
.
_key_order
,
iss_spl
):
dom
=
self
.
_domain
[
k
]
if
k
in
self
.
_domain
.
keys
(
)
else
self
.
_stat_mf
[
k
].
domain
if
len
(
dom
.
shape
)
!=
len
(
ss
):
ve
=
f
"invalid order of keys
{
key_order
}
for subscripts
{
subscripts
}
"
if
len
(
dom
)
!=
len
(
ss
):
raise
ValueError
(
ve
)
for
i
,
a
in
enumerate
(
list
(
ss
)):
if
a
not
in
subscriptmap
.
keys
():
subscriptmap
[
a
]
=
[
alphabet
.
pop
()
for
_
in
range
(
len
(
dom
[
i
].
shape
))]
numpy_subscripts
+=
''
.
join
(
subscriptmap
[
a
])
numpy_subscripts
+=
','
shapes
[
k
]
=
dom
.
shape
numpy_subscripts
=
numpy_subscripts
[:
-
1
]
+
'->'
dom_sscr
=
dict
(
zip
(
self
.
_key_order
,
iss_spl
))
tgt
=
[]
for
o
in
self
.
_
oss
:
for
o
in
oss
:
k_hit
=
tuple
(
k
for
k
,
sscr
in
dom_sscr
.
items
()
if
o
in
sscr
)[
0
]
dom_k_idx
=
dom_sscr
[
k_hit
].
index
(
o
)
if
k_hit
in
self
.
_domain
.
keys
():
...
...
@@ -92,14 +94,30 @@ class MultiLinearEinsum(Operator):
ve
=
f
"
{
k_hit
}
is not in domain nor in static_mf"
raise
ValueError
(
ve
)
tgt
+=
[
self
.
_stat_mf
[
k_hit
].
domain
[
dom_k_idx
]]
numpy_subscripts
+=
''
.
join
(
subscriptmap
[
o
])
self
.
_target
=
DomainTuple
.
make
(
tgt
)
numpy_iss
,
numpy_oss
,
*
_
=
numpy_subscripts
.
split
(
"->"
)
numpy_iss_spl
=
numpy_iss
.
split
(
","
)
self
.
_sscr_endswith
=
dict
()
for
k
,
(
i
,
ss
)
in
zip
(
self
.
_key_order
,
enumerate
(
iss_spl
)):
left_ss_spl
=
(
*
iss_spl
[:
i
],
*
iss_spl
[
i
+
1
:],
ss
)
self
.
_sscr_endswith
[
k
]
=
"->"
.
join
(
(
","
.
join
(
left_ss_spl
),
self
.
_oss
)
)
self
.
_linpaths
=
dict
()
for
k
,
(
i
,
ss
)
in
zip
(
self
.
_key_order
,
enumerate
(
numpy_iss_spl
)):
left_ss_spl
=
(
*
numpy_iss_spl
[:
i
],
*
numpy_iss_spl
[
i
+
1
:],
ss
)
linpath
=
'->'
.
join
((
','
.
join
(
left_ss_spl
),
numpy_oss
))
plc
=
tuple
(
np
.
broadcast_to
(
np
.
nan
,
shapes
[
q
])
for
q
in
shapes
.
keys
()
if
q
!=
k
)
plc
+=
(
np
.
broadcast_to
(
np
.
nan
,
shapes
[
k
]),)
self
.
_sscr_endswith
[
k
]
=
linpath
self
.
_linpaths
[
k
]
=
np
.
einsum_path
(
linpath
,
*
plc
,
optimize
=
optimize
)[
0
]
if
isinstance
(
optimize
,
list
):
path
=
optimize
else
:
plc
=
(
np
.
broadcast_to
(
np
.
nan
,
shapes
[
k
])
for
k
in
shapes
.
keys
())
path
=
np
.
einsum_path
(
numpy_subscripts
,
*
plc
,
optimize
=
optimize
)[
0
]
self
.
_sscr
=
numpy_subscripts
self
.
_ein_kw
=
{
"optimize"
:
path
}
def
apply
(
self
,
x
):
self
.
_check_input
(
x
)
...
...
@@ -129,7 +147,9 @@ class MultiLinearEinsum(Operator):
mf_wo_k
,
ss
,
key_order
=
tuple
(
plc
.
keys
()),
**
self
.
_ein_kw
optimize
=
self
.
_linpaths
[
wrt
],
_target
=
self
.
_target
,
_calling_as_lin
=
True
).
ducktape
(
wrt
)
jac
=
jac
+
jac_k
if
jac
is
not
None
else
jac_k
return
x
.
new
(
Field
.
from_raw
(
self
.
target
,
res
),
jac
)
...
...
@@ -155,43 +175,82 @@ class LinearEinsum(LinearOperator):
key_order: tuple of str, optional
The order of the keys in the multi-field. If not specified, defaults to
the order of the keys in the multi-field.
optimize: bool, optional
Parameter passed on to einsum.
optimize: bool,
String or List,
optional
Parameter passed on to einsum
_path
.
"""
def
__init__
(
self
,
domain
,
mf
,
subscripts
,
key_order
=
None
,
optimize
=
False
):
def
__init__
(
self
,
domain
,
mf
,
subscripts
,
key_order
=
None
,
optimize
=
'optimal'
,
_target
=
None
,
_calling_as_lin
=
False
):
self
.
_domain
=
DomainTuple
.
make
(
domain
)
self
.
_mf
=
mf
self
.
_sscr
=
subscripts
if
key_order
is
None
:
self
.
_key_order
=
tuple
(
self
.
_mf
.
domain
.
keys
())
if
_calling_as_lin
:
self
.
_init2
(
mf
,
subscripts
,
key_order
,
optimize
,
_target
)
else
:
self
.
_key_order
=
key_order
self
.
_ein_kw
=
{
"optimize"
:
optimize
}
iss
,
oss
,
*
rest
=
subscripts
.
split
(
"->"
)
iss_spl
=
iss
.
split
(
","
)
sscr_consist
=
all
(
o
in
iss
for
o
in
oss
)
len_consist
=
len
(
self
.
_key_order
)
==
len
(
iss_spl
[:
-
1
])
if
rest
or
not
sscr_consist
or
","
in
oss
or
not
len_consist
:
raise
ValueError
(
f
"invalid subscripts specified; got
{
subscripts
}
"
)
ve
=
f
"invalid order of keys
{
key_order
}
for subscripts
{
subscripts
}
"
for
k
,
ss
in
zip
(
self
.
_key_order
,
iss_spl
[:
-
1
]):
if
len
(
self
.
_mf
[
k
].
shape
)
!=
len
(
ss
):
self
.
_mf
=
mf
if
key_order
is
None
:
_key_order
=
tuple
(
self
.
_mf
.
domain
.
keys
())
else
:
_key_order
=
key_order
self
.
_ein_kw
=
{
"optimize"
:
optimize
}
iss
,
oss
,
*
rest
=
subscripts
.
split
(
"->"
)
iss_spl
=
iss
.
split
(
","
)
sscr_consist
=
all
(
o
in
iss
for
o
in
oss
)
len_consist
=
len
(
_key_order
)
==
len
(
iss_spl
[:
-
1
])
if
rest
or
not
sscr_consist
or
","
in
oss
or
not
len_consist
:
raise
ValueError
(
f
"invalid subscripts specified; got
{
subscripts
}
"
)
ve
=
f
"invalid order of keys
{
_key_order
}
for subscripts
{
subscripts
}
"
shapes
,
numpy_subscripts
,
subscriptmap
=
(),
''
,{}
alphabet
=
list
(
string
.
ascii_lowercase
)
for
k
,
ss
in
zip
(
_key_order
,
iss_spl
[:
-
1
]):
dom
=
self
.
_mf
[
k
].
domain
if
len
(
dom
)
!=
len
(
ss
):
raise
ValueError
(
ve
)
for
i
,
a
in
enumerate
(
list
(
ss
)):
if
a
not
in
subscriptmap
.
keys
():
subscriptmap
[
a
]
=
[
alphabet
.
pop
()
for
_
in
range
(
len
(
dom
[
i
].
shape
))]
numpy_subscripts
+=
''
.
join
(
subscriptmap
[
a
])
numpy_subscripts
+=
','
shapes
+=
(
dom
.
shape
,)
if
len
(
self
.
_domain
)
!=
len
(
iss_spl
[
-
1
]):
raise
ValueError
(
ve
)
if
len
(
self
.
_domain
.
shape
)
!=
len
(
iss_spl
[
-
1
]):
raise
ValueError
(
ve
)
dom_sscr
=
dict
(
zip
(
self
.
_key_order
,
iss_spl
[:
-
1
]))
dom_sscr
[
id
(
self
)]
=
iss_spl
[
-
1
]
tgt
=
[]
for
o
in
oss
:
k_hit
=
tuple
(
k
for
k
,
sscr
in
dom_sscr
.
items
()
if
o
in
sscr
)[
0
]
dom_k_idx
=
dom_sscr
[
k_hit
].
index
(
o
)
if
k_hit
in
self
.
_key_order
:
tgt
+=
[
self
.
_mf
.
domain
[
k_hit
][
dom_k_idx
]]
for
i
,
a
in
enumerate
(
list
(
iss_spl
[
-
1
])):
if
a
not
in
subscriptmap
.
keys
():
subscriptmap
[
a
]
=
[
alphabet
.
pop
()
for
_
in
range
(
len
(
self
.
_domain
[
i
].
shape
))]
numpy_subscripts
+=
''
.
join
(
subscriptmap
[
a
])
shapes
+=
(
self
.
_domain
.
shape
,)
numpy_subscripts
+=
'->'
dom_sscr
=
dict
(
zip
(
_key_order
,
iss_spl
[:
-
1
]))
dom_sscr
[
id
(
self
)]
=
iss_spl
[
-
1
]
tgt
=
[]
for
o
in
oss
:
k_hit
=
tuple
(
k
for
k
,
sscr
in
dom_sscr
.
items
()
if
o
in
sscr
)[
0
]
dom_k_idx
=
dom_sscr
[
k_hit
].
index
(
o
)
if
k_hit
in
_key_order
:
tgt
+=
[
self
.
_mf
.
domain
[
k_hit
][
dom_k_idx
]]
else
:
assert
k_hit
==
id
(
self
)
tgt
+=
[
self
.
_domain
[
dom_k_idx
]]
numpy_subscripts
+=
""
.
join
(
subscriptmap
[
o
])
_target
=
DomainTuple
.
make
(
tgt
)
self
.
_sscr
=
numpy_subscripts
if
isinstance
(
optimize
,
list
):
path
=
optimize
else
:
assert
k_hit
==
id
(
self
)
tgt
+=
[
self
.
_domain
[
dom_k_idx
]]
self
.
_target
=
DomainTuple
.
make
(
tgt
)
plc
=
(
np
.
broadcast_to
(
np
.
nan
,
shp
)
for
shp
in
shapes
)
path
=
np
.
einsum_path
(
numpy_subscripts
,
*
plc
,
optimize
=
optimize
)[
0
]
self
.
_init2
(
mf
,
numpy_subscripts
,
_key_order
,
path
,
_target
)
def
_init2
(
self
,
mf
,
subscripts
,
keyorder
,
optimize
,
target
):
self
.
_ein_kw
=
{
"optimize"
:
optimize
}
self
.
_mf
=
mf
self
.
_sscr
=
subscripts
self
.
_key_order
=
keyorder
self
.
_target
=
target
iss
,
oss
,
*
_
=
subscripts
.
split
(
"->"
)
iss_spl
=
iss
.
split
(
","
)
adj_iss
=
","
.
join
((
","
.
join
(
iss_spl
[:
-
1
]),
oss
))
self
.
_adj_sscr
=
"->"
.
join
((
adj_iss
,
iss_spl
[
-
1
]))
...
...
test/test_operators/test_einsum.py
View file @
88a0ae9f
...
...
@@ -21,31 +21,35 @@ from numpy.testing import assert_allclose
from
nifty6.extra
import
check_jacobian_consistency
,
consistency_check
import
nifty6
as
ift
from
..common
import
setup_function
,
teardown_function
from
..common
import
list2fixture
,
setup_function
,
teardown_function
pmp
=
pytest
.
mark
.
parametrize
spaces
=
(
ift
.
UnstructuredDomain
(
4
),
ift
.
RGSpace
((
3
,
2
)),
ift
.
LMSpace
(
5
),
ift
.
HPSpace
(
4
),
ift
.
GLSpace
(
4
))
space1
=
list2fixture
(
spaces
)
space2
=
list2fixture
(
spaces
)
@
pmp
(
"n_unstructured"
,
(
3
,
9
))
@
pmp
(
"nside"
,
(
4
,
8
))
def
test_linear_einsum_outer
(
n_unstructured
,
nside
,
n_invocations
=
10
):
def
test_linear_einsum_outer
(
space1
,
space2
,
n_invocations
=
10
):
setup_function
()
pos_space
=
ift
.
HPSpace
(
nside
)
mf_dom
=
ift
.
MultiDomain
.
make
(
{
"dom01"
:
ift
.
UnstructuredDomain
(
n_unstructured
),
"dom01"
:
space1
,
"dom02"
:
ift
.
DomainTuple
.
make
(
(
ift
.
UnstructuredDomain
(
n_unstructured
),
pos_
space
)
(
space1
,
space
2
)
)
}
)
mf
=
ift
.
from_random
(
"normal"
,
mf_dom
)
ss
=
"i,ij,j->ij"
key_order
=
(
"dom01"
,
"dom02"
)
le
=
ift
.
LinearEinsum
(
pos_
space
,
mf
,
ss
,
key_order
=
key_order
)
le
=
ift
.
LinearEinsum
(
space
2
,
mf
,
ss
,
key_order
=
key_order
)
assert
consistency_check
(
le
)
is
None
le_ift
=
ift
.
DiagonalOperator
(
...
...
@@ -63,26 +67,22 @@ def test_linear_einsum_outer(n_unstructured, nside, n_invocations=10):
teardown_function
()
@
pmp
(
"n_unstructured"
,
(
3
,
9
))
@
pmp
(
"nside"
,
(
4
,
8
))
def
test_linear_einsum_contraction
(
n_unstructured
,
nside
,
n_invocations
=
10
):
def
test_linear_einsum_contraction
(
space1
,
space2
,
n_invocations
=
10
):
setup_function
()
pos_space
=
ift
.
HPSpace
(
nside
)
mf_dom
=
ift
.
MultiDomain
.
make
(
{
"dom01"
:
ift
.
UnstructuredDomain
(
n_unstructured
),
"dom01"
:
space1
,
"dom02"
:
ift
.
DomainTuple
.
make
(
(
ift
.
UnstructuredDomain
(
n_unstructured
),
pos_
space
)
(
space1
,
space
2
)
)
}
)
mf
=
ift
.
from_random
(
"normal"
,
mf_dom
)
ss
=
"i,ij,j->i"
key_order
=
(
"dom01"
,
"dom02"
)
le
=
ift
.
LinearEinsum
(
pos_
space
,
mf
,
ss
,
key_order
=
key_order
)
le
=
ift
.
LinearEinsum
(
space
2
,
mf
,
ss
,
key_order
=
key_order
)
assert
consistency_check
(
le
)
is
None
le_ift
=
ift
.
ContractionOperator
(
mf_dom
[
"dom02"
],
1
)
@
ift
.
DiagonalOperator
(
...
...
@@ -100,24 +100,16 @@ def test_linear_einsum_contraction(n_unstructured, nside, n_invocations=10):
teardown_function
()
@
pmp
(
"n_unstructured"
,
(
3
,
9
))
@
pmp
(
"nside"
,
(
4
,
8
))
def
test_multi_linear_einsum_outer
(
n_unstructured
,
nside
,
n_invocations
=
10
,
ntries
=
100
space1
,
space2
,
n_invocations
=
10
,
ntries
=
100
):
setup_function
()
pos_space
=
ift
.
HPSpace
(
nside
)
mf_dom
=
ift
.
MultiDomain
.
make
(
{
"dom01"
:
ift
.
UnstructuredDomain
(
n_unstructured
),
"dom02"
:
ift
.
DomainTuple
.
make
(
(
ift
.
UnstructuredDomain
(
n_unstructured
),
pos_space
)
),
"dom03"
:
pos_space
"dom01"
:
space1
,
"dom02"
:
ift
.
DomainTuple
.
make
((
space1
,
space2
)),
"dom03"
:
space2
}
)
ss
=
"i,ij,j->ij"
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment