# -*- coding: utf-8 -*-

from builtins import range
4
5
import numpy as np

6
7
from ..endomorphic_operator import EndomorphicOperator
from ..fft_operator import FFTOperator
8
9


10
class FFTSmoothingOperator(EndomorphicOperator):
    """Smooth a field by pointwise multiplication with a kernel in the
    FFT codomain.

    The input is transformed with an ``FFTOperator`` built on ``domain``.
    It is then multiplied by the kernel obtained from
    ``codomain.get_fft_smoothing_kernel_function(sigma)``, evaluated on the
    codomain's distance array. Finally it is transformed back with the
    transformation's ``adjoint_times``.

    Parameters
    ----------
    domain :
        Domain of the operator. After ``_parse_domain`` it must consist of
        exactly one space.
    sigma :
        Smoothing scale handed to the codomain's kernel factory. A value
        equal to 0 makes the operator return a copy of its input.
    default_spaces :
        Forwarded unchanged to ``EndomorphicOperator.__init__``.

    Raises
    ------
    ValueError
        If the parsed domain does not contain exactly one space.
    """

    def __init__(self, domain, sigma,
                 default_spaces=None):
        super(FFTSmoothingOperator, self).__init__(default_spaces)

        self._domain = self._parse_domain(domain)
        if len(self._domain) != 1:
            raise ValueError("SmoothingOperator only accepts exactly one "
                             "space as input domain.")

        self._sigma = sigma
        # Lazily filled by _get_transformator: dtype -> FFTOperator.
        self._transformator_cache = {}

    def _times(self, x, spaces):
        """Apply the smoothing to field ``x``.

        Returns ``x.copy()`` when ``sigma == 0``; otherwise delegates to
        ``_smooth``.
        """
        if self.sigma == 0:
            return x.copy()

        # the domain of the smoothing operator contains exactly one space.
        # Hence, if spaces is None, but we passed LinearOperator's
        # _check_input_compatibility, we know that x is also solely defined
        # on that space
        if spaces is None:
            spaces = (0,)

        return self._smooth(x, spaces)

    # ---Mandatory properties and methods---
    @property
    def domain(self):
        """The parsed one-space domain given at construction."""
        return self._domain

    @property
    def self_adjoint(self):
        """Always True for this operator."""
        return True

    @property
    def unitary(self):
        """Always False for this operator."""
        return False

    # ---Added properties and methods---

    @property
    def sigma(self):
        """The smoothing scale passed at construction."""
        return self._sigma

    def _smooth(self, x, spaces):
        """Transform ``x``, multiply by the smoothing kernel, transform back.

        Only ``spaces[0]`` is used to look up the codomain and its axes.
        """
        # transform to the (global-)default codomain and perform all remaining
        # steps therein
        transformator = self._get_transformator(x.dtype)
        transformed_x = transformator(x, spaces=spaces)

        codomain = transformed_x.domain[spaces[0]]
        coaxes = transformed_x.domain_axes[spaces[0]]

        # Evaluate the codomain's smoothing kernel on its distance array.
        kernel = codomain.get_distance_array()
        kernel = codomain.get_fft_smoothing_kernel_function(self.sigma)(kernel)

        # now, apply the kernel to transformed_x
        # this is done node-locally utilizing numpy's reshaping in order to
        # apply the kernel to the correct axes
        local_transformed_x = transformed_x.val

        # Keep the kernel's extent on the codomain axes and insert singleton
        # axes everywhere else, so it broadcasts against the full array.
        reshaper = [local_transformed_x.shape[i] if i in coaxes else 1
                    for i in range(len(transformed_x.shape))]
        local_kernel = np.reshape(kernel, reshaper)

        local_transformed_x *= local_kernel

        # Store the product back explicitly rather than relying on `.val`
        # returning a view.
        transformed_x.val = local_transformed_x

        return transformator.adjoint_times(transformed_x, spaces=spaces)

    def _get_transformator(self, dtype):
        """Return the cached FFTOperator for ``dtype``, creating it on demand.

        NOTE(review): the operator is built from ``self.domain`` only;
        ``dtype`` serves solely as the cache key. Confirm whether
        FFTOperator is meant to receive the dtype.
        """
        if dtype not in self._transformator_cache:
            self._transformator_cache[dtype] = FFTOperator(self.domain)
        return self._transformator_cache[dtype]