Commit 212b66bb authored by Theo Steininger's avatar Theo Steininger

The MagneticFieldFactory now forwards a staticmethod from itself to...

The MagneticFieldFactory now forwards a staticmethod from itself to MagneticField which contains the _create_array functionality.
parent d1e73544
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import types
import numpy as np import numpy as np
from keepers import Loggable from keepers import Loggable
...@@ -7,7 +9,7 @@ from keepers import Loggable ...@@ -7,7 +9,7 @@ from keepers import Loggable
class MagneticField(Loggable, object): class MagneticField(Loggable, object):
def __init__(self, box_dimensions, resolution, descriptor, def __init__(self, box_dimensions, resolution, descriptor,
parameters={}): parameters={}, create_array=None):
self._box_dimensions = box_dimensions self._box_dimensions = box_dimensions
self._resolution = resolution self._resolution = resolution
...@@ -18,6 +20,8 @@ class MagneticField(Loggable, object): ...@@ -18,6 +20,8 @@ class MagneticField(Loggable, object):
self._parameters[str(key)] = np.float(value) self._parameters[str(key)] = np.float(value)
self._array = None self._array = None
if create_array is not None:
self._create_array = types.MethodType(create_array, self)
@property @property
def parameters(self): def parameters(self):
......
...@@ -79,7 +79,11 @@ class MagneticFieldFactory(Loggable, object): ...@@ -79,7 +79,11 @@ class MagneticFieldFactory(Loggable, object):
parameter_dict[variable_name] = mapped_variable parameter_dict[variable_name] = mapped_variable
return parameter_dict return parameter_dict
def generate(self, variables): @staticmethod
def _create_array(self):
raise NotImplementedError
def generate(self, variables={}):
mapped_variables = self._map_variables_to_parameters(variables) mapped_variables = self._map_variables_to_parameters(variables)
work_parameters = self.parameter_defaults.copy() work_parameters = self.parameter_defaults.copy()
work_parameters.update(mapped_variables) work_parameters.update(mapped_variables)
...@@ -88,6 +92,7 @@ class MagneticFieldFactory(Loggable, object): ...@@ -88,6 +92,7 @@ class MagneticFieldFactory(Loggable, object):
box_dimensions=self.box_dimensions, box_dimensions=self.box_dimensions,
resolution=self.resolution, resolution=self.resolution,
descriptor=self.descriptor, descriptor=self.descriptor,
parameters=work_parameters) parameters=work_parameters,
create_array=self._create_array)
return result_magnetic_field return result_magnetic_field
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment