diff --git a/nifty6/operators/operator.py b/nifty6/operators/operator.py index 4d2ebcc51d4e851ecf0f23ca230fda50d967d445..31b24772b9d7fc5eafd5acf41357b2d7eaaa132a 100644 --- a/nifty6/operators/operator.py +++ b/nifty6/operators/operator.py @@ -79,6 +79,35 @@ class Operator(metaclass=NiftyMeta): return NotImplemented return _OpChain.make((self, x)) + def partial_insert(self, x): + from ..multi_domain import MultiDomain + if not isinstance(x, Operator): + raise TypeError + if not isinstance(self.domain, MultiDomain): + raise TypeError + if not isinstance(x.target, MultiDomain): + raise TypeError + bigdom = MultiDomain.union([self.domain, x.target]) + k1, k2 = set(self.domain.keys()), set(x.target.keys()) + le, ri = k2 - k1, k1 - k2 + leop, riop = self, x + if len(ri) > 0: + riop = riop + self.identity_operator( + MultiDomain.make({kk: bigdom[kk] + for kk in ri})) + if len(le) > 0: + leop = leop + self.identity_operator( + MultiDomain.make({kk: bigdom[kk] + for kk in le})) + return leop @ riop + + @staticmethod + def identity_operator(dom): + from .block_diagonal_operator import BlockDiagonalOperator + from .scaling_operator import ScalingOperator + idops = {kk: ScalingOperator(dd, 1.) for kk, dd in dom.items()} + return BlockDiagonalOperator(dom, idops) + def __mul__(self, x): if isinstance(x, Operator): return _OpProd(self, x) diff --git a/test/test_operators/test_partial_multifield_insert.py b/test/test_operators/test_partial_multifield_insert.py new file mode 100644 index 0000000000000000000000000000000000000000..9e6cba61cbb78d3c0dbfd66541dcb54f3919d6e3 --- /dev/null +++ b/test/test_operators/test_partial_multifield_insert.py @@ -0,0 +1,50 @@ +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# +# Copyright(C) 2013-2019 Max-Planck-Society +# +# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. + +import numpy as np +import pytest +from numpy.testing import assert_, assert_allclose + +import nifty6 as ift + +from ..common import list2fixture + +pmp = pytest.mark.parametrize +dtype = list2fixture([np.float64, np.float32, np.complex64, np.complex128]) + + +def test_part_mf_insert(): + dom = ift.RGSpace(3) + op1 = ift.ScalingOperator(dom, 1.32).ducktape('a').ducktape_left('a1') + op2 = ift.ScalingOperator(dom, 1).exp().ducktape('b').ducktape_left('b1') + op3 = ift.ScalingOperator(dom, 1).sin().ducktape('c').ducktape_left('c1') + op4 = ift.ScalingOperator(dom, 1).ducktape('c0').ducktape_left('c')**2 + op5 = ift.ScalingOperator(dom, 1).tan().ducktape('d0').ducktape_left('d') + a = op1 + op2 + op3 + b = op4 + op5 + op = a.partial_insert(b) + fld = ift.from_random('normal', op.domain) + ift.extra.check_jacobian_consistency(op, fld) + assert_(op.domain is ift.MultiDomain.union( + [op1.domain, op2.domain, op4.domain, op5.domain])) + assert_(op.target is ift.MultiDomain.union( + [op1.target, op2.target, op3.target, op5.target])) + x, y = fld.val, op(fld).val + assert_allclose(y['a1'], x['a']*1.32) + assert_allclose(y['b1'], np.exp(x['b'])) + assert_allclose(y['c1'], np.sin(x['c0']**2)) + assert_allclose(y['d'], np.tan(x['d0']))