# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

from builtins import str
from builtins import range
from builtins import object
import numpy as np
from ...field import Field
from ... import nifty_utilities as utilities
from ... import nifty_configuration as nc
from d2o import STRATEGIES as DISTRIBUTION_STRATEGIES


class Prober(object):
    """Base class that feeds random probe fields to a callee.

    A probing run (:meth:`probing_run`, also reachable via ``__call__``)
    draws ``probe_count`` random fields on ``domain``, evaluates the callee
    on each of them and hands every result to :meth:`finish_probe`.  The
    hooks :meth:`reset` and :meth:`finish_probe` are no-ops in this class;
    :meth:`get_probe` and :meth:`process_probe` are thin indirection layers.
    All four are intended to be overridden, e.g. by mixin classes.

    See the following webpages for the principles behind the usage of
    mixin-classes

    https://www.python.org/download/releases/2.2.3/descrintro/#cooperation
    https://rhettinger.wordpress.com/2011/05/26/super-considered-super/

    Parameters
    ----------
    domain : optional
        Domain of the probe fields; normalized by ``utilities.parse_domain``.
    distribution_strategy : str, optional
        Must be one of ``DISTRIBUTION_STRATEGIES['global']``.  ``None``
        selects ``nc['default_distribution_strategy']``.
    probe_count : int
        Number of probes drawn per run (converted with ``int``; must not be
        negative).
    random_type : {'pm1', 'normal'}
        Random distribution passed on to ``Field.from_random``.
    probe_dtype : dtype-like
        Data type of the probe fields (normalized with ``np.dtype``).
    compute_variance : bool
        Stored as ``self.compute_variance``; not read by this base class
        itself (presumably consumed by subclasses/mixins).
    """

    def __init__(self, domain=None, distribution_strategy=None, probe_count=8,
                 random_type='pm1', probe_dtype=np.float64,
                 compute_variance=False):
        # np.float64 replaces the removed np.float alias (NumPy >= 1.24);
        # np.dtype() of either yields the same float64 dtype.
        self._domain = utilities.parse_domain(domain)
        self._distribution_strategy = \
            self._parse_distribution_strategy(distribution_strategy)
        self._probe_count = self._parse_probe_count(probe_count)
        self._random_type = self._parse_random_type(random_type)
        self.compute_variance = bool(compute_variance)
        self.probe_dtype = np.dtype(probe_dtype)
        # Running id attached to every probe produced by generate_probe.
        self._uid_counter = 0

    # ---Properties---

    @property
    def domain(self):
        """The (parsed) domain the probe fields live on."""
        return self._domain

    @property
    def distribution_strategy(self):
        """The validated distribution strategy used for the probe fields."""
        return self._distribution_strategy

    def _parse_distribution_strategy(self, distribution_strategy):
        """Validate ``distribution_strategy`` and return it.

        ``None`` selects ``nc['default_distribution_strategy']``; any other
        value is converted with ``str``.

        Raises
        ------
        ValueError
            If the strategy is not in ``DISTRIBUTION_STRATEGIES['global']``.
        """
        if distribution_strategy is None:
            distribution_strategy = nc['default_distribution_strategy']
        else:
            distribution_strategy = str(distribution_strategy)
        if distribution_strategy not in DISTRIBUTION_STRATEGIES['global']:
            raise ValueError("distribution_strategy must be a global-type "
                             "strategy.")
        # __init__ stores the return value, so the strategy must be returned
        # (assigning it to self here would be overwritten with None).
        return distribution_strategy

    @property
    def probe_count(self):
        """Number of probes drawn per probing run."""
        return self._probe_count

    def _parse_probe_count(self, probe_count):
        """Convert ``probe_count`` to ``int``, rejecting negative values.

        Raises
        ------
        ValueError
            If the converted count is negative (such a count would otherwise
            silently produce a run without any probes).
        """
        probe_count = int(probe_count)
        if probe_count < 0:
            raise ValueError("probe_count must be non-negative.")
        return probe_count

    @property
    def random_type(self):
        """Name of the random distribution used to draw probes."""
        return self._random_type

    def _parse_random_type(self, random_type):
        """Return ``random_type`` if it is 'pm1' or 'normal'.

        Raises
        ------
        ValueError
            For any other value.
        """
        if random_type not in ["pm1", "normal"]:
            raise ValueError(
                "unsupported random type: '" + str(random_type) + "'.")
        return random_type

    # ---Probing methods---

    def probing_run(self, callee):
        """ controls the generation, evaluation and finalization of probes """
        self.reset()
        for index in range(self.probe_count):
            current_probe = self.get_probe(index)
            pre_result = self.process_probe(callee, current_probe, index)
            self.finish_probe(current_probe, pre_result)

    def reset(self):
        """Hook called once at the start of every probing run (no-op)."""
        pass

    def get_probe(self, index):
        """ layer of abstraction for potential probe-caching """
        # `index` is unused here; it is part of the hook's signature so that
        # overrides can look probes up by position.
        return self.generate_probe()

    def generate_probe(self):
        """ a random-probe generator

        Returns a ``(uid, field)`` tuple, where ``uid`` is a running counter
        and ``field`` is a random ``Field`` drawn with the configured
        random type, domain, dtype and distribution strategy.
        """
        f = Field.from_random(random_type=self.random_type,
                              domain=self.domain,
                              dtype=self.probe_dtype,
                              distribution_strategy=self.distribution_strategy)
        uid = self._uid_counter
        self._uid_counter += 1
        return (uid, f)

    def process_probe(self, callee, probe, index):
        """ layer of abstraction for potential result-caching/recycling """
        # `probe` is the (uid, field) tuple; only the field is evaluated.
        return self.evaluate_probe(callee, probe[1])

    def evaluate_probe(self, callee, probe, **kwargs):
        """ processes a probe by calling ``callee(probe, **kwargs)`` """
        return callee(probe, **kwargs)

    def finish_probe(self, probe, pre_result):
        """Hook receiving each (uid, field) probe and its result (no-op)."""
        pass

    def __call__(self, callee):
        """Shorthand for :meth:`probing_run`."""
        return self.probing_run(callee)