Commit 5aca56f7 authored by Philipp Arras's avatar Philipp Arras
Browse files

Cosmetics

parent 7130009c
Pipeline #75151 passed with stages
in 10 minutes and 33 seconds
......@@ -11,20 +11,22 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Copyright(C) 2013-2020 Max-Planck-Society
# Authors: Gordian Edenhofer, Philipp Frank
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
import string
import numpy as np
from ..domain_tuple import DomainTuple
from ..linearization import Linearization
from ..field import Field
from ..linearization import Linearization
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from .operator import Operator
from .linear_operator import LinearOperator
from .operator import Operator
class MultiLinearEinsum(Operator):
......@@ -251,7 +253,6 @@ class LinearEinsum(LinearOperator):
path = np.einsum_path(numpy_subscripts, *plc, optimize=optimize)[0]
self._init_wo_preproc(mf, numpy_subscripts, _key_order, path, _target)
def _init_wo_preproc(self, mf, subscripts, keyorder, optimize, target):
self._ein_kw = {"optimize": optimize}
self._mf = mf
......
......@@ -11,21 +11,21 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Copyright(C) 2013-2020 Max-Planck-Society
# Authors: Gordian Edenhofer, Philipp Frank
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import pytest
from numpy.testing import assert_allclose
import numpy as np
from nifty6.extra import check_jacobian_consistency, consistency_check
from numpy.testing import assert_, assert_allclose
import nifty6 as ift
from nifty6.extra import check_jacobian_consistency, consistency_check
from ..common import list2fixture, setup_function, teardown_function
spaces = (ift.UnstructuredDomain(4),
ift.RGSpace((3,2)),
ift.RGSpace((3, 2)),
ift.LMSpace(5),
ift.GLSpace(4))
......@@ -35,26 +35,18 @@ dtype = list2fixture([np.float64, np.complex128])
def test_linear_einsum_outer(space1, space2, dtype, n_invocations=10):
mf_dom = ift.MultiDomain.make(
{
"dom01": space1,
"dom02":
ift.DomainTuple.make(
(space1, space2)
)
}
)
mf_dom = ift.MultiDomain.make({
"dom01": space1,
"dom02": ift.DomainTuple.make((space1, space2))})
mf = ift.from_random("normal", mf_dom, dtype=dtype)
ss = "i,ij,j->ij"
key_order = ("dom01", "dom02")
le = ift.LinearEinsum(space2, mf, ss, key_order=key_order)
assert consistency_check(le, domain_dtype=dtype,target_dtype=dtype) is None
assert_(consistency_check(le, domain_dtype=dtype, target_dtype=dtype) is None)
le_ift = ift.DiagonalOperator(
mf["dom01"], domain=mf_dom["dom02"], spaces=0
) @ ift.DiagonalOperator(mf["dom02"]) @ ift.OuterProduct(
ift.full(mf_dom["dom01"], 1.), ift.DomainTuple.make(mf_dom["dom02"][1])
)
le_ift = ift.DiagonalOperator(mf["dom01"], domain=mf_dom["dom02"], spaces=0) @ ift.DiagonalOperator(mf["dom02"])
le_ift = le_ift @ ift.OuterProduct(ift.full(mf_dom["dom01"], 1.),
ift.DomainTuple.make(mf_dom["dom02"][1]))
for _ in range(n_invocations):
r = ift.from_random("normal", le.domain, dtype=dtype)
......@@ -64,26 +56,20 @@ def test_linear_einsum_outer(space1, space2, dtype, n_invocations=10):
def test_linear_einsum_contraction(space1, space2, dtype, n_invocations=10):
mf_dom = ift.MultiDomain.make(
{
"dom01": space1,
"dom02":
ift.DomainTuple.make(
(space1, space2)
)
}
)
mf_dom = ift.MultiDomain.make({
"dom01": space1,
"dom02": ift.DomainTuple.make((space1, space2))})
mf = ift.from_random("normal", mf_dom, dtype=dtype)
ss = "i,ij,j->i"
key_order = ("dom01", "dom02")
le = ift.LinearEinsum(space2, mf, ss, key_order=key_order)
assert consistency_check(le, domain_dtype=dtype,target_dtype=dtype) is None
assert_(consistency_check(le, domain_dtype=dtype, target_dtype=dtype) is None)
le_ift = ift.ContractionOperator(mf_dom["dom02"], 1) @ ift.DiagonalOperator(
mf["dom01"], domain=mf_dom["dom02"], spaces=0
) @ ift.DiagonalOperator(mf["dom02"]) @ ift.OuterProduct(
ift.full(mf_dom["dom01"], 1.), ift.DomainTuple.make(mf_dom["dom02"][1])
)
le_ift = ift.ContractionOperator(mf_dom["dom02"], 1)
le_ift = le_ift @ ift.DiagonalOperator(mf["dom01"], domain=mf_dom["dom02"], spaces=0)
le_ift = le_ift @ ift.DiagonalOperator(mf["dom02"])
le_ift = le_ift @ ift.OuterProduct(ift.full(mf_dom["dom01"], 1.),
ift.DomainTuple.make(mf_dom["dom02"][1]))
for _ in range(n_invocations):
r = ift.from_random("normal", le.domain, dtype=dtype)
......@@ -103,7 +89,7 @@ class _SwitchSpacesOperator(ift.LinearOperator):
n_spaces = len(self._domain)
if space1 >= n_spaces or space1 < 0 \
or space2 >= n_spaces or space2 < 0:
or space2 >= n_spaces or space2 < 0:
raise ValueError("invalid space value")
tgt = list(self._domain)
......@@ -119,33 +105,23 @@ class _SwitchSpacesOperator(ift.LinearOperator):
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
val = np.moveaxis(x.val, self._axes_dom, self._axes_tgt)
dom = self._target
else:
val = np.moveaxis(x.val, self._axes_tgt, self._axes_dom)
dom = self._domain
return ift.Field(dom, val)
def test_multi_linear_einsum_outer(
space1, space2, dtype, n_invocations=10, ntries=100
):
mf_dom = ift.MultiDomain.make(
{
"dom01": space1,
"dom02":ift.DomainTuple.make((space1, space2)),
"dom03": space2
}
)
args = self._axes_dom, self._axes_tgt
if mode == self.ADJOINT_TIMES:
args = args[::-1]
return ift.Field(self._tgt(mode), np.moveaxis(x.val, *args))
def test_multi_linear_einsum_outer(space1, space2, dtype):
ntries = 100
n_invocations = 10
mf_dom = ift.MultiDomain.make({
"dom01": space1,
"dom02": ift.DomainTuple.make((space1, space2)),
"dom03": space2})
ss = "i,ij,j->ij"
key_order = ("dom01", "dom02", "dom03")
mle = ift.MultiLinearEinsum(mf_dom, ss, key_order=key_order)
check_jacobian_consistency(
mle, ift.from_random("normal", mle.domain, dtype=dtype), ntries=ntries
)
check_jacobian_consistency(mle, ift.from_random("normal", mle.domain, dtype=dtype), ntries=ntries)
outer_i = ift.OuterProduct(
ift.full(mf_dom["dom03"], 1.), ift.DomainTuple.make(mf_dom["dom02"][0])
......@@ -154,15 +130,13 @@ def test_multi_linear_einsum_outer(
ift.full(mf_dom["dom01"], 1.), ift.DomainTuple.make(mf_dom["dom02"][1])
)
# SwitchSpacesOperator is equivalent to LinearEinsum with "ij->ji"
mle_ift = _SwitchSpacesOperator(
outer_i.target, 1
) @ outer_i @ ift.FieldAdapter(mf_dom["dom01"], "dom01") * ift.FieldAdapter(
mf_dom["dom02"], "dom02"
) * (outer_j @ ift.FieldAdapter(mf_dom["dom03"], "dom03"))
mle_ift = _SwitchSpacesOperator(outer_i.target, 1) @ outer_i @ \
ift.FieldAdapter(mf_dom["dom01"], "dom01") * \
ift.FieldAdapter(mf_dom["dom02"], "dom02") * \
(outer_j @ ift.FieldAdapter(mf_dom["dom03"], "dom03"))
for _ in range(n_invocations):
rl = ift.Linearization.make_var(
ift.from_random("normal", mle.domain, dtype=dtype))
rl = ift.Linearization.make_var(ift.from_random("normal", mle.domain, dtype=dtype))
mle_rl, mle_ift_rl = mle(rl), mle_ift(rl)
assert_allclose(mle_rl.val.val, mle_ift_rl.val.val)
assert_allclose(mle_rl.jac(rl.val).val, mle_ift_rl.jac(rl.val).val)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment