# NIFTy
# Copyright (C) 2017  Theo Steininger
#
# Author: Theo Steininger
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import numpy as np

import d2o

from power_index_factory import PowerIndexFactory

from nifty.spaces.space import Space
from nifty.spaces.rg_space import RGSpace


class PowerSpace(Space):
    """Space whose pixels are the bins of a power spectrum.

    The binning is derived from a harmonic domain by
    ``PowerIndexFactory.get_power_index``; the resulting arrays are exposed
    as the ``pindex``, ``kindex``, ``rho``, ``pundex`` and ``k_array``
    properties.

    Parameters
    ----------
    harmonic_domain : Space
        Harmonic space to bin. A ``ValueError`` is raised if it is not a
        ``Space`` instance or if its ``harmonic`` flag is false.
    distribution_strategy : str
        Forwarded to ``PowerIndexFactory.get_power_index``.
    log, nbin, binbounds :
        Binning configuration, forwarded to
        ``PowerIndexFactory.get_power_index``. ``nbin`` and ``binbounds``
        default to None.

    Notes
    -----
    The ``log``, ``nbin`` and ``binbounds`` properties return the values
    reported back in the factory's ``config`` entry, not the raw
    constructor arguments.
    """

    # ---Overwritten properties and methods---

    def __init__(self, harmonic_domain=RGSpace((1,)),
                 distribution_strategy='not',
                 log=False, nbin=None, binbounds=None):

        super(PowerSpace, self).__init__()
        # The binning arrays are computed from the harmonic domain and the
        # binning configuration, so they are excluded from the hash.
        self._ignore_for_hash += ['_pindex', '_kindex', '_rho', '_pundex',
                                  '_k_array']

        if not isinstance(harmonic_domain, Space):
            raise ValueError(
                "harmonic_domain must be a Space.")
        if not harmonic_domain.harmonic:
            raise ValueError(
                "harmonic_domain must be a harmonic space.")
        self._harmonic_domain = harmonic_domain

        power_index = PowerIndexFactory.get_power_index(
                        domain=self.harmonic_domain,
                        distribution_strategy=distribution_strategy,
                        log=log,
                        nbin=nbin,
                        binbounds=binbounds)

        # Keep the configuration as reported by the factory rather than the
        # raw arguments.
        config = power_index['config']
        self._log = config['log']
        self._nbin = config['nbin']
        self._binbounds = config['binbounds']

        self._pindex = power_index['pindex']
        self._kindex = power_index['kindex']
        self._rho = power_index['rho']
        self._pundex = power_index['pundex']
        self._k_array = power_index['k_array']

    def pre_cast(self, x, axes=None):
        """Evaluate ``x`` on ``kindex`` if it is callable, else return it.

        ``axes`` is accepted for interface compatibility and is not used.
        """
        if callable(x):
            return x(self.kindex)
        else:
            return x

    # ---Mandatory properties and methods---

    @property
    def harmonic(self):
        # A PowerSpace always reports itself as harmonic.
        return True

    @property
    def shape(self):
        # The shape of the space is the shape of kindex.
        return self.kindex.shape

    @property
    def dim(self):
        # Length of the first axis of `shape`.
        return self.shape[0]

    @property
    def total_volume(self):
        # Number of entries in pindex, i.e. the product of its shape.
        # Power pixels are weighted by `rho` (see `weight`); if rho counts
        # the harmonic pixels per bin, this equals sum(rho) -- TODO confirm.
        # np.prod replaces the builtin `reduce`, which does not exist on
        # Python 3 without an import.
        return float(np.prod(self.pindex.shape))

    def copy(self):
        """Return a new instance rebuilt from this space's configuration.

        The binning is recomputed by ``PowerIndexFactory`` via ``__init__``,
        reusing the distribution strategy of the current ``pindex``.
        """
        distribution_strategy = self.pindex.distribution_strategy
        return self.__class__(harmonic_domain=self.harmonic_domain,
                              distribution_strategy=distribution_strategy,
                              log=self.log,
                              nbin=self.nbin,
                              binbounds=self.binbounds)

    def weight(self, x, power=1, axes=None, inplace=False):
        """Multiply ``x`` by ``rho ** power`` along the power-space axis.

        Parameters
        ----------
        x : array-like
            Data whose axis ``axes[0]`` has length ``self.shape[0]``.
        power : number
            Exponent applied to ``rho``.
        axes : sequence of int
            Must be provided; only ``axes[0]`` is used.
        inplace : bool
            If True, ``x`` is modified with ``*=`` and returned; otherwise a
            new product is returned.
        """
        reshaper = [1, ] * len(x.shape)
        # we know len(axes) is always 1
        reshaper[axes[0]] = self.shape[0]

        # Broadcast rho along the chosen axis of x.
        weight = self.rho.reshape(reshaper)
        if power != 1:
            weight = weight ** power

        if inplace:
            x *= weight
            result_x = x
        else:
            result_x = x*weight

        return result_x

    def get_distance_array(self, distribution_strategy):
        """Return ``kindex`` as a float64 ``d2o.distributed_data_object``."""
        result = d2o.distributed_data_object(
                                self.kindex, dtype=np.float64,
                                distribution_strategy=distribution_strategy)
        return result

    def get_fft_smoothing_kernel_function(self, sigma):
        raise NotImplementedError(
            "There is no fft smoothing function for PowerSpace.")

    # ---Added properties and methods---

    @property
    def harmonic_domain(self):
        return self._harmonic_domain

    @property
    def log(self):
        return self._log

    @property
    def nbin(self):
        return self._nbin

    @property
    def binbounds(self):
        return self._binbounds

    @property
    def pindex(self):
        return self._pindex

    @property
    def kindex(self):
        return self._kindex

    @property
    def rho(self):
        return self._rho

    @property
    def pundex(self):
        return self._pundex

    @property
    def k_array(self):
        return self._k_array

    # ---Serialization---

    def _to_hdf5(self, hdf5_group):
        """Write this space into ``hdf5_group``.

        Returns a dict of sub-objects (harmonic domain, ``pindex``,
        ``k_array``) that the caller is expected to store separately.
        """
        hdf5_group['kindex'] = self.kindex
        hdf5_group['rho'] = self.rho
        hdf5_group['pundex'] = self.pundex
        hdf5_group['log'] = self.log
        # Store nbin as string, since it can be None
        hdf5_group.attrs['nbin'] = str(self.nbin)
        hdf5_group.attrs['binbounds'] = str(self.binbounds)

        return {
            'harmonic_domain': self.harmonic_domain,
            'pindex': self.pindex,
            'k_array': self.k_array
        }

    @staticmethod
    def _parse_literal_attribute(value):
        """Parse an attribute written as ``str(value)`` back into a literal.

        Only Python literal syntax (e.g. ``None``, numbers, tuples/lists of
        numbers) is accepted. The previous ``exec``-based parsing would run
        arbitrary code embedded in a crafted HDF5 file.

        Raises
        ------
        ValueError, SyntaxError
            If the stored text is not a Python literal.
        """
        import ast
        # h5py may hand back byte strings; literal_eval needs text.
        if isinstance(value, bytes):
            value = value.decode('utf-8')
        return ast.literal_eval(value)

    @classmethod
    def _from_hdf5(cls, hdf5_group, repository):
        """Rebuild a PowerSpace from data written by ``_to_hdf5``.

        ``__init__`` is bypassed, so the factory is not re-run; all stored
        arrays are restored directly.
        """
        # make an empty PowerSpace object
        new_ps = EmptyPowerSpace()
        # reset class
        new_ps.__class__ = cls
        # call the base-class constructor so that Space state is set up
        super(PowerSpace, new_ps).__init__()
        # set all values
        new_ps._harmonic_domain = repository.get('harmonic_domain', hdf5_group)
        new_ps._log = hdf5_group['log'][()]
        # nbin/binbounds were stored as str(...); parse them safely.
        new_ps._nbin = cls._parse_literal_attribute(
                                            hdf5_group.attrs['nbin'])
        new_ps._binbounds = cls._parse_literal_attribute(
                                            hdf5_group.attrs['binbounds'])

        new_ps._pindex = repository.get('pindex', hdf5_group)
        new_ps._kindex = hdf5_group['kindex'][:]
        new_ps._rho = hdf5_group['rho'][:]
        new_ps._pundex = hdf5_group['pundex'][:]
        new_ps._k_array = repository.get('k_array', hdf5_group)
        new_ps._ignore_for_hash += ['_pindex', '_kindex', '_rho', '_pundex',
                                    '_k_array']

        return new_ps


class EmptyPowerSpace(PowerSpace):
    """PowerSpace subclass whose constructor performs no initialisation.

    ``PowerSpace._from_hdf5`` instantiates it to obtain an object without
    running ``PowerSpace.__init__``, then reassigns ``__class__`` and fills
    in the attributes itself.
    """

    def __init__(self):
        """Intentionally do nothing; no base-class constructor is called."""