Commit 1cb1c94e authored by Martin Reinecke's avatar Martin Reinecke
Browse files

generalize DomainDistributor; some more operator tests

parent 6ea15ac4
......@@ -25,26 +25,16 @@ from ..compat import *
from ..domain_tuple import DomainTuple
from ..field import Field
from .linear_operator import LinearOperator
from .. import utilities
# MR FIXME: this needs to be rewritten in a generic fashion
class DomainDistributor(LinearOperator):
def __init__(self, target, axis):
if dobj.ntask > 1:
raise NotImplementedError('UpProj class does not support MPI.')
assert len(target) == 2
assert axis in [0, 1]
if axis == 0:
domain = target[1]
self._size = target[0].size
else:
domain = target[0]
self._size = target[1].size
self._axis = axis
self._domain = DomainTuple.make(domain)
def __init__(self, target, spaces):
self._target = DomainTuple.make(target)
self._spaces = utilities.parse_spaces(spaces, len(self._target))
self._domain = [tgt for i, tgt in enumerate(self._target)
if i in self._spaces]
self._domain = DomainTuple.make(self._domain)
@property
def domain(self):
......@@ -57,23 +47,16 @@ class DomainDistributor(LinearOperator):
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
x = x.local_data
otherDirection = np.ones(self._size)
if self._axis == 0:
res = np.outer(otherDirection, x)
else:
res = np.outer(x, otherDirection)
res = res.reshape(dobj.local_shape(self.target.shape))
return Field.from_local_data(self.target, res)
ldat = x.local_data if 0 in self._spaces else x.to_global_data()
shp = []
for i, tgt in enumerate(self._target):
tmp = tgt.shape if i > 0 else tgt.local_shape
shp += tmp if i in self._spaces else(1,)*len(tgt.shape)
ldat = np.broadcast_to(ldat.reshape(shp), self._target.local_shape)
return Field.from_local_data(self._target, ldat)
else:
if self._axis == 0:
x = x.local_data.reshape(self._size, -1)
res = np.sum(x, axis=0)
else:
x = x.local_data.reshape(-1, self._size)
res = np.sum(x, axis=1)
res = res.reshape(dobj.local_shape(self.domain.shape))
return Field.from_local_data(self.domain, res)
return x.sum([s for s in range(len(x.domain))
if s not in self._spaces])
@property
def capability(self):
......
......@@ -27,12 +27,13 @@ from ..domains.power_space import PowerSpace
from ..domains.rg_space import RGSpace
from ..field import Field
from .linear_operator import LinearOperator
from .. import utilities
class ExpTransform(LinearOperator):
def __init__(self, target, dof, space=0):
self._target = DomainTuple.make(target)
self._space = int(space)
self._space = utilities.infer_space(self._target, space)
tgt = self._target[self._space]
if not ((isinstance(tgt, RGSpace) and tgt.harmonic) or
isinstance(tgt, PowerSpace)):
......
......@@ -8,13 +8,14 @@ from ..domain_tuple import DomainTuple
from ..domains.rg_space import RGSpace
from ..field import Field
from .linear_operator import LinearOperator
from .. import utilities
class FieldZeroPadder(LinearOperator):
def __init__(self, domain, factor, space=0):
super(FieldZeroPadder, self).__init__()
self._domain = DomainTuple.make(domain)
self._space = int(space)
self._space = utilities.infer_space(self._domain, space)
dom = self._domain[self._space]
if not isinstance(dom, RGSpace):
raise TypeError("RGSpace required")
......@@ -52,11 +53,11 @@ class FieldZeroPadder(LinearOperator):
curax = dobj.distaxis(x)
if mode == self.ADJOINT_TIMES:
newarr = np.empty(dobj.local_shape(shp_out), dtype=x.dtype)
newarr = np.empty(dobj.local_shape(shp_out, curax), dtype=x.dtype)
newarr[()] = dobj.local_data(x)[(slice(None),)*ax +
(slice(0, shp_out[ax]),)]
else:
newarr = np.zeros(dobj.local_shape(shp_out), dtype=x.dtype)
newarr = np.zeros(dobj.local_shape(shp_out, curax), dtype=x.dtype)
newarr[(slice(None),)*ax +
(slice(0, shp_in[ax]),)] = dobj.local_data(x)
newarr = dobj.from_local_data(shp_out, newarr, distaxis=curax)
......
......@@ -22,7 +22,7 @@ from .. import dobj
from ..compat import *
from ..domain_tuple import DomainTuple
from ..field import Field
from ..utilities import hartley
from ..utilities import hartley, infer_space
from .linear_operator import LinearOperator
......@@ -47,7 +47,7 @@ class QHTOperator(LinearOperator):
"""
def __init__(self, domain, target, space=0):
self._domain = DomainTuple.make(domain)
self._space = int(space)
self._space = infer_space(self._domain, space)
from ..domains.log_rg_space import LogRGSpace
if not isinstance(self._domain[self._space], LogRGSpace):
......
......@@ -24,15 +24,16 @@ from ..domain_tuple import DomainTuple
from ..domains.log_rg_space import LogRGSpace
from ..field import Field
from .endomorphic_operator import EndomorphicOperator
from .. import utilities
class SymmetrizingOperator(EndomorphicOperator):
def __init__(self, domain, space=0):
self._domain = DomainTuple.make(domain)
self._space = int(space)
self._space = utilities.infer_space(self._domain, space)
dom = self._domain[self._space]
if not (isinstance(dom, LogRGSpace) and not dom.harmonic):
raise TypeError
raise TypeError("nonharmonic LogRGSpace needed")
@property
def domain(self):
......
......@@ -101,3 +101,25 @@ class Consistency_Tests(unittest.TestCase):
def testGeometryRemover(self, sp, dtype):
op = ift.GeometryRemover(sp)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product([0, 1, 2, 3, (0, 1), (0, 2), (0, 1, 2), (0, 2, 3), (1, 3)],
[np.float64, np.complex128]))
def testDomainDistributor(self, spaces, dtype):
dom = (ift.RGSpace(10), ift.UnstructuredDomain(13), ift.GLSpace(5),
ift.HPSpace(4))
op = ift.DomainDistributor(dom, spaces)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product([0, 2], [np.float64, np.complex128]))
def testSymmetrizingOperator(self, space, dtype):
dom = (ift.LogRGSpace(10, [2.], [1.]), ift.UnstructuredDomain(13),
ift.LogRGSpace((5, 27), [1., 2.7], [0., 4.]), ift.HPSpace(4))
op = ift.SymmetrizingOperator(dom, space)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product([0, 2], [2, 2.7], [np.float64, np.complex128]))
def testZeroPadder(self, space, factor, dtype):
dom = (ift.RGSpace(10), ift.UnstructuredDomain(13), ift.RGSpace(7),
ift.HPSpace(4))
op = ift.FieldZeroPadder(dom, factor, space)
ift.extra.consistency_check(op, dtype, dtype)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment