qht_operator.py 1.75 KB
Newer Older
1
2
from ..domain_tuple import DomainTuple
from ..field import Field
3
from .. import dobj
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from ..utilities import hartley
from .linear_operator import LinearOperator


class QHTOperator(LinearOperator):
    """Linear operator from a harmonic LogRGSpace to a non-harmonic one.

    `apply` scales the input values by the domain's scalar volume factor
    and then applies the `hartley` utility along every axis in turn. Along
    axis ``i`` only the entries with index >= 1 on that axis are
    transformed; the index-0 slice of axis ``i`` is excluded from that
    axis's transform.

    Parameters
    ----------
    domain : LogRGSpace
        Input space; must be harmonic.
    target : LogRGSpace
        Output space; must not be harmonic.
    """

    def __init__(self, domain, target):
        """Validate `domain` and `target` and store them as DomainTuples.

        Raises
        ------
        TypeError
            If `domain` is not harmonic, or if `target` is harmonic.
        ValueError
            If `domain` or `target` is not a `LogRGSpace`.
        """
        # The harmonic flags are checked before the types. A wrongly
        # flagged space therefore reports TypeError regardless of its
        # class. This order is kept so callers see the same exceptions.
        if not domain.harmonic:
            raise TypeError(
                "QHTOperator only works on a harmonic space")
        if target.harmonic:
            raise TypeError("Target is not a codomain of domain")

        # Function-local import, as in the original. Presumably this
        # avoids an import cycle with ..domains (TODO confirm).
        from ..domains import LogRGSpace
        if not isinstance(domain, LogRGSpace):
            raise ValueError("Domain has to be a LogRGSpace!")
        if not isinstance(target, LogRGSpace):
            raise ValueError("Target has to be a LogRGSpace!")

        self._domain = DomainTuple.make(domain)
        self._target = DomainTuple.make(target)

    @property
    def domain(self):
        """The DomainTuple wrapping the harmonic input LogRGSpace."""
        return self._domain

    @property
    def target(self):
        """The DomainTuple wrapping the non-harmonic output LogRGSpace."""
        return self._target

    def apply(self, x, mode):
        """Apply the operator (TIMES) or its adjoint (ADJOINT_TIMES) to `x`.

        The values of `x` are multiplied by
        ``self.domain[0].scalar_dvol()`` in both modes. The `hartley`
        utility is then applied along each axis: in ascending axis order
        for TIMES, and in descending order for ADJOINT_TIMES.

        Parameters
        ----------
        x : Field
            Input field, validated by ``self._check_input(x, mode)``.
        mode : int
            ``self.TIMES`` or ``self.ADJOINT_TIMES``.

        Returns
        -------
        Field
            The transformed values, defined on ``self._tgt(mode)``.
        """
        self._check_input(x, mode)
        # Scale first. The slice assignments below then operate on this
        # product. For NumPy data the product is a new array, so `x` is
        # left untouched.
        v = x.val * self.domain[0].scalar_dvol()
        n = len(self.domain[0].shape)
        axes = range(n) if mode == self.TIMES else reversed(range(n))
        for i in axes:
            # Select every entry whose index along axis i is >= 1.
            sl = (slice(None),)*i + (slice(1, None),)
            if i == dobj.distaxis(v):
                # Axis i is the distributed axis. Redistribute so that
                # axis i is local, transform the local block, then
                # restore the distribution along axis i.
                v = dobj.redistribute(v, nodist=(i,))
                ax = dobj.distaxis(v)
                v = dobj.local_data(v)
                v[sl] = hartley(v[sl], axes=(i,))
                v = dobj.from_local_data(v.shape, v, distaxis=ax)
                v = dobj.redistribute(v, dist=i)
            else:
                v[sl] = hartley(v[sl], axes=(i,))

        return Field(self._tgt(mode), val=v)

    @property
    def capability(self):
        """Supported modes: forward application and its adjoint."""
        return self.TIMES | self.ADJOINT_TIMES