# fft_smoothing_operator.py
# -*- coding: utf-8 -*-
import numpy as np

from nifty.operators.fft_operator import FFTOperator

from .smoothing_operator import SmoothingOperator


class FFTSmoothingOperator(SmoothingOperator):
    """Smoothing operator that applies its kernel in Fourier space.

    The input is transformed with an ``FFTOperator`` built on this
    operator's domain. The transformed data is then multiplied by a
    smoothing kernel, or divided by it for the inverse. The kernel comes
    from ``codomain.get_fft_smoothing_kernel_function(self.sigma)``,
    evaluated on the codomain's distance array (optionally on
    log-distances). Finally, the data is transformed back with the FFT
    operator's adjoint.

    FFT operators are created lazily and cached per input dtype, because
    each ``FFTOperator`` is built with ``domain_dtype`` equal to the dtype
    of the data being smoothed.
    """

    def __init__(self, domain, sigma, log_distances=False,
                 default_spaces=None):
        # All arguments are forwarded unchanged to SmoothingOperator.
        super(FFTSmoothingOperator, self).__init__(
                                                domain=domain,
                                                sigma=sigma,
                                                log_distances=log_distances,
                                                default_spaces=default_spaces)
        # Maps input dtype -> FFTOperator; populated by _get_transformator.
        self._transformator_cache = {}

    def _smooth(self, x, spaces, inverse):
        """Smooth (or, if `inverse`, un-smooth) `x` in Fourier space.

        Parameters
        ----------
        x : field-like
            Data to smooth. Must provide ``dtype`` and ``copy_empty()``
            (presumably a NIFTy ``Field`` -- confirm against callers).
        spaces : sequence of int
            Space indices handed to the FFT operator. Only ``spaces[0]`` is
            used to build the kernel.
        inverse : bool
            If True, divide by the kernel instead of multiplying by it.

        Returns
        -------
        ``x.copy_empty()`` with its values set to the back-transformed,
        kernel-weighted data.
        """
        # Transform to the (global-)default codomain and perform all
        # remaining steps therein.
        transformator = self._get_transformator(x.dtype)
        transformed_x = transformator(x, spaces=spaces)

        codomain = transformed_x.domain[spaces[0]]
        coaxes = transformed_x.domain_axes[spaces[0]]

        # Create the kernel using the codomain's knowledge about itself,
        # distributed according to how transformed_x is laid out along the
        # codomain axes.
        axes_local_distribution_strategy = \
            transformed_x.val.get_axes_local_distribution_strategy(axes=coaxes)

        kernel = codomain.get_distance_array(
            distribution_strategy=axes_local_distribution_strategy)

        # MR FIXME: this causes calls of log(0.) which should probably be
        # avoided
        if self.log_distances:
            kernel.apply_scalar_function(np.log, inplace=True)

        kernel.apply_scalar_function(
            codomain.get_fft_smoothing_kernel_function(self.sigma),
            inplace=True)

        # Apply the kernel node-locally. The local kernel is reshaped so it
        # broadcasts against the local data: it keeps its extent along the
        # codomain axes and has length 1 on every other axis. The extents
        # must be the *local* ones, because the in-place operation below
        # works on local_transformed_x. The global shape only matches when
        # the data is not distributed along a codomain axis.
        local_transformed_x = transformed_x.val.get_local_data(copy=False)
        local_kernel = kernel.get_local_data(copy=False)

        reshaper = [extent if axis in coaxes else 1
                    for axis, extent in enumerate(local_transformed_x.shape)]
        local_kernel = np.reshape(local_kernel, reshaper)

        # apply the kernel
        if inverse:
            # MR FIXME: danger of having division by zero or overflows
            local_transformed_x /= local_kernel
        else:
            local_transformed_x *= local_kernel

        transformed_x.val.set_local_data(local_transformed_x, copy=False)

        smoothed_x = transformator.adjoint_times(transformed_x,
                                                 spaces=spaces)

        result = x.copy_empty()
        result.set_val(smoothed_x, copy=False)

        return result

    def _get_transformator(self, dtype):
        """Return the cached FFTOperator for inputs of `dtype`.

        On first use for a given dtype, an ``FFTOperator`` is built on this
        operator's domain with ``domain_dtype=dtype`` and a complex target
        dtype, and then stored in ``self._transformator_cache``.
        """
        if dtype not in self._transformator_cache:
            # ``np.complex`` was merely an alias of the builtin ``complex``
            # and has been removed from NumPy (1.24). Passing the builtin
            # keeps the exact same object.
            self._transformator_cache[dtype] = FFTOperator(
                                                    self.domain,
                                                    domain_dtype=dtype,
                                                    target_dtype=complex)
        return self._transformator_cache[dtype]