field.py 31.1 KB
Newer Older
csongor's avatar
csongor committed
1 2 3
from __future__ import division
import numpy as np

4
from d2o import distributed_data_object,\
5
    STRATEGIES as DISTRIBUTION_STRATEGIES
csongor's avatar
csongor committed
6

7
from nifty.config import nifty_configuration as gc
csongor's avatar
csongor committed
8

9
from nifty.field_types import FieldType
10

11
from nifty.spaces.space import Space
12
from nifty.spaces.power_space import PowerSpace
csongor's avatar
csongor committed
13

csongor's avatar
csongor committed
14
import nifty.nifty_utilities as utilities
15 16
from nifty.random import Random

17 18 19
import logging
logger = logging.getLogger('NIFTy.Field')

csongor's avatar
csongor committed
20

21
class Field(object):
theos's avatar
theos committed
22
    # ---Initialization methods---
23

theos's avatar
theos committed
24
    def __init__(self, domain=None, val=None, dtype=None, field_type=None,
                 distribution_strategy=None, copy=False):
        """Create a Field defined on `domain` (spaces) and `field_type`.

        Parameters
        ----------
        domain : Space, iterable of Space, or None
            Spaces the field lives on. If None, taken from `val` when `val`
            is a Field, otherwise empty.
        val : optional
            Initial data, handed to `set_val` (and thereby to `cast`).
        dtype : optional
            Requested data type; combined with the dtypes of the spaces and
            field types by `_infer_dtype`.
        field_type : FieldType, iterable of FieldType, or None
            Additional field-type axes; inferred from `val` like `domain`.
        distribution_strategy : optional
            d2o distribution strategy, see `_parse_distribution_strategy`.
        copy : bool
            Passed to `set_val`; whether the cast data is copied.
        """
        self.domain = self._parse_domain(domain=domain, val=val)
        self.domain_axes = self._get_axes_tuple(self.domain)

        self.field_type = self._parse_field_type(field_type, val=val)

        # field_type axes are numbered after all axes of the domain spaces
        start = sum(len(axes) for axes in self.domain_axes)
        self.field_type_axes = self._get_axes_tuple(self.field_type,
                                                    start=start)

        self.dtype = self._infer_dtype(dtype=dtype,
                                       val=val,
                                       domain=self.domain,
                                       field_type=self.field_type)

        self.distribution_strategy = self._parse_distribution_strategy(
                                distribution_strategy=distribution_strategy,
                                val=val)

        self.set_val(new_val=val, copy=copy)

50
    def _parse_domain(self, domain, val=None):
51
        if domain is None:
52 53 54 55
            if isinstance(val, Field):
                domain = val.domain
            else:
                domain = ()
56
        elif isinstance(domain, Space):
57
            domain = (domain,)
58 59 60
        elif not isinstance(domain, tuple):
            domain = tuple(domain)

csongor's avatar
csongor committed
61
        for d in domain:
62
            if not isinstance(d, Space):
63 64 65
                raise TypeError(
                    "Given domain contains something that is not a "
                    "nifty.space.")
csongor's avatar
csongor committed
66 67
        return domain

68
    def _parse_field_type(self, field_type, val=None):
69
        if field_type is None:
70 71 72 73
            if isinstance(val, Field):
                field_type = val.field_type
            else:
                field_type = ()
74
        elif isinstance(field_type, FieldType):
75
            field_type = (field_type,)
76 77
        elif not isinstance(field_type, tuple):
            field_type = tuple(field_type)
78
        for ft in field_type:
79
            if not isinstance(ft, FieldType):
80 81
                raise TypeError(
                    "Given object is not a nifty.FieldType.")
82 83
        return field_type

theos's avatar
theos committed
84 85 86 87 88 89 90 91 92 93
    def _get_axes_tuple(self, things_with_shape, start=0):
        i = start
        axes_list = []
        for thing in things_with_shape:
            l = []
            for j in range(len(thing.shape)):
                l += [i]
                i += 1
            axes_list += [tuple(l)]
        return tuple(axes_list)
94

95
    def _infer_dtype(self, dtype, val, domain, field_type):
csongor's avatar
csongor committed
96
        if dtype is None:
97 98 99
            if isinstance(val, Field) or \
               isinstance(val, distributed_data_object):
                dtype = val.dtype
theos's avatar
theos committed
100 101 102 103 104 105 106
            dtype_tuple = (np.dtype(gc['default_field_dtype']),)
        else:
            dtype_tuple = (np.dtype(dtype),)
        if domain is not None:
            dtype_tuple += tuple(np.dtype(sp.dtype) for sp in domain)
        if field_type is not None:
            dtype_tuple += tuple(np.dtype(ft.dtype) for ft in field_type)
csongor's avatar
csongor committed
107

theos's avatar
theos committed
108
        dtype = reduce(lambda x, y: np.result_type(x, y), dtype_tuple)
109

theos's avatar
theos committed
110
        return dtype
111

112 113
    def _parse_distribution_strategy(self, distribution_strategy, val):
        """Choose the d2o distribution strategy for the field's data.

        If `distribution_strategy` is None, it is taken from `val` when
        `val` is a distributed_data_object or Field, otherwise from
        gc['default_distribution_strategy'] (logged at info level). An
        explicitly given strategy must be a member of
        DISTRIBUTION_STRATEGIES['global']; inferred ones are not checked.

        Raises
        ------
        ValueError
            If an explicit strategy is not a global-type strategy.
        """
        if distribution_strategy is None:
            if isinstance(val, (distributed_data_object, Field)):
                distribution_strategy = val.distribution_strategy
            else:
                logger.info("Datamodel set to default!")
                distribution_strategy = gc['default_distribution_strategy']
        elif distribution_strategy not in DISTRIBUTION_STRATEGIES['global']:
            raise ValueError(
                    "distribution_strategy must be a global-type "
                    "strategy.")
        return distribution_strategy
126 127

    # ---Factory methods---
128

129 130
    @classmethod
    def from_random(cls, random_type, domain=None, dtype=None, field_type=None,
                    distribution_strategy=None, **kwargs):
        """Create a field of type `cls` filled with random numbers.

        Parameters
        ----------
        random_type : str
            Name of the generator on `Random`; must be accepted by
            `_parse_random_arguments`.
        domain, dtype, field_type, distribution_strategy
            Passed unchanged to the constructor of the new field.
        **kwargs
            Generator parameters, see `_parse_random_arguments`.

        Returns
        -------
        Field
            The new field; its data is filled via `apply_generator`.

        Raises
        ------
        KeyError
            If `random_type` is not supported.
        """
        # create an initially empty field
        f = cls(domain=domain, dtype=dtype, field_type=field_type,
                distribution_strategy=distribution_strategy)

        # use the processed input (in terms of f) to parse the random
        # arguments
        random_arguments = cls._parse_random_arguments(random_type=random_type,
                                                       f=f,
                                                       **kwargs)

        # extract the distributed_data_object from f and apply the
        # appropriate random number generator to it
        sample = f.get_val(copy=False)
        generator_function = getattr(Random, random_type)
        sample.apply_generator(
            lambda shape: generator_function(dtype=f.dtype,
                                             shape=shape,
                                             **random_arguments))
        return f

    @staticmethod
    def _parse_random_arguments(random_type, f, **kwargs):

        if random_type == "pm1":
            random_arguments = {}

        elif random_type == "normal":
            mean = kwargs.get('mean', 0)
            std = kwargs.get('std', 1)
            random_arguments = {'mean': mean,
                                'std': std}

        elif random_type == "uniform":
            low = kwargs.get('low', 0)
            high = kwargs.get('high', 1)
            random_arguments = {'low': low,
                                'high': high}

csongor's avatar
csongor committed
170
        else:
171 172
            raise KeyError(
                "unsupported random key '" + str(random_type) + "'.")
csongor's avatar
csongor committed
173

174
        return random_arguments
csongor's avatar
csongor committed
175

176 177 178 179 180 181 182 183 184
    # ---Powerspectral methods---

    def power_analyze(self, spaces=None, log=False, nbin=None, binbounds=None,
                      real_signal=True):
        """Compute the power spectrum of the field over one harmonic space.

        Parameters
        ----------
        spaces : int or tuple of int, optional
            Index of the domain space to analyze; may be omitted if the
            domain has exactly one space.
        log, nbin, binbounds
            Binning parameters passed to the PowerSpace constructor.
        real_signal : bool
            If True, the hermitian and anti-hermitian parts of the field
            are analyzed separately and become the real and imaginary parts
            of a complex power spectrum. If False, one real power spectrum
            of the whole field is computed.

        Returns
        -------
        Field
            Field whose analyzed space is replaced by the new PowerSpace.

        Raises
        ------
        AttributeError
            If a domain space is neither harmonic nor a PowerSpace.
        ValueError
            If not exactly one space is selected or it is not harmonic.
        """
        # all spaces in `self.domain` must be harmonic or PowerSpaces
        for sp in self.domain:
            if not sp.harmonic and not isinstance(sp, PowerSpace):
                raise AttributeError(
                    "Field has a space in `domain` which is neither "
                    "harmonic nor a PowerSpace.")

        # check if the `spaces` input is valid
        spaces = utilities.cast_axis_to_tuple(spaces, len(self.domain))
        if spaces is None:
            if len(self.domain) == 1:
                spaces = (0,)
            else:
                raise ValueError(
                    "Field has multiple spaces as domain "
                    "but `spaces` is None.")

        if len(spaces) == 0:
            raise ValueError(
                "No space for analysis specified.")
        elif len(spaces) > 1:
            raise ValueError(
                "Conversion of only one space at a time is allowed.")

        space_index = spaces[0]
        harmonic_domain = self.domain[space_index]
        axes = self.domain_axes[space_index]

        if not harmonic_domain.harmonic:
            raise ValueError(
                "The analyzed space must be harmonic.")

        # Create the target PowerSpace instance:
        # If the associated signal-space field was real, we extract the
        # hermitian and anti-hermitian parts of `self` and put them
        # into the real and imaginary parts of the power spectrum.
        # If it was complex, all the power is put into a real power spectrum.
        distribution_strategy = \
            self.val.get_axes_local_distribution_strategy(axes)

        if real_signal:
            power_dtype = np.dtype('complex')
        else:
            power_dtype = np.dtype('float')

        power_domain = PowerSpace(harmonic_domain=harmonic_domain,
                                  distribution_strategy=distribution_strategy,
                                  log=log, nbin=nbin, binbounds=binbounds,
                                  dtype=power_dtype)

        # extract pindex and rho from power_domain
        pindex = power_domain.pindex
        rho = power_domain.rho

        if real_signal:
            hermitian_part, anti_hermitian_part = \
                harmonic_domain.hermitian_decomposition(self.val, axes=axes)
            hermitian_power = self._calculate_power_spectrum(
                x=hermitian_part, pindex=pindex, rho=rho, axes=axes)
            anti_hermitian_power = self._calculate_power_spectrum(
                x=anti_hermitian_part, pindex=pindex, rho=rho, axes=axes)
            power_spectrum = hermitian_power + 1j * anti_hermitian_power
        else:
            power_spectrum = self._calculate_power_spectrum(
                x=self.val, pindex=pindex, rho=rho, axes=axes)

        # create the result field and put power_spectrum into it
        result_domain = list(self.domain)
        result_domain[space_index] = power_domain

        result_field = self.copy_empty(domain=result_domain)
        result_field.set_val(new_val=power_spectrum, copy=False)

        return result_field

    def _calculate_power_spectrum(self, x, pindex, rho, axes=None):
        fieldabs = abs(x)
        fieldabs **= 2

        if axes is not None:
            pindex = self._shape_up_pindex(
                                    pindex=pindex,
                                    target_shape=x.shape,
                                    target_strategy=x.distribution_strategy,
                                    axes=axes)
        power_spectrum = pindex.bincount(weights=fieldabs,
                                         axis=axes)
        if axes is not None:
            new_rho_shape = [1, ] * len(power_spectrum.shape)
            new_rho_shape[axes[0]] = len(rho)
            rho = rho.reshape(new_rho_shape)
        power_spectrum /= rho

        power_spectrum **= 0.5
        return power_spectrum

    def _shape_up_pindex(self, pindex, target_shape, target_strategy, axes):
        """Expand `pindex` to a d2o of global shape `target_shape`.

        The local pindex data is reshaped to length 1 on every axis except
        `axes`, where it takes the target extent. It is then written into
        an empty copy of `pindex` with `target_shape` and `target_strategy`
        via set_full_data. NOTE(review): this relies on set_full_data
        broadcasting the reduced array; confirm.

        Raises
        ------
        ValueError
            If pindex is not globally distributed, or if a slicing-distributed
            pindex would be reshaped into something not sliced along axis 0
            with the same strategy.
        """
        if pindex.distribution_strategy not in \
                DISTRIBUTION_STRATEGIES['global']:
            raise ValueError("pindex's distribution strategy must be "
                             "global-type")

        if pindex.distribution_strategy in DISTRIBUTION_STRATEGIES['slicing']:
            # compare strategies by value (consistent with the `in` tests
            # above); identity comparison is not a reliable equality test
            if ((0 not in axes) or
                    (target_strategy != pindex.distribution_strategy)):
                raise ValueError(
                    "A slicing distributor shall not be reshaped to "
                    "something non-sliced.")

        semiscaled_shape = [1] * len(target_shape)
        for i in axes:
            semiscaled_shape[i] = target_shape[i]
        local_data = pindex.get_local_data(copy=False)
        semiscaled_local_data = local_data.reshape(semiscaled_shape)
        result_obj = pindex.copy_empty(global_shape=target_shape,
                                       distribution_strategy=target_strategy)
        result_obj.set_full_data(semiscaled_local_data, copy=False)

        return result_obj

312 313
    def power_synthesize(self, spaces=None, real_signal=True,
                         mean=None, std=None):
        """Draw a random field whose power spectrum is given by `self`.

        The PowerSpace selected by `spaces` is replaced by its
        `harmonic_domain`. One normal random field is drawn with
        `from_random('normal', mean=mean, std=std, ...)`, or two if the
        PowerSpace dtype is complex. For `real_signal` only the hermitian
        part of each sample is kept. The first sample is multiplied
        entry-wise by the real part of the spectrum values looked up through
        the PowerSpace's pindex, and the second by the imaginary part. For a
        complex spectrum the result is ``first + 1j * second``.

        Raises
        ------
        AttributeError
            If a domain space is neither harmonic nor a PowerSpace.
        ValueError
            If not exactly one space is selected or it is no PowerSpace.
        """
        # all spaces in `self.domain` must be harmonic or PowerSpaces
        for sp in self.domain:
            if not sp.harmonic and not isinstance(sp, PowerSpace):
                raise AttributeError(
                    "Field has a space in `domain` which is neither "
                    "harmonic nor a PowerSpace.")

        # check if the `spaces` input is valid
        spaces = utilities.cast_axis_to_tuple(spaces, len(self.domain))
        if spaces is None:
            if len(self.domain) == 1:
                spaces = (0,)
            else:
                raise ValueError(
                    "Field has multiple spaces as domain "
                    "but `spaces` is None.")

        if len(spaces) == 0:
            raise ValueError(
                "No space for synthesis specified.")
        elif len(spaces) > 1:
            raise ValueError(
                "Conversion of only one space at a time is allowed.")

        power_space_index = spaces[0]
        power_domain = self.domain[power_space_index]
        if not isinstance(power_domain, PowerSpace):
            raise ValueError(
                "A PowerSpace is needed for field synthetization.")

        # create the result domain
        result_domain = list(self.domain)
        harmonic_domain = power_domain.harmonic_domain
        result_domain[power_space_index] = harmonic_domain

        # one random sample for a real power spectrum, two for a complex one
        complex_power = issubclass(power_domain.dtype.type, np.complexfloating)
        sample_count = 2 if complex_power else 1

        result_list = [self.__class__.from_random(
                             'normal',
                             mean=mean,
                             std=std,
                             domain=result_domain,
                             dtype=harmonic_domain.dtype,
                             field_type=self.field_type,
                             distribution_strategy=self.distribution_strategy)
                       for _ in range(sample_count)]

        # from now on work on the values of the random fields without
        # destroying the fields; for a real signal keep only the hermitian
        # part of each sample
        if real_signal:
            result_val_list = [harmonic_domain.hermitian_decomposition(
                                    x.val,
                                    axes=x.domain_axes[power_space_index])[0]
                               for x in result_list]
        else:
            result_val_list = [x.val for x in result_list]

        # weight the random fields with the power spectrum, using the pindex
        # of the power space. Its local data must be compatible with the
        # local data of the field given the slice of the PowerSpace.
        pindex = power_domain.pindex
        local_distribution_strategy = \
            result_list[0].val.get_axes_local_distribution_strategy(
                result_list[0].domain_axes[power_space_index])

        if pindex.distribution_strategy != local_distribution_strategy:
            logger.warning(
                "The distribution_strategy of pindex does not fit the "
                "slice_local distribution strategy of the synthesized field.")

        # Use numpy advanced indexing to put the entries of the power
        # spectrum into the places given by pindex; the slice(None)s keep
        # all other axes intact.
        local_pindex = pindex.get_local_data(copy=False)
        full_spec = self.val.get_full_data()

        local_blow_up = [slice(None)] * len(self.shape)
        local_blow_up[self.domain_axes[power_space_index][0]] = local_pindex

        # multidimensional indexing requires a tuple, not a list
        local_rescaler = full_spec[tuple(local_blow_up)]

        # apply the rescaler to the random fields
        result_val_list[0].apply_scalar_function(
                                            lambda x: x * local_rescaler.real,
                                            inplace=True)

        if complex_power:
            result_val_list[1].apply_scalar_function(
                                            lambda x: x * local_rescaler.imag,
                                            inplace=True)

        # store the result into the fields
        for field, new_val in zip(result_list, result_val_list):
            field.set_val(new_val=new_val, copy=False)

        if complex_power:
            result = result_list[0] + 1j*result_list[1]
        else:
            result = result_list[0]

        return result
426

theos's avatar
theos committed
427
    # ---Properties---
428

theos's avatar
theos committed
429
    def set_val(self, new_val=None, copy=False):
430 431
        new_val = self.cast(new_val)
        if copy:
theos's avatar
theos committed
432 433 434
            new_val = new_val.copy()
        self._val = new_val
        return self._val
csongor's avatar
csongor committed
435

436 437
    def get_val(self, copy=False):
        if copy:
theos's avatar
theos committed
438
            return self._val.copy()
439
        else:
theos's avatar
theos committed
440
            return self._val
csongor's avatar
csongor committed
441

theos's avatar
theos committed
442 443 444
    @property
    def val(self):
        return self._val
csongor's avatar
csongor committed
445

theos's avatar
theos committed
446 447 448
    @val.setter
    def val(self, new_val):
        self._val = self.cast(new_val)
csongor's avatar
csongor committed
449

450 451
    @property
    def shape(self):
452 453 454 455 456 457 458
        shape_tuple = ()
        shape_tuple += tuple(sp.shape for sp in self.domain)
        shape_tuple += tuple(ft.shape for ft in self.field_type)
        try:
            global_shape = reduce(lambda x, y: x + y, shape_tuple)
        except TypeError:
            global_shape = ()
csongor's avatar
csongor committed
459

460
        return global_shape
csongor's avatar
csongor committed
461

462 463
    @property
    def dim(self):
theos's avatar
theos committed
464 465 466 467 468 469 470
        dim_tuple = ()
        dim_tuple += tuple(sp.dim for sp in self.domain)
        dim_tuple += tuple(ft.dim for ft in self.field_type)
        try:
            return reduce(lambda x, y: x * y, dim_tuple)
        except TypeError:
            return 0
csongor's avatar
csongor committed
471

472 473
    @property
    def dof(self):
theos's avatar
theos committed
474 475 476 477 478 479 480 481
        dof = self.dim
        if issubclass(self.dtype.type, np.complexfloating):
            dof *= 2
        return dof

    @property
    def total_volume(self):
        volume_tuple = tuple(sp.total_volume for sp in self.domain)
482
        try:
theos's avatar
theos committed
483
            return reduce(lambda x, y: x * y, volume_tuple)
484
        except TypeError:
theos's avatar
theos committed
485
            return 0
486

theos's avatar
theos committed
487
    # ---Special unary/binary operations---
488

csongor's avatar
csongor committed
489 490 491
    def cast(self, x=None, dtype=None):
        if dtype is None:
            dtype = self.dtype
492 493
        else:
            dtype = np.dtype(dtype)
494

495 496
        casted_x = x

497
        for ind, sp in enumerate(self.domain):
498
            casted_x = sp.pre_cast(casted_x,
499 500 501 502 503 504 505
                                   axes=self.domain_axes[ind])

        for ind, ft in enumerate(self.field_type):
            casted_x = ft.pre_cast(casted_x,
                                   axes=self.field_type_axes[ind])

        casted_x = self._actual_cast(casted_x, dtype=dtype)
506 507

        for ind, sp in enumerate(self.domain):
508 509
            casted_x = sp.post_cast(casted_x,
                                    axes=self.domain_axes[ind])
510 511

        for ind, ft in enumerate(self.field_type):
512 513
            casted_x = ft.post_cast(casted_x,
                                    axes=self.field_type_axes[ind])
514 515

        return casted_x
csongor's avatar
csongor committed
516

theos's avatar
theos committed
517
    def _actual_cast(self, x, dtype=None):
        """Wrap `x` into a new distributed_data_object for this field.

        The object gets this field's global shape and distribution strategy
        and `dtype` (default: self.dtype). A Field is first unwrapped to its
        data. The data is handed to set_full_data with ``copy=False``.
        """
        if isinstance(x, Field):
            x = x.get_val()

        if dtype is None:
            dtype = self.dtype

        return_x = distributed_data_object(
                            global_shape=self.shape,
                            dtype=dtype,
                            distribution_strategy=self.distribution_strategy)
        return_x.set_full_data(x, copy=False)
        return return_x
theos's avatar
theos committed
530 531

    def copy(self, domain=None, dtype=None, field_type=None,
532
             distribution_strategy=None):
theos's avatar
theos committed
533
        copied_val = self.get_val(copy=True)
534 535 536 537 538
        new_field = self.copy_empty(
                                domain=domain,
                                dtype=dtype,
                                field_type=field_type,
                                distribution_strategy=distribution_strategy)
theos's avatar
theos committed
539 540
        new_field.set_val(new_val=copied_val, copy=False)
        return new_field
csongor's avatar
csongor committed
541

theos's avatar
theos committed
542
    def copy_empty(self, domain=None, dtype=None, field_type=None,
543
                   distribution_strategy=None):
theos's avatar
theos committed
544 545
        if domain is None:
            domain = self.domain
csongor's avatar
csongor committed
546
        else:
theos's avatar
theos committed
547
            domain = self._parse_domain(domain)
csongor's avatar
csongor committed
548

theos's avatar
theos committed
549 550 551 552
        if dtype is None:
            dtype = self.dtype
        else:
            dtype = np.dtype(dtype)
csongor's avatar
csongor committed
553

theos's avatar
theos committed
554 555 556 557
        if field_type is None:
            field_type = self.field_type
        else:
            field_type = self._parse_field_type(field_type)
csongor's avatar
csongor committed
558

559 560
        if distribution_strategy is None:
            distribution_strategy = self.distribution_strategy
csongor's avatar
csongor committed
561

theos's avatar
theos committed
562 563 564 565 566 567 568 569 570 571 572 573 574 575
        fast_copyable = True
        try:
            for i in xrange(len(self.domain)):
                if self.domain[i] is not domain[i]:
                    fast_copyable = False
                    break
            for i in xrange(len(self.field_type)):
                if self.field_type[i] is not field_type[i]:
                    fast_copyable = False
                    break
        except IndexError:
            fast_copyable = False

        if (fast_copyable and dtype == self.dtype and
576
                distribution_strategy == self.distribution_strategy):
theos's avatar
theos committed
577 578 579 580 581
            new_field = self._fast_copy_empty()
        else:
            new_field = Field(domain=domain,
                              dtype=dtype,
                              field_type=field_type,
582
                              distribution_strategy=distribution_strategy)
theos's avatar
theos committed
583
        return new_field
csongor's avatar
csongor committed
584

theos's avatar
theos committed
585 586 587 588 589 590 591
    def _fast_copy_empty(self):
        """Clone `self` without running __init__, giving it empty data.

        All instance attributes are shared (same objects) with `self`,
        except `_val`, which is set to ``self.val.copy_empty()``.
        """
        # make an empty field
        # NOTE(review): EmptyField is defined outside this view; presumably
        # a Field subclass with a no-op __init__ -- confirm.
        new_field = EmptyField()
        # repair its class
        new_field.__class__ = self.__class__
        # share every attribute (domain, axes, dtype, ...) except the data
        for key, value in self.__dict__.items():
            if key == '_val':
                new_field.__dict__[key] = self.val.copy_empty()
            else:
                new_field.__dict__[key] = value
        return new_field

    def weight(self, power=1, inplace=False, spaces=None):
599
        if inplace:
csongor's avatar
csongor committed
600 601 602 603
            new_field = self
        else:
            new_field = self.copy_empty()

604
        new_val = self.get_val(copy=False)
csongor's avatar
csongor committed
605

csongor's avatar
csongor committed
606
        if spaces is None:
theos's avatar
theos committed
607 608 609
            spaces = range(len(self.domain))
        else:
            spaces = utilities.cast_axis_to_tuple(spaces, len(self.domain))
csongor's avatar
csongor committed
610

611
        for ind, sp in enumerate(self.domain):
theos's avatar
theos committed
612 613 614 615 616
            if ind in spaces:
                new_val = sp.weight(new_val,
                                    power=power,
                                    axes=self.domain_axes[ind],
                                    inplace=inplace)
617 618

        new_field.set_val(new_val=new_val, copy=False)
csongor's avatar
csongor committed
619 620
        return new_field

theos's avatar
theos committed
621 622 623 624 625 626 627 628 629
    def dot(self, x=None, bare=False):
        if isinstance(x, Field):
            try:
                assert len(x.domain) == len(self.domain)
                for index in xrange(len(self.domain)):
                    assert x.domain[index] == self.domain[index]
                for index in xrange(len(self.field_type)):
                    assert x.field_type[index] == self.field_type[index]
            except AssertionError:
630 631
                raise ValueError(
                    "domains are incompatible.")
theos's avatar
theos committed
632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649
            # extract the data from x and try to dot with this
            x = x.get_val(copy=False)

        # Compute the dot respecting the fact of discrete/continous spaces
        if bare:
            y = self
        else:
            y = self.weight(power=1)

        y = y.get_val(copy=False)

        # Cast the input in order to cure dtype and shape differences
        x = self.cast(x)

        dotted = x.conjugate() * y

        return dotted.sum()

650
    def norm(self, q=2):
csongor's avatar
csongor committed
651 652 653 654 655 656 657 658 659 660 661 662 663 664
        """
            Computes the Lq-norm of the field values.

            Parameters
            ----------
            q : scalar
                Parameter q of the Lq-norm (default: 2).

            Returns
            -------
            norm : scalar
                The Lq-norm of the field values.

        """
665
        if q == 2:
666
            return (self.dot(x=self)) ** (1 / 2)
csongor's avatar
csongor committed
667
        else:
668
            return self.dot(x=self ** (q - 1)) ** (1 / q)
csongor's avatar
csongor committed
669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684

    def conjugate(self, inplace=False):
        """
            Computes the complex conjugate of the field.

            Returns
            -------
            cc : field
                The complex conjugated field.

        """
        if inplace:
            work_field = self
        else:
            work_field = self.copy_empty()

685
        new_val = self.get_val(copy=False)
theos's avatar
theos committed
686
        new_val = new_val.conjugate()
687
        work_field.set_val(new_val=new_val, copy=False)
csongor's avatar
csongor committed
688 689 690

        return work_field

theos's avatar
theos committed
691
    # ---General unary/contraction methods---
692

theos's avatar
theos committed
693 694
    def __pos__(self):
        return self.copy()
695

theos's avatar
theos committed
696 697 698 699
    def __neg__(self):
        return_field = self.copy_empty()
        new_val = -self.get_val(copy=False)
        return_field.set_val(new_val, copy=False)
csongor's avatar
csongor committed
700 701
        return return_field

theos's avatar
theos committed
702 703 704 705 706
    def __abs__(self):
        return_field = self.copy_empty()
        new_val = abs(self.get_val(copy=False))
        return_field.set_val(new_val, copy=False)
        return return_field
csongor's avatar
csongor committed
707

theos's avatar
theos committed
708 709 710 711 712 713
    def _contraction_helper(self, op, spaces, types):
        # build a list of all axes
        if spaces is None:
            spaces = xrange(len(self.domain))
        else:
            spaces = utilities.cast_axis_to_tuple(spaces, len(self.domain))
csongor's avatar
csongor committed
714

theos's avatar
theos committed
715 716 717 718
        if types is None:
            types = xrange(len(self.field_type))
        else:
            types = utilities.cast_axis_to_tuple(types, len(self.field_type))
719

theos's avatar
theos committed
720 721 722 723
        axes_list = ()
        axes_list += tuple(self.domain_axes[sp_index] for sp_index in spaces)
        axes_list += tuple(self.field_type_axes[ft_index] for
                           ft_index in types)
724
        try:
theos's avatar
theos committed
725
            axes_list = reduce(lambda x, y: x+y, axes_list)
726
        except TypeError:
theos's avatar
theos committed
727
            axes_list = ()
csongor's avatar
csongor committed
728

theos's avatar
theos committed
729 730 731
        # perform the contraction on the d2o
        data = self.get_val(copy=False)
        data = getattr(data, op)(axis=axes_list)
csongor's avatar
csongor committed
732

theos's avatar
theos committed
733 734 735
        # check if the result is scalar or if a result_field must be constr.
        if np.isscalar(data):
            return data
csongor's avatar
csongor committed
736
        else:
theos's avatar
theos committed
737 738 739 740 741 742 743 744 745 746 747
            return_domain = tuple(self.domain[i]
                                  for i in xrange(len(self.domain))
                                  if i not in spaces)
            return_field_type = tuple(self.field_type[i]
                                      for i in xrange(len(self.field_type))
                                      if i not in types)
            return_field = Field(domain=return_domain,
                                 val=data,
                                 field_type=return_field_type,
                                 copy=False)
            return return_field
csongor's avatar
csongor committed
748

theos's avatar
theos committed
749 750
    def sum(self, spaces=None, types=None):
        return self._contraction_helper('sum', spaces, types)
csongor's avatar
csongor committed
751

theos's avatar
theos committed
752 753
    def prod(self, spaces=None, types=None):
        return self._contraction_helper('prod', spaces, types)
csongor's avatar
csongor committed
754

theos's avatar
theos committed
755 756
    def all(self, spaces=None, types=None):
        return self._contraction_helper('all', spaces, types)
csongor's avatar
csongor committed
757

theos's avatar
theos committed
758 759
    def any(self, spaces=None, types=None):
        return self._contraction_helper('any', spaces, types)
csongor's avatar
csongor committed
760

theos's avatar
theos committed
761 762
    def min(self, spaces=None, types=None):
        return self._contraction_helper('min', spaces, types)
csongor's avatar
csongor committed
763

theos's avatar
theos committed
764 765
    def nanmin(self, spaces=None, types=None):
        return self._contraction_helper('nanmin', spaces, types)
csongor's avatar
csongor committed
766

theos's avatar
theos committed
767 768
    def max(self, spaces=None, types=None):
        return self._contraction_helper('max', spaces, types)
csongor's avatar
csongor committed
769

theos's avatar
theos committed
770 771
    def nanmax(self, spaces=None, types=None):
        return self._contraction_helper('nanmax', spaces, types)
csongor's avatar
csongor committed
772

theos's avatar
theos committed
773 774
    def mean(self, spaces=None, types=None):
        return self._contraction_helper('mean', spaces, types)
csongor's avatar
csongor committed
775

theos's avatar
theos committed
776 777
    def var(self, spaces=None, types=None):
        return self._contraction_helper('var', spaces, types)
csongor's avatar
csongor committed
778

theos's avatar
theos committed
779 780
    def std(self, spaces=None, types=None):
        return self._contraction_helper('std', spaces, types)
csongor's avatar
csongor committed
781

theos's avatar
theos committed
782
    # ---General binary methods---
csongor's avatar
csongor committed
783

theos's avatar
theos committed
784
    def _binary_helper(self, other, op, inplace=False):
csongor's avatar
csongor committed
785
        # if other is a field, make sure that the domains match
786
        if isinstance(other, Field):
theos's avatar
theos committed
787 788 789 790
            try:
                assert len(other.domain) == len(self.domain)
                for index in xrange(len(self.domain)):
                    assert other.domain[index] == self.domain[index]
791
                assert len(other.field_type) == len(self.field_type)
theos's avatar
theos committed
792 793 794
                for index in xrange(len(self.field_type)):
                    assert other.field_type[index] == self.field_type[index]
            except AssertionError:
795 796
                raise ValueError(
                    "domains are incompatible.")
theos's avatar
theos committed
797
            other = other.get_val(copy=False)
csongor's avatar
csongor committed
798

theos's avatar
theos committed
799 800
        self_val = self.get_val(copy=False)
        return_val = getattr(self_val, op)(other)
csongor's avatar
csongor committed
801 802 803 804 805 806

        if inplace:
            working_field = self
        else:
            working_field = self.copy_empty()

theos's avatar
theos committed
807
        working_field.set_val(return_val, copy=False)
csongor's avatar
csongor committed
808 809 810
        return working_field

    def __add__(self, other):
theos's avatar
theos committed
811
        return self._binary_helper(other, op='__add__')
812

813
    def __radd__(self, other):
theos's avatar
theos committed
814
        return self._binary_helper(other, op='__radd__')
csongor's avatar
csongor committed
815 816

    def __iadd__(self, other):
theos's avatar
theos committed
817
        return self._binary_helper(other, op='__iadd__', inplace=True)
csongor's avatar
csongor committed
818 819

    def __sub__(self, other):
theos's avatar
theos committed
820
        return self._binary_helper(other, op='__sub__')
csongor's avatar
csongor committed
821 822

    def __rsub__(self, other):
theos's avatar
theos committed
823
        return self._binary_helper(other, op='__rsub__')
csongor's avatar
csongor committed
824 825

    def __isub__(self, other):
theos's avatar
theos committed
826
        return self._binary_helper(other, op='__isub__', inplace=True)
csongor's avatar
csongor committed
827 828

    def __mul__(self, other):
theos's avatar
theos committed
829
        return self._binary_helper(other, op='__mul__')
830

831
    def __rmul__(self, other):
theos's avatar
theos committed
832
        return self._binary_helper(other, op='__rmul__')
csongor's avatar
csongor committed
833 834

    def __imul__(self, other):
theos's avatar
theos committed
835
        return self._binary_helper(other, op='__imul__', inplace=True)
csongor's avatar
csongor committed
836 837

    def __div__(self, other):
theos's avatar
theos committed
838
        return self._binary_helper(other, op='__div__')
csongor's avatar
csongor committed
839 840

    def __rdiv__(self, other):
theos's avatar
theos committed
841
        return self._binary_helper(other, op='__rdiv__')
csongor's avatar
csongor committed
842 843

    def __idiv__(self, other):
theos's avatar
theos committed
844
        return self._binary_helper(other, op='__idiv__', inplace=True)
845

csongor's avatar
csongor committed
846
    def __pow__(self, other):
theos's avatar
theos committed
847
        return self._binary_helper(other, op='__pow__')
csongor's avatar
csongor committed
848 849

    def __rpow__(self, other):
theos's avatar
theos committed
850
        return self._binary_helper(other, op='__rpow__')
csongor's avatar
csongor committed
851 852

    def __ipow__(self, other):
theos's avatar
theos committed
853
        return self._binary_helper(other, op='__ipow__', inplace=True)
csongor's avatar
csongor committed
854 855

    def __lt__(self, other):
theos's avatar
theos committed
856
        return self._binary_helper(other, op='__lt__')
csongor's avatar
csongor committed
857 858

    def __le__(self, other):
theos's avatar
theos committed
859
        return self._binary_helper(other, op='__le__')
csongor's avatar
csongor committed
860 861 862 863 864

    def __ne__(self, other):
        if other is None:
            return True
        else:
theos's avatar
theos committed
865
            return self._binary_helper(other, op='__ne__')
csongor's avatar
csongor committed
866 867 868 869 870

    def __eq__(self, other):
        if other is None:
            return False
        else:
theos's avatar
theos committed
871
            return self._binary_helper(other, op='__eq__')
csongor's avatar
csongor committed
872 873

    def __ge__(self, other):
theos's avatar
theos committed
874
        return self._binary_helper(other, op='__ge__')
csongor's avatar
csongor committed
875 876

    def __gt__(self, other):
theos's avatar
theos committed
877 878 879 880 881 882 883 884 885 886 887 888 889
        return self._binary_helper(other, op='__gt__')

    def __repr__(self):
        return "<nifty_core.field>"

    def __str__(self):
        minmax = [self.min(), self.max()]
        mean = self.mean()
        return "nifty_core.field instance\n- domain      = " + \
               repr(self.domain) + \
               "\n- val         = " + repr(self.get_val()) + \
               "\n  - min.,max. = " + str(minmax) + \
               "\n  - mean = " + str(mean)
csongor's avatar
csongor committed
890

891

892
class EmptyField(Field):
    """Field subclass whose constructor performs no initialisation.

    ``Field.__init__`` is deliberately not called, so none of the
    attributes it assigns (domain, field_type, dtype, ...) are set by
    constructing an EmptyField.
    """

    def __init__(self):
        """Create the instance without running ``Field.__init__``."""