observable.py 1.53 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# -*- coding: utf-8 -*-


from nifty import Field, FieldArray


class Observable(Field):
    """A Field over a two-entry domain whose first entry is a FieldArray.

    The constructor enforces that ``self.domain`` has exactly two entries and
    that ``self.domain[0]`` is a ``FieldArray``. The mean over space 0 is
    exposed via :meth:`ensemble_mean` and cached on the instance.
    NOTE(review): space 0 presumably indexes ensemble members (inferred from
    the method name) -- confirm.
    """

    def __init__(self, domain=None, val=None, dtype=None,
                 distribution_strategy=None, copy=False):
        """Build the field via ``Field.__init__`` and validate its domain.

        All parameters are forwarded unchanged to ``Field.__init__``.

        Raises
        ------
        ValueError
            If the resulting domain does not have exactly two entries.
        TypeError
            If the first domain entry is not a ``FieldArray``.
        """
        super(Observable, self).__init__(
                                domain=domain,
                                val=val,
                                dtype=dtype,
                                distribution_strategy=distribution_strategy,
                                copy=copy)

        # Explicit checks instead of ``assert`` so the validation is not
        # silently removed when Python runs with ``-O``.
        if len(self.domain) != 2:
            raise ValueError(
                "Observable requires a domain with exactly two entries, "
                "got %d." % len(self.domain))
        if not isinstance(self.domain[0], FieldArray):
            raise TypeError(
                "The first domain entry of an Observable must be a "
                "FieldArray.")

    def ensemble_mean(self):
        """Return the mean of this field over domain space 0 (cached).

        The first call computes ``self.mean(spaces=0)`` and stores it in
        ``self._ensemble_mean``. Later calls, including after a mean was
        restored by ``_from_hdf5``, return the stored value without
        recomputing it.

        NOTE(review): the cache is never invalidated here, so it may become
        stale if the field's values are modified after the first call --
        confirm whether Field offers a hook to reset it.
        """
        # getattr with a default replaces the old try/except/finally, whose
        # ``return`` inside ``finally`` masked any error raised by
        # ``self.mean`` with a misleading AttributeError.
        cached_mean = getattr(self, '_ensemble_mean', None)
        if cached_mean is None:
            cached_mean = self.mean(spaces=0)
            self._ensemble_mean = cached_mean
        return cached_mean

    def _to_hdf5(self, hdf5_group):
        """Collect the entries to serialize for this field.

        Returns a dict holding the cached ensemble mean under the key
        ``'ensemble_mean'`` (only if it has been computed or restored),
        updated with the dict returned by ``Field._to_hdf5`` for
        ``hdf5_group``.
        """
        return_dict = {}
        # ``_ensemble_mean`` only exists once ensemble_mean() has run or
        # _from_hdf5 restored it; reading it directly raised AttributeError.
        cached_mean = getattr(self, '_ensemble_mean', None)
        if cached_mean is not None:
            return_dict['ensemble_mean'] = cached_mean
        return_dict.update(
                   super(Observable, self)._to_hdf5(hdf5_group=hdf5_group))
        return return_dict

    @classmethod
    def _from_hdf5(cls, hdf5_group, repository):
        """Rebuild an Observable via ``Field._from_hdf5`` and restore its mean.

        If ``repository.get('ensemble_mean', hdf5_group)`` raises
        ``KeyError`` (no stored mean), the field is returned without a cached
        mean, and ``ensemble_mean()`` will compute it on demand.
        """
        new_field = super(Observable, cls)._from_hdf5(hdf5_group=hdf5_group,
                                                      repository=repository)
        try:
            stored_mean = repository.get('ensemble_mean', hdf5_group)
        except KeyError:
            pass
        else:
            # Must be the attribute read by ensemble_mean() and _to_hdf5();
            # it was previously written to ``_observable_mean`` and ignored.
            new_field._ensemble_mean = stored_mean
        return new_field