# rgrgtransformation.py
import numpy as np
from transformation import Transformation
from rg_transforms import FFTW, GFFT
from nifty.config import dependency_injector as gdi
from nifty import RGSpace, nifty_configuration

import logging
logger = logging.getLogger('NIFTy.RGRGTransformation')


class RGRGTransformation(Transformation):
    """
        Transformation between two RGSpaces, one harmonic and one not.

        The numerical work is delegated to an FFT backend object stored in
        `self._transform` (an `FFTW` or a `GFFT` instance). The backend is
        selected explicitly via `module`, or otherwise from
        `nifty_configuration['fft_module']`.
    """

    def __init__(self, domain, codomain=None, module=None):
        """
            Parameters
            ----------
            domain : RGSpace
                The space the transformation maps from.
            codomain : RGSpace, *optional*
                The space the transformation maps to. It is forwarded
                unchanged to the parent class and to the FFT backend.
            module : str, *optional*
                FFT backend to use: 'pyfftw', 'gfft' or 'gfft_dummy'.
                If None (default), nifty_configuration['fft_module'] is used.

            Raises
            ------
            ValueError
                If the requested or configured FFT module is not one of the
                supported names.
        """
        super(RGRGTransformation, self).__init__(domain, codomain, module)

        # The explicit and the configured module share one dispatch; only
        # the error message tells the two cases apart.
        if module is None:
            fft_module = nifty_configuration['fft_module']
            error_prefix = 'ERROR: unknow default FFT module:'
        else:
            fft_module = module
            error_prefix = 'ERROR: unknow FFT module:'

        if fft_module == 'pyfftw':
            self._transform = FFTW(domain, codomain)
        elif fft_module == 'gfft' or fft_module == 'gfft_dummy':
            # The GFFT backend receives whatever the dependency injector
            # holds under the same name (presumably the gfft module or its
            # dummy stand-in; verify against nifty.config).
            self._transform = GFFT(domain, codomain, gdi.get(fft_module))
        else:
            # str() ensures a non-string `module` still produces the intended
            # ValueError instead of a TypeError from the concatenation.
            raise ValueError(error_prefix + str(fft_module))

    @classmethod
    def get_codomain(cls, domain, dtype=None, zerocenter=None):
        """
            Generates a compatible codomain to which transformations are
            reasonable, i.e. either a shifted grid or a Fourier conjugate
            grid.

            Parameters
            ----------
            domain : RGSpace
                Space for which a codomain is to be generated.
            dtype : numpy.dtype, *optional*
                dtype of the codomain. If None (default), the complex dtype
                resulting from adding 1j to a scalar of `domain.dtype` is used.
            zerocenter : {bool, numpy.ndarray}, *optional*
                Whether or not the grid is zerocentered for each axis.
                If None (default), `domain.zerocenter` is reused; otherwise
                the value is broadcast to the shape of `domain.zerocenter`.

            Returns
            -------
            codomain : RGSpace
                A compatible codomain with the opposite `harmonic` flag.

            Raises
            ------
            TypeError
                If `domain` is not an RGSpace.
        """
        if not isinstance(domain, RGSpace):
            raise TypeError('ERROR: domain needs to be a RGSpace')

        # parse the zerocenter input
        if zerocenter is None:
            zerocenter = domain.zerocenter
        else:
            # Broadcast the given value (scalar or per-axis) into an array
            # with the shape and dtype of domain.zerocenter.
            temp = np.empty_like(domain.zerocenter)
            temp[:] = zerocenter
            zerocenter = temp

        # The conjugate grid has pixel sizes 1 / (num * dist) per axis.
        # The float numerator avoids truncating integer division where `/`
        # on integer arrays floors (Python 2).
        distances = 1.0 / (np.array(domain.shape) *
                           np.array(domain.distances))

        if dtype is None:
            # create a definitely complex dtype from the dtype of domain
            one = domain.dtype.type(1)
            dtype = np.dtype(type(one + 1j))

        new_space = RGSpace(domain.shape,
                            zerocenter=zerocenter,
                            distances=distances,
                            harmonic=(not domain.harmonic),
                            dtype=dtype)
        # NOTE(review): the boolean result is discarded here; the call only
        # matters through the warning check_codomain may log.
        cls.check_codomain(domain, new_space)
        return new_space

    @staticmethod
    def check_codomain(domain, codomain):
        """
            Checks whether `codomain` is a valid transformation target for
            `domain`.

            Returns True only if `codomain` is an RGSpace with the same shape
            as `domain`, the opposite `harmonic` flag, and distances that
            satisfy |shape * dist * codist - 1| < 1e-7 on every axis.
            A harmonic codomain with a non-complex dtype only logs a warning;
            it does not make the check fail.

            Raises
            ------
            TypeError
                If `domain` is not an RGSpace.
        """
        if not isinstance(domain, RGSpace):
            raise TypeError('ERROR: domain is not a RGSpace')

        if codomain is None:
            return False

        if not isinstance(codomain, RGSpace):
            return False

        if not np.all(np.array(domain.shape) ==
                      np.array(codomain.shape)):
            return False

        if domain.harmonic == codomain.harmonic:
            return False

        if codomain.harmonic and not issubclass(codomain.dtype.type,
                                                np.complexfloating):
            # Logger.warn is a deprecated alias of Logger.warning.
            logger.warning("codomain is harmonic but dtype is real.")

        # Check if the distances match, i.e. dist' = 1 / (num * dist)
        if not np.all(
                np.absolute(np.array(domain.shape) *
                            np.array(domain.distances) *
                            np.array(codomain.distances) - 1) <
                10**-7):
            return False

        return True

    def transform(self, val, axes=None, **kwargs):
        """
        RG -> RG transform method.

        Parameters
        ----------
        val : np.ndarray or distributed_data_object
            The value array which is to be transformed
        axes : None or tuple
            The axes along which the transformation should take place
        **kwargs
            Passed on unchanged to the backend's `transform` method.

        Returns
        -------
        Tval
            The backend's result, volume-corrected as described below.

        Notes
        -----
        For a harmonic codomain, `val` is first weighted via
        `domain.weight(val, power=1)`. Otherwise the backend's result is
        weighted via `codomain.weight(Tval, power=-1)`.
        """
        if self._transform.codomain.harmonic:
            # correct for forward fft.
            # naively one would set power to 0.5 here in order to
            # apply effectively a factor of 1/sqrt(N) to the field.
            # BUT: the pixel volumes of the domain and codomain are different.
            # Hence, in order to produce the same scalar product, power===1.
            val = self._transform.domain.weight(val, power=1, axes=axes)

        # Perform the transformation
        Tval = self._transform.transform(val, axes, **kwargs)

        if not self._transform.codomain.harmonic:
            # correct for inverse fft.
            # See discussion above.
            Tval = self._transform.codomain.weight(Tval, power=-1, axes=axes)

        return Tval