domain_tuple.py 6.21 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1 2 3 4 5 6 7 8 9 10 11 12 13
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

Martin Reinecke's avatar
Martin Reinecke committed
18
from functools import reduce
Martin Reinecke's avatar
Martin Reinecke committed
19

20 21
import numpy as np

Philipp Arras's avatar
Philipp Arras committed
22 23 24
from . import utilities
from .domains.domain import Domain

Martin Reinecke's avatar
Martin Reinecke committed
25

Martin Reinecke's avatar
Martin Reinecke committed
26
class DomainTuple(object):
    """Ordered sequence of Domain objects.

    This class holds a tuple of :class:`Domain` objects, which together form
    the space on which a :class:`Field` is defined.
    This corresponds to a tensor product of the corresponding vector
    fields.

    Notes
    -----

    DomainTuples should never be created using the constructor, but rather
    via the factory function :attr:`make`!
    """
    # Maps a parsed tuple of Domain objects to the DomainTuple that `make`
    # created for it, so that equal inputs yield the very same object.
    _tupleCache = {}
    # Filled on the first call of `scalar_domain`.
    _scalarDomain = None

    def __init__(self, domain, _callingfrommake=False):
        # Direct construction would bypass the cache maintained by `make`.
        if not _callingfrommake:
            raise NotImplementedError(
                "DomainTuple must not be instantiated directly; "
                "use DomainTuple.make() instead.")
        self._dom = self._parse_domain(domain)
        self._axtuple = self._get_axes_tuple()
        # Concatenate the shapes of all sub-domains into one flat tuple.
        self._shape = reduce(lambda x, y: x+y, (sp.shape for sp in self._dom),
                             ())
        self._size = reduce(lambda x, y: x*y, self._shape, 1)

    def _get_axes_tuple(self):
        """Returns, for every sub-domain, the tuple of array axes it occupies.

        Axes are numbered consecutively in the order of the sub-domains; a
        sub-domain with `len(shape) == k` receives the next `k` indices.
        """
        first = 0
        res = []
        for dom in self._dom:
            nax = len(dom.shape)
            res.append(tuple(range(first, first+nax)))
            first += nax
        return tuple(res)

    @staticmethod
    def make(domain):
        """Returns a DomainTuple matching `domain`.

        This function checks whether a matching DomainTuple already exists.
        If yes, this object is returned, otherwise a new DomainTuple object
        is created and returned.

        Parameters
        ----------
        domain : Domain or tuple of Domain or DomainTuple
            The geometrical structure for which the DomainTuple shall be
            obtained.

        Returns
        -------
        DomainTuple
            `domain` itself if it already is a DomainTuple, otherwise the
            cached or newly created object.

        Raises
        ------
        TypeError
            If `domain` contains an entry that is not a Domain.
        """
        if isinstance(domain, DomainTuple):
            return domain
        domain = DomainTuple._parse_domain(domain)
        obj = DomainTuple._tupleCache.get(domain)
        if obj is not None:
            return obj
        obj = DomainTuple(domain, _callingfrommake=True)
        DomainTuple._tupleCache[domain] = obj
        return obj

    @staticmethod
    def _parse_domain(domain):
        """Converts `domain` into a tuple of Domain objects.

        `None` becomes the empty tuple, a single Domain becomes a 1-tuple,
        and any other input is converted with `tuple()` and then checked
        entry by entry.

        Raises
        ------
        TypeError
            If an entry of the resulting tuple is not a Domain.
        """
        if domain is None:
            return ()
        if isinstance(domain, Domain):
            return (domain,)

        if not isinstance(domain, tuple):
            domain = tuple(domain)
        for d in domain:
            if not isinstance(d, Domain):
                raise TypeError(
                    "Given object contains something that is not an "
                    "instance of Domain class.")
        return domain

    def __getitem__(self, i):
        return self._dom[i]

    @property
    def shape(self):
        """tuple of int: number of pixels along each axis

        The shape of the array-like object required to store information
        defined on the DomainTuple.
        """
        return self._shape

    @property
    def size(self):
        """int : total number of pixels.

        Equivalent to the products over all entries in the object's shape.
        """
        return self._size

    def scalar_weight(self, spaces=None):
        """Returns the uniform volume element for a sub-domain of `self`.

        Parameters
        ----------
        spaces : int, tuple of int or None
            Indices of the sub-domains to be considered.
            If `None`, the entire domain is used.

        Returns
        -------
        float or None
            If the requested sub-domain has a uniform volume element, it is
            returned. Otherwise, `None` is returned.
        """
        if np.isscalar(spaces):
            return self._dom[spaces].scalar_dvol

        if spaces is None:
            spaces = range(len(self._dom))
        res = 1.
        for i in spaces:
            tmp = self._dom[i].scalar_dvol
            # A single sub-domain without a uniform volume element means
            # the product has none either.
            if tmp is None:
                return None
            res *= tmp
        return res

    def total_volume(self, spaces=None):
        """Returns the total volume of `self` or of a subspace of it.

        Parameters
        ----------
        spaces : int, tuple of int or None
            Indices of the sub-domains of the domain to be considered.
            If `None`, the total volume of the whole domain is returned.

        Returns
        -------
        float
            the total volume of the requested (sub-)domain.
        """
        if np.isscalar(spaces):
            return self._dom[spaces].total_volume

        if spaces is None:
            spaces = range(len(self._dom))
        res = 1.
        for i in spaces:
            res *= self._dom[i].total_volume
        return res

    @property
    def axes(self):
        """tuple of tuple of int : axis indices of the underlying domains"""
        return self._axtuple

    def __len__(self):
        return len(self._dom)

    def __hash__(self):
        return hash(self._dom)

    def __eq__(self, x):
        if self is x:
            return True
        # Anything that is not a DomainTuple has no `_dom`; let Python fall
        # back to its default handling instead of raising AttributeError.
        if not isinstance(x, DomainTuple):
            return NotImplemented
        return self._dom == x._dom

    def __ne__(self, x):
        res = self.__eq__(x)
        # NotImplemented must be passed through unchanged, not negated.
        if res is NotImplemented:
            return res
        return not res

    def __str__(self):
        return ("DomainTuple, len: {}\n".format(len(self)) +
                "\n".join(str(i) for i in self))

    def __reduce__(self):
        # Unpickling goes through `make` so that the cache is honoured.
        return (_unpickleDomainTuple, (self._dom,))

    @staticmethod
    def scalar_domain():
        """Returns the DomainTuple obtained from the empty tuple.

        The object is created via `make` on first use and stored in a class
        attribute; later calls return the stored object.
        """
        if DomainTuple._scalarDomain is None:
            DomainTuple._scalarDomain = DomainTuple.make(())
        return DomainTuple._scalarDomain

    def __repr__(self):
        subs = "\n".join(sub.__repr__() for sub in self._dom)
        return "DomainTuple:\n"+utilities.indent(subs)


def _unpickleDomainTuple(*args):
    """Rebuilds a DomainTuple during unpickling.

    Serves as the callable returned by `DomainTuple.__reduce__`; all
    positional arguments are handed on to `DomainTuple.make`, and its
    result is returned.
    """
    factory = DomainTuple.make
    return factory(*args)