# nifty_los.py (last committed by Ultima)
# -*- coding: utf-8 -*-

3
4
import numpy as np

5
6
7
8
9
10
11
12
13
from line_integrator import multi_integrator, \
                            gaussian_error_function

from nifty.keepers import about,\
                          global_dependency_injector as gdi
from nifty.nifty_mpi_data import distributed_data_object,\
                                 STRATEGIES
from nifty.nifty_core import point_space,\
                             field
14
15
16
from nifty.rg import rg_space
from nifty.operators import operator

17
# MPI implementation obtained from nifty's global dependency injector (gdi);
# los_response._multiply uses MPI.SUM for its Allreduce over sliced domains.
MPI = gdi['MPI']
18
19
20
21

class los_response(operator):
    """Line-of-sight response operator on a regular grid.

    Maps a field on an ``rg_space`` domain to a target space with one entry
    per line of sight.  Each entry is the weighted sum of the domain pixels
    touched by that line; the (line index, pixel indices, weights) triples
    are computed once at construction time by ``multi_integrator``.

    Parameters
    ----------
    domain : rg_space
        Regular grid the input fields live on.
    starts, ends : list
        One entry per grid dimension.  Each entry is either a scalar or a
        flat numpy array holding one physical coordinate per line of sight;
        scalars are broadcast to all lines.
    sigmas_low, sigmas_up : None, scalar or numpy.ndarray, optional
        Physical lengths, measured along each line, before (low) and beyond
        (up) its end point.  None is treated as zero.
    zero_point : list, optional
        Physical coordinates of the grid origin; defaults to all zeros.
    error_function : callable, optional
        Passed through to ``multi_integrator``.
    target : space, optional
        Target space; defaults to a ``point_space`` with one point per line
        of sight that shares the domain's dtype, datamodel and comm.

    Raises
    ------
    TypeError
        If ``domain`` is not an ``rg_space`` or the geometry arguments have
        unsupported types.
    ValueError
        If ``error_function`` is not callable or the geometry arguments have
        inconsistent shapes.
    """

    def __init__(self, domain, starts, ends, sigmas_low=None, sigmas_up=None,
                 zero_point=None, error_function=gaussian_error_function,
                 target=None):
        if not isinstance(domain, rg_space):
            raise TypeError(about._errors.cstring(
                "ERROR: The domain must be a rg_space instance."))
        self.domain = domain
        self.codomain = self.domain.get_codomain()

        if callable(error_function):
            self.error_function = error_function
        else:
            raise ValueError(about._errors.cstring(
                "ERROR: error_function must be callable."))

        (self.starts,
         self.ends,
         self.sigmas_low,
         self.sigmas_up,
         self.zero_point) = self._parse_coordinates(self.domain,
                                                    starts, ends, sigmas_low,
                                                    sigmas_up, zero_point)

        # Precompute, for the local data slice, which pixels each line of
        # sight touches and with which weight; _multiply and
        # _adjoint_multiply only replay this list.
        self.local_weights_and_indices = self._compute_weights_and_indices()

        self.number_of_los = len(self.sigmas_low)

        if target is None:
            self.target = point_space(num=self.number_of_los,
                                      dtype=self.domain.dtype,
                                      datamodel=self.domain.datamodel,
                                      comm=self.domain.comm)
        else:
            self.target = target

        self.cotarget = self.target.get_codomain()

        # Flags read by the nifty operator base class (presumably
        # implicit / unitary / symmetric -- confirm against `operator`).
        self.imp = True
        self.uni = False
        self.sym = False

    def _parse_coordinates(self, domain, starts, ends, sigmas_low, sigmas_up,
                           zero_point):
        """Validate and normalise the user-supplied geometry.

        Returns
        -------
        tuple
            ``(starts, ends, sigmas_low, sigmas_up, zero_point)``: starts
            and ends as lists (one per dimension) of float64 arrays of length
            number_of_los, the sigmas as float64 arrays of that length, and
            zero_point as a list with one entry per dimension.

        Raises
        ------
        TypeError
            If starts/ends are not lists, contain entries that are neither
            scalars nor numpy arrays, or a sigma has an unsupported type.
        ValueError
            On any length or shape mismatch.
        """
        # basic sanity checks
        if not isinstance(starts, list):
            raise TypeError(about._errors.cstring(
                "ERROR: starts must be a list instance."))
        if not isinstance(ends, list):
            raise TypeError(about._errors.cstring(
                "ERROR: ends must be a list instance."))
        if not (len(domain.get_shape()) == len(starts) == len(ends)):
            raise ValueError(about._errors.cstring(
                "ERROR: The length of starts and ends must " +
                "be the same as the number of dimension of the domain."))

        number_of_dimensions = len(starts)

        if zero_point is None:
            zero_point = [0.] * number_of_dimensions

        if np.shape(zero_point) != (number_of_dimensions,):
            raise ValueError(about._errors.cstring(
                "ERROR: The shape of zero_point must match the length of " +
                "the starts and ends list"))
        parsed_zero_point = list(zero_point)

        # Extract the number of lines of sight and, on the way, check that
        # every entry of starts and ends is a scalar or a flat array and
        # that all arrays agree in length.
        number_of_los = None
        for temp_entry in starts + ends:
            if isinstance(temp_entry, np.ndarray):
                if len(np.shape(temp_entry)) != 1:
                    raise ValueError(about._errors.cstring(
                        "ERROR: The numpy ndarrays in starts " +
                        "and ends must be flat."))

                if number_of_los is None:
                    number_of_los = len(temp_entry)
                elif number_of_los != len(temp_entry):
                    raise ValueError(about._errors.cstring(
                        "ERROR: The length of all numpy ndarrays in starts " +
                        "and ends must be the same."))
            elif not np.isscalar(temp_entry):
                raise TypeError(about._errors.cstring(
                    "ERROR: The entries of starts and ends must be either " +
                    "scalar or numpy ndarrays."))

        # Only scalars were given: this describes a single line of sight.
        if number_of_los is None:
            number_of_los = 1
            starts = [np.array([x]) for x in starts]
            ends = [np.array([x]) for x in ends]

        # Parse the coordinate arrays/scalars in the starts and ends list
        parsed_starts = self._parse_startsends(starts, number_of_los)
        parsed_ends = self._parse_startsends(ends, number_of_los)

        # check that sigmas_up/lows have the right shape and parse scalars
        parsed_sigmas_low = self._parse_sigmas_uplows(sigmas_low,
                                                      number_of_los)
        parsed_sigmas_up = self._parse_sigmas_uplows(sigmas_up, number_of_los)
        return (parsed_starts, parsed_ends, parsed_sigmas_low,
                parsed_sigmas_up, parsed_zero_point)

    def _parse_startsends(self, coords, number_of_los):
        """Return a list of fresh float64 arrays of length number_of_los.

        Each entry of ``coords`` (scalar or length-number_of_los array) is
        broadcast into its own newly allocated array.
        """
        result_coords = []
        for coord in coords:
            temp_array = np.empty(number_of_los, dtype=np.float64)
            temp_array[:] = coord
            result_coords.append(temp_array)
        return result_coords

    def _parse_sigmas_uplows(self, sig, number_of_los):
        """Turn a None/scalar/array sigma into a float64 array.

        Returns an array of length number_of_los: zeros for None, the scalar
        broadcast for a scalar, a float64 copy for an array.

        Raises
        ------
        ValueError
            If an array's shape is not ``(number_of_los,)``.
        TypeError
            If ``sig`` is neither None, a scalar nor a numpy array.
        """
        if sig is None:
            parsed_sig = np.zeros(number_of_los, dtype=np.float64)
        elif isinstance(sig, np.ndarray):
            if np.shape(sig) != (number_of_los,):
                raise ValueError(about._errors.cstring(
                    "ERROR: The length of sigmas_up/sigmas_low must be " +
                    "the same as the number of line-of-sights."))
            parsed_sig = sig.astype(np.float64)
        elif np.isscalar(sig):
            parsed_sig = np.empty(number_of_los, dtype=np.float64)
            parsed_sig[:] = sig
        else:
            raise TypeError(about._errors.cstring(
                "ERROR: sigmas_up/sigmas_low must either be a scalar or a " +
                "numpy ndarray."))
        return parsed_sig

    def convert_physical_to_indices(self, physical_positions):
        """Convert physical coordinates to local (fractional) pixel units.

        ``physical_positions`` holds one entry (scalar or array) per
        dimension.  Each is shifted by the process-local zero point, divided
        by the grid spacing of that dimension and offset by 0.5.
        """
        local_zero_point = self._get_local_zero_point()
        distances = self.domain.distances

        pixel_coordinates = []
        for i, position in enumerate(physical_positions):
            # Distance to the zeroth (local) pixel, rescaled to the uniform
            # grid.  NOTE(review): the +0.5 offset is presumably the pixel
            # centre/edge convention expected by multi_integrator -- confirm.
            pixel_coordinates.append(
                ((position - local_zero_point[i]) / distances[i]) + 0.5)
        return pixel_coordinates

    def _convert_physical_to_pixel_lengths(self, lengths, starts, ends):
        """Rescale physical lengths along each line to pixel units.

        The factor per line is the line's length measured in pixel units
        divided by its physical length.
        """
        directions = np.array(ends) - np.array(starts)
        distances = np.array(self.domain.distances)[:, None]
        rescalers = (np.linalg.norm(directions / distances, axis=0) /
                     np.linalg.norm(directions, axis=0))
        return lengths * rescalers

    def _convert_sigmas_to_physical_coordinates(self, starts, ends,
                                                sigmas_low, sigmas_up):
        """Return the physical points sigma_low before / sigma_up after the
        end of each line, both measured along the line's direction.

        Returns two lists (one entry per dimension).  NOTE(review): a line
        whose start equals its end divides by a zero length here.
        """
        starts = np.array(starts)
        ends = np.array(ends)
        c = ends - starts
        abs_c = np.linalg.norm(c, axis=0)
        sigmas_low_coords = list(starts + (abs_c - sigmas_low)*c/abs_c)
        sigmas_up_coords = list(starts + (abs_c + sigmas_up)*c/abs_c)
        return (sigmas_low_coords, sigmas_up_coords)

    def _get_local_zero_point(self):
        """Return the physical origin of this process's data slice.

        Raises
        ------
        NotImplementedError
            For datamodels other than 'np', 'not'-type and slicing ones.
        """
        datamodel = self.domain.datamodel
        if datamodel == 'np' or datamodel in STRATEGIES['not']:
            return self.zero_point
        elif datamodel in STRATEGIES['slicing']:
            dummy_d2o = distributed_data_object(
                                global_shape=self.domain.get_shape(),
                                dtype=np.dtype('int16'),
                                distribution_strategy=datamodel,
                                skip_parsing=True)

            # The local slab starts `local_start` pixels into the first
            # axis; only that axis is shifted.
            pixel_offset = dummy_d2o.distributor.local_start
            distance_offset = pixel_offset * self.domain.distances[0]
            local_zero_point = self.zero_point[:]
            local_zero_point[0] += distance_offset
            return local_zero_point
        else:
            raise NotImplementedError(about._errors.cstring(
                "ERROR: The space's datamodel is not supported:" +
                str(datamodel)))

    def _get_local_shape(self):
        """Return the shape of this process's data slice.

        Raises
        ------
        NotImplementedError
            For unsupported datamodels (previously returned None silently).
        """
        datamodel = self.domain.datamodel
        if datamodel == 'np' or datamodel in STRATEGIES['not']:
            return self.domain.get_shape()
        elif datamodel in STRATEGIES['slicing']:
            dummy_d2o = distributed_data_object(
                                global_shape=self.domain.get_shape(),
                                dtype=np.dtype('int'),
                                distribution_strategy=datamodel,
                                skip_parsing=True)
            return dummy_d2o.distributor.local_shape
        else:
            raise NotImplementedError(about._errors.cstring(
                "ERROR: The space's datamodel is not supported:" +
                str(datamodel)))

    def _compute_weights_and_indices(self):
        """Run multi_integrator on the local slice and return its result.

        The result is iterated as ``(los_index, indices, weights)`` triples
        by _multiply and _adjoint_multiply.
        """
        # compute the local pixel coordinates for the starts and ends
        localized_pixel_starts = self.convert_physical_to_indices(self.starts)
        localized_pixel_ends = self.convert_physical_to_indices(self.ends)

        # Convert the sigmas from physical distances to pixel coordinates:
        # first turn the distances into physical points on the lines...
        (sigmas_low_coords, sigmas_up_coords) = \
            self._convert_sigmas_to_physical_coordinates(self.starts,
                                                         self.ends,
                                                         self.sigmas_low,
                                                         self.sigmas_up)
        # ...and then transform those points to pixel coordinates
        localized_pixel_sigmas_low = self.convert_physical_to_indices(
                                                             sigmas_low_coords)
        localized_pixel_sigmas_up = self.convert_physical_to_indices(
                                                             sigmas_up_coords)

        # get the shape of the local data slice
        local_shape = self._get_local_shape()
        # let the cython function do the hard work of integrating over
        # the individual lines
        local_indices_and_weights_list = multi_integrator(
                                                  localized_pixel_starts,
                                                  localized_pixel_ends,
                                                  localized_pixel_sigmas_low,
                                                  localized_pixel_sigmas_up,
                                                  local_shape,
                                                  list(self.domain.distances),
                                                  self.error_function)
        return local_indices_and_weights_list

    @staticmethod
    def _local_data_of(obj):
        """Return the process-local numpy array behind ``obj``.

        numpy arrays are returned unchanged: ``ndarray.data`` exists but is
        a raw buffer object, not an array, so it must not be used here.
        Other objects (presumably distributed_data_object) expose their
        local array as ``.data``; anything without ``.data`` is returned
        as is.
        """
        if isinstance(obj, np.ndarray):
            return obj
        try:
            return obj.data
        except AttributeError:
            return obj

    def _multiply(self, input_field):
        """Apply the response: weighted pixel sums along every line.

        Returns a field on self.target holding one value per line of sight.

        Raises
        ------
        NotImplementedError
            For unsupported datamodels.
        """
        local_input_data = self._local_data_of(input_field.val)

        local_result = np.zeros(self.number_of_los, dtype=self.target.dtype)
        for (los_index, indices, weights) in self.local_weights_and_indices:
            local_result[los_index] += \
                np.sum(local_input_data[indices]*weights)

        datamodel = self.domain.datamodel
        if datamodel == 'np' or datamodel in STRATEGIES['not']:
            # The local data covers the whole grid (see _get_local_shape),
            # so the local sums are already complete.
            global_result = local_result
        elif datamodel in STRATEGIES['slicing']:
            # Each process only integrated over its own slab: add up the
            # partial sums of all processes.
            global_result = np.empty_like(local_result)
            self.domain.comm.Allreduce(local_result, global_result, op=MPI.SUM)
        else:
            raise NotImplementedError(about._errors.cstring(
                "ERROR: The space's datamodel is not supported:" +
                str(datamodel)))

        result_field = field(self.target, val=global_result)
        return result_field

    def _adjoint_multiply(self, input_field):
        """Apply the adjoint: spread each line's value back onto its pixels.

        Every pixel receives value * weight for every line touching it; the
        result is then divided by ``self.domain.get_vol()``.  Returns a field
        on self.domain.
        """
        # get the full data (one value per line of sight) from the input
        try:
            full_input_data = input_field.val.get_full_data()
        except AttributeError:
            full_input_data = input_field.val

        # produce a zero-initialised data object suitable to domain
        global_result_data_object = self.domain.cast(0)
        local_result_data = self._local_data_of(global_result_data_object)

        for (los_index, indices, weights) in self.local_weights_and_indices:
            # NOTE(review): fancy-index `+=` does not accumulate repeated
            # indices; this assumes multi_integrator yields unique pixels
            # per line -- confirm.
            local_result_data[indices] += \
                (full_input_data[los_index]*weights)

        # weight the result
        local_result_data /= self.domain.get_vol()

        # construct the result field; a plain numpy val is replaced rather
        # than having its raw `.data` buffer overwritten
        result_field = field(self.domain)
        if isinstance(result_field.val, np.ndarray):
            result_field.val = local_result_data
        else:
            try:
                result_field.val.data = local_result_data
            except AttributeError:
                result_field.val = local_result_data

        return result_field
314
315
316
317
318
319
320
321
322
323
324
325