# -*- coding: utf-8 -*-

import numpy as np

from keepers import Loggable

from imagine.carrier_mapper import unity_mapper

from nifty import FieldArray, RGSpace

from magnetic_field import MagneticField


class MagneticFieldFactory(Loggable, object):
    """Build magnetic-field objects on a fixed three-dimensional grid.

    The factory stores the grid geometry (``box_dimensions``,
    ``resolution``), a dict of default parameter values and, per parameter,
    a ``[min, max]`` range used to translate "variables" into "parameters"
    in ``generate``. Subclasses may customise the model by overriding
    ``magnetic_field_class``, ``_initial_parameter_defaults`` and
    ``_initial_variable_to_parameter_mappings``.
    """

    def __init__(self, box_dimensions, resolution):
        """
        Parameters
        ----------
        box_dimensions : sequence of three numbers
            Extent of the box along each of the three axes.
        resolution : sequence of three ints
            Number of grid points along each axis; entries must be positive.
        """
        self.logger.debug("Setting up MagneticFieldFactory.")

        # The setters validate their input and (re)build self._grid_space
        # as soon as both values are known.
        self.box_dimensions = box_dimensions
        self.resolution = resolution

        # Copy the initial dicts: the setters below update them in place and
        # must not mutate a dict shared with the subclass or other instances.
        self._parameter_defaults = dict(self._initial_parameter_defaults)
        self._variable_to_parameter_mappings = \
            dict(self._initial_variable_to_parameter_mappings)

        self._vector = FieldArray(shape=(3,))
        # One FieldArray ensemble domain per requested ensemble size.
        self._ensemble_cache = {}

    def _get_ensemble(self, ensemble_size):
        """Return the cached FieldArray of shape (ensemble_size,)."""
        if ensemble_size not in self._ensemble_cache:
            self._ensemble_cache[ensemble_size] = \
                FieldArray(shape=(ensemble_size,))
        return self._ensemble_cache[ensemble_size]

    def _rebuild_grid_space(self):
        """(Re)create the RGSpace grid from the current geometry.

        Does nothing until both box_dimensions and resolution have been set,
        so it can safely be called from either setter during __init__.
        """
        box_dimensions = getattr(self, '_box_dimensions', None)
        resolution = getattr(self, '_resolution', None)
        if box_dimensions is None or resolution is None:
            return
        distances = np.array(box_dimensions) / np.array(resolution)
        self._grid_space = RGSpace(shape=resolution,
                                   distances=distances)

    @property
    def box_dimensions(self):
        """Box extent as a tuple of three floats."""
        return self._box_dimensions

    @box_dimensions.setter
    def box_dimensions(self, box_dimensions):
        dim = np.array(box_dimensions, dtype=float)
        # Require a flat 3-vector; a scalar or a nested sequence would
        # otherwise crash in tuple() or slip past a plain length check.
        if dim.shape != (3,):
            raise ValueError("Input of box_dimensions must have length three.")
        self._box_dimensions = tuple(dim)
        # Keep the grid consistent with the new geometry.
        self._rebuild_grid_space()

    @property
    def resolution(self):
        """Grid resolution as a tuple of three ints."""
        return self._resolution

    @resolution.setter
    def resolution(self, resolution):
        res = np.array(resolution, dtype=int)
        if res.shape != (3,):
            raise ValueError("Input for resolution must have length three.")
        # The grid distances are box_dimensions / resolution.
        if np.any(res <= 0):
            raise ValueError("Entries of resolution must be positive.")
        self._resolution = tuple(res)
        # Keep the grid consistent with the new geometry.
        self._rebuild_grid_space()

    @property
    def magnetic_field_class(self):
        """Class instantiated by generate(); subclasses may override it."""
        return MagneticField

    @property
    def _initial_parameter_defaults(self):
        """Initial {parameter name: default value} dict (empty here)."""
        return {}

    @property
    def _initial_variable_to_parameter_mappings(self):
        """Initial {parameter name: [min, max]} dict (empty here)."""
        return {}

    @staticmethod
    def _interval(mean, sigma, n):
        """Return [mean - n*sigma, mean + n*sigma]."""
        return [mean-n*sigma, mean+n*sigma]

    @staticmethod
    def _positive_interval(mean, sigma, n):
        """Like _interval, but with the lower bound clipped at 0."""
        return [max(0, mean-n*sigma), mean+n*sigma]

    @property
    def parameter_defaults(self):
        """Current {parameter name: default value} dict."""
        return self._parameter_defaults

    @parameter_defaults.setter
    def parameter_defaults(self, new_defaults):
        """Update defaults of already-known parameters (cast to float).

        Keys that are not already present are ignored.
        """
        self._parameter_defaults.update((str(k), float(v))
                                        for k, v in new_defaults.items()
                                        if k in self._parameter_defaults)

    @property
    def variable_defaults(self):
        """Variable values that correspond to the current parameter defaults.

        For a parameter with a [low, high] mapping this computes
        (default - low) / (high - low), i.e. the inverse of a linear
        [0, 1] -> [low, high] map. NOTE(review): this assumes unity_mapper is
        that linear map -- confirm against imagine.carrier_mapper.
        Parameters without a mapping are passed through unchanged by
        _map_variables_to_parameters, so their variable default is the
        parameter default itself.
        """
        mappings = self.variable_to_parameter_mappings
        variable_defaults = {}
        for parameter, default in self.parameter_defaults.items():
            if parameter in mappings:
                low, high = mappings[parameter]
                variable_defaults[parameter] = (default - low)/(high - low)
            else:
                variable_defaults[parameter] = default
        return variable_defaults

    @property
    def variable_to_parameter_mappings(self):
        """Current {parameter name: [min, max]} dict."""
        return self._variable_to_parameter_mappings

    @variable_to_parameter_mappings.setter
    def variable_to_parameter_mappings(self, new_mapping):
        """
        The parameter-mapping must be a dictionary with
        key: parameter-name
        value: [min, max]

        Only parameters that already have a mapping are updated; other keys
        are ignored. min and max are cast to float.
        """
        for k, v in new_mapping.items():
            if k not in self._variable_to_parameter_mappings:
                continue
            key = str(k)
            value = [float(v[0]), float(v[1])]
            self._variable_to_parameter_mappings[key] = value
            self.logger.debug("Updated variable_to_parameter_mapping %s "
                              "to %s", key, str(value))

    def _map_variables_to_parameters(self, variables):
        """Translate a {name: variable} dict into a {name: parameter} dict.

        Names that have a mapping are passed through unity_mapper with
        a=mapping[0] and b=mapping[1]; all other values are cast to float
        and used unchanged.
        """
        mappings = self.variable_to_parameter_mappings
        parameter_dict = {}
        for variable_name, value in variables.items():
            if variable_name in mappings:
                mapping = mappings[variable_name]
                parameter_dict[variable_name] = unity_mapper(value,
                                                             a=mapping[0],
                                                             b=mapping[1])
            else:
                parameter_dict[variable_name] = float(value)
        return parameter_dict

    def generate(self, variables=None, ensemble_size=1, random_seed=None):
        """Create a magnetic field for the given variables.

        Parameters
        ----------
        variables : dict, optional
            {name: value}; mapped via _map_variables_to_parameters and laid
            over a copy of parameter_defaults. None means no overrides.
        ensemble_size : int
            Shape of the first element of the field's domain tuple.
        random_seed : optional
            Forwarded unchanged to magnetic_field_class.

        Returns
        -------
        An instance of magnetic_field_class on the domain
        (ensemble FieldArray, grid RGSpace, 3-vector FieldArray).
        """
        # None instead of a shared mutable {} default.
        if variables is None:
            variables = {}
        work_parameters = self.parameter_defaults.copy()
        work_parameters.update(self._map_variables_to_parameters(variables))

        domain = (self._get_ensemble(ensemble_size),
                  self._grid_space,
                  self._vector)

        result_magnetic_field = self.magnetic_field_class(
                                              domain=domain,
                                              parameters=work_parameters,
                                              distribution_strategy='equal',
                                              random_seed=random_seed)
        self.logger.debug("Generated magnetic field with work-parameters %s",
                          work_parameters)
        return result_magnetic_field