dof_projection_operator.py 5.03 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

Martin Reinecke's avatar
Martin Reinecke committed
19
import numpy as np
Martin Reinecke's avatar
Martin Reinecke committed
20
from .linear_operator import LinearOperator
21
from ..utilities import infer_space
Martin Reinecke's avatar
Martin Reinecke committed
22
from .. import Field, DomainTuple, dobj
Martin Reinecke's avatar
Martin Reinecke committed
23 24 25 26
from ..spaces import DOFSpace


class DOFProjectionOperator(LinearOperator):
    """Linear operator mapping one sub-space of a domain onto a DOFSpace.

    Every pixel of one sub-space ("partner") of the domain is assigned an
    integer bin index by `dofdex`.  `times` accumulates the volume-weighted
    pixel values of each bin and divides by the bin's total volume;
    `adjoint_times` copies each bin value back to every pixel that belongs
    to that bin.  All other sub-spaces of the domain pass through unchanged.

    Parameters
    ----------
    dofdex : Field
        Integer-valued field living on exactly one space; its values are the
        bin indices of the corresponding pixels.  Every bin in
        ``0 .. dofdex.max()`` must receive a nonzero total volume.
    domain : DomainTuple or anything accepted by ``DomainTuple.make``, optional
        Domain of the operator.  Defaults to ``dofdex.domain``.
    space : int, optional
        Index of the sub-space of `domain` that `dofdex` lives on, resolved
        via ``infer_space``.

    Raises
    ------
    TypeError
        If `dofdex` is not a Field or does not contain integers.
    ValueError
        If `dofdex` does not live on exactly one space, that space differs
        from ``domain[space]``, or some bin has zero total volume.

    Notes
    -----
    NOTE(review): the "volume-weighted average" reading above assumes that
    ``Field.weight(p, spaces=...)`` multiplies by the cell volumes raised to
    the power `p` and that ``DOFSpace(wgt)`` uses `wgt` as its cell
    volumes -- confirm against those classes.
    """

    def __init__(self, dofdex, domain=None, space=None):
        super(DOFProjectionOperator, self).__init__()

        # Validate `dofdex` before anything dereferences it (the default for
        # `domain` below reads dofdex.domain), so bad input yields the
        # intended TypeError/ValueError rather than an AttributeError.
        if not isinstance(dofdex, Field):
            raise TypeError("dofdex must be a Field")
        if len(dofdex.domain) != 1:
            raise ValueError("dofdex must live on exactly one Space")
        if not np.issubdtype(dofdex.dtype, np.integer):
            raise TypeError("dofdex must contain integer numbers")

        if domain is None:
            domain = dofdex.domain
        self._domain = DomainTuple.make(domain)
        space = infer_space(self._domain, space)
        partner = self._domain[space]
        if partner != dofdex.domain[0]:
            raise ValueError("incorrect dofdex domain")

        wgt = self._bin_volumes(dofdex, partner)
        self._init2(dofdex.val, space, DOFSpace(wgt))

    @staticmethod
    def _bin_volumes(dofdex, partner):
        """Return the total cell volume of every bin as a float64 array.

        Local bin sums are reduced over all tasks, so every task obtains the
        same array of length ``dofdex.max() + 1``.

        Raises
        ------
        ValueError
            If any bin ends up with zero total volume.
        """
        # minlength is the number of bins, i.e. max index + 1.  Passing only
        # the max would let a task whose local chunk lacks the largest index
        # produce a shorter array than the others, breaking the reduction.
        nbin = dofdex.max() + 1
        idx = dobj.local_data(dofdex.val).ravel()
        scalar_dvol = partner.scalar_dvol()
        if scalar_dvol is not None:
            wgt = np.bincount(idx, minlength=nbin)*scalar_dvol
        else:
            # bincount needs 1-D weights in the same (C) order as `idx`.
            # NOTE(review): assumes the local dvol array covers exactly the
            # local chunk of dofdex -- confirm for distributed partners.
            dvol = np.ravel(dobj.local_data(partner.dvol()))
            wgt = np.bincount(idx, minlength=nbin, weights=dvol)
        # The explicit conversion to float64 is necessary because bincount
        # sometimes returns its result as an integer array, even when
        # floating-point weights are present ...
        wgt = wgt.astype(np.float64, copy=False)
        wgt = dobj.np_allreduce_sum(wgt)
        if (wgt == 0).any():
            raise ValueError("empty bins detected")
        return wgt

    def _init2(self, dofdex, space, other_space):
        """Set up target domain, index array and reshape geometry.

        `dofdex` is the raw data object of the index field; `other_space`
        replaces ``self._domain[space]`` in the target.
        """
        self._space = space
        tgt = list(self._domain)
        tgt[self._space] = other_space
        self._target = DomainTuple.make(tgt)

        if dobj.default_distaxis() in self._domain.axes[self._space]:
            # The projected space is itself split across tasks: only the
            # local slice of the index array is needed.
            dofdex = dobj.local_data(dofdex)
        else:  # dofdex must be available fully on every task
            dofdex = dobj.to_global_data(dofdex)
        self._dofdex = dofdex.ravel()

        # View the (local) data as (pre, dof, post): all axes before the
        # projected space, the flattened projected space, all axes after it.
        firstaxis = self._domain.axes[self._space][0]
        lastaxis = self._domain.axes[self._space][-1]
        # NOTE(review): distaxis 0 is hard-coded here; presumably it matches
        # dobj.default_distaxis() -- confirm.
        arrshape = dobj.local_shape(self._domain.shape, 0)
        # np.prod of an empty tuple is 1.0 (float); force an integer result.
        # (np.int was removed in NumPy 1.24, hence np.int64.)
        presize = int(np.prod(arrshape[0:firstaxis], dtype=np.int64))
        postsize = int(np.prod(arrshape[lastaxis+1:], dtype=np.int64))
        self._hshape = (presize, self._target[self._space].shape[0], postsize)
        self._pshape = (presize, self._dofdex.size, postsize)

    def _times(self, x):
        """Sum volume-weighted values per bin, then divide by bin volume."""
        # Weight only the projected space: the other spaces are carried over
        # to the target unchanged, and the final weight(-1) below likewise
        # only acts on self._space.
        arr = dobj.local_data(x.weight(1, spaces=self._space).val)
        arr = arr.reshape(self._pshape)
        oarr = np.zeros(self._hshape, dtype=x.dtype)
        # np.add.at accumulates correctly even when indices repeat, unlike
        # plain fancy-index assignment.
        np.add.at(oarr, (slice(None), self._dofdex, slice(None)), arr)
        if dobj.distaxis(x.val) in x.domain.axes[self._space]:
            # Each task only saw part of every bin: combine partial sums.
            oarr = dobj.np_allreduce_sum(oarr).reshape(self._target.shape)
            res = Field(self._target, dobj.from_global_data(oarr))
        else:
            # NOTE(review): the local shape uses x's distaxis while the
            # result is declared on the default distaxis; these presumably
            # coincide -- confirm.
            oarr = oarr.reshape(dobj.local_shape(self._target.shape,
                                                 dobj.distaxis(x.val)))
            res = Field(self._target,
                        dobj.from_local_data(self._target.shape, oarr,
                                             dobj.default_distaxis()))
        return res.weight(-1, spaces=self._space)

    def _adjoint_times(self, x):
        """Broadcast each bin value back to all pixels of that bin."""
        res = Field.empty(self._domain, dtype=x.dtype)
        if dobj.distaxis(x.val) in x.domain.axes[self._space]:
            # Bins are split across tasks but any local pixel may refer to
            # any bin, so gather the full bin array.
            arr = dobj.to_global_data(x.val)
        else:
            arr = dobj.local_data(x.val)
        arr = arr.reshape(self._hshape)
        # `oarr` is a reshaped view of res's local data; writing through it
        # fills `res` in place.
        oarr = dobj.local_data(res.val).reshape(self._pshape)
        oarr[()] = arr[(slice(None), self._dofdex, slice(None))]
        return res

    def apply(self, x, mode):
        """Apply the operator (TIMES) or its adjoint (ADJOINT_TIMES) to x."""
        self._check_input(x, mode)
        return self._times(x) if mode == self.TIMES else self._adjoint_times(x)

    @property
    def domain(self):
        """DomainTuple the operator acts on."""
        return self._domain

    @property
    def target(self):
        """Domain with the projected space replaced by its DOFSpace."""
        return self._target

    @property
    def capability(self):
        """Supported modes: forward and adjoint application."""
        return self.TIMES | self.ADJOINT_TIMES