Commit 17727718 authored by Philipp Frank's avatar Philipp Frank
Browse files

Einsum handling for complex conjugation

parent d296d7a7
Pipeline #74969 passed with stages
in 26 minutes and 27 seconds
......@@ -12,7 +12,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Authors: Gordian Edenhofer
# Authors: Gordian Edenhofer, Philipp Frank
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
......@@ -28,9 +28,7 @@ from .linear_operator import LinearOperator
class MultiLinearEinsum(Operator):
"""Multi-linear Einsum operator with corresponding derivates
FIXME: This operator does not perform any complex conjugation!
"""Multi-linear Einsum operator with corresponding derivates.
Parameters
----------
......@@ -48,6 +46,13 @@ class MultiLinearEinsum(Operator):
Linearization.
optimize: bool, String or List, optional
Parameter passed on to einsum_path.
Notes
-----
By convention :class:`MultiLinearEinsum` only performs operations with
lower indices. Therefore no complex conjugation is performed on complex
Inputs. To achieve operations with upper/lower indices use
:class:`PartialConjugate` before applying this operator.
"""
def __init__(self, domain, subscripts,
key_order=None, static_mf=None, optimize='optimal'):
......@@ -159,7 +164,6 @@ class MultiLinearEinsum(Operator):
class LinearEinsum(LinearOperator):
"""Linear Einsum operator with exactly one freely varying field
FIXME: This operator does not perform any complex conjugation!
Parameters
----------
......@@ -259,11 +263,11 @@ class LinearEinsum(LinearOperator):
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
dom, ss = self.target, self._sscr
dom, ss, mf = self.target, self._sscr, self._mf
else:
dom, ss = self.domain, self._adj_sscr
dom, ss, mf = self.domain, self._adj_sscr, self._mf.conjugate()
res = np.einsum(
ss, *(self._mf.val[k] for k in self._key_order), x.val,
ss, *(mf[k].val for k in self._key_order), x.val,
**self._ein_kw
)
return Field.from_raw(dom, res)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Authors: Philipp Frank
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from .endomorphic_operator import EndomorphicOperator
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
class PartialConjugate(EndomorphicOperator):
"""Perform partial conjugation of a :class:`MultiField`
Parameters
----------
domain : MultiDomain
The operator's input domain and output target
conjugation_keys : iterable of string
The keys of the :class:`MultiField` for which complex conjugation
should be performed.
"""
def __init__(self, domain, conjugation_keys):
if not isinstance(domain, MultiDomain):
raise ValueError("MultiDomain expected!")
indom = (key in domain.keys() for key in conjugation_keys)
if sum(indom) != len(conjugation_keys):
raise ValueError("conjugation_keys not in domain!")
self._domain = domain
self._conjugation_keys = conjugation_keys
self._capabilities = self._all_ops
def apply(self, x, mode):
self._check_input(x, mode)
x = x.to_dict()
for k in self._conjugation_keys:
x[k] = x[k].conjugate()
return MultiField.from_dict(x, self._domain)
......@@ -12,12 +12,13 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Authors: Gordian Edenhofer
# Authors: Gordian Edenhofer, Philipp Frank
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import pytest
from numpy.testing import assert_allclose
import numpy as np
from nifty6.extra import check_jacobian_consistency, consistency_check
import nifty6 as ift
......@@ -27,14 +28,14 @@ pmp = pytest.mark.parametrize
spaces = (ift.UnstructuredDomain(4),
ift.RGSpace((3,2)),
ift.LMSpace(5),
ift.HPSpace(4),
ift.GLSpace(4))
space1 = list2fixture(spaces)
space2 = list2fixture(spaces)
dtype = list2fixture([np.float64, np.complex128])
def test_linear_einsum_outer(space1, space2, n_invocations=10):
def test_linear_einsum_outer(space1, space2, dtype, n_invocations=10):
setup_function()
mf_dom = ift.MultiDomain.make(
......@@ -46,11 +47,11 @@ def test_linear_einsum_outer(space1, space2, n_invocations=10):
)
}
)
mf = ift.from_random("normal", mf_dom)
mf = ift.from_random("normal", mf_dom, dtype=dtype)
ss = "i,ij,j->ij"
key_order = ("dom01", "dom02")
le = ift.LinearEinsum(space2, mf, ss, key_order=key_order)
assert consistency_check(le) is None
assert consistency_check(le, domain_dtype=dtype,target_dtype=dtype) is None
le_ift = ift.DiagonalOperator(
mf["dom01"], domain=mf_dom["dom02"], spaces=0
......@@ -59,15 +60,15 @@ def test_linear_einsum_outer(space1, space2, n_invocations=10):
)
for _ in range(n_invocations):
r = ift.from_random("normal", le.domain)
r = ift.from_random("normal", le.domain, dtype=dtype)
assert_allclose(le(r).val, le_ift(r).val)
r_adj = ift.from_random("normal", le.target)
r_adj = ift.from_random("normal", le.target, dtype=dtype)
assert_allclose(le.adjoint(r_adj).val, le_ift.adjoint(r_adj).val)
teardown_function()
def test_linear_einsum_contraction(space1, space2, n_invocations=10):
def test_linear_einsum_contraction(space1, space2, dtype, n_invocations=10):
setup_function()
mf_dom = ift.MultiDomain.make(
......@@ -79,11 +80,11 @@ def test_linear_einsum_contraction(space1, space2, n_invocations=10):
)
}
)
mf = ift.from_random("normal", mf_dom)
mf = ift.from_random("normal", mf_dom, dtype=dtype)
ss = "i,ij,j->i"
key_order = ("dom01", "dom02")
le = ift.LinearEinsum(space2, mf, ss, key_order=key_order)
assert consistency_check(le) is None
assert consistency_check(le, domain_dtype=dtype,target_dtype=dtype) is None
le_ift = ift.ContractionOperator(mf_dom["dom02"], 1) @ ift.DiagonalOperator(
mf["dom01"], domain=mf_dom["dom02"], spaces=0
......@@ -92,16 +93,16 @@ def test_linear_einsum_contraction(space1, space2, n_invocations=10):
)
for _ in range(n_invocations):
r = ift.from_random("normal", le.domain)
r = ift.from_random("normal", le.domain, dtype=dtype)
assert_allclose(le(r).val, le_ift(r).val)
r_adj = ift.from_random("normal", le.target)
r_adj = ift.from_random("normal", le.target, dtype=dtype)
assert_allclose(le.adjoint(r_adj).val, le_ift.adjoint(r_adj).val)
teardown_function()
def test_multi_linear_einsum_outer(
space1, space2, n_invocations=10, ntries=100
space1, space2, dtype, n_invocations=10, ntries=100
):
setup_function()
......@@ -116,7 +117,7 @@ def test_multi_linear_einsum_outer(
key_order = ("dom01", "dom02", "dom03")
mle = ift.MultiLinearEinsum(mf_dom, ss, key_order=key_order)
check_jacobian_consistency(
mle, ift.from_random("normal", mle.domain), ntries=ntries
mle, ift.from_random("normal", mle.domain, dtype=dtype), ntries=ntries
)
outer_i = ift.OuterProduct(
......@@ -133,12 +134,13 @@ def test_multi_linear_einsum_outer(
) * (outer_j @ ift.FieldAdapter(mf_dom["dom03"], "dom03"))
for _ in range(n_invocations):
rl = ift.Linearization.make_var(ift.from_random("normal", mle.domain))
rl = ift.Linearization.make_var(
ift.from_random("normal", mle.domain, dtype=dtype))
mle_rl, mle_ift_rl = mle(rl), mle_ift(rl)
assert_allclose(mle_rl.val.val, mle_ift_rl.val.val)
assert_allclose(mle_rl.jac(rl.val).val, mle_ift_rl.jac(rl.val).val)
rj_adj = ift.from_random("normal", mle_rl.jac.target)
rj_adj = ift.from_random("normal", mle_rl.jac.target, dtype=dtype)
mle_j_val = mle_rl.jac.adjoint(rj_adj).val
mle_ift_j_val = mle_ift_rl.jac.adjoint(rj_adj).val
for k in mle_ift.domain.keys():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment