Commit ef11cbe0 authored by Theo Steininger's avatar Theo Steininger

Added Sample class.

parent 5c0e173d
...@@ -10,5 +10,7 @@ from pymultinest_importer import pymultinest ...@@ -10,5 +10,7 @@ from pymultinest_importer import pymultinest
from pipeline import Pipeline from pipeline import Pipeline
from sample import Sample
import nifty import nifty
nifty.nifty_configuration['default_distribution_strategy'] = 'equal' nifty.nifty_configuration['default_distribution_strategy'] = 'equal'
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import simplejson as json
import numpy as np import numpy as np
from nifty import Field, FieldArray, RGSpace from nifty import Field, FieldArray, RGSpace
...@@ -25,11 +27,12 @@ class MagneticField(Field): ...@@ -25,11 +27,12 @@ class MagneticField(Field):
for p in self.parameter_list: for p in self.parameter_list:
self._parameters[p] = np.float(parameters[p]) self._parameters[p] = np.float(parameters[p])
self.random_seed = np.empty(self.shape[0], casted_random_seed = np.empty(self.shape[0],
if random_seed is None: if random_seed is None:
random_seed = np.random.randint(np.uint32(-1)/3, random_seed = np.random.randint(np.uint32(-1)/3,
size=self.shape[0]) size=self.shape[0])
self.random_seed[:] = random_seed casted_random_seed[:] = random_seed
self.random_seed = tuple(casted_random_seed)
@property @property
def parameter_list(self): def parameter_list(self):
...@@ -44,3 +47,16 @@ class MagneticField(Field): ...@@ -44,3 +47,16 @@ class MagneticField(Field):
raise RuntimeError("Setting the field values explicitly is not " raise RuntimeError("Setting the field values explicitly is not "
"supported by MagneticField.") "supported by MagneticField.")
self._val = self._create_field() self._val = self._create_field()
def _to_hdf5(self, hdf5_group):
hdf5_group.attrs['_parameters'] = json.dumps(self._parameters)
hdf5_group.create_dataset('random_seed', data=self.random_seed)
return super(MagneticField, self)._to_hdf5(hdf5_group=hdf5_group)
def _from_hdf5(cls, hdf5_group, repository):
new_field = super(MagneticField, cls)._from_hdf5(hdf5_group=hdf5_group,
new_field._parameters = json.loads(hdf5_group.attrs['_parameters'])
new_field.random_seed = tuple(hdf5_group['random_seed'])
return new_field
# -*- coding: utf-8 -*-
import simplejson as json
from keepers import Loggable,\
from nifty import Field
from imagine.magnetic_fields import MagneticField
class Sample(Loggable, Versionable, object):
def __init__(self):
self._variables = None
self._magnetic_field = None
self._observables = None
def variables(self):
return self._variables
def variables(self, variables):
self._variables = variables
def magnetic_field(self):
return self._magnetic_field
def magnetic_field(self, magnetic_field):
if not isinstance(magnetic_field, MagneticField):
raise TypeError("Input must be a MagneticField instance.")
self._magnetic_field = magnetic_field
def observables(self):
return self._observables
def observables(self, observables):
parsed_observables = {}
if not isinstance(observables, dict):
raise TypeError("Input must be a dict.")
for key, value in observables.iteritems():
if not isinstance(key, str):
raise TypeError("Observable name must be a string.")
if not isinstance(value, Field):
raise TypeError("Observable must be a NIFTy-Field.")
parsed_observables[key] = value
self._observables = parsed_observables
def _to_hdf5(self, hdf5_group):
if self._variables is not None:
hdf5_group.attrs['variables'] = json.dumps(self._variables)
return_dict = {}
if self._magnetic_field is not None:
return_dict['magnetic_field'] = self._magnetic_field
if self._observables is not None:
hdf5_group.attrs['observable_names'] = \
return return_dict
def _from_hdf5(cls, hdf5_group, repository):
new_sample = cls()
variables = hdf5_group.attrs['variables']
new_sample._variables = json.loads(variables)
magnetic_field = repository.get('magnetic_field', hdf5_group)
new_sample._magnetic_field = magnetic_field
observable_names = hdf5_group.attrs['observable_names']
observable_names = json.loads(observable_names)
observables = {}
for name in observable_names:
observables[name] = repository.get(name, hdf5_group)
new_sample._observables = observables
return new_sample
...@@ -25,6 +25,9 @@ setup(name = "imagine", ...@@ -25,6 +25,9 @@ setup(name = "imagine",
package_data={'': ['*.npy'], package_data={'': ['*.npy'],
'imagine.hammurapy': ['confs/*'],}, 'imagine.hammurapy': ['confs/*'],},
package_dir={"imagine": "imagine"}, package_dir={"imagine": "imagine"},
install_requires=['ift_nifty>=3.0.3', 'simplejson'],
zip_safe=False, zip_safe=False,
classifiers=[ classifiers=[
"Development Status :: 3 - Alpha", "Development Status :: 3 - Alpha",
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment