# power_space.py
# -*- coding: utf-8 -*-

from functools import reduce

import numpy as np
from d2o import STRATEGIES

from nifty.config import about
from nifty.spaces.space import Space
from power_space_paradict import PowerSpaceParadict
from nifty.nifty_utilities import cast_axis_to_tuple


class PowerSpace(Space):
    """Space whose pixels are the bins of a power spectrum.

    The binning is described by the arrays `pindex`, `kindex` and `rho`,
    which are stored together with `config` and `harmonic_domain` in a
    PowerSpaceParadict under `self.paradict`. The shape of the space is the
    shape of `kindex`.
    """

    def __init__(self, pindex, kindex, rho, config,
                 harmonic_domain, dtype=np.dtype('float'), **kwargs):
        """Create a PowerSpace from its binning description.

        Parameters
        ----------
        pindex : distributed array (d2o)
            Bin index of every pixel of the harmonic domain; used as the
            bin labels in `calculate_power_spectrum`.
        kindex : array-like
            One entry per bin; its shape defines the shape of this space.
        rho : array-like
            Per-bin weights, used by `weight` and as the normalisation in
            `calculate_power_spectrum` (presumably the number of pixels per
            bin -- TODO confirm against the code that builds power indices).
        config : object
            Stored in the paradict unchanged.
        harmonic_domain : Space
            Stored in the paradict unchanged.
        dtype : numpy dtype-like, optional
            Data type of fields living in this space (default float).
        **kwargs
            Ignored. Accepted so that the space can be built directly from a
            power-index dictionary via ``PowerSpace(**power_index)``; keys
            that are not parameters of this method are discarded.
        """
        self.dtype = np.dtype(dtype)

        self.paradict = PowerSpaceParadict(pindex=pindex,
                                           kindex=kindex,
                                           rho=rho,
                                           config=config,
                                           harmonic_domain=harmonic_domain)
        self._harmonic = True

    @property
    def shape(self):
        """Shape of the space, i.e. the shape of `kindex`."""
        return self.paradict['kindex'].shape

    @property
    def dim(self):
        """Number of bins (length of the first axis of `shape`)."""
        return self.shape[0]

    @property
    def total_volume(self):
        """Total volume of the space: the number of entries in `pindex`.

        Every pixel of the underlying harmonic grid (one entry of `pindex`)
        counts with volume 1, so the total volume is the product of
        `pindex`'s shape -- not the number of bins.
        """
        return reduce(lambda x, y: x*y, self.paradict['pindex'].shape)

    def weight(self, x, power=1, axes=None, inplace=False):
        """Multiply `x` by the bin weights `rho` raised to `power`.

        Parameters
        ----------
        x : array-like
            Array with exactly one axis living in this space.
        power : number, optional
            Exponent applied to `rho` before multiplying (default 1).
        axes : int, sequence of int or None, optional
            The axis of `x` belonging to this space. ``None`` means that all
            of `x` lives in this space, which is only valid for a
            one-dimensional `x`.
        inplace : bool, optional
            If True, `x` is multiplied in place and returned.

        Returns
        -------
        The weighted array (`x` itself when `inplace` is True).

        Raises
        ------
        ValueError
            If `axes` does not select exactly one axis.
        """
        total_shape = x.shape

        if axes is None:
            # No explicit axes: x is taken to live entirely in this space.
            # The length check below then requires x to be one-dimensional.
            axes = tuple(range(len(total_shape)))
        else:
            axes = cast_axis_to_tuple(axes, len(total_shape))
        if len(axes) != 1:
            raise ValueError(about._errors.cstring(
                "ERROR: axes must be of length 1."))

        # Give rho length-1 axes everywhere except the selected one, so it
        # broadcasts against x along that axis only.
        reshaper = [1, ] * len(total_shape)
        reshaper[axes[0]] = self.shape[0]

        weight = self.paradict['rho'].reshape(reshaper)
        if power != 1:
            weight = weight ** power

        if inplace:
            x *= weight
            result_x = x
        else:
            result_x = x*weight

        return result_x

    def compute_k_array(self, distribution_strategy):
        """Not supported for PowerSpace; always raises NotImplementedError."""
        raise NotImplementedError(about._errors.cstring(
            "ERROR: There is no k_array implementation for PowerSpace."))

    def calculate_power_spectrum(self, x, axes=None):
        """Compute the binned power spectrum of `x`.

        |x|**2 is summed within each bin given by `pindex`, divided by
        `rho`, and the square root of that ratio is returned.

        Parameters
        ----------
        x : distributed array (d2o)
            Field values on the harmonic domain.
        axes : int, sequence of int or None, optional
            Axes of `x` that carry the harmonic domain. If given, `pindex` is
            broadcast to the full shape of `x` before binning and the result
            keeps the remaining axes.

        Returns
        -------
        The square root of the rho-normalised binned power.
        """
        fieldabs = abs(x)
        fieldabs **= 2

        pindex = self.paradict['pindex']
        if axes is not None:
            # Normalise axes (scalar -> tuple, negative -> non-negative) the
            # same way `weight` does, so that `axes[0]` and the membership
            # tests in _shape_up_pindex are valid.
            axes = cast_axis_to_tuple(axes, len(x.shape))
            pindex = self._shape_up_pindex(
                                    pindex=pindex,
                                    target_shape=x.shape,
                                    target_strategy=x.distribution_strategy,
                                    axes=axes)
        power_spectrum = pindex.bincount(weights=fieldabs,
                                         axis=axes)

        # Normalise every bin by rho. With extra axes, rho must broadcast
        # along the binned axis only.
        rho = self.paradict['rho']
        if axes is not None:
            new_rho_shape = [1, ] * len(power_spectrum.shape)
            new_rho_shape[axes[0]] = len(rho)
            rho = rho.reshape(new_rho_shape)
        power_spectrum /= rho

        power_spectrum **= 0.5
        return power_spectrum

    def _shape_up_pindex(self, pindex, target_shape, target_strategy, axes):
        """Broadcast `pindex` to `target_shape` along the given `axes`.

        Parameters
        ----------
        pindex : distributed array (d2o)
            Bin indices; must use a global-type distribution strategy.
        target_shape : tuple of int
            Global shape of the returned array.
        target_strategy : distribution strategy
            Distribution strategy of the returned array.
        axes : tuple of int
            Axes of `target_shape` that correspond to `pindex`'s axes.

        Returns
        -------
        A new distributed array of shape `target_shape` that contains
        `pindex` repeated along all axes not listed in `axes`.

        Raises
        ------
        ValueError
            If `pindex` is not global-type, or if a slicing-distributed
            `pindex` would end up not sliced along axis 0 / under a
            different strategy.
        """
        if pindex.distribution_strategy not in STRATEGIES['global']:
            raise ValueError("ERROR: pindex's distribution strategy must be "
                             "global-type")

        if pindex.distribution_strategy in STRATEGIES['slicing']:
            # Compare strategies by value, as the `in` tests above do.
            # An `is not` identity test on the names could spuriously
            # reject equal strategies.
            if ((0 not in axes) or
                    (target_strategy != pindex.distribution_strategy)):
                raise ValueError(
                    "ERROR: A slicing distributor shall not be reshaped to "
                    "something non-sliced.")

        # Keep pindex's extent on `axes` and use length-1 axes elsewhere, so
        # that the data broadcasts over the remaining axes of the target.
        semiscaled_shape = [1, ] * len(target_shape)
        for i in axes:
            semiscaled_shape[i] = target_shape[i]
        # NOTE(review): the local data is reshaped to the *global* extent on
        # `axes`; for a slicing-distributed pindex spread over several
        # processes the local part is presumably smaller -- TODO confirm
        # this path works with more than one process.
        local_data = pindex.get_local_data(copy=False)
        semiscaled_local_data = local_data.reshape(semiscaled_shape)
        result_obj = pindex.copy_empty(global_shape=target_shape,
                                       distribution_strategy=target_strategy)
        result_obj.set_full_data(semiscaled_local_data, copy=False)

        return result_obj