# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

from __future__ import division
import numpy as np
from ..field import Field
from ..domain_tuple import DomainTuple
from .endomorphic_operator import EndomorphicOperator


class ScalingOperator(EndomorphicOperator):
    """Operator which multiplies a Field with a scalar.

    The NIFTy ScalingOperator class is a subclass derived from the
    EndomorphicOperator. It multiplies an input field with a given factor.

    Parameters
    ----------
    factor : scalar
        The multiplication factor
    domain : Domain or tuple of Domain or DomainTuple
        The domain on which the Operator's input Field lives.

    Notes
    -----
    Formally, this operator always supports all operation modes (times,
    adjoint_times, inverse_times and inverse_adjoint_times), even if `factor`
    is 0 or infinity. It is the user's responsibility to apply the operator
    only in appropriate ways (e.g. call inverse_times only if `factor` is
    nonzero).

    This shortcoming will hopefully be fixed in the future.
    """

    def __init__(self, factor, domain):
        super(ScalingOperator, self).__init__()

        if not np.isscalar(factor):
            raise TypeError("Scalar required")
        self._factor = factor
        self._domain = DomainTuple.make(domain)

    def apply(self, x, mode):
        """Apply the operator to the Field `x` in the requested `mode`.

        TIMES multiplies by `factor`, ADJOINT_TIMES by its complex
        conjugate, INVERSE_TIMES by ``1/factor`` and any other mode by
        ``1/conj(factor)``.

        Returns
        -------
        Field
            A new Field; `x` itself is never returned.
        """
        self._check_input(x, mode)

        # Shortcuts. A factor of 1 still returns a copy so the result never
        # aliases the input. A factor of 0 yields zeros in *every* mode,
        # including the inverse ones (see the class Notes on misuse).
        if self._factor == 1.:
            return x.copy()
        if self._factor == 0.:
            return Field.zeros_like(x)

        if mode == self.TIMES:
            return x*self._factor
        elif mode == self.ADJOINT_TIMES:
            return x*np.conj(self._factor)
        elif mode == self.INVERSE_TIMES:
            return x*(1./self._factor)
        else:
            return x*(1./np.conj(self._factor))

    def _flip_modes(self, mode):
        """Return the operator with the factor transformed according to `mode`.

        mode 0 returns `self`; mode 1 conjugates the factor (adjoint);
        mode 2 inverts it; mode 3 inverts the conjugate. Any other value
        raises ValueError.
        """
        if mode == 0:
            return self
        # Any real factor (int, float, numpy real scalar) equals its own
        # conjugate, so the adjoint is this very operator.
        if mode == 1 and np.isrealobj(self._factor):
            return self
        if mode == 1:
            return ScalingOperator(np.conj(self._factor), self._domain)
        elif mode == 2:
            return ScalingOperator(1./self._factor, self._domain)
        elif mode == 3:
            return ScalingOperator(1./np.conj(self._factor), self._domain)
        raise ValueError("bad operator flipping mode")

    @property
    def domain(self):
        return self._domain

    @property
    def capability(self):
        return self._all_ops

    def _sample_helper(self, fct, dtype):
        """Draw a normal random Field on the domain with std `fct`.

        Raises
        ------
        ValueError
            If `factor` is not real and strictly positive.
        """
        # Validate the factor itself rather than the derived `fct`. For a
        # negative real factor np.sqrt yields NaN, and for factor 0 the
        # inverse path yields 1./0. == inf. Neither would fail a
        # `fct.real <= 0.` test, so invalid std values would leak through.
        # Writing the test as `not (... > 0.)` also rejects a NaN factor.
        factor = self._factor
        if np.imag(factor) != 0. or not (np.real(factor) > 0.):
            raise ValueError("operator not positive definite")
        return Field.from_random(
            random_type="normal", domain=self._domain, std=fct, dtype=dtype)

    def draw_sample(self, dtype=np.float64):
        """Draw a normal sample with standard deviation sqrt(factor)."""
        return self._sample_helper(np.sqrt(self._factor), dtype)

    def inverse_draw_sample(self, dtype=np.float64):
        """Draw a normal sample with standard deviation 1/sqrt(factor)."""
        return self._sample_helper(1./np.sqrt(self._factor), dtype)