multi_field.py 8.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

Martin Reinecke's avatar
Martin Reinecke committed
19
20
from __future__ import absolute_import, division, print_function
from ..compat import *
Martin Reinecke's avatar
Martin Reinecke committed
21
22
23
from ..field import Field
import numpy as np
from .multi_domain import MultiDomain
24
from ..utilities import frozendict
Martin Reinecke's avatar
Martin Reinecke committed
25
26
27


class MultiField(object):
    """Tuple of `Field` objects, one per sub-domain of a `MultiDomain`.

    Entries are stored in the same order as the domain's keys. An entry may
    be ``None``; element-wise helpers such as `_transform` pass such entries
    through unchanged, and `vdot` skips them.
    """

    def __init__(self, domain, val):
        """
        Parameters
        ----------
        domain: MultiDomain
        val: tuple containing Field or None entries
            One entry per sub-domain of `domain`, in the same order. Every
            Field entry must live on exactly (by identity) the corresponding
            sub-domain.

        Raises
        ------
        TypeError
            If `domain` is not a MultiDomain, `val` is not a tuple, or an
            entry of `val` is neither a Field nor None.
        ValueError
            If `val` and `domain` differ in length, or a Field entry lives on
            a different domain object than its sub-domain.
        """
        if not isinstance(domain, MultiDomain):
            raise TypeError("domain must be of type MultiDomain")
        if not isinstance(val, tuple):
            raise TypeError("val must be a tuple")
        if len(val) != len(domain):
            raise ValueError("length mismatch")
        for d, v in zip(domain._domains, val):
            if isinstance(v, Field):
                # identity (not equality) check: sub-domains are expected
                # to be shared objects
                if v._domain is not d:
                    raise ValueError("domain mismatch")
            elif v is not None:
                raise TypeError("bad entry in val (must be Field or None)")
        self._domain = domain
        self._val = val

    @staticmethod
    def from_dict(dict, domain=None):
        """Builds a MultiField from a mapping of keys to Fields.

        Parameters
        ----------
        dict : mapping
            Maps sub-domain keys to Field objects.
        domain : MultiDomain or None
            Target domain. If None, it is built via `MultiDomain.make` from
            the domains of the Fields in `dict`. Keys of `domain` missing
            from `dict` get a ``None`` entry; keys of `dict` that are not in
            `domain` are ignored.
        """
        # NOTE: the parameter name shadows the builtin `dict`; it is kept
        # unchanged so keyword callers keep working.
        if domain is None:
            domain = MultiDomain.make({key: v._domain
                                       for key, v in dict.items()})
        return MultiField(domain, tuple(dict[key] if key in dict else None
                                        for key in domain.keys()))

    def to_dict(self):
        """Returns a new dict mapping every key to its entry (possibly None)."""
        return {key: val for key, val in zip(self._domain.keys(), self._val)}

    def __getitem__(self, key):
        """Returns the entry (Field or None) stored under `key`."""
        return self._val[self._domain.idx[key]]

    def keys(self):
        """Returns the keys of the underlying domain."""
        return self._domain.keys()

    def items(self):
        """Returns (key, entry) pairs, in domain order."""
        return zip(self._domain.keys(), self._val)

    def values(self):
        """Returns the tuple of entries, in domain order."""
        return self._val

    @property
    def domain(self):
        """MultiDomain : the domain this MultiField lives on."""
        return self._domain

    def _transform(self, op):
        """Applies `op` to every non-None entry; None entries stay None."""
        return MultiField(
            self._domain,
            tuple(op(v) if v is not None else None for v in self._val))

    @property
    def real(self):
        """MultiField : The real part of the multi field"""
        return self._transform(lambda x: x.real)

    @property
    def imag(self):
        """MultiField : The imaginary part of the multi field"""
        return self._transform(lambda x: x.imag)

    @staticmethod
    def from_random(random_type, domain, dtype=np.float64, **kwargs):
        """Creates a MultiField with random entries on every sub-domain.

        `domain` is passed through `MultiDomain.make`; `random_type`, `dtype`
        and any extra keyword arguments are forwarded to `Field.from_random`
        for each sub-domain.
        """
        domain = MultiDomain.make(domain)
        return MultiField(
            domain, tuple(Field.from_random(random_type, dom, dtype, **kwargs)
                          for dom in domain._domains))

    def _check_domain(self, other):
        """Raises ValueError unless `other` lives on the very same domain
        object as `self`."""
        if other._domain is not self._domain:
            raise ValueError("domains are incompatible.")

    def vdot(self, x):
        """Returns the sum of the per-entry `vdot` products with `x`.

        Positions where either side is None contribute nothing.

        Raises
        ------
        ValueError
            If `x` does not live on the same domain object as `self`.
        """
        result = 0.
        self._check_domain(x)
        for v1, v2 in zip(self._val, x._val):
            if v1 is not None and v2 is not None:
                result += v1.vdot(v2)
        return result

    @staticmethod
    def full(domain, val):
        """Returns a MultiField whose entries are `Field.full(dom, val)` for
        every sub-domain `dom` of the MultiDomain `domain`."""
        return MultiField(domain, tuple(Field.full(dom, val)
                          for dom in domain._domains))

    def to_global_data(self):
        """Returns a dict mapping every key to its entry's global data.

        NOTE(review): a None entry makes this raise AttributeError.
        """
        return {key: val.to_global_data()
                for key, val in zip(self._domain.keys(), self._val)}

    @staticmethod
    def from_global_data(domain, arr, sum_up=False):
        """Builds a MultiField from `arr`, a mapping of every key of `domain`
        to array data; `sum_up` is forwarded to `Field.from_global_data`."""
        return MultiField(domain, tuple(Field.from_global_data(domain[key],
                                                               arr[key], sum_up)
                          for key in domain.keys()))

    def norm(self):
        """ Computes the L2-norm of the field values.

        Returns
        -------
        norm : float
            The L2-norm of the field values.
        """
        return np.sqrt(np.abs(self.vdot(x=self)))

    def squared_norm(self):
        """ Computes the square of the L2-norm of the field values.

        Returns
        -------
        float
            The square of the L2-norm of the field values.
        """
        return abs(self.vdot(x=self))

    def __neg__(self):
        return self._transform(lambda x: -x)

    def __abs__(self):
        return self._transform(abs)

    def conjugate(self):
        """Returns the entry-wise complex conjugate."""
        return self._transform(lambda x: x.conjugate())

    def all(self):
        """True iff no entry is None and every entry's `all()` is true."""
        for v in self._val:
            if v is None or not v.all():
                return False
        return True

    def any(self):
        """True iff some non-None entry's `any()` is true."""
        for v in self._val:
            if v is not None and v.any():
                return True
        return False

    def isEquivalentTo(self, other):
        """Determines (as quickly as possible) whether `self`'s content is
        identical to `other`'s content.

        A None entry is only considered equivalent to another None entry.
        """
        if self is other:
            return True
        if not isinstance(other, MultiField):
            return False
        if self._domain is not other._domain:
            return False
        for v1, v2 in zip(self._val, other._val):
            if v1 is None or v2 is None:
                # previously crashed with AttributeError on None entries
                if v1 is not v2:
                    return False
            elif not v1.isEquivalentTo(v2):
                return False
        return True

    def isSubsetOf(self, other):
        """Determines (as quickly as possible) whether `self`'s content is
        a subset of `other`'s content.

        Every key of `self` must exist in `other` with the identical
        sub-domain, and each entry of `self` must be a subset of the
        matching entry of `other`. A None entry only matches another None
        entry.
        """
        if self is other:
            return True
        if not isinstance(other, MultiField):
            return False
        if not set(self._domain.keys()).issubset(other._domain.keys()):
            return False
        for key in self._domain.keys():
            if other._domain[key] is not self._domain[key]:
                return False
            mine, theirs = self[key], other[key]
            if mine is None or theirs is None:
                if mine is not theirs:
                    return False
            # direction matters: ask whether *our* entry is a subset of
            # theirs (the previous code asked the reverse question)
            elif not mine.isSubsetOf(theirs):
                return False
        return True
def _make_binary_op(opname):
    """Returns a MultiField method implementing the binary operator `opname`.

    If the other operand is a MultiField, it must live on the very same
    domain object (else ValueError). The operator is then applied position
    by position. An entry that is None on only one side is replaced by the
    other side's entry multiplied by 0, and a position that is None on both
    sides stays None. For any other operand, the operator is applied to
    every non-None entry together with that operand.
    """
    def binary_op(self, other):
        if not isinstance(other, MultiField):
            return self._transform(lambda fld: getattr(fld, opname)(other))
        if self._domain is not other._domain:
            raise ValueError("domain mismatch")
        results = []
        for left, right in zip(self._val, other._val):
            if left is None and right is None:
                results.append(None)
                continue
            # substitute a zero stand-in for the missing side
            if left is None:
                left = right*0
            elif right is None:
                right = left*0
            results.append(getattr(left, opname)(right))
        return MultiField(self._domain, tuple(results))
    return binary_op


for _opname in ("__add__", "__radd__",
                "__sub__", "__rsub__",
                "__mul__", "__rmul__",
                "__div__", "__rdiv__",
                "__truediv__", "__rtruediv__",
                "__floordiv__", "__rfloordiv__",
                "__pow__", "__rpow__",
                "__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"):
    setattr(MultiField, _opname, _make_binary_op(_opname))
def _reject_inplace(self, other):
    """Installed for every in-place operator: MultiField refuses in-place
    arithmetic and always raises TypeError."""
    raise TypeError(
        "In-place operations are deliberately not supported")


for _inplace_name in ("__iadd__", "__isub__", "__imul__", "__idiv__",
                      "__itruediv__", "__ifloordiv__", "__ipow__"):
    setattr(MultiField, _inplace_name, _reject_inplace)