# smooth_operator.py
import numpy as np

from nifty.config import about
import nifty.nifty_utilities as utilities
from nifty import RGSpace, LMSpace
from nifty.operators.endomorphic_operator import EndomorphicOperator
from nifty.operators.fft_operator import FFTOperator

class SmoothOperator(EndomorphicOperator):
    """Smooth a field over a single space by multiplication in harmonic space.

    The field is transformed with an `FFTOperator`. It is then multiplied
    point-wise by a kernel obtained by applying the space's
    `codomain_smoothing_function(sigma)` to the harmonic-space distance
    array, and finally transformed back. `inverse_times` divides by the same
    kernel instead of multiplying.

    Parameters
    ----------
    domain : passed to `EndomorphicOperator`
        After parsing by the base class it must contain exactly one space.
    field_type : passed to `EndomorphicOperator`
        Must resolve to an empty tuple.
    sigma : number or None
        Smoothing width handed to `codomain_smoothing_function`. A value of
        0 disables smoothing (the input is returned as a copy).
        NOTE(review): units presumably match the space's distances; confirm.

    Raises
    ------
    ValueError
        If the domain does not consist of exactly one space, or if
        `field_type` is not empty.
    """

    # ---Overwritten properties and methods---

    def __init__(self, domain=(), field_type=(), sigma=None):
        super(SmoothOperator, self).__init__(domain=domain,
                                             field_type=field_type)

        if len(self.domain) != 1:
            raise ValueError(
                about._errors.cstring(
                    'ERROR: SmoothOperator accepts only exactly one '
                    'space as input domain.')
            )

        if self.field_type != ():
            raise ValueError(about._errors.cstring(
                'ERROR: SmoothOperator field-type must be an '
                'empty tuple.'
            ))

        self._sigma = sigma

    def _inverse_times(self, x, spaces, types):
        """Undo the smoothing: divide by the kernel in harmonic space."""
        return self._smooth_helper(x, spaces, types, inverse=True)

    def _times(self, x, spaces, types):
        """Apply the smoothing: multiply by the kernel in harmonic space."""
        return self._smooth_helper(x, spaces, types)

    # ---Mandatory properties and methods---
    @property
    def implemented(self):
        return True

    @property
    def symmetric(self):
        # NOTE(review): a real, even smoothing kernel would normally make
        # this operator symmetric; confirm before changing the flag.
        return False

    @property
    def unitary(self):
        return False

    # ---Added properties and methods---
    @property
    def sigma(self):
        """Smoothing width passed to `codomain_smoothing_function`."""
        return self._sigma

    def _smooth_helper(self, x, spaces, types, inverse=False):
        """Smooth (or, with ``inverse=True``, un-smooth) the field `x`.

        Only the first space listed in `spaces` is smoothed. If `spaces` is
        None or `sigma` equals 0, a plain copy of `x` is returned.
        `types` is accepted for interface compatibility and is not used.
        """
        # Work on a copy; `x` itself is not modified here.
        smooth_out = x.copy()

        if spaces is None or self.sigma == 0:
            return smooth_out

        spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
        space_index = spaces[0]
        axes = x.domain_axes[space_index]

        transform = FFTOperator(x.domain[space_index])

        # Go to the harmonic (codomain) representation.
        smooth_out = transform(smooth_out, spaces=space_index)

        # Build the kernel on the transformed space, using the local
        # distribution strategy of x's data along the smoothed axes.
        space_obj = smooth_out.domain[space_index]
        kernel = space_obj.distance_array(
            x.val.get_axes_local_distribution_strategy(axes=axes))
        kernel = kernel.apply_scalar_function(
            x.domain[space_index].codomain_smoothing_function(self.sigma))

        # Local (per-process) view of the data; modified in place below.
        local_val = smooth_out.val.get_local_data(copy=False)

        # Give the local kernel length 1 on every axis except the smoothed
        # ones, so that it broadcasts against local_val.
        local_kernel = kernel.get_local_data(copy=False)
        new_shape = [1] * local_val.ndim
        for space_axis, val_axis in zip(range(len(space_obj.shape)), axes):
            new_shape[val_axis] = local_kernel.shape[space_axis]
        local_kernel = local_kernel.reshape(new_shape)

        # Use the reshaped local array, not the distributed `kernel` object.
        if inverse:
            local_val /= local_kernel
        else:
            local_val *= local_kernel

        smooth_out.val.set_local_data(local_val, copy=False)

        # Back to the original representation.
        return transform.inverse_times(smooth_out, spaces=space_index)