selection_operators.py 8.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Authors: Gordian Edenhofer
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np
from ..domain_tuple import DomainTuple
from ..domains.rg_space import RGSpace
from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from .linear_operator import LinearOperator


class SliceOperator(LinearOperator):
    """Geometry preserving mask operator

    Takes a field, slices it into the desired shape and returns the values of
    the field in the sliced domain all while preserving the original distances.

    Parameters
    ----------
    domain : Domain, DomainTuple or tuple of Domain
        The operator's input domain.
    tgt_shape : tuple of integers or None
        The shape of the target domain with one entry per array axis of
        `domain`, i.e. a domain spanning several axes (e.g. a two-dimensional
        RGSpace) takes one entry per axis. `None` indicates to copy the shape
        of the original domain for this axis.
    center : bool, optional
        Whether to center the slice that is selected in the input field. If
        `False`, every axis is sliced starting at index 0.
    preserve_dist: bool, optional
        Whether to preserve the distance of the input field. Only used for
        RGSpace domains; if `False`, no `distances` are passed to the new
        RGSpace and its default applies.

    Raises
    ------
    ValueError
        If `tgt_shape` does not have one entry per axis of `domain`, if an
        entry of `tgt_shape` is larger than the corresponding domain axis, or
        if a domain which has to be shrunk is neither an RGSpace nor an
        UnstructuredDomain.

    Notes
    -----
    The adjoint embeds the values into a zero-filled array of the input shape.
    """
    def __init__(self, domain, tgt_shape, center=False, preserve_dist=True):
        self._domain = DomainTuple.make(domain)
        dom_shape = self._domain.shape
        if len(tgt_shape) != len(dom_shape):
            ve = (
                f"shape ({tgt_shape}) is incompatible with the shape of the"
                f" domain ({dom_shape})"
            )
            raise ValueError(ve)
        # Resolve `None` entries up front so that all later computations (in
        # particular the centered offset) operate on plain integers.
        tgt_shape = tuple(
            n if t is None else t for n, t in zip(dom_shape, tgt_shape)
        )

        tgt = []
        slc_by_ax = []
        ax = 0
        for d in self._domain:
            # A single domain may span several array axes while `tgt_shape`
            # has one entry per axis; take this domain's share of it.
            d_shape = tuple(d.shape)
            d_tgt = tgt_shape[ax:ax + len(d_shape)]
            ax += len(d_shape)

            if any(t > n for n, t in zip(d_shape, d_tgt)):
                ve = (
                    f"domain axes ({d}) is smaller than the target shape"
                    f" {d_tgt}"
                )
                raise ValueError(ve)
            if d_tgt == d_shape:
                tgt += [d]
            else:
                dom_kw = dict()
                if isinstance(d, RGSpace):
                    if preserve_dist:
                        dom_kw["distances"] = d.distances
                    dom_kw["harmonic"] = d.harmonic
                elif not isinstance(d, UnstructuredDomain):
                    # Some domains like HPSpace or LMSPace can not be sliced
                    ve = f"{d.__class__.__name__} can not be sliced"
                    raise ValueError(ve)
                tgt += [d.__class__(d_tgt, **dom_kw)]

            for n, t in zip(d_shape, d_tgt):
                # When centering, floor((n - t) / 2) entries are dropped at
                # the start of the axis and the remainder at its end.
                start = (n - t) // 2 if center else 0
                slc_by_ax += [slice(start, start + t)]

        self._slc_by_ax = tuple(slc_by_ax)
        self._target = DomainTuple.make(tgt)
        self._capability = self.TIMES | self.ADJOINT_TIMES

    def apply(self, x, mode):
        self._check_input(x, mode)
        x = x.val
        if mode == self.TIMES:
            res = x[self._slc_by_ax]
            return Field.from_raw(self.target, res)
        # Adjoint: embed the sliced values, everything else stays zero
        res = np.zeros(self.domain.shape, x.dtype)
        res[self._slc_by_ax] = x
        return Field.from_raw(self.domain, res)

    def __str__(self):
        ss = (
            f"{self.__class__.__name__}"
            f"({self.domain.shape} -> {self.target.shape})"
        )
        return ss


class SplitOperator(LinearOperator):
    """Split a single field into a multi-field

    Takes a field, selects the desired entries for each multi-field key and
    puts the result into a multi-field. Along sliced axis, the domain will
    be replaced by an UnstructuredDomain as no distance measures are preserved.

    Note, slices may intersect, i.e. slices may reference the same input
    multiple times if the `intersecting_slices` option is set. However, a
    single field in the output may not contain the same part of the input more
    than once.

    Parameters
    ----------
    domain : Domain, DomainTuple or tuple of Domain
        The operator's input domain.
    slices_by_key : dict{key: tuple of selections}
        The key-value pairs of which the values indicate the parts to be
        selected. The result will be a multi-field with the given keys as
        entries and the selected slices of the domain as values. Each value
        holds one selection per domain of `domain`; missing trailing entries
        select the whole domain. A selection may be

        - `None` or `slice(None)`: keep the whole domain,
        - a `slice`: keep the sliced range,
        - an integer: pick a single entry and drop the axis,
        - a boolean mask with the shape of the domain,
        - a one-dimensional sequence or array of integer indices.

        Apart from `None`/`slice(None)` and boolean masks, selections are only
        supported on one-dimensional domains.
    intersecting_slices : bool, optional
        Tells the operator whether slices may contain intersections. If true,
        the adjoint is constructed a little less efficiently.  Set this
        parameter to `False` to gain a little more efficiency.

    Raises
    ------
    ValueError
        If a selection has more entries than `domain` has domains, has an
        invalid type or shape, or slices a multi-dimensional domain in an
        unsupported way.

    Notes
    -----
    All selections of one key are applied in a single numpy indexing
    operation. Combining several array-like selections (or an array-like
    selection with an integer) within one key is therefore subject to numpy's
    advanced-indexing broadcasting rules and may not produce the target shape
    described above.
    """
    def __init__(self, domain, slices_by_key, intersecting_slices=True):
        self._domain = DomainTuple.make(domain)
        self._intersec_slc = intersecting_slices

        tgt = dict()
        self._k_slc = dict()
        for k, slc in slices_by_key.items():
            if len(slc) > len(self._domain):
                ve = f"slice at key {k!r} has more dimensions than the input"
                raise ValueError(ve)
            k_tgt = []
            k_slc_by_ax = []
            for i, d in enumerate(self._domain):
                sel = slc[i] if i < len(slc) else None
                dom, idx = self._parse_selection(d, sel, k)
                if dom is not None:
                    k_tgt += [dom]
                k_slc_by_ax += idx
            tgt[k] = DomainTuple.make(k_tgt)
            self._k_slc[k] = tuple(k_slc_by_ax)

        self._target = MultiDomain.make(tgt)
        self._capability = self.TIMES | self.ADJOINT_TIMES

    @staticmethod
    def _parse_selection(d, sel, key):
        """Translate the selection `sel` on domain `d` into a pair of the
        resulting target domain (None if the axis is dropped) and the list of
        numpy index entries covering all array axes of `d`."""
        if sel is None or (isinstance(sel, slice) and sel == slice(None)):
            # Keep the whole domain. It may span several array axes, each of
            # which needs its own entry in the index tuple.
            return d, [slice(None)] * len(d.shape)

        is_int = isinstance(sel, (int, np.integer)) and not isinstance(
            sel, bool
        )
        is_seq = isinstance(sel, (tuple, list, np.ndarray))
        if not (is_int or is_seq or isinstance(sel, slice)):
            ve = f"invalid type for specifying a slice; got {sel}"
            raise ValueError(ve)

        arr = np.asarray(sel) if is_seq else None
        if arr is not None and arr.dtype == np.dtype(bool):
            # numpy collapses all axes covered by a boolean mask into one
            if arr.shape != tuple(d.shape):
                ve = (
                    f"shape mismatch between desired slice {sel}"
                    f" and the shape of the domain {d.shape}"
                )
                raise ValueError(ve)
            return UnstructuredDomain(int(arr.sum())), [arr]

        # Slices, integers and index arrays act on a single array axis only
        if len(d.shape) != 1:
            ve = (
                f"can not slice {d.__class__.__name__} of shape {d.shape} at"
                f" key {key!r}; only one-dimensional domains support this"
                " kind of selection"
            )
            raise ValueError(ve)
        if isinstance(sel, slice):
            # `slice.indices` resolves omitted or negative bounds and clips
            # them to the axis length in the same way numpy does
            n = len(range(*sel.indices(d.shape[0])))
            return UnstructuredDomain(n), [sel]
        if is_int:
            # An integer picks a single entry and removes the axis
            return None, [sel]
        if arr.ndim != 1:
            ve = (
                f"index array at key {key!r} must be one-dimensional; got"
                f" shape {arr.shape}"
            )
            raise ValueError(ve)
        # Keep the original object so that e.g. an empty list still indexes
        # as integers (np.asarray([]) would be float-typed)
        return UnstructuredDomain(arr.shape[0]), [sel]

    def apply(self, x, mode):
        self._check_input(x, mode)
        x = x.val
        if mode == self.TIMES:
            res = {k: x[slc] for k, slc in self._k_slc.items()}
            return MultiField.from_raw(self.target, res)

        # Note, not-selected parts must be zero. Hence, using the quicker
        # `np.empty` method is unfortunately not possible. The common dtype of
        # all entries is used so that e.g. complex entries are not truncated
        # when the first entry happens to be real-valued.
        dtypes = [v.dtype for v in x.values()]
        dtype = np.result_type(*dtypes) if dtypes else np.float64
        res = np.zeros(self.domain.shape, dtype)
        for k, slc in self._k_slc.items():
            if self._intersec_slc:
                # Mind the `+` here for coping with intersections
                res[slc] += x[k]
            else:
                res[slc] = x[k]
        return Field.from_raw(self.domain, res)

    def __str__(self):
        return f"{self.__class__.__name__} {self._target.keys()!r} <-"