power_space.py 5.58 KB
Newer Older
theos's avatar
theos committed
1
2
# -*- coding: utf-8 -*-

theos's avatar
theos committed
3
4
import numpy as np

5
6
import d2o

7
8
from power_index_factory import PowerIndexFactory

9
from nifty.spaces.space import Space
10
from nifty.spaces.rg_space import RGSpace
11
from nifty.nifty_utilities import cast_axis_to_tuple
theos's avatar
theos committed
12
13


Theo Steininger's avatar
Theo Steininger committed
14
class PowerSpace(Space):
    """Space of power-spectrum bins derived from a harmonic domain.

    The binning arrays (``pindex``, ``kindex``, ``rho``, ``pundex``,
    ``k_array``) and the effective binning configuration (``log``, ``nbin``,
    ``binbounds``) are taken from ``PowerIndexFactory.get_power_index`` for
    the given harmonic domain and binning options.
    """

    # ---Overwritten properties and methods---

    # NOTE(review): the default ``RGSpace((1,))`` is built once, when the
    # class is defined, and shared by every instance that relies on it.
    # Also confirm that RGSpace((1,)) is harmonic by default; otherwise the
    # default argument always fails the harmonic check in __init__.
    def __init__(self, harmonic_domain=RGSpace((1,)),
                 distribution_strategy='not',
                 log=False, nbin=None, binbounds=None,
                 dtype=None):
        """Build the power space of ``harmonic_domain``.

        Parameters
        ----------
        harmonic_domain : Space
            Harmonic space whose modes are binned.
        distribution_strategy : str
            Forwarded to ``PowerIndexFactory.get_power_index``.
        log, nbin, binbounds
            Binning options forwarded to the factory. The values stored on
            the instance are the ones the factory reports back in its
            ``config`` entry.
        dtype
            Passed to ``Space.__init__``.

        Raises
        ------
        ValueError
            If ``harmonic_domain`` is not a Space or is not harmonic.
        """
        super(PowerSpace, self).__init__(dtype)
        # The binning arrays are derived data; keep them out of the hash.
        self._ignore_for_hash += ['_pindex', '_kindex', '_rho', '_pundex',
                                  '_k_array']

        if not isinstance(harmonic_domain, Space):
            raise ValueError(
                "harmonic_domain must be a Space.")
        if not harmonic_domain.harmonic:
            raise ValueError(
                "harmonic_domain must be a harmonic space.")
        self._harmonic_domain = harmonic_domain

        power_index = PowerIndexFactory.get_power_index(
                        domain=self.harmonic_domain,
                        distribution_strategy=distribution_strategy,
                        log=log,
                        nbin=nbin,
                        binbounds=binbounds)

        # Store the configuration the factory actually used, not the raw
        # arguments.
        config = power_index['config']
        self._log = config['log']
        self._nbin = config['nbin']
        self._binbounds = config['binbounds']

        self._pindex = power_index['pindex']
        self._kindex = power_index['kindex']
        self._rho = power_index['rho']
        self._pundex = power_index['pundex']
        self._k_array = power_index['k_array']

    def pre_cast(self, x, axes=None):
        """Evaluate ``x`` on ``kindex`` if it is callable, else return it."""
        if callable(x):
            return x(self.kindex)
        else:
            return x

    # ---Mandatory properties and methods---

    @property
    def harmonic(self):
        """A PowerSpace always reports itself as harmonic."""
        return True

    @property
    def shape(self):
        """Shape of the space, i.e. the shape of ``kindex``."""
        return self.kindex.shape

    @property
    def dim(self):
        """Number of power bins (first entry of ``shape``)."""
        return self.shape[0]

    @property
    def total_volume(self):
        """Product of the extents of ``pindex.shape``."""
        # every power-pixel has a volume of 1, so the total volume is the
        # number of entries of pindex. A plain loop replaces the former
        # bare `reduce`, which is not a builtin on Python 3; it also yields
        # 1 (the empty product) instead of raising for an empty shape.
        volume = 1
        for extent in self.pindex.shape:
            volume *= extent
        return volume

    def copy(self):
        """Return a new PowerSpace built with this space's configuration.

        The power index is recomputed by ``__init__`` from the harmonic
        domain, the distribution strategy of ``pindex``, and the stored
        ``log``/``nbin``/``binbounds``/``dtype``.
        """
        distribution_strategy = self.pindex.distribution_strategy
        return self.__class__(harmonic_domain=self.harmonic_domain,
                              distribution_strategy=distribution_strategy,
                              log=self.log,
                              nbin=self.nbin,
                              binbounds=self.binbounds,
                              dtype=self.dtype)

    def weight(self, x, power=1, axes=None, inplace=False):
        """Multiply ``x`` by ``rho ** power`` along a single axis.

        Parameters
        ----------
        x : array-like
            Data to weight; ``rho`` is reshaped to broadcast along
            ``axes[0]`` of ``x``.
        power : number
            Exponent applied to ``rho`` (skipped when it equals 1).
        axes
            Normalised by ``cast_axis_to_tuple``; must resolve to exactly
            one axis.
        inplace : bool
            If True, ``x`` is multiplied in place and returned.

        Raises
        ------
        ValueError
            If ``axes`` does not contain exactly one axis.
        """
        total_shape = x.shape

        axes = cast_axis_to_tuple(axes, len(total_shape))
        if len(axes) != 1:
            raise ValueError(
                "axes must be of length 1.")

        # Shape (1, ..., dim, ..., 1) so rho broadcasts along the chosen axis.
        reshaper = [1, ] * len(total_shape)
        reshaper[axes[0]] = self.shape[0]

        weight = self.rho.reshape(reshaper)
        if power != 1:
            weight = weight ** power

        if inplace:
            x *= weight
            result_x = x
        else:
            result_x = x*weight

        return result_x

    def get_distance_array(self, distribution_strategy):
        """Wrap ``kindex`` in a d2o object with the given distribution."""
        result = d2o.distributed_data_object(
                                self.kindex,
                                distribution_strategy=distribution_strategy)
        return result

    def get_fft_smoothing_kernel_function(self, sigma):
        """Not available for PowerSpace; always raises NotImplementedError."""
        raise NotImplementedError(
            "There is no fft smoothing function for PowerSpace.")

    # ---Added properties and methods---

    @property
    def harmonic_domain(self):
        """The harmonic Space this power space was built from."""
        return self._harmonic_domain

    @property
    def log(self):
        """Binning option ``log`` as reported by the power-index factory."""
        return self._log

    @property
    def nbin(self):
        """Binning option ``nbin`` as reported by the power-index factory."""
        return self._nbin

    @property
    def binbounds(self):
        """Binning option ``binbounds`` as reported by the factory."""
        return self._binbounds

    @property
    def pindex(self):
        """``pindex`` entry of the power index."""
        return self._pindex

    @property
    def kindex(self):
        """``kindex`` entry of the power index; defines ``shape``."""
        return self._kindex

    @property
    def rho(self):
        """``rho`` entry of the power index; used as weights in ``weight``."""
        return self._rho

    @property
    def pundex(self):
        """``pundex`` entry of the power index."""
        return self._pundex

    @property
    def k_array(self):
        """``k_array`` entry of the power index."""
        return self._k_array

    # ---Serialization---

    @staticmethod
    def _to_python_literal(value):
        """Convert NumPy scalars, also inside tuples/lists, to Python ones.

        Other values are returned unchanged. This keeps ``str(value)`` a
        plain Python literal. Without it, NumPy >= 2 would print a scalar
        as e.g. ``np.float64(0.5)``, which ``ast.literal_eval`` rejects.
        """
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, tuple):
            return tuple(PowerSpace._to_python_literal(v) for v in value)
        if isinstance(value, list):
            return [PowerSpace._to_python_literal(v) for v in value]
        return value

    @staticmethod
    def _parse_hdf5_literal(text):
        """Parse a Python literal stored as an HDF5 string attribute.

        Uses ``ast.literal_eval``. The text comes from a file and must never
        be executed as code, so it may only contain literals such as
        ``None``, numbers, tuples and lists.
        """
        import ast
        # h5py may hand back bytes (e.g. fixed-length string attributes).
        if isinstance(text, bytes) and not isinstance(text, str):
            text = text.decode('utf-8')
        return ast.literal_eval(text)

    def _to_hdf5(self, hdf5_group):
        """Write this space's own data into ``hdf5_group``.

        Returns a dict of sub-objects (``harmonic_domain``, ``pindex``,
        ``k_array``). ``_from_hdf5`` retrieves them with ``repository.get``
        under the same keys; storing them is presumably done by the caller.
        """
        hdf5_group['kindex'] = self.kindex
        hdf5_group['rho'] = self.rho
        hdf5_group['pundex'] = self.pundex
        hdf5_group.attrs['dtype'] = self.dtype.name
        hdf5_group['log'] = self.log
        # nbin and binbounds can be None, so they are stored as the string
        # form of a Python literal. NumPy scalars are converted first so the
        # text is always parseable by ast.literal_eval.
        hdf5_group.attrs['nbin'] = str(self._to_python_literal(self.nbin))
        hdf5_group.attrs['binbounds'] = str(
            self._to_python_literal(self.binbounds))

        return {
            'harmonic_domain': self.harmonic_domain,
            'pindex': self.pindex,
            'k_array': self.k_array
        }

    @classmethod
    def _from_hdf5(cls, hdf5_group, repository):
        """Rebuild a PowerSpace from data written by ``_to_hdf5``.

        The power index is not recomputed. Every attribute is read back from
        ``hdf5_group`` or fetched from ``repository``.
        """
        # make an empty PowerSpace object (EmptyPowerSpace.__init__ does
        # nothing, so PowerSpace.__init__ and the factory are skipped)
        new_ps = EmptyPowerSpace()
        # reset class
        new_ps.__class__ = cls
        # NOTE(review): Space.__init__ never runs on new_ps, so anything it
        # would set (presumably including _ignore_for_hash, which __init__
        # extends) is absent here. Confirm Space does not rely on it.
        # set all values
        new_ps.dtype = np.dtype(hdf5_group.attrs['dtype'])
        new_ps._harmonic_domain = repository.get('harmonic_domain', hdf5_group)
        new_ps._log = hdf5_group['log'][()]
        # Parsed with ast.literal_eval rather than exec: the attribute text
        # comes from the file and must not be executed as code.
        new_ps._nbin = cls._parse_hdf5_literal(hdf5_group.attrs['nbin'])
        new_ps._binbounds = cls._parse_hdf5_literal(
            hdf5_group.attrs['binbounds'])

        new_ps._pindex = repository.get('pindex', hdf5_group)
        new_ps._kindex = hdf5_group['kindex'][:]
        new_ps._rho = hdf5_group['rho'][:]
        new_ps._pundex = hdf5_group['pundex'][:]
        new_ps._k_array = repository.get('k_array', hdf5_group)

        return new_ps


class EmptyPowerSpace(PowerSpace):
    """PowerSpace subclass with a constructor that does nothing.

    Instantiating it skips PowerSpace.__init__ (and therefore Space.__init__
    and the power-index computation). The result is a bare object whose
    attributes are assigned afterwards, as PowerSpace._from_hdf5 does.
    """

    def __init__(self):
        """Intentionally empty: no base-class initialisation is performed."""