# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

from .domain_tuple import DomainTuple
19
from .utilities import frozendict, indent


class MultiDomain(object):
    """A tuple of domains corresponding to a direct sum.

    This class is the domain of the direct sum of fields defined
    on (possibly different) domains. To make an instance
    of this class, call `MultiDomain.make(inp)`.
    """
    # Instances already built by `make`, keyed by the frozendict of their
    # (name -> DomainTuple) entries. `make` returns the cached object when it
    # is called again with an equal mapping.
    _domainCache = {}

    def __init__(self, dict, _callingfrommake=False):
        # NOTE: the parameter name `dict` shadows the builtin inside this
        # method only; it is kept so existing keyword callers keep working.
        if not _callingfrommake:
            raise NotImplementedError(
                'To create a MultiDomain call `MultiDomain.make()`.')
        # Keys are stored sorted, so ordering (and therefore iteration,
        # equality and hashing) does not depend on dict insertion order.
        self._keys = tuple(sorted(dict.keys()))
        self._domains = tuple(dict[key] for key in self._keys)
        # Maps each key to its position in `_keys` / `_domains`.
        self._idx = frozendict({key: i for i, key in enumerate(self._keys)})

    @staticmethod
    def make(inp):
        """Creates a MultiDomain object from a dictionary of names and domains

        Parameters
        ----------
        inp : MultiDomain or dict{name: DomainTuple}
            The already built MultiDomain or a dictionary of DomainTuples.
            Every value is passed through `DomainTuple.make`.

        Returns
        -------
        MultiDomain
            A MultiDomain with the input domains as domains. If `inp` is
            already a MultiDomain it is returned unchanged; otherwise a
            previously cached instance for an equal mapping is reused.

        Raises
        ------
        TypeError
            If `inp` is neither a MultiDomain nor a dict, or if any key
            is not a string.
        """
        if isinstance(inp, MultiDomain):
            return inp
        if not isinstance(inp, dict):
            raise TypeError("dict expected")
        tmp = {}
        for key, value in inp.items():
            if not isinstance(key, str):
                raise TypeError("keys must be strings")
            tmp[key] = DomainTuple.make(value)
        # Freeze the mapping so it can serve as a hashable cache key.
        tmp = frozendict(tmp)
        obj = MultiDomain._domainCache.get(tmp)
        if obj is not None:
            return obj
        obj = MultiDomain(tmp, _callingfrommake=True)
        MultiDomain._domainCache[tmp] = obj
        return obj

    def keys(self):
        """Return the tuple of keys, in sorted order."""
        return self._keys

    def values(self):
        """Return the tuple of domains, ordered like `keys()`."""
        return self._domains

    def domains(self):
        """Return the tuple of domains, ordered like `keys()`."""
        return self._domains

    @property
    def idx(self):
        """Mapping from each key to its position in `keys()`."""
        return self._idx

    def items(self):
        """Return a one-shot iterator of (key, domain) pairs in key order."""
        return zip(self._keys, self._domains)

    def __getitem__(self, key):
        # Raises KeyError (from the index mapping) for unknown keys.
        return self._domains[self._idx[key]]

    def __len__(self):
        return len(self._keys)

    def __hash__(self):
        return self._keys.__hash__() ^ self._domains.__hash__()

    def __eq__(self, x):
        """Compare item-wise with any object that provides `items()`.

        Objects without an `items` attribute (e.g. None, numbers, strings)
        compare unequal instead of raising AttributeError.
        """
        if self is x:
            return True
        other_items = getattr(x, "items", None)
        if other_items is None:
            return False
        return list(self.items()) == list(other_items())

    def __ne__(self, x):
        return not self.__eq__(x)

    @property
    def size(self):
        """Total size: the sum of the `size` of every contained domain."""
        return sum(dom.size for dom in self._domains)

    def __str__(self):
        # Build the body with join instead of repeated string concatenation.
        body = "".join(key + ": " + str(dom) + "\n"
                       for key, dom in zip(self._keys, self._domains))
        return "MultiDomain:\n" + body

    @staticmethod
    def union(inp):
        """Merge several MultiDomains into one containing all their entries.

        Parameters
        ----------
        inp : iterable of MultiDomain

        Returns
        -------
        MultiDomain
            If all inputs are equal, that (single distinct) input is
            returned unchanged; otherwise a new MultiDomain built from the
            combined entries.

        Raises
        ------
        ValueError
            If the same key is mapped to different domains in two inputs.
        """
        inp = set(inp)
        if len(inp) == 1:  # all domains are identical
            return inp.pop()
        res = {}
        for dom in inp:
            for key, subdom in zip(dom._keys, dom._domains):
                if key in res:
                    if res[key] != subdom:
                        raise ValueError("domain mismatch")
                else:
                    res[key] = subdom
        return MultiDomain.make(res)

    def __reduce__(self):
        # Pickle as a plain dict of name -> domain; unpickling goes through
        # `MultiDomain.make`, so the instance cache is consulted again.
        return (_unpickleMultiDomain, (dict(self),))

    def __repr__(self):
        subs = "\n".join("{}:\n  {}".format(key, dom.__repr__())
                         for key, dom in self.items())
        return "MultiDomain:\n"+indent(subs)


def _unpickleMultiDomain(*state):
    """Rebuild a MultiDomain while unpickling.

    This is the reconstruction callable named by `MultiDomain.__reduce__`.
    The pickled arguments are forwarded unchanged to `MultiDomain.make`.
    """
    factory = MultiDomain.make
    return factory(*state)