# -*- coding: utf-8 -*-

import pickle
import numpy as np

from keepers import Versionable

import d2o

from power_index_factory import PowerIndexFactory

from nifty.spaces.space import Space
from nifty.spaces.rg_space import RGSpace
from nifty.nifty_utilities import cast_axis_to_tuple


class PowerSpace(Versionable, Space):
    """Space whose pixels are the bins of a power spectrum.

    A PowerSpace is built on top of a harmonic ``Space``. Its binning is
    described by a power-index dictionary (``pindex``, ``kindex``, ``rho``,
    ``pundex``, ``k_array`` plus a ``config`` sub-dict holding ``log``,
    ``nbin`` and ``binbounds``). The dictionary is either passed in directly
    or computed by ``PowerIndexFactory.get_power_index``.
    """

    # Attributes pickled by _to_hdf5 and restored by name in _from_hdf5.
    # The stored list is positional, so keep this order stable (append new
    # names at the end) to stay readable against files already written.
    _serializable = ('log', 'nbin', 'binbounds', 'kindex', 'rho',
                     'pundex', 'dtype')

    # ---Overwritten properties and methods---

    def __init__(self, harmonic_domain=RGSpace((1,)),
                 distribution_strategy='not',
                 log=False, nbin=None, binbounds=None, power_index=None,
                 dtype=np.dtype('float')):
        """Build the power space for ``harmonic_domain``.

        Parameters
        ----------
        harmonic_domain : Space
            Space whose modes are binned. It must be a ``Space`` instance
            whose ``harmonic`` attribute is true.
            NOTE(review): the default ``RGSpace((1,))`` is evaluated once,
            when the class is defined, and is shared by every instance that
            uses it. It must itself report ``harmonic`` as true, or this
            constructor raises. Confirm RGSpace's default.
        distribution_strategy, log, nbin, binbounds
            Forwarded to ``PowerIndexFactory.get_power_index``. They are only
            used when ``power_index`` is None. The stored ``log``, ``nbin``
            and ``binbounds`` always come from ``power_index['config']``.
        power_index : dict, optional
            Precomputed power index with keys ``'config'`` (holding
            ``'log'``, ``'nbin'`` and ``'binbounds'``), ``'pindex'``,
            ``'kindex'``, ``'rho'``, ``'pundex'`` and ``'k_array'``.
        dtype : numpy.dtype
            Passed to ``Space.__init__``.

        Raises
        ------
        ValueError
            If ``harmonic_domain`` is not a Space or is not harmonic.
        """
        super(PowerSpace, self).__init__(dtype)
        # The binning arrays are listed in _ignore_for_hash (presumably
        # excluded from the Space hash; see Space for how it is used).
        self._ignore_for_hash += ['_pindex', '_kindex', '_rho', '_pundex',
                                  '_k_array']

        if not isinstance(harmonic_domain, Space):
            raise ValueError(
                "harmonic_domain must be a Space.")
        if not harmonic_domain.harmonic:
            raise ValueError(
                "harmonic_domain must be a harmonic space.")
        self._harmonic_domain = harmonic_domain

        if power_index is None:
            power_index = PowerIndexFactory.get_power_index(
                            domain=self.harmonic_domain,
                            distribution_strategy=distribution_strategy,
                            log=log,
                            nbin=nbin,
                            binbounds=binbounds)

        config = power_index['config']
        self._log = config['log']
        self._nbin = config['nbin']
        self._binbounds = config['binbounds']

        self._pindex = power_index['pindex']
        self._kindex = power_index['kindex']
        self._rho = power_index['rho']
        self._pundex = power_index['pundex']
        self._k_array = power_index['k_array']

    def pre_cast(self, x, axes=None):
        """Evaluate ``x`` on ``kindex`` if it is callable; else return it.

        ``axes`` is accepted for interface compatibility and is not used.
        """
        if callable(x):
            return x(self.kindex)
        else:
            return x

    # ---Mandatory properties and methods---

    @property
    def harmonic(self):
        """Always True for a PowerSpace."""
        return True

    @property
    def shape(self):
        """Shape of ``kindex``, i.e. one entry per power bin."""
        return self.kindex.shape

    @property
    def dim(self):
        """Number of power bins (first entry of ``shape``)."""
        return self.shape[0]

    @property
    def total_volume(self):
        """Product of the entries of ``pindex.shape``.

        NOTE(review): a previous comment here read "every power-pixel has a
        volume of 1", but the value returned is the element count of
        ``pindex``. Confirm which meaning is intended.
        """
        # functools.reduce exists on Python 2.6+ and 3; the bare builtin
        # ``reduce`` was removed in Python 3. The initial value 1 leaves the
        # result unchanged for any non-empty shape.
        from functools import reduce
        from operator import mul
        return reduce(mul, self.pindex.shape, 1)

    def copy(self):
        """Return a new PowerSpace with the same domain and binning config.

        The power index is not passed on; the constructor recomputes it from
        ``log``, ``nbin`` and ``binbounds``. The distribution strategy is
        taken from ``pindex.distribution_strategy``.
        """
        distribution_strategy = self.pindex.distribution_strategy
        return self.__class__(harmonic_domain=self.harmonic_domain,
                              distribution_strategy=distribution_strategy,
                              log=self.log,
                              nbin=self.nbin,
                              binbounds=self.binbounds,
                              dtype=self.dtype)

    def weight(self, x, power=1, axes=None, inplace=False):
        """Multiply ``x`` by ``rho ** power`` along a single axis.

        ``rho`` is reshaped to lie along ``axes[0]`` (length 1 on every other
        axis), so it broadcasts against the remaining axes of ``x``.

        Parameters
        ----------
        x : array-like
            Must expose ``shape`` and support ``*`` / ``*=`` with ``rho``
            (presumably a NumPy array).
        power : number
            Exponent applied to ``rho``; skipped when equal to 1.
        axes : int or tuple, optional
            Normalized with ``cast_axis_to_tuple``; must select exactly one
            axis.
        inplace : bool
            If true, ``x`` is multiplied in place and returned.

        Returns
        -------
        The weighted object (``x`` itself when ``inplace`` is true).

        Raises
        ------
        ValueError
            If ``axes`` does not contain exactly one axis.
        """
        total_shape = x.shape

        axes = cast_axis_to_tuple(axes, len(total_shape))
        if len(axes) != 1:
            raise ValueError(
                "axes must be of length 1.")

        # Broadcast shape: the bin count along the target axis, 1 elsewhere.
        reshaper = [1, ] * len(total_shape)
        reshaper[axes[0]] = self.shape[0]

        weight = self.rho.reshape(reshaper)
        if power != 1:
            weight = weight ** power

        if inplace:
            x *= weight
            result_x = x
        else:
            result_x = x*weight

        return result_x

    def get_distance_array(self, distribution_strategy):
        """Return ``kindex`` wrapped in a d2o distributed_data_object."""
        result = d2o.distributed_data_object(
                                self.kindex,
                                distribution_strategy=distribution_strategy)
        return result

    def get_fft_smoothing_kernel_function(self, sigma):
        """Not supported for PowerSpace; always raises NotImplementedError."""
        raise NotImplementedError(
            "There is no fft smoothing function for PowerSpace.")

    # ---Added properties and methods---

    @property
    def harmonic_domain(self):
        """The harmonic Space this power space was built from."""
        return self._harmonic_domain

    @property
    def log(self):
        """``log`` entry of the power-index config."""
        return self._log

    @property
    def nbin(self):
        """``nbin`` entry of the power-index config."""
        return self._nbin

    @property
    def binbounds(self):
        """``binbounds`` entry of the power-index config."""
        return self._binbounds

    @property
    def pindex(self):
        """``pindex`` entry of the power index."""
        return self._pindex

    @property
    def kindex(self):
        """``kindex`` entry of the power index; defines ``shape``."""
        return self._kindex

    @property
    def rho(self):
        """``rho`` entry of the power index; used as weights by ``weight``."""
        return self._rho

    @property
    def pundex(self):
        """``pundex`` entry of the power index."""
        return self._pundex

    @property
    def k_array(self):
        """``k_array`` entry of the power index."""
        return self._k_array

    # ---Serialization---

    def _to_hdf5(self, hdf5_group):
        """Store the ``_serializable`` attributes as a list of pickles.

        The pickled values are written to ``hdf5_group['serialized']`` in
        ``_serializable`` order. The larger sub-objects are returned so the
        caller can store them separately; ``_from_hdf5`` reads them back via
        ``loopback_get``.
        """
        hdf5_group['serialized'] = [
            pickle.dumps(getattr(self, item)) for item in self._serializable
        ]

        return {
            'harmonic_domain': self.harmonic_domain,
            'pindex': self.pindex,
            'k_array': self.k_array
        }

    @classmethod
    def _from_hdf5(cls, hdf5_group, loopback_get):
        """Rebuild a PowerSpace written by ``_to_hdf5``.

        Parameters
        ----------
        hdf5_group
            Group containing the ``'serialized'`` list of pickles.
        loopback_get : callable
            Returns the sub-objects stored under ``'harmonic_domain'``,
            ``'pindex'`` and ``'k_array'``.

        Raises
        ------
        ValueError
            If fewer values are stored than ``_serializable`` names.
        """
        # SECURITY: pickle.loads can execute arbitrary code. Only load files
        # from trusted sources.
        deserialized = \
            [pickle.loads(item) for item in hdf5_group['serialized']]
        if len(deserialized) < len(cls._serializable):
            raise ValueError(
                "Serialized PowerSpace data has %d entries, expected at "
                "least %d." % (len(deserialized), len(cls._serializable)))
        # Map values back to their attribute names, so _serializable is the
        # single source of truth for the stored layout (no magic indices).
        stored = dict(zip(cls._serializable, deserialized))

        harmonic_domain = loopback_get('harmonic_domain')
        power_index = {
            'config': {
                'log': stored['log'],
                'nbin': stored['nbin'],
                'binbounds': stored['binbounds'],
            },
            'pindex': loopback_get('pindex'),
            'kindex': stored['kindex'],
            'rho': stored['rho'],
            'pundex': stored['pundex'],
            'k_array': loopback_get('k_array'),
        }

        result = cls(
            harmonic_domain=harmonic_domain,
            power_index=power_index,
            dtype=stored['dtype']
        )
        return result