power_projection_operator.py 3.08 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
from ... import Field
2
from ..linear_operator import LinearOperator
Martin Reinecke's avatar
Martin Reinecke committed
3
from ...spaces.power_space import PowerSpace
4

Martin Reinecke's avatar
Martin Reinecke committed
5

6
class PowerProjection(LinearOperator):
    """Project a field living on a harmonic space onto a PowerSpace.

    `times` takes the real part of the input, weighted with ``weight(1)`` on
    the acted-on space, and accumulates it per bin of the target PowerSpace's
    ``pindex`` (via ``pindex.bincount``). It then applies ``weight(-1)`` on
    the resulting power space. `adjoint_times` copies every bin value back
    onto all harmonic pixels that ``pindex`` assigns to that bin.

    Parameters
    ----------
    domain :
        Anything ``_parse_domain`` accepts. It must resolve to exactly one
        space whose ``harmonic`` flag is set.
    target :
        Anything ``_parse_domain`` accepts. It must resolve to exactly one
        ``PowerSpace``; its ``pindex`` defines the binning.
    default_spaces : optional
        Forwarded unchanged to ``LinearOperator.__init__``.

    Raises
    ------
    ValueError
        If domain or target do not consist of exactly one space, if the
        domain is not harmonic, or if the target is not a PowerSpace.
    """

    def __init__(self, domain, target, default_spaces=None):
        self._domain = self._parse_domain(domain)
        self._target = self._parse_domain(target)
        if len(self._domain) != 1 or len(self._target) != 1:
            raise ValueError("Operator only works over one space")
        if not self._domain[0].harmonic:
            raise ValueError("domain must be a harmonic space")
        if not isinstance(self._target[0], PowerSpace):
            raise ValueError("target must be a PowerSpace")
        # Bin index of every harmonic pixel. It serves as the bincount index
        # in _times and as the gather index in _spec_to_rescaler.
        self.pindex = self._target[0].pindex
        super(PowerProjection, self).__init__(default_spaces)

    @staticmethod
    def _space_index(spaces):
        """Return the position of the acted-on space in a field's domain.

        ``None`` maps to position 0. A one-element tuple/list is unwrapped.
        Any other value is returned unchanged.

        NOTE(review): LinearOperator presumably forwards an explicit `spaces`
        as a tuple. The previous code used it directly as an index, which
        only works for a plain integer.

        Raises
        ------
        ValueError
            If a sequence with more than one entry is given.
        """
        if spaces is None:
            return 0
        if isinstance(spaces, (tuple, list)):
            if len(spaces) != 1:
                raise ValueError("Operator only works over one space")
            return spaces[0]
        return spaces

    @staticmethod
    def _replace_space(domain, index, space):
        """Return `domain` as a list with entry `index` replaced by `space`."""
        new_domain = list(domain)
        new_domain[index] = space
        return new_domain

    def _times(self, x, spaces):
        """Bin `x` onto the target PowerSpace.

        The real part of `x`, weighted with ``weight(1)`` on the acted-on
        space, is accumulated per ``pindex`` bin along that space's axes.
        The result lives on x's domain with the acted-on space swapped for
        the target PowerSpace, and gets ``weight(-1)`` there.
        """
        space = self._space_index(spaces)
        projected_x = self.pindex.bincount(
            weights=x.weight(1, spaces=space).val.real,
            axis=x.domain_axes[space])
        tgt_domain = self._replace_space(x.domain, space, self._target[0])
        return Field(tgt_domain, val=projected_x).weight(-1, spaces=space)

    def _adjoint_times(self, x, spaces):
        """Spread a power-space field back onto the harmonic domain.

        Every harmonic pixel receives the real part of the value of the bin
        ``pindex`` assigns it to. No volume weighting is applied here.
        """
        space = self._space_index(spaces)
        tgt_domain = self._replace_space(x.domain, space, self._domain[0])
        y = Field(tgt_domain, val=1.)
        # Array axes occupied by the power space inside x.
        axes = x.domain_axes[space]

        spec = x.val.get_full_data()
        spec = self._spec_to_rescaler(spec, space, axes)
        y.val.apply_scalar_function(lambda val: val * spec.real,
                                    inplace=True)
        return y

    def _spec_to_rescaler(self, spec, power_space_index, axes):
        """Expand a binned spectrum onto the harmonic pixels of each bin.

        Parameters
        ----------
        spec : numpy.ndarray
            Full data of a field that contains the power space.
        power_space_index : int
            Position of the power space in the field's domain. It is not
            needed for the indexing (`axes` already locates the power space)
            and is kept only so the call signature stays unchanged.
        axes : tuple of int
            The array axes occupied by the power space inside `spec`. Only
            ``axes[0]`` is used (presumably the only axis of a PowerSpace).

        Returns
        -------
        numpy.ndarray
            `spec` with the power axis replaced by the shape of ``pindex``.
            Each harmonic pixel holds the value of its bin.
        """
        # NOTE(review): compatibility between pindex's distribution strategy
        # and the field's local data is not checked here (an earlier warning
        # for that case was left commented out). Confirm this is safe for
        # distributed fields.
        local_pindex = self.pindex.get_local_data(copy=False)

        # Advanced indexing. Gather along the power axis with pindex and keep
        # every other axis whole, so all pindex slices are filled in one step.
        local_blow_up = [slice(None)] * len(spec.shape)
        # `axes` lists only the power space's own axes, so its first entry is
        # the target. Indexing it with `power_space_index` picked the wrong
        # axis, or raised IndexError, for any power space not at position 0.
        local_blow_up[axes[0]] = local_pindex
        # NumPy requires a tuple, not a list, for multi-dimensional indexing.
        return spec[tuple(local_blow_up)]

    @property
    def domain(self):
        """The single harmonic space this operator acts on."""
        return self._domain

    @property
    def target(self):
        """The single PowerSpace this operator projects onto."""
        return self._target

    @property
    def unitary(self):
        """This projection is not unitary."""
        return False