# -*- coding: utf-8 -*-

from functools import reduce

import numpy as np
from d2o import STRATEGIES

from nifty.config import about
from nifty.space import Space
from power_space_paradict import PowerSpaceParadict
from nifty.nifty_utilities import cast_axis_to_tuple


class PowerSpace(Space):
    """Space whose pixels are the bins of a power spectrum.

    All binning information lives in a ``PowerSpaceParadict`` built from
    ``pindex``, ``kindex``, ``rho``, ``config`` and ``harmonic_domain``.
    ``kindex`` fixes the shape of the space, ``pindex`` assigns field
    entries to bins (see ``calculate_power_spectrum``) and ``rho`` holds
    one weight per entry of the first axis (see ``weight``).
    """

    def __init__(self, pindex, kindex, rho, config,
                 harmonic_domain, dtype=np.dtype('float'), **kwargs):
        """Build the power space from a precomputed power index.

        Parameters
        ----------
        pindex, kindex, rho, config, harmonic_domain
            Forwarded unchanged to ``PowerSpaceParadict``.
        dtype : numpy dtype-like, optional
            Data type of the space; normalized through ``np.dtype``.
        **kwargs
            Accepted and ignored. This makes a
            ``PowerSpace(**power_index)`` initialization work even when
            the dictionary carries additional keys.
        """
        self.dtype = np.dtype(dtype)

        self.paradict = PowerSpaceParadict(pindex=pindex,
                                           kindex=kindex,
                                           rho=rho,
                                           config=config,
                                           harmonic_domain=harmonic_domain)
        # a power space is always flagged as harmonic
        self._harmonic = True

    @property
    def shape(self):
        """Shape of the space, taken from ``kindex``."""
        return self.paradict['kindex'].shape

    @property
    def dim(self):
        """Length of the first axis of ``shape``."""
        return self.shape[0]

    @property
    def total_volume(self):
        """Number of entries in ``pindex``; every power pixel has volume 1."""
        # `reduce` comes from functools (it is not a builtin on Python 3);
        # the initial value 1 makes a 0-d pindex yield 1 instead of raising.
        return reduce(lambda x, y: x * y, self.paradict['pindex'].shape, 1)

    def weight(self, x, power=1, axes=None):
        """Multiply ``x`` by ``rho ** power`` along a single axis.

        Parameters
        ----------
        x : array-like
            Object with a ``shape`` attribute that supports broadcasting
            multiplication with a numpy array.
        power : number, optional
            Exponent applied to ``rho`` (skipped when it equals 1).
        axes : int or sequence of int
            Axis of ``x`` that corresponds to this space. It is normalized
            via ``cast_axis_to_tuple`` and must contain exactly one entry.

        Returns
        -------
        The product ``x * rho**power``, with ``rho`` reshaped so that it
        runs along ``axes[0]``.

        Raises
        ------
        ValueError
            If ``axes`` does not contain exactly one axis.
        """
        total_shape = x.shape

        axes = cast_axis_to_tuple(axes, len(total_shape))
        if len(axes) != 1:
            raise ValueError(about._errors.cstring(
                "ERROR: axes must be of length 1."))

        # length-1 on every axis except the target one, so rho broadcasts
        reshaper = [1, ] * len(total_shape)
        reshaper[axes[0]] = self.shape[0]

        weight = self.paradict['rho'].reshape(reshaper)
        if power != 1:
            weight = weight ** power
        result_x = x * weight

        return result_x

    def compute_k_array(self, distribution_strategy):
        """Not available for power spaces; always raises NotImplementedError."""
        raise NotImplementedError(about._errors.cstring(
            "ERROR: There is no k_array implementation for PowerSpace."))

    def calculate_power_spectrum(self, x, axes=None):
        """Bin ``abs(x)**2`` with ``pindex`` and divide by ``rho``.

        Parameters
        ----------
        x : distributed data object
            Needs a ``shape`` and, when ``axes`` is given, a
            ``distribution_strategy`` attribute.
        axes : sequence of int, optional
            Axes of ``x`` that correspond to this space. When given,
            ``pindex`` is first broadcast to ``x.shape`` along them, and
            ``rho`` is aligned with ``axes[0]`` of the result. The code
            indexes ``axes[0]``, so this must be a sequence, not a bare int.

        Returns
        -------
        The per-bin sums of ``abs(x)**2`` returned by ``pindex.bincount``,
        divided in place by ``rho``. If ``rho`` counts the entries per bin,
        this is the bin average (TODO confirm against PowerSpaceParadict).
        """
        fieldabs = abs(x)**2
        pindex = self.paradict['pindex']

        if axes is not None:
            pindex = self._shape_up_pindex(
                                    pindex=pindex,
                                    target_shape=x.shape,
                                    target_strategy=x.distribution_strategy,
                                    axes=axes)
        power_spectrum = pindex.bincount(weights=fieldabs,
                                         axis=axes)

        rho = self.paradict['rho']
        if axes is not None:
            # align rho with the binned axis so the division broadcasts
            new_rho_shape = [1, ] * len(power_spectrum.shape)
            new_rho_shape[axes[0]] = len(rho)
            rho = rho.reshape(new_rho_shape)
        power_spectrum /= rho

        return power_spectrum

    def _shape_up_pindex(self, pindex, target_shape, target_strategy, axes):
        """Spread ``pindex`` over ``target_shape`` along ``axes``.

        The data of ``pindex`` is reshaped to have length 1 on every axis
        not listed in ``axes``. It is then written into an empty copy of
        ``pindex`` with ``target_shape`` and ``target_strategy``, which
        presumably broadcasts it over the remaining axes.

        Raises
        ------
        ValueError
            If ``pindex`` does not use a global-type distribution strategy,
            or if it uses a slicing strategy while axis 0 is not in
            ``axes`` or ``target_strategy`` differs from its own strategy.
        """
        if pindex.distribution_strategy not in STRATEGIES['global']:
            raise ValueError("ERROR: pindex's distribution strategy must be "
                             "global-type")

        if pindex.distribution_strategy in STRATEGIES['slicing']:
            # Compare strategies by value: `is not` on strings tests object
            # identity and can reject two equal strategy names.
            if ((0 not in axes) or
                    (target_strategy != pindex.distribution_strategy)):
                raise ValueError(
                    "ERROR: A slicing distributor shall not be reshaped to "
                    "something non-sliced.")

        semiscaled_shape = [1, ] * len(target_shape)
        for i in axes:
            semiscaled_shape[i] = target_shape[i]
        # NOTE(review): the *local* data is passed to set_full_data. This
        # assumes local == full data for the accepted strategies; confirm
        # for slicing strategies running on more than one process.
        local_data = pindex.get_local_data(copy=False)
        semiscaled_local_data = local_data.reshape(semiscaled_shape)
        result_obj = pindex.copy_empty(global_shape=target_shape,
                                       distribution_strategy=target_strategy)
        result_obj.set_full_data(semiscaled_local_data, copy=False)

        return result_obj