smoothing_operator.py 9.74 KB
Newer Older
Jait Dixit's avatar
Jait Dixit committed
1
2
3
4
5
import numpy as np

import nifty.nifty_utilities as utilities
from nifty.operators.endomorphic_operator import EndomorphicOperator
from nifty.operators.fft_operator import FFTOperator
6
7
import smooth_util as su
from d2o import STRATEGIES
Jait Dixit's avatar
Jait Dixit committed
8

9

10
class SmoothingOperator(EndomorphicOperator):
    """Endomorphic operator that smooths fields defined on a single space.

    Smoothing is first attempted in Fourier space: the field is transformed
    with an FFTOperator, and the result is multiplied pointwise by a kernel.
    The kernel is obtained from the codomain's
    ``get_fft_smoothing_kernel_function(sigma)`` evaluated on its distance
    array; for the inverse, the result is divided by that kernel. If that
    path raises a ``ValueError``, the operator falls back to direct
    convolution in position space via ``smooth_util``. The direct path is
    supported only for spaces spanning a single axis.

    Parameters
    ----------
    domain : space or tuple of spaces
        Parsed with ``_parse_domain``; must contain exactly one space.
    sigma : float
        Smoothing length. With ``sigma == 0`` both ``times`` and
        ``inverse_times`` return a copy of the input.
    log_distances : bool
        If True, the natural logarithm of the distances is used when
        building the kernel instead of the distances themselves.

    Raises
    ------
    ValueError
        If ``domain`` does not consist of exactly one space.
    """

    # ---Overwritten properties and methods---

    def __init__(self, domain=(), sigma=0, log_distances=False):
        self._domain = self._parse_domain(domain)

        if len(self.domain) != 1:
            raise ValueError(
                'ERROR: SmoothOperator accepts only exactly one '
                'space as input domain.'
            )

        self.sigma = sigma
        self.log_distances = log_distances

        # Passed to smooth_util as `smoothing_width`. _direct_smoothing also
        # fetches ghost cells within `width * sigma` of the local chunk.
        self._direct_smoothing_width = 3.

    def _inverse_times(self, x, spaces):
        """Apply the inverse smoothing to ``x`` on ``spaces``."""
        return self._smoothing_helper(x, spaces, inverse=True)

    def _times(self, x, spaces):
        """Apply the smoothing to ``x`` on ``spaces``."""
        return self._smoothing_helper(x, spaces, inverse=False)

    # ---Mandatory properties and methods---

    @property
    def domain(self):
        """Tuple holding the single space this operator acts on."""
        return self._domain

    @property
    def implemented(self):
        return True

    @property
    def symmetric(self):
        return True

    @property
    def unitary(self):
        return False

    # ---Added properties and methods---

    @property
    def sigma(self):
        """Smoothing length as a Python float."""
        return self._sigma

    @sigma.setter
    def sigma(self, sigma):
        # `np.float` was only an alias of the builtin `float`; it was
        # deprecated in NumPy 1.20 and removed in 1.24.
        self._sigma = float(sigma)

    @property
    def log_distances(self):
        """Whether kernels are built on log-distances."""
        return self._log_distances

    @log_distances.setter
    def log_distances(self, log_distances):
        self._log_distances = bool(log_distances)

    def _smoothing_helper(self, x, spaces, inverse):
        """Normalize ``spaces`` and smooth ``x`` via FFT or directly.

        Returns a copy of ``x`` when ``sigma`` is zero. Otherwise tries
        ``_fft_smoothing`` and falls back to ``_direct_smoothing`` if it
        raises ``ValueError``.
        """
        if self.sigma == 0:
            return x.copy()

        # the domain of the smoothing operator contains exactly one space.
        # Hence, if spaces is None, but we passed LinearOperator's
        # _check_input_compatibility, we know that x is also solely defined
        # on that space
        if spaces is None:
            spaces = (0,)
        else:
            spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))

        try:
            result = self._fft_smoothing(x, spaces, inverse)
        except ValueError:
            # NOTE(review): presumably FFTOperator raises ValueError for
            # spaces it cannot transform. Note that any other ValueError from
            # the FFT path is also routed here. Confirm this is intended.
            result = self._direct_smoothing(x, spaces, inverse)
        return result

    def _fft_smoothing(self, x, spaces, inverse):
        """Smooth ``x`` by pointwise kernel multiplication in the codomain.

        Parameters
        ----------
        x : Field
            Input field.
        spaces : tuple of int
            One-element tuple selecting the space of ``x`` to smooth.
        inverse : bool
            Divide by the kernel instead of multiplying with it.

        Returns
        -------
        Field
            New field with the domain and dtype of ``x``, as produced by
            ``x.copy_empty()``.
        """
        transformator = FFTOperator(x.domain[spaces[0]])

        # transform to the (global-)default codomain and perform all remaining
        # steps therein
        transformed_x = transformator(x, spaces=spaces)
        codomain = transformed_x.domain[spaces[0]]
        coaxes = transformed_x.domain_axes[spaces[0]]

        # create the kernel using the knowledge of codomain about itself,
        # distributed with the axes-local strategy of transformed_x
        axes_local_distribution_strategy = \
            transformed_x.val.get_axes_local_distribution_strategy(axes=coaxes)

        kernel = codomain.get_distance_array(
            distribution_strategy=axes_local_distribution_strategy)

        if self.log_distances:
            kernel.apply_scalar_function(np.log, inplace=True)

        kernel.apply_scalar_function(
            codomain.get_fft_smoothing_kernel_function(self.sigma),
            inplace=True)

        # now, apply the kernel to transformed_x
        # this is done node-locally utilizing numpys reshaping in order to
        # apply the kernel to the correct axes
        local_transformed_x = transformed_x.val.get_local_data(copy=False)
        local_kernel = kernel.get_local_data(copy=False)

        # Use the node-local extents so that the kernel broadcasts against
        # local_transformed_x. The global extents only coincide with the local
        # ones when the data is not split along one of the coaxes.
        reshaper = [extent if axis in coaxes else 1
                    for axis, extent in enumerate(local_transformed_x.shape)]
        local_kernel = np.reshape(local_kernel, reshaper)

        # apply the kernel
        if inverse:
            local_transformed_x /= local_kernel
        else:
            local_transformed_x *= local_kernel

        transformed_x.val.set_local_data(local_transformed_x, copy=False)

        smoothed_x = transformator.inverse_times(transformed_x, spaces=spaces)

        result = x.copy_empty()
        result.set_val(smoothed_x, copy=False)
        return result

    def _direct_smoothing(self, x, spaces, inverse):
        """Smooth ``x`` by direct convolution in position space.

        Only spaces spanning a single axis are supported. When ``x`` is
        sliced along that axis (i.e. it is axis 0), the local chunk is
        extended by ghost cells. These are the values whose distance lies
        within ``self._direct_smoothing_width * self.sigma`` of the chunk's
        first and last distance.

        Raises
        ------
        ValueError
            If the space spans more than one axis, or if ``x`` uses a
            distribution strategy that is neither 'not' nor a slicing one.
        """
        # infer affected axes
        # we rely on the knowledge, that `spaces` is a tuple with length 1.
        affected_axes = x.domain_axes[spaces[0]]

        if len(affected_axes) > 1:
            raise ValueError("By this implementation only one-dimensional "
                             "spaces can be smoothed directly.")

        affected_axis = affected_axes[0]

        distance_array = x.domain[spaces[0]].get_distance_array(
            distribution_strategy='not')
        distance_array = distance_array.get_local_data(copy=False)

        if self.log_distances:
            np.log(distance_array, out=distance_array)

        # collect the local data + ghost cells
        local_data_Q = False

        if x.distribution_strategy == 'not':
            local_data_Q = True
        elif x.distribution_strategy in STRATEGIES['slicing']:
            # infer the local start/end based on the slicing information of
            # x's d2o. Only gets non-trivial for axis==0.
            if 0 != affected_axis:
                local_data_Q = True
            else:
                # NOTE(review): np.searchsorted requires distance_array to be
                # sorted ascending, which is assumed here for a 1-D space
                # (log is monotonic, so log_distances keeps the order). TODO
                # confirm.
                start_index = x.val.distributor.local_start
                start_distance = distance_array[start_index]
                augmented_start_distance = \
                    (start_distance - self._direct_smoothing_width*self.sigma)
                augmented_start_index = \
                    np.searchsorted(distance_array, augmented_start_distance)
                # offset of this node's own data inside the augmented chunk
                true_start = start_index - augmented_start_index
                end_index = x.val.distributor.local_end
                end_distance = distance_array[end_index-1]
                augmented_end_distance = \
                    (end_distance + self._direct_smoothing_width*self.sigma)
                augmented_end_index = \
                    np.searchsorted(distance_array, augmented_end_distance)
                true_end = true_start + x.val.distributor.local_length
                augmented_slice = slice(augmented_start_index,
                                        augmented_end_index)

                augmented_data = x.val.get_data(augmented_slice,
                                                local_keys=True,
                                                copy=False)
                augmented_data = augmented_data.get_local_data(copy=False)

                augmented_distance_array = distance_array[augmented_slice]

        else:
            raise ValueError("Direct smoothing not implemented for given "
                             "distribution strategy.")

        if local_data_Q:
            # if the needed data resides on the nodes already, the necessary
            # are the same; no matter what the distribution strategy was.
            augmented_data = x.val.get_local_data(copy=False)
            augmented_distance_array = distance_array
            true_start = 0
            true_end = x.shape[affected_axis]

        # perform the convolution along the affected axis
        # (currently only one axis is supported)
        local_result = self._direct_smoothing_single_axis(
                                                    augmented_data,
                                                    affected_axis,
                                                    augmented_distance_array,
                                                    true_start,
                                                    true_end,
                                                    inverse)

        result = x.copy_empty()
        result.val.set_local_data(local_result, copy=False)
        return result

    def _direct_smoothing_single_axis(self, data, data_axis, distances,
                                      true_start, true_end, inverse):
        """Convolve ``data`` along ``data_axis`` with smooth_util.

        Parameters
        ----------
        data : numpy.ndarray
            Local data, possibly including ghost cells.
        data_axis : int
            Axis of ``data`` to smooth along.
        distances : numpy.ndarray
            Distances belonging to the entries of ``data`` along
            ``data_axis``.
        true_start, true_end : int
            Index range within ``data`` passed to smooth_util as
            ``startindex``/``endindex``.
        inverse : bool
            Use ``1/sigma`` as the smoothing length instead of ``sigma``.

        Returns
        -------
        numpy.ndarray
            Output of smooth_util. For complex input, the real and imaginary
            parts are smoothed separately and recombined.

        Raises
        ------
        TypeError
            For dtypes other than float32, float64 and complex ones whose
            real part is float32/float64.
        """
        if inverse:
            # NOTE(review): smoothing with 1/sigma is in general not the exact
            # inverse of the forward smoothing. Confirm this is intended.
            true_sigma = 1. / self.sigma
        else:
            true_sigma = self.sigma

        # Compare dtypes by equality: `is` only worked because NumPy caches
        # the builtin dtype instances.
        if data.dtype == np.dtype('float32'):
            distances = distances.astype(np.float32, copy=False)
            smoothed_data = su.apply_along_axis_f(
                                  data_axis, data,
                                  startindex=true_start,
                                  endindex=true_end,
                                  distances=distances,
                                  smooth_length=true_sigma,
                                  smoothing_width=self._direct_smoothing_width)
        elif data.dtype == np.dtype('float64'):
            distances = distances.astype(np.float64, copy=False)
            smoothed_data = su.apply_along_axis(
                                  data_axis, data,
                                  startindex=true_start,
                                  endindex=true_end,
                                  distances=distances,
                                  smooth_length=true_sigma,
                                  smoothing_width=self._direct_smoothing_width)

        elif np.issubdtype(data.dtype, np.complexfloating):
            real = self._direct_smoothing_single_axis(data.real,
                                                      data_axis,
                                                      distances,
                                                      true_start,
                                                      true_end, inverse)
            imag = self._direct_smoothing_single_axis(data.imag,
                                                      data_axis,
                                                      distances,
                                                      true_start,
                                                      true_end, inverse)

            return real + 1j*imag

        else:
            raise TypeError("Dtype %s not supported" % str(data.dtype))

        return smoothed_data