fft_operator.py 5.24 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

Martin Reinecke's avatar
Martin Reinecke committed
19
from __future__ import absolute_import, division, print_function
20

Martin Reinecke's avatar
Martin Reinecke committed
21
import numpy as np
22 23 24

from .. import dobj, utilities
from ..compat import *
Martin Reinecke's avatar
Martin Reinecke committed
25
from ..domain_tuple import DomainTuple
Martin Reinecke's avatar
Martin Reinecke committed
26
from ..domains.rg_space import RGSpace
Martin Reinecke's avatar
Martin Reinecke committed
27
from ..field import Field
28
from .linear_operator import LinearOperator
Jait Dixit's avatar
Jait Dixit committed
29 30


Jait Dixit's avatar
Jait Dixit committed
31
class FFTOperator(LinearOperator):
    """Transforms between a pair of position and harmonic RGSpaces.

    The transform is computed as a Hartley transform over the axes of the
    selected subdomain and then multiplied by a volume factor: the
    `scalar_dvol` of `domain[space]` when applying the operator or its
    adjoint, and the `scalar_dvol` of `target[space]` for all other modes.
    Complex-valued input is handled by transforming its real and imaginary
    parts separately and recombining them.

    Parameters
    ----------
    domain: Domain, tuple of Domain or DomainTuple
        The domain of the data that is input by "times" and output by
        "adjoint_times".
    target: Domain, optional
        The target (sub-)domain of the transform operation.
        If omitted, the default codomain of `domain[space]` is used.
    space: int, optional
        The index of the subdomain on which the operator should act.
        If None, it is set to 0 if `domain` contains exactly one space.
        `domain[space]` must be an RGSpace.

    Raises
    ------
    TypeError
        If `domain[space]` is not an RGSpace.
    """

    def __init__(self, domain, target=None, space=None):
        super(FFTOperator, self).__init__()

        # Initialize domain and target
        self._domain = DomainTuple.make(domain)
        self._space = utilities.infer_space(self._domain, space)

        adom = self._domain[self._space]
        if not isinstance(adom, RGSpace):
            raise TypeError("FFTOperator only works on RGSpaces")
        if target is None:
            target = adom.get_default_codomain()

        # The target tuple equals the domain tuple except for the
        # transformed subspace, which is replaced by its codomain.
        tgt = list(self._domain)
        tgt[self._space] = target
        self._target = DomainTuple.make(tgt)
        # Both directions are checked so that an incompatible pair is
        # rejected at construction time.
        adom.check_codomain(target)
        target.check_codomain(adom)

        utilities.fft_prep()

    def apply(self, x, mode):
        """Apply the operator to the field `x` in the given `mode`."""
        self._check_input(x, mode)
        if np.issubdtype(x.dtype, np.complexfloating):
            # The Hartley transform maps real data to real data, so the
            # real and imaginary parts are transformed independently.
            return (self._apply_cartesian(x.real, mode) +
                    1j*self._apply_cartesian(x.imag, mode))
        return self._apply_cartesian(x, mode)

    def _apply_cartesian(self, x, mode):
        """Hartley-transform the real-valued field `x` and apply scaling.

        The strategy depends on whether the axis along which the data are
        distributed (`dobj.distaxis`) is among the transformed axes.
        """
        axes = x.domain.axes[self._space]
        tdom = self._tgt(mode)
        oldax = dobj.distaxis(x.val)
        if oldax not in axes:  # straightforward, no redistribution needed
            ldat = x.local_data
            ldat = utilities.hartley(ldat, axes=axes)
            tmp = dobj.from_local_data(x.val.shape, ldat, distaxis=oldax)
        elif len(axes) < len(x.shape) or len(axes) == 1:
            # we can use one Hartley pass in between the redistributions
            tmp = dobj.redistribute(x.val, nodist=axes)
            newax = dobj.distaxis(tmp)
            ldat = dobj.local_data(tmp)
            ldat = utilities.hartley(ldat, axes=axes)
            tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=newax)
            tmp = dobj.redistribute(tmp, dist=oldax)
        else:  # two separate, full FFTs needed
            # ideal strategy for the moment would be:
            # - do real-to-complex FFT on all local axes
            # - fill up array
            # - redistribute array
            # - do complex-to-complex FFT on remaining axis
            # - add re+im
            # - redistribute back

            # The 2D transpose scheme below only supports data distributed
            # along axis 0; reject other layouts before doing any FFT work.
            if oldax != 0:
                raise ValueError("bad distribution")
            rem_axes = tuple(i for i in axes if i != oldax)
            tmp = x.val
            ldat = dobj.local_data(tmp)
            # FFT over all locally available (non-distributed) axes.
            ldat = utilities.my_fftn_r2c(ldat, axes=rem_axes)
            # Collapse to 2D (distributed axis x everything else) and
            # transpose, so the formerly distributed axis becomes local
            # (axis 1 of the transposed array).
            ldat2 = ldat.reshape((ldat.shape[0],
                                  np.prod(ldat.shape[1:])))
            shp2d = (x.val.shape[0], np.prod(x.val.shape[1:]))
            tmp = dobj.from_local_data(shp2d, ldat2, distaxis=0)
            tmp = dobj.transpose(tmp)
            ldat2 = dobj.local_data(tmp)
            ldat2 = utilities.my_fftn(ldat2, axes=(1,))
            # Combine real and imaginary parts into the real-valued
            # Hartley result, then undo the transpose and the reshape.
            ldat2 = ldat2.real+ldat2.imag
            tmp = dobj.from_local_data(tmp.shape, ldat2, distaxis=0)
            tmp = dobj.transpose(tmp)
            ldat2 = dobj.local_data(tmp).reshape(ldat.shape)
            tmp = dobj.from_local_data(x.val.shape, ldat2, distaxis=0)
        Tval = Field(tdom, tmp)
        # Forward and adjoint applications are scaled by the volume factor
        # of the domain subspace; the remaining modes use the target's.
        if mode & (LinearOperator.TIMES | LinearOperator.ADJOINT_TIMES):
            fct = self._domain[self._space].scalar_dvol
        else:
            fct = self._target[self._space].scalar_dvol
        return Tval if fct == 1 else Tval*fct

    @property
    def domain(self):
        return self._domain

    @property
    def target(self):
        return self._target

    @property
    def capability(self):
        return self._all_ops