# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np

Philipp Arras's avatar
Philipp Arras committed
20
from . import utilities
Martin Reinecke's avatar
Martin Reinecke committed
21
from .field import Field
Martin Reinecke's avatar
Martin Reinecke committed
22
from .multi_domain import MultiDomain
23
from .domain_tuple import DomainTuple
Martin Reinecke's avatar
misc    
Martin Reinecke committed
24
from .operators.operator import Operator
Martin Reinecke's avatar
Martin Reinecke committed
25
26


Martin Reinecke's avatar
misc    
Martin Reinecke committed
27
class MultiField(Operator):
    def __init__(self, domain, val):
        """The discrete representation of a continuous field over a sum space.

        Parameters
        ----------
        domain: MultiDomain
        val: tuple containing Field entries
            One Field per sub-domain of `domain`, in the same order as the
            sub-domains of `domain`.

        Raises
        ------
        TypeError
            If `domain` is not a MultiDomain, `val` is not a tuple, or an
            entry of `val` is not a Field.
        ValueError
            If `val` and `domain` differ in length, or the domain of an entry
            differs from the corresponding sub-domain.
        """
        if not isinstance(domain, MultiDomain):
            raise TypeError("domain must be of type MultiDomain")
        if not isinstance(val, tuple):
            raise TypeError("val must be a tuple")
        if len(val) != len(domain):
            raise ValueError("length mismatch")
        for d, v in zip(domain._domains, val):
            if isinstance(v, Field):
                if v._domain != d:
                    raise ValueError("domain mismatch")
            else:
                raise TypeError("bad entry in val (must be Field)")
        self._domain = domain
        self._val = val

    @staticmethod
    def from_dict(dct, domain=None):
        """Builds a MultiField from a dictionary of Fields.

        Parameters
        ----------
        dct : dict
            Maps keys to Fields.
        domain : MultiDomain or None
            The domain of the result. If None, it is built from the domains
            of the values of `dct`, each of which must be a DomainTuple.
            If given, keys of `domain` that are missing from `dct` are filled
            with zero-valued Fields, and entries of `dct` whose key is not in
            `domain` are ignored.

        Returns
        -------
        MultiField

        Raises
        ------
        TypeError
            If `domain` is None and a value of `dct` is not defined on a
            DomainTuple.
        """
        if domain is None:
            for dd in dct.values():
                if not isinstance(dd.domain, DomainTuple):
                    raise TypeError('Values of dictionary need to be Fields '
                                    'defined on DomainTuples.')
            domain = MultiDomain.make({key: v._domain
                                       for key, v in dct.items()})
        res = tuple(dct[key] if key in dct else Field(dom, 0.)
                    for key, dom in zip(domain.keys(), domain.domains()))
        return MultiField(domain, res)

    def to_dict(self):
        """Returns a new dict mapping each domain key to its Field."""
        return dict(zip(self._domain.keys(), self._val))

    def __getitem__(self, key):
        """Returns the Field stored under `key`."""
        return self._val[self._domain.idx[key]]

    def __contains__(self, key):
        """Returns True if `key` is one of the domain keys."""
        return key in self._domain.idx

    def keys(self):
        """Returns the keys of the underlying MultiDomain."""
        return self._domain.keys()

    def items(self):
        """Returns a (one-shot) iterator over (key, Field) pairs."""
        return zip(self._domain.keys(), self._val)

    def values(self):
        """Returns the tuple of Fields, ordered like the domain keys."""
        return self._val

    @property
    def domain(self):
        """MultiDomain : the domain of this MultiField"""
        return self._domain

    def _transform(self, op):
        """Applies the unary callable `op` to every entry; returns a new
        MultiField on the same domain."""
        return MultiField(self._domain, tuple(op(v) for v in self._val))

    @property
    def real(self):
        """MultiField : The real part of the multi field"""
        return self._transform(lambda x: x.real)

    @property
    def imag(self):
        """MultiField : The imaginary part of the multi field"""
        return self._transform(lambda x: x.imag)

    @staticmethod
    def from_random(domain, random_type='normal', dtype=np.float64, **kwargs):
        """Draws a random multi-field with the given parameters.

        Parameters
        ----------
        domain : MultiDomain or dict
            The domain of the output random MultiField. It is passed through
            ``MultiDomain.make``.
        random_type : 'pm1', 'normal', or 'uniform'
            The random distribution to use.
        dtype : type or dict
            The datatype of the output random fields. This is either a single
            dtype used for every key, or a dict mapping every key of `domain`
            to its own dtype.
        **kwargs
            Forwarded unchanged to ``Field.from_random`` for every entry.

        Returns
        -------
        MultiField
            The newly created :class:`MultiField`.

        Notes
        -----
        The individual fields within this multi-field will be drawn in alphabetical
        order of the multi-field's domain keys. As a consequence, renaming these
        keys may cause the multi-field to be filled with different random numbers,
        even for the same initial RNG state.
        """
        domain = MultiDomain.make(domain)
        if isinstance(dtype, dict):
            dtype = {kk: np.dtype(dt) for kk, dt in dtype.items()}
        else:
            dtype = np.dtype(dtype)
            dtype = {kk: dtype for kk in domain.keys()}
        dct = {kk: Field.from_random(domain[kk], random_type, dtype[kk], **kwargs)
               for kk in domain.keys()}
        # Every key of `domain` is present in `dct`, so passing the already
        # built domain keeps the result on exactly that domain and avoids
        # re-deriving it.
        return MultiField.from_dict(dct, domain)

    def _check_domain(self, other):
        """Raises ValueError unless `other` lives on the same domain."""
        if other._domain != self._domain:
            raise ValueError("domains are incompatible.")

    def s_vdot(self, x):
        """Returns the scalar product with `x` as the sum of the entry-wise
        ``Field.s_vdot`` results.

        Raises
        ------
        ValueError
            If `x` lives on a different domain.
        """
        self._check_domain(x)
        return sum((v1.s_vdot(v2) for v1, v2 in zip(self._val, x._val)), 0.)

    def vdot(self, x):
        """Returns :meth:`s_vdot` wrapped via ``Field.scalar``."""
        return Field.scalar(self.s_vdot(x))

    @staticmethod
    def full(domain, val):
        """Returns a MultiField on `domain` whose every entry is
        ``Field(dom, val)``.

        `domain` is passed through ``MultiDomain.make``.
        """
        domain = MultiDomain.make(domain)
        return MultiField(domain, tuple(Field(dom, val)
                          for dom in domain._domains))

    @property
    def val(self):
        """dict : maps each domain key to the ``val`` of its Field"""
        return {key: val.val
                for key, val in zip(self._domain.keys(), self._val)}

    def val_rw(self):
        """Returns a dict mapping each domain key to ``Field.val_rw()`` of
        its entry."""
        return {key: val.val_rw()
                for key, val in zip(self._domain.keys(), self._val)}

    @staticmethod
    def from_raw(domain, arr):
        """Builds a MultiField from raw per-key values.

        Parameters
        ----------
        domain : MultiDomain
            The domain of the result.
        arr : mapping
            Maps each key of `domain` to the raw values of that entry; every
            value is wrapped as ``Field(domain[key], arr[key])``.
        """
        return MultiField(
            domain, tuple(Field(domain[key], arr[key])
                          for key in domain.keys()))

    def norm(self, ord=2):
        """Computes the norm of the field values.

        The result is the `ord`-norm of the vector of per-entry norms
        ``f.norm(ord)``.

        Parameters
        ----------
        ord : int, default=2
            accepted values: 1, 2, ..., np.inf

        Returns
        -------
        norm : float
            The norm of the field values. An empty MultiField has norm 0
            for every `ord`.
        """
        nrm = np.asarray([f.norm(ord) for f in self._val])
        if ord == np.inf:
            # max() of an empty array raises; return 0 like the finite-ord
            # branch does for an empty MultiField.
            if nrm.size == 0:
                return 0.
            return nrm.max()
        return (nrm ** ord).sum() ** (1./ord)

    def s_sum(self):
        """Computes the sum of all field values.

        Returns
        -------
        sum : float
            The sum of the field values of all entries.
        """
        return utilities.my_sum(map(lambda v: v.s_sum(), self._val))

    @property
    def size(self):
        """Computes the overall degrees of freedom.

        Returns
        -------
        size : int
            The sum of the size of the individual fields
        """
        return utilities.my_sum(map(lambda d: d.size, self._domain.domains()))

    def __neg__(self):
        return self._transform(lambda x: -x)

    def __abs__(self):
        return self._transform(lambda x: abs(x))

    def conjugate(self):
        """Returns the entry-wise complex conjugate."""
        return self._transform(lambda x: x.conjugate())

    def clip(self, a_min=None, a_max=None):
        """Entry-wise clipping, implemented via ``ptw("clip", ...)``."""
        return self.ptw("clip", a_min, a_max)

    def s_all(self):
        """Returns True if ``s_all()`` is true for every entry."""
        return all(v.s_all() for v in self._val)

    def s_any(self):
        """Returns True if ``s_any()`` is true for at least one entry."""
        return any(v.s_any() for v in self._val)

    def extract(self, subset):
        """Returns the restriction of this MultiField to `subset`.

        Every key of the MultiDomain `subset` must be present in self.
        If `subset` is this field's own domain object, self is returned.
        """
        if subset is self._domain:
            return self
        return MultiField(subset,
                          tuple(self[key] for key in subset.keys()))

    def extract_part(self, subset):
        """Like :meth:`extract`, but silently skips keys of `subset` that
        are not present in self; the result's domain is built from the
        entries that were kept."""
        if subset is self._domain:
            return self
        return MultiField.from_dict({key: self[key] for key in subset.keys()
                                     if key in self})

    def unite(self, other):
        """Merges two MultiFields on potentially different MultiDomains.

        Parameters
        ----------
        other : MultiField
            the partner Field

        Returns
        -------
        MultiField
            This MultiField's domain is the union of the input fields'
            domains. The values are the sum of the fields in self and other.
            If a field is not present, it is assumed to have an uniform value
            of zero.
        """
        # Identical to the addition case of flexible_addsub.
        return self.flexible_addsub(other, False)

    @staticmethod
    def union(fields, domain=None):
        """Returns the union of its input fields.

        Parameters
        ----------
        fields : iterable of MultiFields
            The set of input fields. Their domains need not be identical.
        domain : MultiDomain or None
            If supplied, this will be the domain of the resulting field.
            Providing this domain will accelerate the function.

        Returns
        -------
        MultiField
            The union of the input fields

        Notes
        -----
        If the same key occurs more than once in the input fields, the value
        associated with the last occurrence will be put into the output.
        No summation is performed!
        """
        res = {}
        for field in fields:
            res.update(field.to_dict())
        return MultiField.from_dict(res, domain)

    def flexible_addsub(self, other, neg):
        """Merges two MultiFields on potentially different MultiDomains.

        Parameters
        ----------
        other : MultiField
            the partner Field
        neg : bool
            if True, the partner field is subtracted, otherwise added

        Returns
        -------
        MultiField
            This MultiField's domain is the union of the input fields'
            domains. The values are the sum (or difference, if neg==True) of
            the fields in self and other. If a field is not present, it is
            assumed to have an uniform value of zero.
        """
        if self._domain is other._domain:
            return self-other if neg else self+other
        res = self.to_dict()
        for key, val in other.items():
            if key in res:
                res[key] = res[key]-val if neg else res[key]+val
            else:
                res[key] = -val if neg else val
        return MultiField.from_dict(res)

    def _prep_args(self, args, kwargs, i):
        """Selects the `i`-th entry of every MultiField-valued argument.

        None and scalars are passed through unchanged; every other argument
        must have ``jac is None`` (presumably plain MultiFields rather than
        linearizations -- TODO confirm) and is replaced by its `i`-th entry.

        Raises
        ------
        TypeError
            If an argument has a non-None ``jac``.
        """
        for arg in args + tuple(kwargs.values()):
            if not (arg is None or np.isscalar(arg) or arg.jac is None):
                raise TypeError("bad argument")
        argstmp = tuple(arg if arg is None or np.isscalar(arg) else arg._val[i]
                        for arg in args)
        kwargstmp = {key: val if val is None or np.isscalar(val) else val._val[i]
                     for key, val in kwargs.items()}
        return argstmp, kwargstmp

    def ptw(self, op, *args, **kwargs):
        """Applies the pointwise operation `op` (forwarded to ``Field.ptw``)
        to every entry; MultiField arguments are split entry-wise."""
        tmp = []
        for i, fld in enumerate(self._val):
            argstmp, kwargstmp = self._prep_args(args, kwargs, i)
            tmp.append(fld.ptw(op, *argstmp, **kwargstmp))
        return MultiField(self.domain, tuple(tmp))

    def ptw_with_deriv(self, op, *args, **kwargs):
        """Like :meth:`ptw`, but uses ``Field.ptw_with_deriv``.

        Returns
        -------
        tuple of two MultiFields
            Built from the first and second element, respectively, of each
            entry's ``ptw_with_deriv`` result.
        """
        tmp = []
        for i, fld in enumerate(self._val):
            argstmp, kwargstmp = self._prep_args(args, kwargs, i)
            tmp.append(fld.ptw_with_deriv(op, *argstmp, **kwargstmp))
        return (MultiField(self.domain, tuple(v[0] for v in tmp)),
                MultiField(self.domain, tuple(v[1] for v in tmp)))

    def _binary_op(self, other, op):
        """Applies the Field method named `op` (e.g. "__add__") entry-wise.

        If `other` is a MultiField it must live on the same domain and the
        method is applied to matching entries; otherwise `other` is passed
        unchanged to the method of every entry.

        Raises
        ------
        ValueError
            If `other` is a MultiField on a different domain.
        """
        f = getattr(Field, op)
        if isinstance(other, MultiField):
            if self._domain != other._domain:
                raise ValueError("domain mismatch")
            val = tuple(f(v1, v2)
                        for v1, v2 in zip(self._val, other._val))
        else:
            val = tuple(f(v1, other) for v1 in self._val)
        return MultiField(self._domain, val)


# Install the arithmetic and comparison dunders on MultiField; each one
# forwards to MultiField._binary_op with the corresponding Field method name.
for op in ["__add__", "__radd__",
           "__sub__", "__rsub__",
           "__mul__", "__rmul__",
           "__truediv__", "__rtruediv__",
           "__floordiv__", "__rfloordiv__",
           "__pow__", "__rpow__",
           "__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"]:
    def func(op):
        # The factory binds `op` at call time, avoiding the late-binding
        # closure pitfall of defining the method directly in the loop.
        def func2(self, other):
            return self._binary_op(other, op)
        # Give every generated method its real name, so that tracebacks,
        # help() and introspection show e.g. "MultiField.__add__" rather
        # than "func2".
        func2.__name__ = op
        func2.__qualname__ = "MultiField." + op
        func2.__doc__ = ("Entry-wise Field." + op +
                         " on every component (see MultiField._binary_op).")
        return func2
    setattr(MultiField, op, func(op))


# Install in-place dunders that always raise, so that e.g. `a += b` on a
# MultiField fails loudly. ("__idiv__" is only consulted by Python 2; it is
# harmless under Python 3 and kept so the set of class attributes is
# unchanged.)
for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
           "__itruediv__", "__ifloordiv__", "__ipow__"]:
    def func(op):
        def func2(self, other):
            raise TypeError(
                "In-place operations are deliberately not supported")
        # Real method names make tracebacks point at e.g.
        # "MultiField.__iadd__" instead of "func2".
        func2.__name__ = op
        func2.__qualname__ = "MultiField." + op
        return func2
    setattr(MultiField, op, func(op))