dof_distributor.py 5.04 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

Martin Reinecke's avatar
Martin Reinecke committed
19
import numpy as np
Martin Reinecke's avatar
Martin Reinecke committed
20
from .linear_operator import LinearOperator
21
from ..utilities import infer_space
Martin Reinecke's avatar
Martin Reinecke committed
22
23
24
from ..field import Field
from ..domain_tuple import DomainTuple
from .. import dobj
Martin Reinecke's avatar
Martin Reinecke committed
25
from ..domains.dof_space import DOFSpace
Martin Reinecke's avatar
Martin Reinecke committed
26
27


Martin Reinecke's avatar
Martin Reinecke committed
28
class DOFDistributor(LinearOperator):
    """Operator transforming between a DOFSpace and any other domain by means
    of distribution/projection.

    In TIMES mode, every pixel of the selected sub-domain of the target
    receives the value of the degree of freedom (DOF) that `dofdex` assigns
    to it. In ADJOINT_TIMES mode, the values of all target pixels that
    belong to the same DOF are summed up.

    Parameters
    ----------
    dofdex : Field
        Integer-valued field living on exactly one space. Its value at each
        pixel is the index of the DOF that pixel is assigned to. Every index
        between 0 and ``dofdex.max()`` must occur at least once.
    target : Domain, tuple of Domain or DomainTuple, optional
        Target of the operator. Defaults to ``dofdex.domain``.
    space : int, optional
        Index of the sub-domain of `target` on which `dofdex` lives; resolved
        via ``infer_space(target, space)``.

    Raises
    ------
    TypeError
        If `dofdex` is not a Field or does not contain integers.
    ValueError
        If `dofdex` does not live on exactly one space, does not match the
        selected sub-domain of `target`, or leaves some DOF index unused.
    """

    def __init__(self, dofdex, target=None, space=None):
        super(DOFDistributor, self).__init__()

        # Validate dofdex before reading any of its attributes, so that a
        # wrong argument type yields the documented TypeError instead of an
        # AttributeError from `dofdex.domain`.
        if not isinstance(dofdex, Field):
            raise TypeError("dofdex must be a Field")
        if not len(dofdex.domain) == 1:
            raise ValueError("dofdex must live on exactly one Space")
        if not np.issubdtype(dofdex.dtype, np.integer):
            raise TypeError("dofdex must contain integer numbers")

        if target is None:
            target = dofdex.domain
        self._target = DomainTuple.make(target)
        space = infer_space(self._target, space)
        partner = self._target[space]
        if partner != dofdex.domain[0]:
            raise ValueError("incorrect dofdex domain")

        # DOF indices run from 0 to max, so max+1 bins are needed. Passing
        # this count as `minlength` also keeps every task's local histogram
        # the same length for the allreduce below, even if its local chunk
        # of dofdex does not contain the largest index.
        # NOTE(review): assumes Field.max() is a global reduction — confirm.
        nbin = dofdex.max() + 1
        local_dofdex = dofdex.local_data.ravel()
        if partner.scalar_dvol is not None:
            # Uniform pixel volume: count pixels per DOF, then scale.
            wgt = np.bincount(local_dofdex, minlength=nbin)
            wgt = wgt*partner.scalar_dvol
        else:
            # Per-pixel volumes: accumulate the volume of each DOF directly.
            dvol = dobj.local_data(partner.dvol)
            wgt = np.bincount(local_dofdex, minlength=nbin, weights=dvol)
        # The explicit conversion to float64 is necessary because bincount
        # sometimes returns its result as an integer array, even when
        # floating-point weights are present ...
        wgt = wgt.astype(np.float64, copy=False)
        wgt = dobj.np_allreduce_sum(wgt)
        if (wgt == 0).any():
            raise ValueError("empty bins detected")

        # Each DOF's weight is the total volume of the pixels mapped to it.
        self._init2(dofdex.val, space, DOFSpace(wgt))

    def _init2(self, dofdex, space, other_space):
        """Set up domain, index array and work-array shapes.

        Parameters
        ----------
        dofdex : distributed data object
            The raw index data of the `dofdex` field.
        space : int
            Index of the sub-domain of the target being replaced.
        other_space : Domain
            The domain replacing ``self._target[space]`` in the operator's
            domain.
        """
        self._space = space
        dom = list(self._target)
        dom[self._space] = other_space
        self._domain = DomainTuple.make(dom)

        if dobj.default_distaxis() in self._domain.axes[self._space]:
            dofdex = dobj.local_data(dofdex)
        else:  # dofdex must be available fully on every task
            dofdex = dobj.to_global_data(dofdex)
        self._dofdex = dofdex.ravel()

        # Collapse the axes before and after the selected space so that the
        # operator can work on 3-D arrays (pre, space, post).
        firstaxis = self._target.axes[self._space][0]
        lastaxis = self._target.axes[self._space][-1]
        arrshape = dobj.local_shape(self._target.shape, 0)
        # `int` replaces the removed `np.int` alias (which was `int` itself).
        presize = np.prod(arrshape[0:firstaxis], dtype=int)
        postsize = np.prod(arrshape[lastaxis+1:], dtype=int)
        # Shape of the operator's domain data (one entry per DOF) ...
        self._hshape = (presize, self._domain[self._space].shape[0], postsize)
        # ... and of the target data (one entry per pixel in self._dofdex).
        self._pshape = (presize, self._dofdex.size, postsize)

    def _adjoint_times(self, x):
        """Sum the values of all target pixels belonging to the same DOF."""
        arr = x.local_data
        arr = arr.reshape(self._pshape)
        oarr = np.zeros(self._hshape, dtype=x.dtype)
        # Unbuffered accumulation: several pixels may share one DOF index,
        # which plain fancy-index assignment would not sum correctly.
        np.add.at(oarr, (slice(None), self._dofdex, slice(None)), arr)
        if dobj.distaxis(x.val) in x.domain.axes[self._space]:
            # Each task only saw its share of the pixels; combine partial sums.
            oarr = dobj.np_allreduce_sum(oarr).reshape(self._domain.shape)
            res = Field.from_global_data(self._domain, oarr)
        else:
            oarr = oarr.reshape(dobj.local_shape(self._domain.shape,
                                                 dobj.distaxis(x.val)))
            res = Field(self._domain,
                        dobj.from_local_data(self._domain.shape, oarr,
                                             dobj.default_distaxis()))
        return res

    def _times(self, x):
        """Copy each DOF's value to every target pixel assigned to it."""
        res = Field.empty(self._target, dtype=x.dtype)
        if dobj.distaxis(x.val) in x.domain.axes[self._space]:
            arr = x.to_global_data()
        else:
            arr = x.local_data
        arr = arr.reshape(self._hshape)
        # NOTE(review): writing through this reshape relies on
        # `res.local_data` returning a view of the result's storage.
        oarr = res.local_data.reshape(self._pshape)
        oarr[()] = arr[(slice(None), self._dofdex, slice(None))]
        return res

    def apply(self, x, mode):
        self._check_input(x, mode)
        return self._times(x) if mode == self.TIMES else self._adjoint_times(x)

    @property
    def domain(self):
        return self._domain

    @property
    def target(self):
        return self._target

    @property
    def capability(self):
        return self.TIMES | self.ADJOINT_TIMES