qht_operator.py 2.06 KB
Newer Older
1 2
from ..domain_tuple import DomainTuple
from ..field import Field
3
from .. import dobj
4 5 6 7 8
from ..utilities import hartley
from .linear_operator import LinearOperator


class QHTOperator(LinearOperator):
    """Does a Hartley transform on a LogRGSpace.

    This operator takes a field on a harmonic LogRGSpace and transforms it
    according to the Hartley transform into a field on a non-harmonic
    LogRGSpace. The zero modes (index 0 along each axis) are not
    transformed because they are infinitely far away.

    Parameters
    ----------
    domain : LogRGSpace
        The domain needs to be a harmonic LogRGSpace.
    target : LogRGSpace
        The target needs to be a non-harmonic LogRGSpace.

    Raises
    ------
    TypeError
        If `domain` is not harmonic or `target` is harmonic.
    ValueError
        If `domain` or `target` is not a LogRGSpace.
    """

    def __init__(self, domain, target):
        if not domain.harmonic:
            raise TypeError(
                "QHTOperator only works on a harmonic space")
        if target.harmonic:
            raise TypeError("Target is not a codomain of domain")

        # Imported locally rather than at module level -- presumably to
        # avoid a circular import with the domains subpackage; confirm
        # before hoisting.
        from ..domains.log_rg_space import LogRGSpace
        if not isinstance(domain, LogRGSpace):
            raise ValueError("Domain has to be a LogRGSpace!")
        if not isinstance(target, LogRGSpace):
            raise ValueError("Target has to be a LogRGSpace!")

        self._domain = DomainTuple.make(domain)
        self._target = DomainTuple.make(target)

    @property
    def domain(self):
        return self._domain

    @property
    def target(self):
        return self._target

    def apply(self, x, mode):
        """Apply the (adjoint) Hartley transform to `x`.

        Both directions scale by the domain's scalar volume factor and then
        transform axis by axis; the adjoint walks the axes in reverse order.
        The same kernel is used in both directions (presumably relying on
        the Hartley kernel being symmetric -- confirm against `hartley`).

        Parameters
        ----------
        x : Field
            Input field, living on the domain (TIMES) or target (ADJOINT).
        mode : int
            ``self.TIMES`` or ``self.ADJOINT_TIMES``.

        Returns
        -------
        Field
            The transformed field on ``self._tgt(mode)``.
        """
        self._check_input(x, mode)
        # The multiplication yields a new data object (presumably not
        # aliasing x.val), which is then modified in place below.
        val = x.val * self.domain[0].scalar_dvol()
        n = len(self.domain[0].shape)
        rng = range(n) if mode == self.TIMES else reversed(range(n))
        ax = dobj.distaxis(val)

        for i in rng:
            # Select everything except index 0 along axis i, so the zero
            # mode along that axis is left untouched.
            sl = (slice(None),)*i + (slice(1, None),)
            if i == ax:
                # Axis i is the distributed one: make it non-distributed so
                # the full axis is available locally for the transform.
                val = dobj.redistribute(val, nodist=(ax,))
            tmp = dobj.local_data(val)
            tmp[sl] = hartley(tmp[sl], axes=(i,))
            if i == ax:
                # Restore the original distribution along `ax`.
                val = dobj.redistribute(val, dist=ax)
        return Field(self._tgt(mode), val=val)

    @property
    def capability(self):
        """Supports forward (TIMES) and adjoint (ADJOINT_TIMES) application."""
        return self.TIMES | self.ADJOINT_TIMES