Commit 31330874 authored by Theo Steininger's avatar Theo Steininger

Switched to NIFTy Field for MagneticField.

parent 107300e3
# -*- coding: utf-8 -*-
from likelihood import Likelihood
......@@ -14,14 +14,8 @@ class ConstantMagneticField(MagneticField):
return parameter_list
def _create_field(self):
distances = np.array(self.box_dimensions)/np.array(self.resolution)
space = RGSpace(shape=self.resolution,
distances=distances)
field_array = FieldArray(shape=(3,), dtype=np.float)
result_field = Field(domain=space, field_type=field_array)
result_field.val[:, :, :, 0] = self.parameters['b_x']
result_field.val[:, :, :, 1] = self.parameters['b_y']
result_field.val[:, :, :, 2] = self.parameters['b_z']
return result_field
val = self.cast(None)
val[:, :, :, :, 0] = self.parameters['b_x']
val[:, :, :, :, 1] = self.parameters['b_y']
val[:, :, :, :, 2] = self.parameters['b_z']
return val
......@@ -2,19 +2,24 @@
import numpy as np
from keepers import Loggable
from nifty import Field, FieldArray, RGSpace
class MagneticField(Loggable, object):
def __init__(self, box_dimensions, resolution, parameters):
class MagneticField(Field):
def __init__(self, parameters=[], domain=None, val=None, dtype=None,
distribution_strategy=None, copy=False):
self._box_dimensions = np.empty(3)
self._box_dimensions[:] = box_dimensions
self._box_dimensions = tuple(self._box_dimensions)
super(MagneticField, self).__init__(
domain=domain,
val=val,
dtype=dtype,
distribution_strategy=distribution_strategy,
copy=copy)
self._resolution = np.empty(3)
self._resolution[:] = resolution
self._resolution = tuple(self._resolution)
assert(len(self.domain) == 3)
assert(isinstance(self.domain[0], FieldArray))
assert(isinstance(self.domain[1], RGSpace))
assert(isinstance(self.domain[2], FieldArray))
self._parameters = {}
for p in self.parameter_list:
......@@ -28,23 +33,8 @@ class MagneticField(Loggable, object):
def parameters(self):
return self._parameters
@property
def field(self):
if self._field is None:
self._field = self._create_field()
return self._field
def _create_field(self):
raise NotImplementedError
@property
def box_dimensions(self):
return self._box_dimensions
@property
def resolution(self):
return self._resolution
@property
def descriptor(self):
return self._descriptor
def set_val(self, new_val=None, copy=False):
if new_val is not None:
raise RuntimeError("Setting the field values explicitly is not "
"supported by MagneticField.")
self._val = self._create_field()
......@@ -6,6 +6,8 @@ from keepers import Loggable
from imagine.carrier_mapper import carrier_mapper
from nifty import FieldArray, RGSpace
from magnetic_field import MagneticField
......@@ -18,6 +20,28 @@ class MagneticFieldFactory(Loggable, object):
self._variable_to_parameter_mappings = \
self._initial_variable_to_parameter_mappings
@property
def box_dimensions(self):
return self._box_dimensions
@box_dimensions.setter
def box_dimensions(self, box_dimensions):
dim = tuple(np.array(box_dimensions, dtype=np.float))
if len(dim) != 3:
raise ValueError("Input of box_dimensions must have length three.")
self._box_dimensions = dim
@property
def resolution(self):
return self._resolution
@resolution.setter
def resolution(self, resolution):
resolution = tuple(np.array(resolution, dtype=np.int))
if len(resolution) != 3:
raise ValueError("Input for resolution must have length three.")
self._resolution = resolution
@property
def magnetic_field_class(self):
return MagneticField
......@@ -81,18 +105,20 @@ class MagneticFieldFactory(Loggable, object):
parameter_dict[variable_name] = mapped_variable
return parameter_dict
@staticmethod
def _create_array(self):
raise NotImplementedError
def generate(self, variables={}):
def generate(self, variables={}, ensemble_size=1):
mapped_variables = self._map_variables_to_parameters(variables)
work_parameters = self.parameter_defaults.copy()
work_parameters.update(mapped_variables)
distances = np.array(self.box_dimensions) / np.array(self.resolution)
grid_space = RGSpace(shape=self.resolution,
distances=distances)
ensemble = FieldArray(shape=(ensemble_size,))
vector = FieldArray(shape=(3,))
domain = (ensemble, grid_space, vector)
result_magnetic_field = self.magnetic_field_class(
box_dimensions=self.box_dimensions,
resolution=self.resolution,
parameters=work_parameters)
domain=domain,
parameters=work_parameters)
return result_magnetic_field
# -*- coding: utf-8 -*-
from observer import Observer
from hammurapy import *
......@@ -21,6 +21,13 @@ class HammurapyBase(Observer):
self.last_call_log = ""
self.do_sync_emission = True
self.do_rm = True
self.do_dm = False
self.do_dust = False
self.do_tau = False
self.do_ff = False
self.basic_parameters = {'obs_shell_index_numb': '1',
'total_shell_numb': '1',
'obs_NSIDE': '128',
......@@ -34,30 +41,43 @@ class HammurapyBase(Observer):
'TE_nz': '80',
'B_field_do_random': 'T',
'B_ran_mem_lim': '4',
'do_sync_emission': 'T',
'do_rm': 'T',
'do_dm': 'F',
'do_dust': 'F',
'do_tau': 'F',
'do_ff': 'F'}
}
@abc.abstractproperty
def valid_magnetic_field_descriptor(self):
return []
def check_magnetic_field_descriptor(self, magnetic_field):
for d in self.valid_magnetic_field_descriptor:
if d not in magnetic_field.descriptor:
raise TypeError(
"Given magnetic field does not match the "
"needed descriptor of Hammurapy-Class: "
"%s vs. %s" % (str(self.valid_magnetic_field_descriptor),
str(magnetic_field.descriptor)))
def valid_magnetic_field_class(self):
return object
def _make_temp_folder(self):
prefix = os.path.join(self.working_directory, 'temp_hammurabi_')
return tempfile.mkdtemp(prefix=prefix)
def __call__(self, magnetic_field):
ensemble_number = magnetic_field.shape[0]
ensemble_space = magnetic_field.domain[0]
hp128 = HPSpace(nside=128)
result_observable = {}
if self.do_sync_emission:
result_observable['sync_emission'] = \
Field(domain=(ensemble_space, hp128, FieldArray((3,))))
if self.do_rm:
result_observable['rm'] = Field(domain=(ensemble_space, hp128))
if self.do_dm:
result_observable['dm'] = Field(domain=(hp128,))
if self.do_dust:
result_observable['dust'] = \
Field(domain=(ensemble_space, hp128, FieldArray((3,))))
if self.do_tau:
result_observable['tau'] = Field(domain=(ensemble_space, hp128,))
if self.do_ff:
result_observable['ff'] = Field(domain=(ensemble_space, hp128,))
# create dictionary for parameter file
# iterate over ensemble and put result into result_observable
###########
def _make_parameter_file(self, working_directory, resolution, dimensions,
......@@ -72,6 +92,14 @@ class HammurapyBase(Observer):
'B_field_nz': int(resolution[2]),
}
{
'do_sync_emission': 'T',
'do_rm': 'T',
'do_dm': 'F',
'do_dust': 'F',
'do_tau': 'F',
'do_ff': 'F'}
if self.parameters_dict['do_sync_emission'] == 'T':
obs_sync_file_name = os.path.join(working_directory,
'IQU_sync.fits')
......
# -*- coding: utf-8 -*-
from imagine.magnetic_fields.jf12_magnetic_field import JF12MagneticField
from hammurapy_base import HammurapyBase
......@@ -7,6 +9,4 @@ class HammurapyJF12(HammurapyBase):
@property
def valid_magnetic_field_descriptor(self):
d = super(HammurapyJF12, self).valid_magnetic_field_descriptor
d += ['JF12']
return d
return JF12MagneticField
# -*- coding: utf-8 -*-
from observer import Observer
......@@ -6,4 +6,3 @@ from keepers import Loggable
class Observer(Loggable, object):
def observe(magnetic_field):
raise NotImplementedError
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment