qht_operator.py 2.15 KB
Newer Older
1
2
from ..domain_tuple import DomainTuple
from ..field import Field
3
from .. import dobj
4
5
6
7
8
from ..utilities import hartley
from .linear_operator import LinearOperator


class QHTOperator(LinearOperator):
    """Does a Hartley transform on LogRGSpace

    This operator takes a field on a harmonic LogRGSpace and transforms it
    according to the Hartley transform onto a non-harmonic LogRGSpace.
    The zero modes are not transformed because they are infinitely far
    away: along each axis, index 0 is excluded from that axis's transform
    (it is still multiplied by the domain's volume factor).

    Parameters
    ----------
    domain : LogRGSpace
        The domain needs to be a harmonic LogRGSpace.
    target : LogRGSpace
        The target needs to be a non-harmonic LogRGSpace.

    Raises
    ------
    TypeError
        If `domain` is not harmonic or `target` is harmonic.
    ValueError
        If `domain` or `target` is not a LogRGSpace.
    """

    def __init__(self, domain, target):
        # NOTE: the harmonic checks run before the type checks, so an object
        # without a `harmonic` attribute fails here with AttributeError.
        if not domain.harmonic:
            raise TypeError("QHTOperator only works on a harmonic space")
        # Only the harmonic flag is checked here, not full codomain pairing.
        if target.harmonic:
            raise TypeError("Target is not a codomain of domain")

        # Function-local import kept from the original; presumably it avoids
        # an import cycle with ..domains -- TODO confirm.
        from ..domains import LogRGSpace
        if not isinstance(domain, LogRGSpace):
            raise ValueError("Domain has to be a LogRGSpace!")
        if not isinstance(target, LogRGSpace):
            raise ValueError("Target has to be a LogRGSpace!")

        self._domain = DomainTuple.make(domain)
        self._target = DomainTuple.make(target)

    @property
    def domain(self):
        return self._domain

    @property
    def target(self):
        return self._target

    def apply(self, x, mode):
        """Apply the operator (TIMES) or its adjoint (ADJOINT_TIMES) to `x`.

        Parameters
        ----------
        x : Field
            Input field; validated against `mode` by `_check_input`.
        mode : int
            Either ``self.TIMES`` or ``self.ADJOINT_TIMES``.

        Returns
        -------
        Field
            The transformed values on ``self._tgt(mode)``.
        """
        self._check_input(x, mode)
        # The domain's volume factor is applied in both modes, to all
        # entries (including the zero modes that are not transformed).
        v = x.val * self.domain[0].scalar_dvol()
        ndim = len(self.domain[0].shape)
        # Axes are processed in forward order for TIMES and in reverse
        # order for ADJOINT_TIMES.
        axes = range(ndim) if mode == self.TIMES else reversed(range(ndim))

        dist_ax = dobj.distaxis(v)
        globshape = v.shape
        for i in axes:
            # All entries of the leading axes, indices 1: along axis i, so
            # the zero mode of axis i is left out of this axis's transform.
            sl = (slice(None),)*i + (slice(1, None),)
            # The transform along axis i needs that axis available locally;
            # if it is the distributed axis, redistribute it away first and
            # restore the original distribution afterwards.
            if i == dist_ax:
                v = dobj.redistribute(v, nodist=(dist_ax,))
            cur_ax = dobj.distaxis(v)
            v = dobj.local_data(v)
            v[sl] = hartley(v[sl], axes=(i,))
            v = dobj.from_local_data(globshape, v, distaxis=cur_ax)
            if i == dist_ax:
                v = dobj.redistribute(v, dist=dist_ax)
        return Field(self._tgt(mode), val=v)

    @property
    def capability(self):
        return self.TIMES | self.ADJOINT_TIMES