rgrgtransformation.py 5.31 KB
Newer Older
Jait Dixit's avatar
Jait Dixit committed
1
2
3
import numpy as np
from transformation import Transformation
from rg_transforms import FFTW, GFFT
4
from nifty.config import dependency_injector as gdi
Jait Dixit's avatar
Jait Dixit committed
5
6
7
8
from nifty import RGSpace, nifty_configuration


class RGRGTransformation(Transformation):
    """
        Fourier transformation between two regular-grid spaces (RGSpace).

        The FFT itself is delegated to a backend object (FFTW or GFFT) that is
        selected at construction time. This class adds the pixel-volume
        weighting needed so that scalar products agree between domain and
        codomain.
    """

    def __init__(self, domain, codomain=None, module=None):
        """
            Parameters
            ----------
            domain : RGSpace
                Space the transformation starts from.
            codomain : RGSpace, *optional*
                Target space. Passed on to the Transformation base class,
                which sets `self.codomain`.
            module : {'pyfftw', 'gfft', 'gfft_dummy', None}, *optional*
                FFT backend to use. If None, the backend named by
                `nifty_configuration['fft_module']` is used.

            Raises
            ------
            ValueError
                If the requested (or configured default) backend is unknown.
        """
        super(RGRGTransformation, self).__init__(domain, codomain)

        # Resolve the backend name; remember whether it came from the global
        # configuration so the error message can say which one was wrong.
        if module is None:
            module_name = nifty_configuration['fft_module']
            error_prefix = 'ERROR: unknown default FFT module: '
        else:
            module_name = module
            error_prefix = 'ERROR: unknown FFT module: '

        if module_name == 'pyfftw':
            self._transform = FFTW(self.domain, self.codomain)
        elif module_name == 'gfft' or module_name == 'gfft_dummy':
            # Both GFFT flavours use the same wrapper; only the injected
            # library object differs.
            self._transform = GFFT(self.domain,
                                   self.codomain,
                                   gdi.get(module_name))
        else:
            # str() so a non-string `module` still produces the documented
            # ValueError rather than a TypeError from string concatenation.
            raise ValueError(error_prefix + str(module_name))

    @classmethod
    def get_codomain(cls, domain, dtype=None, zerocenter=None):
        """
            Generates the Fourier-conjugate grid of `domain`.

            The returned space has the same shape, the opposite `harmonic`
            flag and pixel distances 1 / (N * d) along every axis.

            Parameters
            ----------
            domain : RGSpace
                Space for which a codomain is to be generated.
            dtype : numpy.dtype, *optional*
                Data type of the codomain. If None, the complex counterpart
                of `domain.dtype` is used (default: None).
            zerocenter : {bool, numpy.ndarray}, *optional*
                Whether or not the grid is zerocentered for each axis. A
                scalar is applied to all axes; if None, `domain.zerocenter`
                is reused (default: None).

            Returns
            -------
            codomain : RGSpace
                A compatible codomain.

            Raises
            ------
            TypeError
                If `domain` is not an RGSpace.
        """
        if not isinstance(domain, RGSpace):
            raise TypeError('ERROR: domain needs to be a RGSpace')

        # parse the zerocenter input
        if zerocenter is None:
            zerocenter = domain.zerocenter
        else:
            # broadcast a scalar or per-axis input to the layout (shape and
            # dtype) of domain.zerocenter
            temp = np.empty_like(domain.zerocenter)
            temp[:] = zerocenter
            zerocenter = temp

        # Conjugate pixel distances: d' = 1 / (N * d). The float numerator
        # avoids integer division (Python 2) when distances are integers.
        distances = 1.0 / (np.array(domain.shape) *
                           np.array(domain.distances))

        if dtype is None:
            # create a definitely complex dtype from the dtype of domain
            one = domain.dtype.type(1)
            dtype = np.dtype(type(one + 1j))

        new_space = RGSpace(domain.shape,
                            zerocenter=zerocenter,
                            distances=distances,
                            harmonic=(not domain.harmonic),
                            dtype=dtype)
        # Called for its side effect only: it logs a warning if the
        # (harmonic) codomain ended up with a real dtype. Its boolean result
        # is not used here.
        cls.check_codomain(domain, new_space)
        return new_space

    @classmethod
    def check_codomain(cls, domain, codomain):
        """
            Checks whether `codomain` is a valid Fourier partner of `domain`.

            Parameters
            ----------
            domain : RGSpace
                The space the transformation starts from.
            codomain : RGSpace or None
                Candidate target space.

            Returns
            -------
            bool
                True if `codomain` is an RGSpace with the same shape, the
                opposite `harmonic` flag and distances that satisfy
                |N * d * d' - 1| < 1e-7 on every axis; False otherwise.

            Raises
            ------
            TypeError
                If `domain` is not an RGSpace.
        """
        if not isinstance(domain, RGSpace):
            raise TypeError('ERROR: domain is not a RGSpace')

        if codomain is None:
            return False

        if not isinstance(codomain, RGSpace):
            return False

        # array_equal also rejects shapes of different length, which an
        # elementwise `==` would either broadcast (e.g. (4,) vs (4, 4)) or
        # fail on.
        if not np.array_equal(domain.shape, codomain.shape):
            return False

        if domain.harmonic == codomain.harmonic:
            return False

        # A real-valued harmonic codomain is suspicious but not invalid.
        if codomain.harmonic and not issubclass(codomain.dtype.type,
                                                np.complexfloating):
            cls.logger.warn("Codomain is harmonic but dtype is real.")

        # Check if the distances match, i.e. dist' = 1 / (num * dist)
        distance_product = (np.array(domain.shape) *
                            np.array(domain.distances) *
                            np.array(codomain.distances))
        if not np.all(np.absolute(distance_product - 1) < 1e-7):
            return False

        return True

    def transform(self, val, axes=None, **kwargs):
        """
        RG -> RG transform method.

        Parameters
        ----------
        val : np.ndarray or distributed_data_object
            The value array which is to be transformed

        axes : None or tuple
            The axes along which the transformation should take place

        **kwargs
            Passed through unchanged to the backend's `transform`.

        Returns
        -------
        Tval
            The transformed values, weighted as described below.

        """
        if self._transform.codomain.harmonic:
            # correct for forward fft.
            # naively one would set power to 0.5 here in order to
            # apply effectively a factor of 1/sqrt(N) to the field.
            # BUT: the pixel volumes of the domain and codomain are different.
            # Hence, in order to produce the same scalar product, power===1.
            val = self._transform.domain.weight(val, power=1, axes=axes)

        # Perform the transformation
        Tval = self._transform.transform(val, axes, **kwargs)

        if not self._transform.codomain.harmonic:
            # correct for inverse fft.
            # See discussion above.
            Tval = self._transform.codomain.weight(Tval, power=-1, axes=axes)

        return Tval