# -*- coding: utf-8 -*-

import pickle
import numpy as np

from keepers import Versionable

import d2o

from power_index_factory import PowerIndexFactory

from nifty.spaces.space import Space
from nifty.spaces.rg_space import RGSpace
from nifty.nifty_utilities import cast_axis_to_tuple


class PowerSpace(Versionable, Space):
    """Space whose pixels are the bins of a power spectrum.

    The binning of ``harmonic_domain`` is computed by
    ``PowerIndexFactory.get_power_index`` and stored as ``pindex``,
    ``kindex``, ``rho``, ``pundex`` and ``k_array``.

    Parameters
    ----------
    harmonic_domain : Space, optional
        The space the power spectrum is defined over. Must be a ``Space``
        whose ``harmonic`` attribute is truthy. If omitted (or ``None``), a
        new ``RGSpace((1,))`` is created for this instance.
    distribution_strategy : str, optional
        Passed to ``PowerIndexFactory.get_power_index``.
    log, nbin, binbounds : optional
        Binning options passed to ``PowerIndexFactory.get_power_index``. The
        values exposed by the ``log``/``nbin``/``binbounds`` properties are
        the ones reported back in the factory's ``config``.
    dtype : numpy.dtype, optional
        Passed to the base-class initializer.

    Raises
    ------
    ValueError
        If ``harmonic_domain`` is not a ``Space`` or is not harmonic.
    """

    # Power-index attributes appended to the base class's ``_ignore_for_hash``
    # list (presumably so these arrays are skipped when hashing the space).
    _power_index_attributes = ('_pindex', '_kindex', '_rho', '_pundex',
                               '_k_array')

    # ---Overwritten properties and methods---

    def __init__(self, harmonic_domain=None,
                 distribution_strategy='not',
                 log=False, nbin=None, binbounds=None,
                 dtype=np.dtype('float')):

        super(PowerSpace, self).__init__(dtype)
        self._ignore_for_hash += list(self._power_index_attributes)

        # Build the default domain per call; an ``RGSpace((1,))`` default in
        # the signature would be created once at class-definition time and
        # shared by every instance.
        # NOTE(review): confirm RGSpace((1,)) is harmonic by default, or the
        # default fails the harmonic check below.
        if harmonic_domain is None:
            harmonic_domain = RGSpace((1,))

        if not isinstance(harmonic_domain, Space):
            raise ValueError(
                "harmonic_domain must be a Space.")
        if not harmonic_domain.harmonic:
            raise ValueError(
                "harmonic_domain must be a harmonic space.")
        self._harmonic_domain = harmonic_domain

        power_index = PowerIndexFactory.get_power_index(
                        domain=self.harmonic_domain,
                        distribution_strategy=distribution_strategy,
                        log=log,
                        nbin=nbin,
                        binbounds=binbounds)

        # Keep the binning parameters as reported in the factory's config,
        # not the raw arguments.
        config = power_index['config']
        self._log = config['log']
        self._nbin = config['nbin']
        self._binbounds = config['binbounds']

        self._pindex = power_index['pindex']
        self._kindex = power_index['kindex']
        self._rho = power_index['rho']
        self._pundex = power_index['pundex']
        self._k_array = power_index['k_array']

    def pre_cast(self, x, axes=None):
        """Return ``x(self.kindex)`` if ``x`` is callable, else ``x`` itself.

        ``axes`` is not used by this implementation.
        """
        if callable(x):
            return x(self.kindex)
        else:
            return x

    # ---Mandatory properties and methods---

    @property
    def harmonic(self):
        return True

    @property
    def shape(self):
        return self.kindex.shape

    @property
    def dim(self):
        return self.shape[0]

    @property
    def total_volume(self):
        """Number of pixels in ``pindex`` (the product of its shape).

        ``weight`` assigns power bin *i* a volume of ``rho[i]``. Counting
        each pixel of ``pindex`` as volume 1, the total is the pixel count
        of ``pindex``. This equals ``sum(rho)`` if ``rho`` counts pixels per
        bin (assumed; TODO confirm).
        """
        # Exact integer product. It avoids the ``reduce`` builtin (moved to
        # functools in Python 3) and fixed-width overflow on large grids.
        volume = 1
        for extent in self.pindex.shape:
            volume *= extent
        return volume

    def copy(self):
        """Return a new instance constructed with this space's parameters.

        The harmonic domain object is passed through (shared, not copied).
        The constructor calls the power-index factory again.
        """
        distribution_strategy = self.pindex.distribution_strategy
        return self.__class__(harmonic_domain=self.harmonic_domain,
                              distribution_strategy=distribution_strategy,
                              log=self.log,
                              nbin=self.nbin,
                              binbounds=self.binbounds,
                              dtype=self.dtype)

    def weight(self, x, power=1, axes=None, inplace=False):
        """Multiply ``x`` by ``rho ** power`` along a single axis.

        Parameters
        ----------
        x : array-like
            Data to weight. ``rho`` is reshaped to broadcast along the axis
            selected by ``axes``.
        power : number, optional
            Exponent applied to ``rho``.
        axes : optional
            Normalized by ``cast_axis_to_tuple``; must resolve to exactly one
            axis.
        inplace : bool, optional
            If true, ``x`` is multiplied in place and returned.

        Returns
        -------
        The weighted data (``x`` itself when ``inplace`` is true).

        Raises
        ------
        ValueError
            If ``axes`` does not resolve to exactly one axis.
        """
        total_shape = x.shape

        axes = cast_axis_to_tuple(axes, len(total_shape))
        if len(axes) != 1:
            raise ValueError(
                "axes must be of length 1.")

        # All ones except ``self.shape[0]`` at axes[0], so rho broadcasts
        # along that axis only.
        reshaper = [1, ] * len(total_shape)
        reshaper[axes[0]] = self.shape[0]

        weight = self.rho.reshape(reshaper)
        if power != 1:
            weight = weight ** power

        if inplace:
            x *= weight
            result_x = x
        else:
            result_x = x*weight

        return result_x

    def get_distance_array(self, distribution_strategy):
        """Return ``kindex`` wrapped in a ``d2o.distributed_data_object``."""
        result = d2o.distributed_data_object(
                                self.kindex,
                                distribution_strategy=distribution_strategy)
        return result

    def get_fft_smoothing_kernel_function(self, sigma):
        """Not supported for PowerSpace; always raises NotImplementedError."""
        raise NotImplementedError(
            "There is no fft smoothing function for PowerSpace.")

    # ---Added properties and methods---

    @property
    def harmonic_domain(self):
        return self._harmonic_domain

    @property
    def log(self):
        return self._log

    @property
    def nbin(self):
        return self._nbin

    @property
    def binbounds(self):
        return self._binbounds

    @property
    def pindex(self):
        return self._pindex

    @property
    def kindex(self):
        return self._kindex

    @property
    def rho(self):
        return self._rho

    @property
    def pundex(self):
        return self._pundex

    @property
    def k_array(self):
        return self._k_array

    # ---Serialization---

    def _to_hdf5(self, hdf5_group):
        """Write this space's own data into ``hdf5_group``.

        ``nbin``, ``binbounds`` and ``dtype`` are stored pickled. The
        returned objects (harmonic domain, pindex, k_array) are not written
        here. ``_from_hdf5`` retrieves them through ``loopback_get`` under
        the same keys.
        """
        hdf5_group['log'] = self.log
        hdf5_group['nbin'] = pickle.dumps(self.nbin)
        hdf5_group['binbounds'] = pickle.dumps(self.binbounds)
        hdf5_group['kindex'] = self.kindex
        hdf5_group['rho'] = self.rho
        hdf5_group['pundex'] = self.pundex
        hdf5_group['dtype'] = pickle.dumps(self.dtype)

        return {
            'harmonic_domain': self.harmonic_domain,
            'pindex': self.pindex,
            'k_array': self.k_array
        }

    @classmethod
    def _from_hdf5(cls, hdf5_group, loopback_get):
        """Rebuild a space from data written by ``_to_hdf5``.

        Parameters
        ----------
        hdf5_group
            Group previously filled by ``_to_hdf5``.
        loopback_get : callable
            Returns the objects ``_to_hdf5`` returned, looked up by key.

        Returns
        -------
        An instance of ``cls``; the power index is not recomputed.
        """
        # NOTE(review): pickle.loads can execute arbitrary code. Only load
        # files from trusted sources.
        dtype = pickle.loads(hdf5_group['dtype'][()])

        # make an empty PowerSpace object; EmptyPowerSpace.__init__ does
        # nothing, so no power index is computed
        new_ps = EmptyPowerSpace()
        # reset class
        new_ps.__class__ = cls
        # Mirror PowerSpace.__init__'s base-class setup. The restored object
        # then has the same base state (including ``_ignore_for_hash``) as a
        # freshly constructed one.
        super(PowerSpace, new_ps).__init__(dtype)
        new_ps._ignore_for_hash += list(cls._power_index_attributes)

        # set all values
        new_ps.dtype = dtype
        new_ps._harmonic_domain = loopback_get('harmonic_domain')
        # Scalar reads (presumably via h5py) return numpy scalars; store a
        # plain Python bool.
        new_ps._log = bool(hdf5_group['log'][()])
        new_ps._nbin = pickle.loads(hdf5_group['nbin'][()])
        new_ps._binbounds = pickle.loads(hdf5_group['binbounds'][()])

        new_ps._pindex = loopback_get('pindex')
        new_ps._kindex = hdf5_group['kindex'][:]
        new_ps._rho = hdf5_group['rho'][:]
        new_ps._pundex = hdf5_group['pundex'][:]
        new_ps._k_array = loopback_get('k_array')

        return new_ps


class EmptyPowerSpace(PowerSpace):
    """PowerSpace shell whose constructor does nothing.

    ``PowerSpace._from_hdf5`` instantiates this class to obtain an object
    without running ``PowerSpace.__init__`` (and therefore without calling
    the power-index factory). It then reassigns ``__class__`` and sets the
    attributes itself.
    """

    def __init__(self):
        # Deliberately skip every base-class initializer.
        return