multi_domain.py 4.22 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

Martin Reinecke's avatar
Martin Reinecke committed
19
from __future__ import absolute_import, division, print_function
20

Martin Reinecke's avatar
Martin Reinecke committed
21
22
from .compat import *
from .domain_tuple import DomainTuple
23
from .utilities import frozendict, indent
Martin Reinecke's avatar
step 1    
Martin Reinecke committed
24
25


Martin Reinecke's avatar
Martin Reinecke committed
26
class MultiDomain(object):
    """A tuple of domains corresponding to a direct sum.

    This class is the domain of the direct sum of fields living
    over (possibly different) domains. To make an instance
    of this class, call `MultiDomain.make(inp)`.
    """
    # Instances already built by `make`, keyed by the frozendict of
    # {name: DomainTuple} they were constructed from. `make` consults this
    # before constructing, so repeated calls with the same input reuse one
    # object.
    _domainCache = {}

    def __init__(self, dict, _callingfrommake=False):
        """Private constructor; use `MultiDomain.make` instead.

        Parameters
        ----------
        dict : mapping {name: DomainTuple}
            Already-validated sub-domains. (The name shadows the builtin
            `dict` inside this method only; it is kept so that callers
            passing it by keyword keep working.)
        _callingfrommake : bool
            Guard flag; must be True, which only `make` passes.

        Raises
        ------
        NotImplementedError
            If called directly rather than through `make`.
        """
        if not _callingfrommake:
            raise NotImplementedError(
                'To create a MultiDomain call `MultiDomain.make()`.')
        # Sort the keys so the internal layout does not depend on the
        # insertion order of the input mapping.
        self._keys = tuple(sorted(dict.keys()))
        self._domains = tuple(dict[key] for key in self._keys)
        # Reverse lookup: key -> position in self._keys / self._domains.
        self._idx = frozendict({key: i for i, key in enumerate(self._keys)})

    @staticmethod
    def make(inp):
        """Creates a MultiDomain object from a dictionary of names and domains

        Parameters
        ----------
        inp : MultiDomain or dict{name: DomainTuple}
            The already built MultiDomain or a dictionary whose values are
            anything accepted by `DomainTuple.make`.

        Returns
        -------
        MultiDomain
            A MultiDomain with the input Domains as domains. If `inp` is
            already a MultiDomain it is returned unchanged. Otherwise the
            result is looked up in / stored into `_domainCache`, so equal
            inputs are expected to yield the same object (relies on
            `frozendict` hashing by content).

        Raises
        ------
        TypeError
            If `inp` is neither a MultiDomain nor a dict, or if any key
            is not a string.
        """
        if isinstance(inp, MultiDomain):
            return inp
        if not isinstance(inp, dict):
            raise TypeError("dict expected")
        tmp = {}
        for key, value in inp.items():
            if not isinstance(key, str):
                raise TypeError("keys must be strings")
            tmp[key] = DomainTuple.make(value)
        # Freeze the mapping so it can serve as a cache key.
        tmp = frozendict(tmp)
        obj = MultiDomain._domainCache.get(tmp)
        if obj is not None:
            return obj
        obj = MultiDomain(tmp, _callingfrommake=True)
        MultiDomain._domainCache[tmp] = obj
        return obj

    def keys(self):
        """Return the sub-domain names as a sorted tuple."""
        return self._keys

    def values(self):
        """Return the sub-domains as a tuple, ordered like `keys()`."""
        return self._domains

    def domains(self):
        """Return the sub-domains as a tuple (same as `values()`)."""
        return self._domains

    @property
    def idx(self):
        """frozendict mapping each name to its position in `keys()`."""
        return self._idx

    def items(self):
        """Return (name, domain) pairs in sorted-name order."""
        return zip(self._keys, self._domains)

    def __getitem__(self, key):
        """Return the sub-domain stored under `key` (KeyError if absent)."""
        return self._domains[self._idx[key]]

    def __len__(self):
        return len(self._keys)

    def __hash__(self):
        # Consistent with __eq__: equal keys and domains give equal hashes.
        return self._keys.__hash__() ^ self._domains.__hash__()

    def __eq__(self, x):
        """True iff `x` is a MultiDomain with the same names and domains.

        Objects of any other type (including plain dicts with matching
        contents) compare unequal instead of raising AttributeError.
        """
        if self is x:
            return True
        if not isinstance(x, MultiDomain):
            return False
        return self._keys == x._keys and self._domains == x._domains

    def __ne__(self, x):
        # Needed explicitly for Python 2, which does not derive != from ==.
        return not self.__eq__(x)

    def __str__(self):
        return "MultiDomain:\n" + "".join(
            key + ": " + str(dom) + "\n" for key, dom in self.items())

    @staticmethod
    def union(inp):
        """Merge several MultiDomains into one covering all their keys.

        Parameters
        ----------
        inp : iterable of MultiDomain

        Returns
        -------
        MultiDomain
            The single distinct input if all inputs are equal; otherwise a
            MultiDomain built via `make` from the combined keys.

        Raises
        ------
        ValueError
            If the same key maps to different domains in two inputs.
        """
        inp = set(inp)
        if len(inp) == 1:  # all domains are identical
            return inp.pop()
        res = {}
        for dom in inp:
            for key, subdom in zip(dom._keys, dom._domains):
                if key in res:
                    if res[key] != subdom:
                        raise ValueError("domain mismatch")
                else:
                    res[key] = subdom
        return MultiDomain.make(res)

    def __reduce__(self):
        # Pickle support: the object is rebuilt through the module-level
        # helper from a plain {name: domain} dict. `dict` here is the
        # builtin, which works because this class provides keys() and
        # __getitem__.
        return (_unpickleMultiDomain, (dict(self),))

    def __repr__(self):
        subs = "\n".join("{}:\n  {}".format(key, dom.__repr__())
                         for key, dom in self.items())
        return "MultiDomain:\n"+indent(subs)
Martin Reinecke's avatar
Martin Reinecke committed
134
135
136
137


def _unpickleMultiDomain(*args):
    """Rebuild a MultiDomain while unpickling.

    Pickle calls this with the argument tuple recorded by
    `MultiDomain.__reduce__`. The arguments are forwarded unchanged to
    `MultiDomain.make`, so the unpickled object is produced through the
    normal factory, including its `_domainCache` lookup.
    """
    factory = MultiDomain.make
    return factory(*args)