# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

import numpy as np
from transformation import Transformation
from rg_transforms import MPIFFT, SerialFFT
from nifty import RGSpace, nifty_configuration


class RGRGTransformation(Transformation):
    """
        Fourier transformation between two regular-grid spaces (RGSpace).

        Exactly one of the two spaces is harmonic. The actual FFT is
        delegated to an MPIFFT or SerialFFT backend that is selected when
        the transformation is constructed.
    """

    # ---Overwritten properties and methods---

    def __init__(self, domain, codomain=None, module=None):
        """
            Set up the transformation and pick an FFT backend.

            Parameters
            ----------
            domain : RGSpace
                Space the transformation starts from.
            codomain : RGSpace, *optional*
                Target space, forwarded to the Transformation base class
                (default: None).
            module : str, *optional*
                FFT backend name: 'fftw_mpi', 'fftw' or 'numpy'. If None, the
                value of nifty_configuration['fft_module'] is used
                (default: None).

            Raises
            ------
            ValueError
                If `module` does not name a supported backend.
        """
        # NOTE: the base class receives the caller's `module` argument
        # (possibly None); the configuration default is only applied below.
        super(RGRGTransformation, self).__init__(domain, codomain, module)

        if module is None:
            module = nifty_configuration['fft_module']

        if module == 'fftw_mpi':
            self._transform = MPIFFT(self.domain, self.codomain)
        elif module == 'fftw':
            self._transform = SerialFFT(self.domain, self.codomain,
                                        use_fftw=True)
        elif module == 'numpy':
            self._transform = SerialFFT(self.domain, self.codomain,
                                        use_fftw=False)
        else:
            # str() so that a non-string `module` still produces the intended
            # ValueError rather than a TypeError from string concatenation.
            raise ValueError('Unsupported FFT module: ' + str(module))

        # Read by transform(): 'complex' returns the plain FFT output; any
        # other value switches to the real-valued combination of real and
        # imaginary parts.
        self.harmonic_base = nifty_configuration['harmonic_rg_base']

    # ---Mandatory properties and methods---

    @property
    def unitary(self):
        """Always True for this RG <-> RG transformation."""
        return True

    @classmethod
    def get_codomain(cls, domain, zerocenter=None):
        """
            Generates the Fourier-conjugate RGSpace of `domain`.

            The new space has the same shape, grid distances
            1 / (num * dist) per axis and the opposite `harmonic` flag.

            Parameters
            ----------
            domain: RGSpace
                Space for which a codomain is to be generated
            zerocenter : {bool, numpy.ndarray}, *optional*
                Whether or not the grid is zerocentered for each axis. A
                scalar is broadcast to all axes. If None, `domain.zerocenter`
                is used (default: None).

            Returns
            -------
            codomain : nifty.rg_space
                A compatible codomain.

            Raises
            ------
            TypeError
                If `domain` is not an RGSpace.
        """
        if not isinstance(domain, RGSpace):
            raise TypeError("domain needs to be a RGSpace")

        # parse the zerocenter input
        if zerocenter is None:
            zerocenter = domain.zerocenter
        else:
            # Broadcast the (possibly scalar) input into an array shaped and
            # typed like domain.zerocenter.
            temp = np.empty_like(domain.zerocenter)
            temp[:] = zerocenter
            zerocenter = temp

        # Conjugate grid spacing dist' = 1 / (num * dist). The float numerator
        # prevents integer division on Python 2 when shape and distances are
        # both integral.
        distances = 1. / (np.array(domain.shape) *
                          np.array(domain.distances))

        new_space = RGSpace(domain.shape,
                            zerocenter=zerocenter,
                            distances=distances,
                            harmonic=(not domain.harmonic))

        # better safe than sorry
        cls.check_codomain(domain, new_space)
        return new_space

    @classmethod
    def check_codomain(cls, domain, codomain):
        """
            Checks that `codomain` is a valid target space for `domain`.

            Parameters
            ----------
            domain : RGSpace
            codomain : RGSpace

            Raises
            ------
            TypeError
                If `domain` or `codomain` is not an RGSpace.
            AttributeError
                If the shapes differ, both spaces have the same `harmonic`
                flag, or num * dist * dist' deviates from 1 by 1e-7 or more
                on any axis.
        """
        if not isinstance(domain, RGSpace):
            raise TypeError("domain is not a RGSpace")

        if not isinstance(codomain, RGSpace):
            raise TypeError("codomain is not a RGSpace")

        # array_equal also returns False (instead of erroring or warning) when
        # the two spaces have a different number of axes.
        if not np.array_equal(domain.shape, codomain.shape):
            raise AttributeError("The shapes of domain and codomain must be "
                                 "identical.")

        if domain.harmonic == codomain.harmonic:
            raise AttributeError("domain.harmonic and codomain.harmonic must "
                                 "not be the same.")

        # Check if the distances match, i.e. dist' = 1 / (num * dist)
        deviation = np.absolute(np.array(domain.shape) *
                                np.array(domain.distances) *
                                np.array(codomain.distances) - 1)
        if not np.all(deviation < 1e-7):
            raise AttributeError("The grid-distances of domain and codomain "
                                 "do not match.")

        super(RGRGTransformation, cls).check_codomain(domain, codomain)

    def transform(self, val, axes=None, **kwargs):
        """
        RG -> RG transform method.

        Parameters
        ----------
        val : np.ndarray or distributed_data_object
            The value array which is to be transformed
        axes : None or tuple
            The axes along which the transformation should take place
        **kwargs
            Passed through unchanged to the backend's transform call.

        Returns
        -------
        Tval
            The transformed values. If `harmonic_base` is not 'complex' and
            `val` is real-valued, only the real part is returned.
        """
        if self._transform.codomain.harmonic:
            # correct for forward fft.
            # naively one would set power to 0.5 here in order to
            # apply effectively a factor of 1/sqrt(N) to the field.
            # BUT: the pixel volumes of the domain and codomain are different.
            # Hence, in order to produce the same scalar product, power===1.
            val = self._transform.domain.weight(val, power=1, axes=axes)

        # Perform the transformation
        if self.harmonic_base == 'complex':
            Tval = self._transform.transform(val, axes, **kwargs)
        else:
            # Real-valued harmonic base: combine real and imaginary parts of
            # the FFT as Re + Im (forward) or Re - Im (inverse). Presumably a
            # Hartley-like basis -- NOTE(review): confirm the intended
            # sign convention.
            if issubclass(val.dtype.type, np.complexfloating):
                # Transform the real and imaginary input parts separately so
                # each stays real under the combination above.
                Tval_real = self._transform.transform(val.real, axes,
                                                      **kwargs)
                Tval_imag = self._transform.transform(val.imag, axes,
                                                      **kwargs)
                if self.codomain.harmonic:
                    Tval_real.data.real += Tval_real.data.imag
                    Tval_real.data.imag = \
                        Tval_imag.data.real + Tval_imag.data.imag
                else:
                    Tval_real.data.real -= Tval_real.data.imag
                    Tval_real.data.imag = \
                        Tval_imag.data.real - Tval_imag.data.imag

                Tval = Tval_real
            else:
                Tval = self._transform.transform(val, axes, **kwargs)
                if self.codomain.harmonic:
                    Tval.data.real += Tval.data.imag
                else:
                    Tval.data.real -= Tval.data.imag
                Tval = Tval.real

        if not self._transform.codomain.harmonic:
            # correct for inverse fft.
            # See discussion above.
            Tval = self._transform.codomain.weight(Tval, power=-1, axes=axes)

        return Tval