smoothing_operator.py 8.98 KB
Newer Older
Jait Dixit's avatar
Jait Dixit committed
1
2
3
4
5
import numpy as np

import nifty.nifty_utilities as utilities
from nifty.operators.endomorphic_operator import EndomorphicOperator
from nifty.operators.fft_operator import FFTOperator
6
7
import smooth_util as su
from d2o import STRATEGIES
Jait Dixit's avatar
Jait Dixit committed
8

9

10
class SmoothingOperator(EndomorphicOperator):
Jait Dixit's avatar
Jait Dixit committed
11
    # ---Overwritten properties and methods---
12
13
    def __init__(self, domain=(), field_type=(), sigma=0,
                 log_distances=False):
        """Set up a smoothing operator acting on exactly one space.

        Parameters
        ----------
        domain : object
            Handed to ``self._parse_domain``; the parsed result must
            contain exactly one space.
        field_type : object
            Handed to ``self._parse_field_type``; the parsed result must
            be the empty tuple.
        sigma : scalar, optional
            Smoothing length, stored through the ``sigma`` property.
            With 0 (the default) the operator only returns copies of
            its input.
        log_distances : bool, optional
            If true, the logarithm of the distance array is taken before
            the smoothing kernel is evaluated.

        Raises
        ------
        ValueError
            If the parsed domain does not hold exactly one space, or if
            the parsed field type is not the empty tuple.
        """
        self._domain = self._parse_domain(domain)
        self._field_type = self._parse_field_type(field_type)

        if len(self.domain) != 1:
            raise ValueError(
                'ERROR: SmoothingOperator accepts only exactly one '
                'space as input domain.'
            )

        if self.field_type != ():
            raise ValueError(
                'ERROR: SmoothingOperator field-type must be an '
                'empty tuple.'
            )

        self.sigma = sigma
        self.log_distances = log_distances

        # Multiplied with sigma in _direct_smoothing to size the margin
        # that is fetched on both sides of the node-local slab.
        self._direct_smoothing_width = 2.
Jait Dixit's avatar
Jait Dixit committed
35
36

    def _inverse_times(self, x, spaces, types):
37
        return self._smoothing_helper(x, spaces, inverse=True)
Jait Dixit's avatar
Jait Dixit committed
38
39

    def _times(self, x, spaces, types):
40
        return self._smoothing_helper(x, spaces, inverse=False)
Jait Dixit's avatar
Jait Dixit committed
41

Jait Dixit's avatar
Jait Dixit committed
42
    # ---Mandatory properties and methods---
43
44
45
46
47
48
49
50
    @property
    def domain(self):
        """Return the parsed domain stored by the constructor."""
        return self._domain

    @property
    def field_type(self):
        """Return the parsed field type stored by the constructor."""
        return self._field_type

Jait Dixit's avatar
Jait Dixit committed
51
52
53
    @property
    def implemented(self):
        """Return the ``implemented`` flag, which is always True here."""
        return True
Jait Dixit's avatar
Jait Dixit committed
54

Jait Dixit's avatar
Jait Dixit committed
55
56
    @property
    def symmetric(self):
57
        return False
Jait Dixit's avatar
Jait Dixit committed
58

Jait Dixit's avatar
Jait Dixit committed
59
60
61
    @property
    def unitary(self):
        """Return the ``unitary`` flag, which is always False here."""
        return False
Jait Dixit's avatar
Jait Dixit committed
62
63

    # ---Added properties and methods---
64

Jait Dixit's avatar
Jait Dixit committed
65
66
67
68
    @property
    def sigma(self):
        return self._sigma

69
70
71
72
73
74
75
76
77
78
79
80
81
    @sigma.setter
    def sigma(self, sigma):
        self._sigma = np.float(sigma)

    @property
    def log_distances(self):
        """Return the stored ``log_distances`` flag (a bool)."""
        return self._log_distances

    @log_distances.setter
    def log_distances(self, log_distances):
        # Coerce to a real bool so that only True/False is ever stored.
        self._log_distances = bool(log_distances)

    def _smoothing_helper(self, x, spaces, inverse):
theos's avatar
theos committed
82
83
84
85
86
87
88
89
90
91
        if self.sigma == 0:
            return x.copy()

        # the domain of the smoothing operator contains exactly one space.
        # Hence, if spaces is None, but we passed LinearOperator's
        # _check_input_compatibility, we know that x is also solely defined
        # on that space
        if spaces is None:
            spaces = (0,)
        else:
Jait Dixit's avatar
Jait Dixit committed
92
93
            spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))

94
95
96
97
98
99
100
        try:
            result = self._fft_smoothing(x, spaces, inverse)
        except ValueError:
            result = self._direct_smoothing(x, spaces, inverse)
        return result

    def _fft_smoothing(self, x, spaces, inverse):
        """Smooth ``x`` by multiplying with a kernel in the FFT codomain.

        ``x`` is transformed with an ``FFTOperator`` built for the space
        ``x.domain[spaces[0]]``.  The kernel is the codomain's distance
        array (logarithmised if ``self.log_distances`` is set) run through
        ``codomain.get_fft_smoothing_kernel_function(self.sigma)``.  The
        transformed data is multiplied by the kernel, or divided by it if
        ``inverse`` is true, and transformed back.

        Parameters
        ----------
        x : object
            Field to smooth.
        spaces : tuple
            Tuple whose first entry selects the space of ``x`` to act on.
        inverse : bool
            Divide by the kernel instead of multiplying with it.

        Returns
        -------
        The back-transformed, smoothed field.

        Notes
        -----
        ``_smoothing_helper`` falls back to direct smoothing whenever this
        method raises ``ValueError`` (presumably from ``FFTOperator`` when
        no transformation exists for the space -- TODO confirm).
        """
        fft_operator = FFTOperator(x.domain[spaces[0]])

        # transform to the (global-)default codomain and perform all remaining
        # steps therein
        transformed_x = fft_operator(x, spaces=spaces)
        codomain = transformed_x.domain[spaces[0]]
        coaxes = transformed_x.domain_axes[spaces[0]]

        # create the kernel using the knowledge of codomain about itself
        axes_local_distribution_strategy = \
            transformed_x.val.get_axes_local_distribution_strategy(axes=coaxes)

        kernel = codomain.get_distance_array(
            distribution_strategy=axes_local_distribution_strategy)

        if self.log_distances:
            kernel.apply_scalar_function(np.log, inplace=True)

        kernel.apply_scalar_function(
            codomain.get_fft_smoothing_kernel_function(self.sigma),
            inplace=True)

        # now, apply the kernel to transformed_x
        # this is done node-locally utilizing numpys reshaping in order to
        # apply the kernel to the correct axes
        local_transformed_x = transformed_x.val.get_local_data(copy=False)
        local_kernel = kernel.get_local_data(copy=False)

        # Keep the extent along the codomain's axes and use length 1
        # everywhere else, so that the kernel broadcasts over the other axes.
        # (`enumerate` replaces the Python-2-only `xrange` index loop.)
        reshaper = [extent if axis in coaxes else 1
                    for axis, extent in enumerate(transformed_x.shape)]
        local_kernel = np.reshape(local_kernel, reshaper)

        # apply the kernel
        if inverse:
            local_transformed_x /= local_kernel
        else:
            local_transformed_x *= local_kernel

        transformed_x.val.set_local_data(local_transformed_x, copy=False)

        return fft_operator.inverse_times(transformed_x, spaces=spaces)
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251

    def _direct_smoothing(self, x, spaces, inverse):
        """Smooth ``x`` without an FFT, one affected axis after the other.

        The distance array of ``x.domain[spaces[0]]`` is fetched (and
        logarithmised if ``self.log_distances`` is set).  The node-local
        data is then passed through ``_direct_smoothing_single_axis`` once
        per affected axis.  For a slicing distribution strategy with axis 0
        among the affected axes, the local slab is first enlarged on both
        sides by ``self._direct_smoothing_width * self.sigma``.

        Parameters
        ----------
        x : object
            Field to smooth.
        spaces : tuple
            Tuple of length 1 selecting the space of ``x`` to act on.
        inverse : bool
            Passed on to ``_direct_smoothing_single_axis``.

        Returns
        -------
        ``x.copy_empty()`` with its local data set to the smoothed result.

        Raises
        ------
        ValueError
            If ``x.distribution_strategy`` is neither ``'not'`` nor one of
            ``STRATEGIES['slicing']``.
        """
        # infer affected axes
        # we rely on the knowledge, that `spaces` is a tuple with length 1.
        affected_axes = x.domain_axes[spaces[0]]

        axes_local_distribution_strategy = \
            x.val.get_axes_local_distribution_strategy(axes=affected_axes)

        distance_array = x.domain[spaces[0]].get_distance_array(
            distribution_strategy=axes_local_distribution_strategy)

        if self.log_distances:
            distance_array.apply_scalar_function(np.log, inplace=True)

        # collect the local data + ghost cells
        local_data_Q = False

        if x.distribution_strategy == 'not':
            local_data_Q = True
        elif x.distribution_strategy in STRATEGIES['slicing']:
            # infer the local start/end based on the slicing information of
            # x's d2o. Only gets non-trivial for axis==0.
            if 0 not in affected_axes:
                local_data_Q = True
            else:
                # we rely on the fact, that the content of x.domain_axes is
                # sorted
                true_starts = [x.val.distributor.local_start]
                true_starts += [0] * (len(affected_axes) - 1)
                true_ends = [x.val.distributor.local_end]
                true_ends += [x.shape[i] for i in affected_axes[1:]]

                # The margin is a float; cast the bounds so that the slice
                # carries integer indices (float slice bounds are rejected
                # by current NumPy).  Both values are non-negative here, so
                # int() rounds them down.
                margin = self._direct_smoothing_width * self.sigma
                augmented_start = int(max(0, true_starts[0] - margin))
                augmented_end = int(min(x.shape[affected_axes[0]],
                                        true_ends[0] + margin))
                augmented_slice = slice(augmented_start, augmented_end)
                augmented_data = x.val.get_data(augmented_slice,
                                                local_keys=True,
                                                copy=False)
                augmented_data = augmented_data.get_local_data(copy=False)

                augmented_distance_array = distance_array.get_data(
                    augmented_slice,
                    local_keys=True,
                    copy=False)
                augmented_distance_array = \
                    augmented_distance_array.get_local_data(copy=False)

        else:
            # Plain ValueError, as in the constructor: the `about` helper
            # used here before is not imported by this module.
            raise ValueError(
                "ERROR: Direct smoothing not implemented for given "
                "distribution strategy.")

        if local_data_Q:
            # if the needed data resides on the nodes already, the necessary
            # steps are the same; no matter what the distribution strategy
            # was.
            augmented_data = x.val.get_local_data(copy=False)
            augmented_distance_array = \
                distance_array.get_local_data(copy=False)
            true_starts = [0] * len(affected_axes)
            true_ends = [x.shape[i] for i in affected_axes]

        # perform the convolution along the affected axes
        local_result = augmented_data
        for distances_axis, data_axis in enumerate(affected_axes):
            local_result = self._direct_smoothing_single_axis(
                local_result,
                data_axis,
                augmented_distance_array,
                distances_axis,
                true_starts[distances_axis],
                true_ends[distances_axis],
                inverse)

        result = x.copy_empty()
        result.val.set_local_data(local_result, copy=False)
        return result

    def _direct_smoothing_single_axis(self, data, data_axis, distances,
                                      distances_axis, true_start, true_end,
                                      inverse):
        """Smooth ``data`` along ``data_axis`` with the ``su`` helpers.

        The smoothing length handed to the helper is ``self.sigma``, or
        ``1 / self.sigma`` when ``inverse`` is true.  ``true_start`` and
        ``true_end`` are forwarded as ``startindex`` / ``endindex`` and
        ``distances`` as ``distances``.  ``distances_axis`` is accepted
        but not used.

        Returns whatever the selected ``su`` routine returns.
        """
        smooth_length = 1 / self.sigma if inverse else self.sigma

        # float32 input goes to the `_f` variant of the routine, every
        # other dtype to the plain one.
        if data.dtype == np.dtype('float32'):
            smoother = su.apply_along_axis_f
        else:
            smoother = su.apply_along_axis

        return smoother(data_axis, data,
                        startindex=true_start,
                        endindex=true_end,
                        distances=distances,
                        smooth_length=smooth_length)