Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
17727718
Commit
17727718
authored
May 15, 2020
by
Philipp Frank
Browse files
Einsum handling for complex conjugation
parent
d296d7a7
Pipeline
#74969
passed with stages
in 26 minutes and 27 seconds
Changes
3
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nifty6/operators/einsum.py
View file @
17727718
...
...
@@ -12,7 +12,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Authors: Gordian Edenhofer
# Authors: Gordian Edenhofer
, Philipp Frank
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
...
...
@@ -28,9 +28,7 @@ from .linear_operator import LinearOperator
class
MultiLinearEinsum
(
Operator
):
"""Multi-linear Einsum operator with corresponding derivates
FIXME: This operator does not perform any complex conjugation!
"""Multi-linear Einsum operator with corresponding derivates.
Parameters
----------
...
...
@@ -48,6 +46,13 @@ class MultiLinearEinsum(Operator):
Linearization.
optimize: bool, String or List, optional
Parameter passed on to einsum_path.
Notes
-----
By convention :class:`MultiLinearEinsum` only performs operations with
lower indices. Therefore no complex conjugation is performed on complex
Inputs. To achieve operations with upper/lower indices use
:class:`PartialConjugate` before applying this operator.
"""
def
__init__
(
self
,
domain
,
subscripts
,
key_order
=
None
,
static_mf
=
None
,
optimize
=
'optimal'
):
...
...
@@ -159,7 +164,6 @@ class MultiLinearEinsum(Operator):
class
LinearEinsum
(
LinearOperator
):
"""Linear Einsum operator with exactly one freely varying field
FIXME: This operator does not perform any complex conjugation!
Parameters
----------
...
...
@@ -259,11 +263,11 @@ class LinearEinsum(LinearOperator):
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
if
mode
==
self
.
TIMES
:
dom
,
ss
=
self
.
target
,
self
.
_sscr
dom
,
ss
,
mf
=
self
.
target
,
self
.
_sscr
,
self
.
_mf
else
:
dom
,
ss
=
self
.
domain
,
self
.
_adj_sscr
dom
,
ss
,
mf
=
self
.
domain
,
self
.
_adj_sscr
,
self
.
_mf
.
conjugate
()
res
=
np
.
einsum
(
ss
,
*
(
self
.
_mf
.
val
[
k
]
for
k
in
self
.
_key_order
),
x
.
val
,
ss
,
*
(
mf
[
k
]
.
val
for
k
in
self
.
_key_order
),
x
.
val
,
**
self
.
_ein_kw
)
return
Field
.
from_raw
(
dom
,
res
)
nifty6/operators/partial_conjugate.py
0 → 100644
View file @
17727718
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Authors: Philipp Frank
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from
.endomorphic_operator
import
EndomorphicOperator
from
..multi_domain
import
MultiDomain
from
..multi_field
import
MultiField
class
PartialConjugate
(
EndomorphicOperator
):
"""Perform partial conjugation of a :class:`MultiField`
Parameters
----------
domain : MultiDomain
The operator's input domain and output target
conjugation_keys : iterable of string
The keys of the :class:`MultiField` for which complex conjugation
should be performed.
"""
def
__init__
(
self
,
domain
,
conjugation_keys
):
if
not
isinstance
(
domain
,
MultiDomain
):
raise
ValueError
(
"MultiDomain expected!"
)
indom
=
(
key
in
domain
.
keys
()
for
key
in
conjugation_keys
)
if
sum
(
indom
)
!=
len
(
conjugation_keys
):
raise
ValueError
(
"conjugation_keys not in domain!"
)
self
.
_domain
=
domain
self
.
_conjugation_keys
=
conjugation_keys
self
.
_capabilities
=
self
.
_all_ops
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
x
=
x
.
to_dict
()
for
k
in
self
.
_conjugation_keys
:
x
[
k
]
=
x
[
k
].
conjugate
()
return
MultiField
.
from_dict
(
x
,
self
.
_domain
)
test/test_operators/test_einsum.py
View file @
17727718
...
...
@@ -12,12 +12,13 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Authors: Gordian Edenhofer
# Authors: Gordian Edenhofer
, Philipp Frank
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import
pytest
from
numpy.testing
import
assert_allclose
import
numpy
as
np
from
nifty6.extra
import
check_jacobian_consistency
,
consistency_check
import
nifty6
as
ift
...
...
@@ -27,14 +28,14 @@ pmp = pytest.mark.parametrize
spaces
=
(
ift
.
UnstructuredDomain
(
4
),
ift
.
RGSpace
((
3
,
2
)),
ift
.
LMSpace
(
5
),
ift
.
HPSpace
(
4
),
ift
.
GLSpace
(
4
))
space1
=
list2fixture
(
spaces
)
space2
=
list2fixture
(
spaces
)
dtype
=
list2fixture
([
np
.
float64
,
np
.
complex128
])
def
test_linear_einsum_outer
(
space1
,
space2
,
n_invocations
=
10
):
def
test_linear_einsum_outer
(
space1
,
space2
,
dtype
,
n_invocations
=
10
):
setup_function
()
mf_dom
=
ift
.
MultiDomain
.
make
(
...
...
@@ -46,11 +47,11 @@ def test_linear_einsum_outer(space1, space2, n_invocations=10):
)
}
)
mf
=
ift
.
from_random
(
"normal"
,
mf_dom
)
mf
=
ift
.
from_random
(
"normal"
,
mf_dom
,
dtype
=
dtype
)
ss
=
"i,ij,j->ij"
key_order
=
(
"dom01"
,
"dom02"
)
le
=
ift
.
LinearEinsum
(
space2
,
mf
,
ss
,
key_order
=
key_order
)
assert
consistency_check
(
le
)
is
None
assert
consistency_check
(
le
,
domain_dtype
=
dtype
,
target_dtype
=
dtype
)
is
None
le_ift
=
ift
.
DiagonalOperator
(
mf
[
"dom01"
],
domain
=
mf_dom
[
"dom02"
],
spaces
=
0
...
...
@@ -59,15 +60,15 @@ def test_linear_einsum_outer(space1, space2, n_invocations=10):
)
for
_
in
range
(
n_invocations
):
r
=
ift
.
from_random
(
"normal"
,
le
.
domain
)
r
=
ift
.
from_random
(
"normal"
,
le
.
domain
,
dtype
=
dtype
)
assert_allclose
(
le
(
r
).
val
,
le_ift
(
r
).
val
)
r_adj
=
ift
.
from_random
(
"normal"
,
le
.
target
)
r_adj
=
ift
.
from_random
(
"normal"
,
le
.
target
,
dtype
=
dtype
)
assert_allclose
(
le
.
adjoint
(
r_adj
).
val
,
le_ift
.
adjoint
(
r_adj
).
val
)
teardown_function
()
def
test_linear_einsum_contraction
(
space1
,
space2
,
n_invocations
=
10
):
def
test_linear_einsum_contraction
(
space1
,
space2
,
dtype
,
n_invocations
=
10
):
setup_function
()
mf_dom
=
ift
.
MultiDomain
.
make
(
...
...
@@ -79,11 +80,11 @@ def test_linear_einsum_contraction(space1, space2, n_invocations=10):
)
}
)
mf
=
ift
.
from_random
(
"normal"
,
mf_dom
)
mf
=
ift
.
from_random
(
"normal"
,
mf_dom
,
dtype
=
dtype
)
ss
=
"i,ij,j->i"
key_order
=
(
"dom01"
,
"dom02"
)
le
=
ift
.
LinearEinsum
(
space2
,
mf
,
ss
,
key_order
=
key_order
)
assert
consistency_check
(
le
)
is
None
assert
consistency_check
(
le
,
domain_dtype
=
dtype
,
target_dtype
=
dtype
)
is
None
le_ift
=
ift
.
ContractionOperator
(
mf_dom
[
"dom02"
],
1
)
@
ift
.
DiagonalOperator
(
mf
[
"dom01"
],
domain
=
mf_dom
[
"dom02"
],
spaces
=
0
...
...
@@ -92,16 +93,16 @@ def test_linear_einsum_contraction(space1, space2, n_invocations=10):
)
for
_
in
range
(
n_invocations
):
r
=
ift
.
from_random
(
"normal"
,
le
.
domain
)
r
=
ift
.
from_random
(
"normal"
,
le
.
domain
,
dtype
=
dtype
)
assert_allclose
(
le
(
r
).
val
,
le_ift
(
r
).
val
)
r_adj
=
ift
.
from_random
(
"normal"
,
le
.
target
)
r_adj
=
ift
.
from_random
(
"normal"
,
le
.
target
,
dtype
=
dtype
)
assert_allclose
(
le
.
adjoint
(
r_adj
).
val
,
le_ift
.
adjoint
(
r_adj
).
val
)
teardown_function
()
def
test_multi_linear_einsum_outer
(
space1
,
space2
,
n_invocations
=
10
,
ntries
=
100
space1
,
space2
,
dtype
,
n_invocations
=
10
,
ntries
=
100
):
setup_function
()
...
...
@@ -116,7 +117,7 @@ def test_multi_linear_einsum_outer(
key_order
=
(
"dom01"
,
"dom02"
,
"dom03"
)
mle
=
ift
.
MultiLinearEinsum
(
mf_dom
,
ss
,
key_order
=
key_order
)
check_jacobian_consistency
(
mle
,
ift
.
from_random
(
"normal"
,
mle
.
domain
),
ntries
=
ntries
mle
,
ift
.
from_random
(
"normal"
,
mle
.
domain
,
dtype
=
dtype
),
ntries
=
ntries
)
outer_i
=
ift
.
OuterProduct
(
...
...
@@ -133,12 +134,13 @@ def test_multi_linear_einsum_outer(
)
*
(
outer_j
@
ift
.
FieldAdapter
(
mf_dom
[
"dom03"
],
"dom03"
))
for
_
in
range
(
n_invocations
):
rl
=
ift
.
Linearization
.
make_var
(
ift
.
from_random
(
"normal"
,
mle
.
domain
))
rl
=
ift
.
Linearization
.
make_var
(
ift
.
from_random
(
"normal"
,
mle
.
domain
,
dtype
=
dtype
))
mle_rl
,
mle_ift_rl
=
mle
(
rl
),
mle_ift
(
rl
)
assert_allclose
(
mle_rl
.
val
.
val
,
mle_ift_rl
.
val
.
val
)
assert_allclose
(
mle_rl
.
jac
(
rl
.
val
).
val
,
mle_ift_rl
.
jac
(
rl
.
val
).
val
)
rj_adj
=
ift
.
from_random
(
"normal"
,
mle_rl
.
jac
.
target
)
rj_adj
=
ift
.
from_random
(
"normal"
,
mle_rl
.
jac
.
target
,
dtype
=
dtype
)
mle_j_val
=
mle_rl
.
jac
.
adjoint
(
rj_adj
).
val
mle_ift_j_val
=
mle_ift_rl
.
jac
.
adjoint
(
rj_adj
).
val
for
k
in
mle_ift
.
domain
.
keys
():
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment