# exp_transform.py
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

from __future__ import absolute_import, division, print_function

import numpy as np

from .. import dobj
from ..compat import *
from ..domain_tuple import DomainTuple
from ..domains.power_space import PowerSpace
from ..domains.rg_space import RGSpace
from ..field import Field
from .linear_operator import LinearOperator


class ExpTransform(LinearOperator):
    """Linear interpolation from a logarithmic grid onto a k-space grid.

    The operator's domain is a non-harmonic ``LogRGSpace`` with ``2*dof+1``
    pixels per axis; its target is ``target``. Along every axis, each
    target entry with length ``k`` (except the first one) is assigned the
    fractional domain coordinate ``1 + (log(k) - t_min) / bindistance``,
    where ``t_min``/``t_max`` are the smallest/largest ``log(k)`` on that
    axis and ``bindistance = (t_max - t_min) / (dof - 1)``. The first entry
    of each axis is excluded from the logarithm and pinned to coordinate 0.

    ``TIMES`` linearly interpolates the domain field at these coordinates,
    axis by axis; ``ADJOINT_TIMES`` scatters target values back onto the
    domain with the same weights.

    Parameters
    ----------
    target : RGSpace or PowerSpace
        Must be a harmonic ``RGSpace`` or a ``PowerSpace``.
    dof : int or sequence of int
        Number of log-spaced grid points covering ``[t_min, t_max]`` on each
        axis of ``target``. A scalar is used for every axis. Every entry
        must be at least 2.

    Raises
    ------
    ValueError
        If ``target`` is neither a harmonic ``RGSpace`` nor a ``PowerSpace``,
        if ``dof`` does not provide exactly one entry per target axis, or if
        any entry of ``dof`` is smaller than 2.
    """

    def __init__(self, target, dof):
        if not ((isinstance(target, RGSpace) and target.harmonic) or
                isinstance(target, PowerSpace)):
            raise ValueError(
                "Target must be a harmonic RGSpace or a power space.")

        ndim = len(target.shape)
        if np.isscalar(dof):
            # builtin int: the former ``np.int`` alias was removed in
            # NumPy 1.24 and meant exactly this.
            dof = np.full(ndim, int(dof), dtype=int)
        dof = np.array(dof)
        if dof.shape != (ndim,):
            raise ValueError(
                "dof must be a scalar or have one entry per target axis.")
        if np.any(dof < 2):
            # dof == 1 would divide by zero when computing the bin distance
            # and produce NaN/garbage bin indices.
            raise ValueError("dof must be at least 2 on every axis.")

        t_mins = np.empty(ndim)
        bindistances = np.empty(ndim)
        # Per axis: integer lower-neighbour index into the domain axis, and
        # the interpolation weight of the upper neighbour.
        self._bindex = [None] * ndim
        self._frac = [None] * ndim

        for i in range(ndim):
            if isinstance(target, RGSpace):
                rng = np.arange(target.shape[i])
                # NOTE(review): the conventional |FFT index| would be
                # min(j, N-j); the "+1" below is kept unchanged from the
                # original implementation -- confirm it is intended.
                tmp = np.minimum(rng, target.shape[i]+1-rng)
                k_array = tmp * target.distances[i]
            else:
                k_array = target.k_lengths

            # avoid taking log of first entry
            log_k_array = np.log(k_array[1:])

            # Interpolate log_k_array linearly
            t_max = np.max(log_k_array)
            t_min = np.min(log_k_array)

            # Save t_min for later (origin of the log space)
            t_mins[i] = t_min

            bindistances[i] = (t_max-t_min) / (dof[i]-1)
            # The first (non-logged) entry sits at coordinate 0; the logged
            # entries are mapped onto coordinates 1 ... dof.
            coord = np.append(0., 1. + (log_k_array-t_min) / bindistances[i])
            self._bindex[i] = np.floor(coord).astype(int)

            # Interpolated value is computed via
            # (1.-frac)*<value from this bin> + frac*<value from next bin>
            # 0 <= frac < 1.
            self._frac[i] = coord - self._bindex[i]

        # Imported locally (presumably to avoid an import cycle).
        from ..domains.log_rg_space import LogRGSpace
        log_space = LogRGSpace(2*dof+1, bindistances,
                               t_mins, harmonic=False)
        self._target = DomainTuple.make(target)
        self._domain = DomainTuple.make(log_space)

    @property
    def domain(self):
        """DomainTuple holding the logarithmic ``LogRGSpace``."""
        return self._domain

    @property
    def target(self):
        """DomainTuple holding the harmonic RGSpace / PowerSpace."""
        return self._target

    def apply(self, x, mode):
        """Apply the operator (``TIMES``) or its adjoint (``ADJOINT_TIMES``).

        The interpolation is separable and is performed one axis at a time.
        The axis being processed must not be distributed, so if it is the
        distributed axis of the input, the data are temporarily
        redistributed and distributed back along it afterwards.

        Parameters
        ----------
        x : Field
            Input living on ``self._dom(mode)``.
        mode : int
            ``self.TIMES`` or ``self.ADJOINT_TIMES``.

        Returns
        -------
        Field
            Result living on ``self._tgt(mode)``.
        """
        self._check_input(x, mode)
        x = x.val
        ax = dobj.distaxis(x)
        ndim = len(self.target.shape)
        curshp = list(self._dom(mode).shape)
        for d in range(ndim):
            # Full slices for all axes preceding d, so that the bin index
            # array below selects along axis d only.
            idx = (slice(None,),) * d
            # Broadcast the weights along axis d.
            wgt = self._frac[d].reshape((1,)*d + (-1,) + (1,)*(ndim-d-1))

            if d == ax:
                x = dobj.redistribute(x, nodist=(ax,))
            curax = dobj.distaxis(x)
            x = dobj.local_data(x)

            if mode == self.ADJOINT_TIMES:
                shp = list(x.shape)
                shp[d] = self._tgt(mode).shape[d]
                xnew = np.zeros(shp, dtype=x.dtype)
                # np.add.at accumulates correctly for repeated bin indices,
                # unlike fancy-index "+=".
                np.add.at(xnew, idx + (self._bindex[d],), x * (1.-wgt))
                np.add.at(xnew, idx + (self._bindex[d]+1,), x * wgt)
            else:  # TIMES
                xnew = x[idx + (self._bindex[d],)] * (1.-wgt)
                xnew += x[idx + (self._bindex[d]+1,)] * wgt

            curshp[d] = self._tgt(mode).shape[d]
            x = dobj.from_local_data(curshp, xnew, distaxis=curax)
            if d == ax:
                x = dobj.redistribute(x, dist=ax)
        return Field(self._tgt(mode), val=x)

    @property
    def capability(self):
        """Both the forward and the adjoint application are supported."""
        return self.TIMES | self.ADJOINT_TIMES