# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

18
import numpy as np
19
20

from .. import dobj
21
from ..domain_tuple import DomainTuple
Martin Reinecke's avatar
Martin Reinecke committed
22
23
from ..domains.power_space import PowerSpace
from ..domains.rg_space import RGSpace
24
from ..field import Field
Martin Reinecke's avatar
Martin Reinecke committed
25
from ..utilities import infer_space, special_add_at
Philipp Arras's avatar
Philipp Arras committed
26
from .linear_operator import LinearOperator
27
28
29


class ExpTransform(LinearOperator):
    """Transforms log-space to target.

    This operator creates a log-space subject to the degrees of freedom and
    its target-domain. Then it transforms between this log-space and its
    target, which is defined in normal units.

    E.g.: A field in log-log-space can be transformed into log-norm-space,
    that is the y-axis stays logarithmic, but the x-axis is transformed.

    The domain is a copy of `target` in which the subdomain `space` is
    replaced by a non-harmonic LogRGSpace with ``2*dof + 1`` pixels per axis.
    Along each axis, let ``t_min`` and ``t_max`` be the smallest and largest
    logarithm of the non-zero k-lengths of the target. The bin distance of
    the log-space is ``(t_max - t_min) / (dof - 1)``, and ``t_min`` is passed
    to LogRGSpace as its third argument (presumably its ``t_0``). The k = 0
    mode of the target reads log-space pixel 0; a mode with non-zero length
    k reads the fractional pixel coordinate
    ``1 + (log(k) - t_min) / bin_distance``, linearly interpolated between
    the two neighbouring pixels.

    Parameters
    ----------
    target : domain, tuple of domains or DomainTuple
        The full output domain.
    dof : int or sequence of int
        The degrees of freedom of the log-domain, i.e. the number of bins.
        A scalar is used for every axis of the transformed subdomain;
        otherwise exactly one entry per axis is required. Every entry must
        be at least 2.
    space : int, optional
        Index of the subdomain of `target` that is transformed (resolved via
        `infer_space`). Default: 0.

    Raises
    ------
    ValueError
        If the subdomain `space` of `target` is neither a harmonic RGSpace
        nor a PowerSpace, or if `dof` has the wrong length or an entry < 2.
    """

    def __init__(self, target, dof, space=0):
        self._target = DomainTuple.make(target)
        self._capability = self.TIMES | self.ADJOINT_TIMES
        self._space = infer_space(self._target, space)
        tgt = self._target[self._space]
        if not ((isinstance(tgt, RGSpace) and tgt.harmonic) or
                isinstance(tgt, PowerSpace)):
            raise ValueError(
                "Target must be a harmonic RGSpace or a power space.")

        ndim = len(tgt.shape)
        # np.ndim also accepts 0-d arrays, which np.isscalar rejects.
        if np.ndim(dof) == 0:
            # `np.int` was removed in NumPy 1.24; the builtin is equivalent.
            dof = np.full(ndim, int(dof), dtype=int)
        dof = np.array(dof)
        if dof.shape != (ndim,):
            raise ValueError(
                "dof must be a scalar or have one entry per axis ({}), "
                "got shape {}.".format(ndim, dof.shape))
        if np.any(dof < 2):
            # dof - 1 is used as a divisor for the bin distance below.
            raise ValueError("Every entry of dof must be at least 2.")

        t_mins = np.empty(ndim)
        bindistances = np.empty(ndim)
        # Per axis of the transformed subdomain: for every target pixel, the
        # lower log-space pixel used for interpolation (_bindex) and the
        # weight given to the next pixel (_frac).
        self._bindex = []
        self._frac = []

        for i in range(ndim):
            k_array = self._k_lengths(tgt, i)

            # Index 0 is the k = 0 mode, whose logarithm is undefined.
            log_k_array = np.log(k_array[1:])
            t_min = np.min(log_k_array)
            t_max = np.max(log_k_array)
            t_mins[i] = t_min
            bindistances[i] = (t_max-t_min) / (dof[i]-1)

            # Fractional log-space pixel coordinate of every target pixel:
            # 0 for the k = 0 mode, 1 + (log(k)-t_min)/bindistance otherwise.
            coord = np.append(0., 1. + (log_k_array-t_min) / bindistances[i])
            bindex = np.floor(coord).astype(int)
            self._bindex.append(bindex)
            # Interpolated value is computed via
            # (1.-frac)*<value from this bin> + frac*<value from next bin>
            # 0 <= frac < 1.
            self._frac.append(coord - bindex)

        # Imported here rather than at module level, presumably to avoid a
        # circular import.
        from ..domains.log_rg_space import LogRGSpace
        log_space = LogRGSpace(2*dof+1, bindistances, t_mins, harmonic=False)
        domain = list(self._target)
        domain[self._space] = log_space
        self._domain = DomainTuple.make(domain)

    @staticmethod
    def _k_lengths(tgt, axis):
        """Return the k-length of every pixel of `tgt` along `axis`.

        For a harmonic RGSpace the pixels are in FFT order: pixel j holds
        frequency index j for j <= N/2 and j - N otherwise, so its length is
        ``min(j, N - j) * distance``. For a PowerSpace its `k_lengths` are
        returned unchanged.
        """
        if isinstance(tgt, RGSpace):
            npix = tgt.shape[axis]
            rng = np.arange(npix)
            return np.minimum(rng, npix-rng) * tgt.distances[axis]
        return tgt.k_lengths

    def apply(self, x, mode):
        self._check_input(x, mode)
        v = x.val
        ndim = len(self.target.shape)
        curshp = list(self._dom(mode).shape)
        tgtshp = self._tgt(mode).shape
        d0 = self._target.axes[self._space][0]
        for d in self._target.axes[self._space]:
            # Position of global axis d inside the transformed subdomain.
            iax = d - d0
            bindex = self._bindex[iax]
            # Reshape the weights so they broadcast along global axis d only.
            wgt = self._frac[iax].reshape((1,)*d + (-1,) + (1,)*(ndim-d-1))

            # `loc` is the local data of `v`, arranged so that axis d is
            # not distributed.
            v, loc = dobj.ensure_not_distributed(v, (d,))

            if mode == self.ADJOINT_TIMES:
                shp = list(loc.shape)
                shp[d] = tgtshp[d]
                xnew = np.zeros(shp, dtype=loc.dtype)
                # Scatter each value back onto the two pixels it was
                # interpolated from, with the same weights.
                xnew = special_add_at(xnew, d, bindex, loc*(1.-wgt))
                xnew = special_add_at(xnew, d, bindex+1, loc*wgt)
            else:  # TIMES
                idx = (slice(None),) * d
                xnew = loc[idx + (bindex,)] * (1.-wgt)
                xnew += loc[idx + (bindex+1,)] * wgt

            curshp[d] = xnew.shape[d]
            v = dobj.from_local_data(curshp, xnew, distaxis=dobj.distaxis(v))
        return Field(self._tgt(mode), dobj.ensure_default_distributed(v))