power_space.py 2.73 KB
Newer Older
theos's avatar
theos committed
1
2
# -*- coding: utf-8 -*-

theos's avatar
theos committed
3
import numpy as np
4
from d2o import STRATEGIES
theos's avatar
theos committed
5

theos's avatar
theos committed
6
from nifty.space import Space
theos's avatar
theos committed
7
from nifty.nifty_paradict import power_space_paradict
theos's avatar
theos committed
8
9
10


class PowerSpace(Space):
    """Space indexing the bins of a power spectrum.

    All binning information (``pindex``, ``kindex``, ``rho``, ``config``) and
    the ``harmonic_domain`` it belongs to are stored in ``self.paradict``
    (built by ``power_space_paradict``). The shape of the space is the shape
    of ``paradict['kindex']``.
    """

    def __init__(self, pindex, kindex, rho, config,
                 harmonic_domain, dtype=np.dtype('float'), **kwargs):
        """Store the binning parameters of the power space.

        Parameters
        ----------
        pindex : distributed_data_object
            Bin label of every point; used as the labels for ``bincount`` in
            :meth:`calculate_power_spectrum`. Presumably a d2o object -- it
            must provide ``bincount``, ``distribution_strategy``,
            ``get_local_data`` and ``copy_empty``.
        kindex : object with a ``shape``
            Its ``shape`` defines :attr:`shape`.
        rho : array with ``len`` and ``reshape``
            Per-bin divisor applied to the binned power in
            :meth:`calculate_power_spectrum` (presumably the number of modes
            per bin -- TODO confirm).
        config :
            Forwarded unchanged to ``power_space_paradict``.
        harmonic_domain :
            Forwarded unchanged to ``power_space_paradict``.
        dtype : numpy dtype-like, optional
            Converted with ``np.dtype`` and stored as ``self.dtype``.
        **kwargs :
            Ignored.
        """
        # the **kwargs is in the __init__ in order to enable a
        # PowerSpace(**power_index) initialization
        # NOTE(review): Space.__init__ is never called here; confirm the base
        # class needs no initialization of its own.
        self.dtype = np.dtype(dtype)
        self.paradict = power_space_paradict(pindex=pindex,
                                             kindex=kindex,
                                             rho=rho,
                                             config=config,
                                             harmonic_domain=harmonic_domain)
        self._harmonic = True

    @property
    def shape(self):
        """Shape of the space: the shape of ``paradict['kindex']``."""
        return self.paradict['kindex'].shape

    def calculate_power_spectrum(self, x, axes=None):
        """Return the binned and normalised power ``|x|**2``.

        ``|x|**2`` is summed per bin via ``pindex.bincount`` and the result
        is then divided in place by ``rho``.

        Parameters
        ----------
        x : distributed_data_object
            Field values; must support ``abs`` and ``**``. When ``axes`` is
            given it must also expose ``shape`` and ``distribution_strategy``.
        axes : sequence of int, optional
            Axes of ``x`` along which ``pindex`` lies. If given, ``pindex``
            is first broadcast to the shape of ``x`` (see
            :meth:`_shape_up_pindex`) and the binning is done along these
            axes. If None, the stored ``pindex`` is used as is.

        Returns
        -------
        The binned power divided by ``rho``.
        """
        fieldabs = abs(x) ** 2
        pindex = self.paradict['pindex']
        if axes is not None:
            pindex = self._shape_up_pindex(
                pindex=pindex,
                target_shape=x.shape,
                target_strategy=x.distribution_strategy,
                axes=axes)
        power_spectrum = pindex.bincount(weights=fieldabs,
                                         axis=axes)

        rho = self.paradict['rho']
        if axes is not None:
            # Lay rho out along axes[0] so it broadcasts against the binned
            # result; this assumes bincount leaves the bins on axes[0] --
            # TODO confirm against d2o's bincount.
            new_rho_shape = [1, ] * len(power_spectrum.shape)
            new_rho_shape[axes[0]] = len(rho)
            rho = rho.reshape(new_rho_shape)
        power_spectrum /= rho

        return power_spectrum

    def _shape_up_pindex(self, pindex, target_shape, target_strategy, axes):
        """Broadcast ``pindex`` to ``target_shape`` along ``axes``.

        The local data of ``pindex`` is reshaped so that its dimensions occupy
        ``axes`` (every other axis gets length 1). The reshaped data is handed
        to ``set_full_data`` of an empty copy of ``pindex``. That copy has
        global shape ``target_shape`` and distribution strategy
        ``target_strategy``.

        Parameters
        ----------
        pindex : distributed_data_object
            Bin labels; must use a global-type distribution strategy.
        target_shape : tuple of int
            Global shape of the returned object.
        target_strategy :
            Distribution strategy of the returned object (presumably one of
            the names in ``STRATEGIES``).
        axes : sequence of int
            Axes of ``target_shape`` onto which ``pindex``'s dimensions map.

        Returns
        -------
        A new object of global shape ``target_shape``.

        Raises
        ------
        ValueError
            If ``pindex`` is not distributed with a global-type strategy.
            Also raised if it uses a slicing strategy and either axis 0 is
            not among ``axes`` or ``target_strategy`` differs from its own
            strategy.
        """
        if pindex.distribution_strategy not in STRATEGIES['global']:
            raise ValueError("ERROR: pindex's distribution strategy must be "
                             "global-type")

        if pindex.distribution_strategy in STRATEGIES['slicing']:
            # Compare strategy names by value, not identity: two equal
            # strings need not be the same object.
            if ((0 not in axes) or
                    (target_strategy != pindex.distribution_strategy)):
                raise ValueError(
                    "ERROR: A slicing distributor shall not be reshaped to "
                    "something non-sliced.")

        semiscaled_shape = [1, ] * len(target_shape)
        for i in axes:
            semiscaled_shape[i] = target_shape[i]
        # For a global-type strategy the local data presumably holds the
        # full array -- NOTE(review): confirm against d2o.
        local_data = pindex.get_local_data(copy=False)
        semiscaled_local_data = local_data.reshape(semiscaled_shape)
        # Relies on set_full_data broadcasting the length-1 axes up to
        # target_shape -- NOTE(review): confirm against d2o.
        result_obj = pindex.copy_empty(global_shape=target_shape,
                                       distribution_strategy=target_strategy)
        result_obj.set_full_data(semiscaled_local_data, copy=False)

        return result_obj