# -*- coding: utf-8 -*-

import numpy as np

from power_index_factory import PowerIndexFactory

from nifty.config import about
from nifty.spaces.space import Space
from nifty.spaces.rg_space import RGSpace
from nifty.nifty_utilities import cast_axis_to_tuple


class PowerSpace(Space):
    """Space whose pixels are the power-spectrum bins of a harmonic domain.

    The binning data -- ``pindex``, ``kindex``, ``rho``, ``pundex`` and
    ``k_array`` -- together with the effective ``log`` / ``nbin`` /
    ``binbounds`` configuration is obtained from
    ``PowerIndexFactory.get_power_index`` for the given harmonic domain.
    """

    # ---Overwritten properties and methods---

    def __init__(self, harmonic_domain=RGSpace((1,)),
                 distribution_strategy='not',
                 log=False, nbin=None, binbounds=None,
                 dtype=np.dtype('float')):
        """Build the power space belonging to `harmonic_domain`.

        Parameters
        ----------
        harmonic_domain : Space
            Space whose power spectrum this space describes. It must be a
            ``Space`` instance whose ``harmonic`` attribute is truthy.
            NOTE(review): the default ``RGSpace((1,))`` is created once at
            import time and shared by every call that omits the argument.
            If RGSpace is non-harmonic by default, ``PowerSpace()`` without
            arguments raises the ValueError below -- confirm.
        distribution_strategy : str
            Forwarded to ``PowerIndexFactory.get_power_index``.
        log, nbin, binbounds
            Binning options forwarded to
            ``PowerIndexFactory.get_power_index``. The values stored on the
            instance are the ones reported back in its ``'config'`` entry.
        dtype : numpy.dtype
            Passed on to ``Space.__init__``.

        Raises
        ------
        ValueError
            If `harmonic_domain` is not a Space or is not harmonic.
        """
        super(PowerSpace, self).__init__(dtype)
        # Exclude the binning arrays from hashing. Presumably this is because
        # they are fully determined by the domain and binning options.
        # Rebinding to a new list, instead of ``+=`` (which extends a list in
        # place), guarantees that a list object possibly shared with other
        # instances is never mutated.
        self._ignore_for_hash = (self._ignore_for_hash +
                                 ['_pindex', '_kindex', '_rho'])

        if not isinstance(harmonic_domain, Space):
            raise ValueError(about._errors.cstring(
                "ERROR: harmonic_domain must be a Space."))
        if not harmonic_domain.harmonic:
            raise ValueError(about._errors.cstring(
                "ERROR: harmonic_domain must be a harmonic space."))
        self._harmonic_domain = harmonic_domain

        power_index = PowerIndexFactory.get_power_index(
                        domain=self.harmonic_domain,
                        distribution_strategy=distribution_strategy,
                        log=log,
                        nbin=nbin,
                        binbounds=binbounds)

        # Keep the effective configuration reported by the factory. It may
        # differ from the arguments, presumably when they are None.
        config = power_index['config']
        self._log = config['log']
        self._nbin = config['nbin']
        self._binbounds = config['binbounds']

        self._pindex = power_index['pindex']
        self._kindex = power_index['kindex']
        self._rho = power_index['rho']
        self._pundex = power_index['pundex']
        self._k_array = power_index['k_array']

    def compute_k_array(self, distribution_strategy):
        """Not supported for PowerSpace; always raises NotImplementedError."""
        raise NotImplementedError(about._errors.cstring(
            "ERROR: There is no k_array implementation for PowerSpace."))

    # ---Mandatory properties and methods---

    @property
    def harmonic(self):
        """A PowerSpace is always harmonic."""
        return True

    @property
    def shape(self):
        """Shape of the space, taken from ``kindex``."""
        return self.kindex.shape

    @property
    def dim(self):
        """First entry of ``shape``."""
        return self.shape[0]

    @property
    def total_volume(self):
        """Product of the extents of ``pindex``.

        Every power-pixel has a volume of 1. The product is accumulated
        explicitly, starting from 1, for two reasons: ``reduce`` is not a
        builtin on Python 3, and without an initial value it raises for an
        empty shape.
        """
        volume = 1
        for extent in self.pindex.shape:
            volume *= extent
        return volume

    def copy(self):
        """Return a new PowerSpace with the same domain and configuration.

        The distribution strategy is taken from ``pindex``. The same
        harmonic domain object is passed to the new instance, not a copy.
        """
        distribution_strategy = self.pindex.distribution_strategy
        return self.__class__(harmonic_domain=self.harmonic_domain,
                              distribution_strategy=distribution_strategy,
                              log=self.log,
                              nbin=self.nbin,
                              binbounds=self.binbounds,
                              dtype=self.dtype)

    def weight(self, x, power=1, axes=None, inplace=False):
        """Multiply `x` by the bin weights ``rho`` raised to `power`.

        Parameters
        ----------
        x : array-like
            Array to weight. It must provide ``.shape`` and support
            multiplication with the reshaped ``rho`` (presumably a numpy or
            distributed array -- confirm).
        power : number
            Exponent applied to ``rho``. The default of 1 skips the
            exponentiation.
        axes : int or tuple
            The single axis of `x` that corresponds to this space.
        inplace : bool
            If True, `x` is multiplied in place and returned. Otherwise a
            new array is returned.

        Returns
        -------
        The weighted array.

        Raises
        ------
        ValueError
            If `axes` does not describe exactly one axis.
        """
        total_shape = x.shape

        axes = cast_axis_to_tuple(axes, len(total_shape))
        # NOTE(review): cast_axis_to_tuple presumably maps axes=None to None.
        # Guard explicitly so that case raises the ValueError below rather
        # than a TypeError from len(None).
        if axes is None or len(axes) != 1:
            raise ValueError(about._errors.cstring(
                "ERROR: axes must be of length 1."))

        # Shape rho so that it broadcasts along the chosen axis of x only.
        reshaper = [1, ] * len(total_shape)
        reshaper[axes[0]] = self.shape[0]

        weight = self.rho.reshape(reshaper)
        if power != 1:
            weight = weight ** power

        if inplace:
            x *= weight
            result_x = x
        else:
            result_x = x*weight

        return result_x

    # ---Added properties and methods---

    @property
    def harmonic_domain(self):
        """The harmonic space this power space was built from."""
        return self._harmonic_domain

    @property
    def log(self):
        """Effective ``log`` option reported by PowerIndexFactory."""
        return self._log

    @property
    def nbin(self):
        """Effective ``nbin`` option reported by PowerIndexFactory."""
        return self._nbin

    @property
    def binbounds(self):
        """Effective ``binbounds`` option reported by PowerIndexFactory."""
        return self._binbounds

    @property
    def pindex(self):
        """The ``'pindex'`` entry returned by PowerIndexFactory."""
        return self._pindex

    @property
    def kindex(self):
        """The ``'kindex'`` entry returned by PowerIndexFactory."""
        return self._kindex

    @property
    def rho(self):
        """The ``'rho'`` entry returned by PowerIndexFactory (bin weights)."""
        return self._rho

    @property
    def pundex(self):
        """The ``'pundex'`` entry returned by PowerIndexFactory."""
        return self._pundex

    @property
    def k_array(self):
        """The ``'k_array'`` entry returned by PowerIndexFactory."""
        return self._k_array