fft_smoothing_operator.py 2.79 KB
Newer Older
1
# -*- coding: utf-8 -*-
Martin Reinecke's avatar
Martin Reinecke committed
2 3

from builtins import range
4 5
import numpy as np

6 7
from ..endomorphic_operator import EndomorphicOperator
from ..fft_operator import FFTOperator
8 9


10
class FFTSmoothingOperator(EndomorphicOperator):
    """Smooth a field by pointwise multiplication in the FFT codomain.

    The input is transformed with an ``FFTOperator`` built on this
    operator's domain. It is then multiplied by a kernel obtained from the
    codomain's ``get_fft_smoothing_kernel_function(sigma)``, evaluated on
    the codomain's distance array. Finally it is transformed back with the
    FFT operator's ``adjoint_times``.

    Parameters
    ----------
    domain :
        Parsed via ``_parse_domain``; must consist of exactly one space.
    sigma :
        Smoothing scale forwarded to the codomain's kernel factory.
        ``sigma == 0`` makes the operator return a copy of its input.
    default_spaces : optional
        Forwarded unchanged to ``EndomorphicOperator.__init__``.

    Raises
    ------
    ValueError
        If the parsed domain does not contain exactly one space.
    """

    def __init__(self, domain, sigma,
                 default_spaces=None):
        super(FFTSmoothingOperator, self).__init__(default_spaces)

        self._domain = self._parse_domain(domain)
        if len(self._domain) != 1:
            raise ValueError("SmoothingOperator only accepts exactly one "
                             "space as input domain.")

        self._sigma = sigma
        # FFTOperator instances, keyed by the dtype of the input field.
        self._transformator_cache = {}

    def _times(self, x, spaces):
        """Apply the smoothing to field ``x`` over ``spaces``.

        Returns a copy of ``x`` unchanged when ``sigma == 0``.
        """
        if self.sigma == 0:
            return x.copy()

        # the domain of the smoothing operator contains exactly one space.
        # Hence, if spaces is None, but we passed LinearOperator's
        # _check_input_compatibility, we know that x is also solely defined
        # on that space
        if spaces is None:
            spaces = (0,)

        return self._smooth(x, spaces)

    # ---Mandatory properties and methods---
    @property
    def domain(self):
        return self._domain

    @property
    def self_adjoint(self):
        return True

    @property
    def unitary(self):
        return False

    # ---Added properties and methods---

    @property
    def sigma(self):
        return self._sigma

    def _smooth(self, x, spaces):
        """Transform ``x``, multiply by the smoothing kernel, transform back.

        Only ``spaces[0]`` is used to look up the codomain and its axes.
        Returns the field produced by the FFT operator's ``adjoint_times``.
        """
        # transform to the (global-)default codomain and perform all remaining
        # steps therein
        transformator = self._get_transformator(x.dtype)
        transformed_x = transformator(x, spaces=spaces)

        codomain = transformed_x.domain[spaces[0]]
        coaxes = transformed_x.domain_axes[spaces[0]]

        # evaluate the smoothing kernel on the codomain's distance array
        kernel = codomain.get_distance_array()
        kernel = codomain.get_fft_smoothing_kernel_function(self.sigma)(kernel)

        # now, apply the kernel to transformed_x
        # this is done node-locally utilizing numpy's reshaping in order to
        # apply the kernel to the correct axes: every axis that does not
        # belong to the codomain gets length 1, so the kernel broadcasts
        # along it
        local_transformed_x = transformed_x.val
        reshaper = [local_transformed_x.shape[i] if i in coaxes else 1
                    for i in range(len(transformed_x.shape))]
        local_kernel = np.reshape(kernel, reshaper)

        local_transformed_x *= local_kernel

        # NOTE(review): presumably needed in case ``val`` hands out a copy
        # rather than a view of the field's data -- confirm
        transformed_x.val = local_transformed_x

        # return the back-transformed field directly; no scratch field needs
        # to be allocated for the result
        return transformator.adjoint_times(transformed_x, spaces=spaces)

    def _get_transformator(self, dtype):
        """Return the cached FFTOperator for ``dtype``, creating it on demand."""
        # NOTE(review): ``dtype`` is used only as the cache key; the
        # FFTOperator is built from ``self.domain`` alone, so all cached
        # entries are currently equivalent.
        if dtype not in self._transformator_cache:
            self._transformator_cache[dtype] = FFTOperator(self.domain)
        return self._transformator_cache[dtype]