smooth_operator.py 3.55 KB
Newer Older
Jait Dixit's avatar
Jait Dixit committed
1
2
3
4
5
6
7
8
9
10
11
import numpy as np

from nifty.config import about
import nifty.nifty_utilities as utilities
from nifty import RGSpace, LMSpace
from nifty.operators.endomorphic_operator import EndomorphicOperator
from nifty.operators.fft_operator import FFTOperator

class SmoothOperator(EndomorphicOperator):
    """Endomorphic operator that smooths a field over exactly one space.

    The field is transformed with an ``FFTOperator``, multiplied by a kernel
    that the transformed (co)domain builds from its own distance array and
    ``sigma``, and transformed back. The inverse application divides by the
    same kernel instead of multiplying.

    Parameters
    ----------
    domain : tuple
        Domain of the operator; after the base-class initialisation it must
        contain exactly one space.
    field_type : tuple
        Must be an empty tuple.
    sigma : scalar, optional
        Smoothing scale handed to the codomain's
        ``get_smoothing_kernel_function``. A value equal to ``0`` turns the
        operator into a plain copy.
        NOTE(review): the default ``None`` is forwarded unchanged to the
        codomain; how it is interpreted there is not visible from this
        file -- confirm.

    Raises
    ------
    ValueError
        If the domain does not consist of exactly one space or if
        ``field_type`` is not an empty tuple.
    """

    # ---Overwritten properties and methods---

    def __init__(self, domain=(), field_type=(), sigma=None):
        super(SmoothOperator, self).__init__(domain=domain,
                                             field_type=field_type)

        if len(self.domain) != 1:
            raise ValueError(
                about._errors.cstring(
                    'ERROR: SmoothOperator accepts only exactly one '
                    'space as input domain.')
            )

        if self.field_type != ():
            raise ValueError(about._errors.cstring(
                'ERROR: SmoothOperator field-type must be an '
                'empty tuple.'
            ))

        self._sigma = sigma

    def _inverse_times(self, x, spaces, types):
        """Undo the smoothing of `x` by dividing by the kernel."""
        return self._smooth_helper(x, spaces, types, inverse=True)

    def _times(self, x, spaces, types):
        """Smooth `x` by multiplying with the kernel."""
        return self._smooth_helper(x, spaces, types)

    # ---Mandatory properties and methods---

    @property
    def implemented(self):
        """Always ``True``."""
        return True

    @property
    def symmetric(self):
        """Always ``False``."""
        return False

    @property
    def unitary(self):
        """Always ``False``."""
        return False

    # ---Added properties and methods---

    @property
    def sigma(self):
        """The smoothing scale passed to the constructor."""
        return self._sigma

    def _smooth_helper(self, x, spaces, types, inverse=False):
        """Apply the smoothing kernel (or its inverse) to `x`.

        Parameters
        ----------
        x : field
            Input field; it is not modified.
        spaces : None, int or tuple of int
            Index of the space of `x` to smooth over. ``None`` selects
            space ``0``.
        types : unused
            Accepted for interface compatibility with ``_times`` and
            ``_inverse_times``; not read here.
        inverse : bool, optional
            If true, divide by the kernel instead of multiplying.

        Returns
        -------
        field
            A copy of `x` if ``sigma == 0``, otherwise the result of
            transforming the kernel-weighted data back with the inverse
            transform.
        """
        # Nothing to smooth: hand back a copy so the caller never aliases x.
        if self.sigma == 0:
            return x.copy()

        # the domain of the smoothing operator contains exactly one space.
        # Hence, if spaces is None, but we passed LinearOperator's
        # _check_input_compatibility, we know that x is also solely defined
        # on that space
        if spaces is None:
            spaces = (0,)
        else:
            spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))

        Transformator = FFTOperator(x.domain[spaces[0]])

        # transform to the (global-)default codomain and perform all remaining
        # steps therein
        transformed_x = Transformator(x, spaces=spaces)
        codomain = transformed_x.domain[spaces[0]]
        coaxes = transformed_x.domain_axes[spaces[0]]

        # create the kernel using the knowledge of codomain about itself
        axes_local_distribution_strategy = \
            transformed_x.val.get_axes_local_distribution_strategy(axes=coaxes)

        kernel = codomain.distance_array(
                        distribution_strategy=axes_local_distribution_strategy)
        kernel.apply_scalar_function(
            codomain.get_smoothing_kernel_function(self.sigma),
            inplace=True)

        # now, apply the kernel to transformed_x
        # this is done node-locally utilizing numpys reshaping in order to
        # apply the kernel to the correct axes
        local_transformed_x = transformed_x.val.get_local_data(copy=False)
        local_kernel = kernel.get_local_data(copy=False)

        # Keep the extent of the smoothed axes and use length 1 everywhere
        # else so that numpy broadcasting lines the kernel up with the data.
        # The extents are taken from the node-local data rather than from the
        # global field shape, since the kernel being reshaped is node-local
        # as well.
        # NOTE(review): assumes get_local_data returns an array with one
        # axis per field axis; both shapes coincide when the data is not
        # distributed -- confirm for distributed strategies.
        reshaper = [extent if axis in coaxes else 1
                    for axis, extent in enumerate(local_transformed_x.shape)]
        local_kernel = np.reshape(local_kernel, reshaper)

        # apply the kernel (in place on the local data of transformed_x)
        if inverse:
            local_transformed_x /= local_kernel
        else:
            local_transformed_x *= local_kernel

        transformed_x.val.set_local_data(local_transformed_x, copy=False)

        result = Transformator.inverse_times(transformed_x, spaces=spaces)

        return result