# -*- coding: utf-8 -*-

from builtins import range
import numpy as np

from ..endomorphic_operator import EndomorphicOperator
from ..fft_operator import FFTOperator
from ... import DomainTuple

class FFTSmoothingOperator(EndomorphicOperator):
    """Smoothing operator that acts through the FFT of its single space.

    Applying the operator transforms the input with an ``FFTOperator`` into
    the default codomain of the domain's only space, multiplies it pointwise
    by a precomputed kernel, and maps the result back with the FFT
    operator's adjoint.

    Parameters
    ----------
    domain : object
        Anything ``DomainTuple.make`` accepts. The resulting domain must
        contain exactly one space.
    sigma : float-like
        Smoothing scale, stored as ``float(sigma)``. It parametrizes the
        kernel returned by the codomain's
        ``get_fft_smoothing_kernel_function``; the kernel's form and the
        units of ``sigma`` are defined there (presumably a Gaussian width in
        the space's length units -- confirm). A value of exactly zero makes
        the operator an identity that returns a copy of its input.
    default_spaces : optional
        Passed through unchanged to ``EndomorphicOperator.__init__``.

    Raises
    ------
    ValueError
        If ``domain`` does not consist of exactly one space.
    """

    def __init__(self, domain, sigma,
                 default_spaces=None):
        super(FFTSmoothingOperator, self).__init__(default_spaces)

        self._domain = DomainTuple.make(domain)
        if len(self._domain) != 1:
            raise ValueError("SmoothingOperator only accepts exactly one "
                             "space as input domain.")

        self._sigma = float(sigma)
        # With sigma == 0 the operator is the identity, so neither the FFT
        # operator nor the kernel is built (_times never reaches _smooth).
        if self._sigma != 0.:
            self._transformator = FFTOperator(self._domain)
            codomain = self._domain[0].get_default_codomain()
            k_lengths = codomain.get_k_length_array()
            kernel_of = codomain.get_fft_smoothing_kernel_function(
                self._sigma)
            # Kernel values on the codomain's k-length array; _smooth
            # applies them as a pointwise factor.
            self._kernel = kernel_of(k_lengths)

    def _times(self, x, spaces):
        """Apply the smoothing to ``x``; return a copy when sigma == 0."""
        if self._sigma == 0:
            return x.copy()
        # The operator's domain holds exactly one space. If spaces is None
        # and x passed LinearOperator's input compatibility check, x is
        # defined solely on that space, i.e. at index 0.
        if spaces is None:
            spaces = (0,)
        return self._smooth(x, spaces)

    # --- Properties required by the operator interface ---

    @property
    def domain(self):
        """The DomainTuple (one space) this operator acts on."""
        return self._domain

    @property
    def self_adjoint(self):
        """Always True: the operator is declared self-adjoint."""
        return True

    @property
    def unitary(self):
        """Always False: the operator is not declared unitary."""
        return False

    # --- Implementation helpers ---

    def _smooth(self, x, spaces):
        """Multiply ``x`` by the kernel in the default codomain.

        ``spaces`` is forwarded to the FFT operator; ``spaces[0]`` selects
        the space whose (transformed) axes receive the kernel. Returns the
        result of the adjoint (back) transform.
        """
        # Forward transform of the selected space into its default codomain.
        x_k = self._transformator(x, spaces=spaces)
        kernel_axes = x_k.domain.axes[spaces[0]]

        # Reshape the kernel so that it spans the codomain's axes and has
        # length 1 on every other axis; the product then broadcasts.
        kernel_shape = [extent if axis in kernel_axes else 1
                        for axis, extent in enumerate(x_k.shape)]
        # In-place multiply on the transform's output (presumably a fresh
        # object, not the caller's x -- confirm in FFTOperator).
        x_k *= np.reshape(self._kernel, kernel_shape)

        return self._transformator.adjoint_times(x_k, spaces=spaces)