Commit b5abf435 authored by Theo Steininger's avatar Theo Steininger

Bug-fixes from code review.

parent 69eba271
from __future__ import division
import numpy as np
from keepers import Versionable
from keepers import Versionable,\
Loggable
from d2o import distributed_data_object,\
STRATEGIES as DISTRIBUTION_STRATEGIES
......@@ -16,8 +17,6 @@ from nifty.spaces.power_space import PowerSpace
import nifty.nifty_utilities as utilities
from nifty.random import Random
from keepers import Loggable
class Field(Loggable, Versionable, object):
# ---Initialization methods---
......@@ -892,16 +891,14 @@ class Field(Loggable, Versionable, object):
# ---Serialization---
def _to_hdf5(self, hdf5_group):
hdf5_group['dtype'] = self.dtype.name
hdf5_group['distribution_strategy'] = self.distribution_strategy
hdf5_group['field_type_axes'] = str(self.field_type_axes)
hdf5_group['domain_axes'] = str(self.domain_axes)
hdf5_group.attrs['dtype'] = self.dtype.name
hdf5_group.attrs['distribution_strategy'] = self.distribution_strategy
hdf5_group.attrs['field_type_axes'] = str(self.field_type_axes)
hdf5_group.attrs['domain_axes'] = str(self.domain_axes)
hdf5_group['num_domain'] = len(self.domain)
hdf5_group['num_ft'] = len(self.field_type)
ret_dict = {
'val' : self.val
}
ret_dict = {'val': self.val}
for i in range(len(self.domain)):
ret_dict['s_' + str(i)] = self.domain[i]
......@@ -911,9 +908,8 @@ class Field(Loggable, Versionable, object):
return ret_dict
@classmethod
def _from_hdf5(cls, hdf5_group, loopback_get):
def _from_hdf5(cls, hdf5_group, repository):
# create empty field
new_field = EmptyField()
# reset class
......@@ -921,22 +917,25 @@ class Field(Loggable, Versionable, object):
# set values
temp_domain = []
for i in range(hdf5_group['num_domain'][()]):
temp_domain.append(loopback_get('s_' + str(i)))
temp_domain.append(repository.get('s_' + str(i), hdf5_group))
new_field.domain = tuple(temp_domain)
temp_ft = []
for i in range(hdf5_group['num_ft'][()]):
temp_domain.append(loopback_get('ft_' + str(i)))
temp_domain.append(repository.get('ft_' + str(i), hdf5_group))
new_field.field_type = tuple(temp_ft)
exec('new_field.domain_axes = ' + hdf5_group['domain_axes'][()])
exec('new_field.field_type_axes = ' + hdf5_group['field_type_axes'][()])
new_field._val = loopback_get('val')
new_field.dtype = np.dtype(hdf5_group['dtype'][()])
new_field.distribution_strategy = hdf5_group['distribution_strategy'][()]
exec('new_field.domain_axes = ' + hdf5_group.attrs['domain_axes'])
exec('new_field.field_type_axes = ' +
hdf5_group.attrs['field_type_axes'])
new_field._val = repository.get('val', hdf5_group)
new_field.dtype = np.dtype(hdf5_group.attrs['dtype'])
new_field.distribution_strategy =\
hdf5_group.attrs['distribution_strategy']
return new_field
class EmptyField(Field):
def __init__(self):
pass
......@@ -5,7 +5,6 @@ import numpy as np
import d2o
from d2o import STRATEGIES as DISTRIBUTION_STRATEGIES
from keepers import Versionable
from nifty.spaces.space import Space
from nifty.config import nifty_configuration as gc,\
......@@ -17,7 +16,7 @@ gl = gdi.get('libsharp_wrapper_gl')
GL_DISTRIBUTION_STRATEGIES = DISTRIBUTION_STRATEGIES['global']
class GLSpace(Versionable, Space):
class GLSpace(Space):
"""
.. __
.. / /
......@@ -220,7 +219,7 @@ class GLSpace(Versionable, Space):
return None
@classmethod
def _from_hdf5(cls, hdf5_group, loopback_get):
def _from_hdf5(cls, hdf5_group, repository):
result = cls(
nlat=hdf5_group['nlat'][()],
nlon=hdf5_group['nlon'][()],
......
......@@ -36,7 +36,6 @@ from __future__ import division
import numpy as np
import d2o
from keepers import Versionable
from nifty.spaces.space import Space
from nifty.config import nifty_configuration as gc, \
......@@ -45,7 +44,7 @@ from nifty.config import nifty_configuration as gc, \
hp = gdi.get('healpy')
class HPSpace(Versionable, Space):
class HPSpace(Space):
"""
.. __
.. / /
......@@ -213,7 +212,7 @@ class HPSpace(Versionable, Space):
return None
@classmethod
def _from_hdf5(cls, hdf5_group, loopback_get):
def _from_hdf5(cls, hdf5_group, repository):
result = cls(
nside=hdf5_group['nside'][()],
dtype=np.dtype(hdf5_group['dtype'][()])
......
......@@ -2,8 +2,6 @@ from __future__ import division
import numpy as np
from keepers import Versionable
from nifty.spaces.space import Space
from nifty.config import nifty_configuration as gc,\
......@@ -17,7 +15,7 @@ gl = gdi.get('libsharp_wrapper_gl')
hp = gdi.get('healpy')
class LMSpace(Versionable, Space):
class LMSpace(Space):
"""
.. __
.. / /
......@@ -191,7 +189,7 @@ class LMSpace(Versionable, Space):
return None
@classmethod
def _from_hdf5(cls, hdf5_group, loopback_get):
def _from_hdf5(cls, hdf5_group, repository):
result = cls(
lmax=hdf5_group['lmax'][()],
dtype=np.dtype(hdf5_group['dtype'][()])
......
......@@ -2,8 +2,6 @@
import numpy as np
from keepers import Versionable
import d2o
from power_index_factory import PowerIndexFactory
......@@ -13,7 +11,7 @@ from nifty.spaces.rg_space import RGSpace
from nifty.nifty_utilities import cast_axis_to_tuple
class PowerSpace(Versionable, Space):
class PowerSpace(Space):
# ---Overwritten properties and methods---
......@@ -163,10 +161,11 @@ class PowerSpace(Versionable, Space):
hdf5_group['kindex'] = self.kindex
hdf5_group['rho'] = self.rho
hdf5_group['pundex'] = self.pundex
hdf5_group['dtype'] = self.dtype.name
hdf5_group.attrs['dtype'] = self.dtype.name
hdf5_group['log'] = self.log
hdf5_group['nbin'] = str(self.nbin)
hdf5_group['binbounds'] = str(self.binbounds)
# Store nbin as string, since it can be None
hdf5_group.attrs['nbin'] = str(self.nbin)
hdf5_group.attrs['binbounds'] = str(self.binbounds)
return {
'harmonic_domain': self.harmonic_domain,
......@@ -175,23 +174,23 @@ class PowerSpace(Versionable, Space):
}
@classmethod
def _from_hdf5(cls, hdf5_group, loopback_get):
def _from_hdf5(cls, hdf5_group, repository):
# make an empty PowerSpace object
new_ps = EmptyPowerSpace()
# reset class
new_ps.__class__ = cls
# set all values
new_ps.dtype = np.dtype(hdf5_group['dtype'][()])
new_ps._harmonic_domain = loopback_get('harmonic_domain')
new_ps.dtype = np.dtype(hdf5_group.attrs['dtype'])
new_ps._harmonic_domain = repository.get('harmonic_domain', hdf5_group)
new_ps._log = hdf5_group['log'][()]
exec('new_ps._nbin = ' + hdf5_group['nbin'][()])
exec('new_ps._binbounds = ' + hdf5_group['binbounds'][()])
exec('new_ps._nbin = ' + hdf5_group.attrs['nbin'])
exec('new_ps._binbounds = ' + hdf5_group.attrs['binbounds'])
new_ps._pindex = loopback_get('pindex')
new_ps._pindex = repository.get('pindex', hdf5_group)
new_ps._kindex = hdf5_group['kindex'][:]
new_ps._rho = hdf5_group['rho'][:]
new_ps._pundex = hdf5_group['pundex'][:]
new_ps._k_array = loopback_get('k_array')
new_ps._k_array = repository.get('k_array', hdf5_group)
return new_ps
......
......@@ -35,15 +35,13 @@ from __future__ import division
import numpy as np
from keepers import Versionable
from d2o import distributed_data_object,\
STRATEGIES as DISTRIBUTION_STRATEGIES
from nifty.spaces.space import Space
class RGSpace(Versionable, Space):
class RGSpace(Space):
"""
.. _____ _______
.. / __/ / _ /
......@@ -329,17 +327,17 @@ class RGSpace(Versionable, Space):
hdf5_group['zerocenter'] = self.zerocenter
hdf5_group['distances'] = self.distances
hdf5_group['harmonic'] = self.harmonic
hdf5_group['dtype'] = self.dtype.name
hdf5_group.attrs['dtype'] = self.dtype.name
return None
@classmethod
def _from_hdf5(cls, hdf5_group, loopback_get):
def _from_hdf5(cls, hdf5_group, repository):
result = cls(
shape=hdf5_group['shape'][:],
zerocenter=hdf5_group['zerocenter'][:],
distances=hdf5_group['distances'][:],
harmonic=hdf5_group['harmonic'][()],
dtype=np.dtype(hdf5_group['dtype'][()])
dtype=np.dtype(hdf5_group.attrs['dtype'])
)
return result
......@@ -146,10 +146,11 @@ import abc
import numpy as np
from keepers import Loggable
from keepers import Loggable,\
Versionable
class Space(Loggable, object):
class Space(Versionable, Loggable, object):
"""
.. __ __
.. /__/ / /_
......@@ -205,12 +206,13 @@ class Space(Loggable, object):
# parse dtype
self.dtype = np.dtype(dtype)
self._ignore_for_hash = []
self._ignore_for_hash = ['_global_id']
def __hash__(self):
# Extract the identifying parts from the vars(self) dict.
result_hash = 0
for (key, item) in vars(self).items():
for key in sorted(vars(self).keys()):
item = vars(self)[key]
if key in self._ignore_for_hash or key == '_ignore_for_hash':
continue
result_hash ^= item.__hash__() ^ int(hash(key)/117)
......@@ -290,3 +292,15 @@ class Space(Loggable, object):
string += str(type(self)) + "\n"
string += "dtype: " + str(self.dtype) + "\n"
return string
# ---Serialization---
def _to_hdf5(self, hdf5_group):
hdf5_group.attrs['dtype'] = self.dtype.name
return None
@classmethod
def _from_hdf5(cls, hdf5_group, repository):
result = cls(dtype=np.dtype(hdf5_group.attrs['dtype']))
return result
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment