Commit e547f77b authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'partial_insert' into 'NIFTy_6'

Introduce partial insert

See merge request !394
parents 4d5cfe1a 67337be2
Pipeline #65242 passed with stages
in 14 minutes and 10 seconds
......@@ -79,6 +79,35 @@ class Operator(metaclass=NiftyMeta):
return NotImplemented
return _OpChain.make((self, x))
def partial_insert(self, x):
from ..multi_domain import MultiDomain
if not isinstance(x, Operator):
raise TypeError
if not isinstance(self.domain, MultiDomain):
raise TypeError
if not isinstance(x.target, MultiDomain):
raise TypeError
bigdom = MultiDomain.union([self.domain, x.target])
k1, k2 = set(self.domain.keys()), set(x.target.keys())
le, ri = k2 - k1, k1 - k2
leop, riop = self, x
if len(ri) > 0:
riop = riop + self.identity_operator(
MultiDomain.make({kk: bigdom[kk]
for kk in ri}))
if len(le) > 0:
leop = leop + self.identity_operator(
MultiDomain.make({kk: bigdom[kk]
for kk in le}))
return leop @ riop
@staticmethod
def identity_operator(dom):
from .block_diagonal_operator import BlockDiagonalOperator
from .scaling_operator import ScalingOperator
idops = {kk: ScalingOperator(dd, 1.) for kk, dd in dom.items()}
return BlockDiagonalOperator(dom, idops)
def __mul__(self, x):
if isinstance(x, Operator):
return _OpProd(self, x)
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
import pytest
from numpy.testing import assert_, assert_allclose
import nifty6 as ift
from ..common import list2fixture
pmp = pytest.mark.parametrize
dtype = list2fixture([np.float64, np.float32, np.complex64, np.complex128])
def test_part_mf_insert():
dom = ift.RGSpace(3)
op1 = ift.ScalingOperator(dom, 1.32).ducktape('a').ducktape_left('a1')
op2 = ift.ScalingOperator(dom, 1).exp().ducktape('b').ducktape_left('b1')
op3 = ift.ScalingOperator(dom, 1).sin().ducktape('c').ducktape_left('c1')
op4 = ift.ScalingOperator(dom, 1).ducktape('c0').ducktape_left('c')**2
op5 = ift.ScalingOperator(dom, 1).tan().ducktape('d0').ducktape_left('d')
a = op1 + op2 + op3
b = op4 + op5
op = a.partial_insert(b)
fld = ift.from_random('normal', op.domain)
ift.extra.check_jacobian_consistency(op, fld)
assert_(op.domain is ift.MultiDomain.union(
[op1.domain, op2.domain, op4.domain, op5.domain]))
assert_(op.target is ift.MultiDomain.union(
[op1.target, op2.target, op3.target, op5.target]))
x, y = fld.val, op(fld).val
assert_allclose(y['a1'], x['a']*1.32)
assert_allclose(y['b1'], np.exp(x['b']))
assert_allclose(y['c1'], np.sin(x['c0']**2))
assert_allclose(y['d'], np.tan(x['d0']))
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment