rgrgtransformation.py 5.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# NIFTy
# Copyright (C) 2017  Theo Steininger
#
# Author: Theo Steininger
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

Jait Dixit's avatar
Jait Dixit committed
19
20
import numpy as np
from transformation import Transformation
Theo Steininger's avatar
Theo Steininger committed
21
from rg_transforms import FFTW, NUMPYFFT
Jait Dixit's avatar
Jait Dixit committed
22
23
24
25
from nifty import RGSpace, nifty_configuration


class RGRGTransformation(Transformation):
    """
        Transformation between two regular-grid spaces (RGSpace -> RGSpace).

        The FFT itself is delegated to a backend object (FFTW or NUMPYFFT),
        selected either explicitly via `module` or from
        nifty_configuration['fft_module'].
    """

    def __init__(self, domain, codomain=None, module=None):
        """
            Parameters
            ----------
            domain : RGSpace
                Space the transformation maps from.
            codomain : RGSpace, *optional*
                Space the transformation maps to; forwarded to the
                Transformation base class (default: None).
            module : str, *optional*
                FFT backend, either 'fftw' or 'numpy'. If None, the value of
                nifty_configuration['fft_module'] is used.

            Raises
            ------
            ValueError
                If the requested (or configured default) FFT module is not
                supported.
        """
        super(RGRGTransformation, self).__init__(domain, codomain, module)

        # Resolve which backend to use; remember whether it came from the
        # configuration so the error message says which source was wrong.
        if module is None:
            module_name = nifty_configuration['fft_module']
            error_prefix = 'Unsupported default FFT module:'
        else:
            module_name = module
            error_prefix = 'Unsupported FFT module:'

        if module_name == 'fftw':
            self._transform = FFTW(self.domain, self.codomain)
        elif module_name == 'numpy':
            self._transform = NUMPYFFT(self.domain, self.codomain)
        else:
            # str() so that a non-string `module` yields the intended
            # ValueError rather than a TypeError from concatenation.
            raise ValueError(error_prefix + str(module_name))

    @classmethod
    def get_codomain(cls, domain, zerocenter=None):
        """
            Generates a compatible codomain to which transformations are
            reasonable, i.e.\  either a shifted grid or a Fourier conjugate
            grid.

            Parameters
            ----------
            domain: RGSpace
                Space for which a codomain is to be generated
            zerocenter : {bool, numpy.ndarray}, *optional*
                Whether or not the grid is zerocentered for each axis or not
                (default: None, meaning domain.zerocenter is reused).

            Returns
            -------
            codomain : nifty.rg_space
                A compatible codomain: same shape, opposite `harmonic` flag
                and grid distances 1 / (shape * domain.distances).

            Raises
            ------
            TypeError
                If `domain` is not an RGSpace.
            AttributeError
                If the generated codomain fails check_codomain.
        """
        if not isinstance(domain, RGSpace):
            raise TypeError("domain needs to be a RGSpace")

        # parse the zerocenter input
        if zerocenter is None:
            zerocenter = domain.zerocenter
        else:
            # Broadcast a scalar (or per-axis sequence) onto an array with the
            # same shape and dtype as domain.zerocenter.
            # NOTE(review): assumes domain.zerocenter is array-like with one
            # entry per axis -- confirm against RGSpace.
            temp = np.empty_like(domain.zerocenter)
            temp[:] = zerocenter
            zerocenter = temp

        # Fourier-conjugate spacing: dist' = 1 / (num * dist).
        # 1.0 (not 1) avoids Python 2 integer division when both shape and
        # distances happen to be integer arrays.
        distances = 1.0 / (np.array(domain.shape) *
                           np.array(domain.distances))

        new_space = RGSpace(domain.shape,
                            zerocenter=zerocenter,
                            distances=distances,
                            harmonic=(not domain.harmonic))

        # better safe than sorry
        cls.check_codomain(domain, new_space)
        return new_space

    @classmethod
    def check_codomain(cls, domain, codomain):
        """
            Verifies that `codomain` is a valid target for transforming
            `domain`.

            Raises
            ------
            TypeError
                If `domain` or `codomain` is not an RGSpace.
            AttributeError
                If the shapes differ, the `harmonic` flags are equal, or the
                grid distances violate dist' = 1 / (num * dist) beyond a
                tolerance of 1e-7.
        """
        if not isinstance(domain, RGSpace):
            raise TypeError("domain is not a RGSpace")

        if not isinstance(codomain, RGSpace):
            raise TypeError("codomain is not a RGSpace")

        # Compare as tuples: element-wise numpy comparison of shapes with a
        # different number of axes is ill-defined (warning or error depending
        # on the numpy version) instead of simply "not equal".
        if tuple(domain.shape) != tuple(codomain.shape):
            raise AttributeError("The shapes of domain and codomain must be "
                                 "identical.")

        if domain.harmonic == codomain.harmonic:
            raise AttributeError("domain.harmonic and codomain.harmonic must "
                                 "not be the same.")

        # Check if the distances match, i.e. dist' = 1 / (num * dist)
        distance_mismatch = np.absolute(np.array(domain.shape) *
                                        np.array(domain.distances) *
                                        np.array(codomain.distances) - 1)
        if not np.all(distance_mismatch < 1e-7):
            raise AttributeError("The grid-distances of domain and codomain "
                                 "do not match.")

        super(RGRGTransformation, cls).check_codomain(domain, codomain)

    def transform(self, val, axes=None, **kwargs):
        """
        RG -> RG transform method.

        Parameters
        ----------
        val : np.ndarray or distributed_data_object
            The value array which is to be transformed

        axes : None or tuple
            The axes along which the transformation should take place

        **kwargs
            Passed through unchanged to the backend's transform method.

        Returns
        -------
        The transformed values, as returned by the backend and then (for a
        transform into a non-harmonic codomain) weighted by the codomain.
        """
        if self._transform.codomain.harmonic:
            # correct for forward fft.
            # naively one would set power to 0.5 here in order to
            # apply effectively a factor of 1/sqrt(N) to the field.
            # BUT: the pixel volumes of the domain and codomain are different.
            # Hence, in order to produce the same scalar product, power===1.
            val = self._transform.domain.weight(val, power=1, axes=axes)

        # Perform the transformation
        Tval = self._transform.transform(val, axes, **kwargs)

        if not self._transform.codomain.harmonic:
            # correct for inverse fft.
            # See discussion above.
            Tval = self._transform.codomain.weight(Tval, power=-1, axes=axes)

        return Tval