dof_distributor.py 6.09 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

Martin Reinecke's avatar
Martin Reinecke committed
19
from __future__ import absolute_import, division, print_function
20

Martin Reinecke's avatar
Martin Reinecke committed
21
import numpy as np
22

Martin Reinecke's avatar
Martin Reinecke committed
23
from .. import dobj
24 25
from ..compat import *
from ..domain_tuple import DomainTuple
Martin Reinecke's avatar
Martin Reinecke committed
26
from ..domains.dof_space import DOFSpace
27 28 29
from ..field import Field
from ..utilities import infer_space
from .linear_operator import LinearOperator
Martin Reinecke's avatar
Martin Reinecke committed
30 31


Martin Reinecke's avatar
Martin Reinecke committed
32
class DOFDistributor(LinearOperator):
    """Operator which distributes actual degrees of freedom (dof) according to
    some distribution scheme into a higher dimensional space. This distribution
    scheme is defined by the dofdex, a degree of freedom index, which
    associates the entries within the operator's domain to locations in its
    target. This operator's domain is a DOFSpace, which is defined according to
    the distribution scheme.

    Parameters
    ----------
    dofdex: Field of integers
        An integer Field on exactly one Space establishing the association
        between the locations in the operator's target and the underlying
        degrees of freedom in its domain.
        It has to start at 0 and it increases monotonically, no empty bins are
        allowed.
    target: Domain, tuple of Domain, or DomainTuple, optional
        The target of the operator. If not specified, the domain of the dofdex
        is used.
    space: int, optional
        The index of the sub-domain on which the operator acts.
        Can be omitted if `target` only has one sub-domain.

    Raises
    ------
    TypeError
        If `dofdex` is not a Field or does not have an integer dtype.
    ValueError
        If `dofdex` lives on more than one Space, if its Space differs from
        the selected sub-domain of `target`, or if at least one bin has zero
        weight.
    """

    def __init__(self, dofdex, target=None, space=None):
        super(DOFDistributor, self).__init__()

        # Check the type first: `dofdex.domain` is accessed right below, and
        # a non-Field argument should yield the documented TypeError rather
        # than an incidental AttributeError.
        if not isinstance(dofdex, Field):
            raise TypeError("dofdex must be a Field")
        if target is None:
            target = dofdex.domain
        self._target = DomainTuple.make(target)
        space = infer_space(self._target, space)
        partner = self._target[space]
        if len(dofdex.domain) != 1:
            raise ValueError("dofdex must live on exactly one Space")
        if not np.issubdtype(dofdex.dtype, np.integer):
            raise TypeError("dofdex must contain integer numbers")
        if partner != dofdex.domain[0]:
            raise ValueError("incorrect dofdex domain")

        # The indices run from 0 to max inclusively, so there are max+1 bins.
        # Passing the true bin count as `minlength` guarantees that every
        # task obtains a weight array of identical length, even if its local
        # part of dofdex does not contain the largest index; otherwise the
        # subsequent all-reduce would operate on arrays of different shapes.
        nbin = dofdex.max() + 1
        if partner.scalar_dvol is not None:
            wgt = np.bincount(dofdex.local_data.ravel(), minlength=nbin)
            wgt = wgt*partner.scalar_dvol
        else:
            dvol = dobj.local_data(partner.dvol)
            wgt = np.bincount(dofdex.local_data.ravel(),
                              minlength=nbin, weights=dvol)
        # The explicit conversion to float64 is necessary because bincount
        # sometimes returns its result as an integer array, even when
        # floating-point weights are present ...
        wgt = wgt.astype(np.float64, copy=False)
        wgt = dobj.np_allreduce_sum(wgt)
        if (wgt == 0).any():
            raise ValueError("empty bins detected")

        self._init2(dofdex.val, space, DOFSpace(wgt))

    def _init2(self, dofdex, space, other_space):
        """Set up domain, index array and the internal array shapes.

        Parameters
        ----------
        dofdex: data object
            The index data (the `val` of the dofdex Field).
        space: int
            Index of the sub-domain of the target on which the operator acts.
        other_space: Domain
            The space that replaces sub-domain `space` of the target in order
            to form the operator's domain.
        """
        self._space = space
        dom = list(self._target)
        dom[self._space] = other_space
        self._domain = DomainTuple.make(dom)

        if dobj.default_distaxis() in self._domain.axes[self._space]:
            dofdex = dobj.local_data(dofdex)
        else:  # dofdex must be available fully on every task
            dofdex = dobj.to_global_data(dofdex)
        self._dofdex = dofdex.ravel()
        firstaxis = self._target.axes[self._space][0]
        lastaxis = self._target.axes[self._space][-1]
        arrshape = dobj.local_shape(self._target.shape, 0)
        # Collapse all axes before and after the operator's space into one
        # leading and one trailing axis. The builtin `int` is used as dtype
        # because the `np.int` alias no longer exists in current NumPy.
        presize = np.prod(arrshape[0:firstaxis], dtype=int)
        postsize = np.prod(arrshape[lastaxis+1:], dtype=int)
        # 3D shape of the data on the domain side (dof axis in the middle)
        self._hshape = (presize, self._domain[self._space].shape[0], postsize)
        # 3D shape of the data on the target side (flattened partner space
        # in the middle)
        self._pshape = (presize, self._dofdex.size, postsize)

    def _adjoint_times(self, x):
        """Sum up all entries of `x` which belong to the same dof.

        Parameters
        ----------
        x: Field
            Field living on the operator's target.

        Returns
        -------
        Field
            Field living on the operator's domain.
        """
        arr = x.local_data
        arr = arr.reshape(self._pshape)
        oarr = np.zeros(self._hshape, dtype=x.dtype)
        # np.add.at accumulates correctly for repeated indices, which a plain
        # fancy-indexed "+=" would not do.
        np.add.at(oarr, (slice(None), self._dofdex, slice(None)), arr)
        if dobj.distaxis(x.val) in x.domain.axes[self._space]:
            # Every task only holds partial sums; combine them.
            oarr = dobj.np_allreduce_sum(oarr).reshape(self._domain.shape)
            res = Field.from_global_data(self._domain, oarr)
        else:
            oarr = oarr.reshape(dobj.local_shape(self._domain.shape,
                                                 dobj.distaxis(x.val)))
            res = Field(self._domain,
                        dobj.from_local_data(self._domain.shape, oarr,
                                             dobj.default_distaxis()))
        return res

    def _times(self, x):
        """Copy every dof entry of `x` to all of its associated locations.

        Parameters
        ----------
        x: Field
            Field living on the operator's domain.

        Returns
        -------
        Field
            Field living on the operator's target.
        """
        if dobj.distaxis(x.val) in x.domain.axes[self._space]:
            # The dof axis is distributed, but the local part of the result
            # may refer to any dof, so the full array is fetched.
            arr = x.to_global_data()
        else:
            arr = x.local_data
        arr = arr.reshape(self._hshape)
        oarr = np.empty(self._pshape, dtype=x.dtype)
        oarr[()] = arr[(slice(None), self._dofdex, slice(None))]
        return Field.from_local_data(
            self._target, oarr.reshape(self._target.local_shape))

    def apply(self, x, mode):
        """Apply the operator to `x` in the requested mode.

        `mode` equal to `self.TIMES` distributes the dofs; every other mode
        accepted by `_check_input` is treated as the adjoint operation.
        """
        self._check_input(x, mode)
        return self._times(x) if mode == self.TIMES else self._adjoint_times(x)

    @property
    def domain(self):
        return self._domain

    @property
    def target(self):
        return self._target

    @property
    def capability(self):
        return self.TIMES | self.ADJOINT_TIMES