Commit 37582f42 authored by Theo Steininger's avatar Theo Steininger
Browse files

Unified spaces and field_types into single domain object.

parent e38eae92
Pipeline #9997 passed with stage
in 33 minutes and 35 seconds
# -*- coding: utf-8 -*-
import abc
import numpy as np
from keepers import Loggable,\
Versionable
class DomainObject(Versionable, Loggable, object):
__metaclass__ = abc.ABCMeta
def __init__(self, dtype):
self._dtype = np.dtype(dtype)
self._ignore_for_hash = []
def __hash__(self):
# Extract the identifying parts from the vars(self) dict.
result_hash = 0
for key in sorted(vars(self).keys()):
item = vars(self)[key]
if key in self._ignore_for_hash or key == '_ignore_for_hash':
continue
result_hash ^= item.__hash__() ^ int(hash(key)/117)
return result_hash
def __eq__(self, x):
if isinstance(x, type(self)):
return hash(self) == hash(x)
else:
return False
def __ne__(self, x):
return not self.__eq__(x)
@property
def dtype(self):
return self._dtype
@abc.abstractproperty
def shape(self):
raise NotImplementedError(
"There is no generic shape for DomainObject.")
@abc.abstractproperty
def dim(self):
raise NotImplementedError(
"There is no generic dim for DomainObject.")
def pre_cast(self, x, axes=None):
return x
def post_cast(self, x, axes=None):
return x
# ---Serialization---
def _to_hdf5(self, hdf5_group):
hdf5_group.attrs['dtype'] = self.dtype.name
return None
@classmethod
def _from_hdf5(cls, hdf5_group, repository):
result = cls(dtype=np.dtype(hdf5_group.attrs['dtype']))
return result
...@@ -9,9 +9,8 @@ from d2o import distributed_data_object,\ ...@@ -9,9 +9,8 @@ from d2o import distributed_data_object,\
from nifty.config import nifty_configuration as gc from nifty.config import nifty_configuration as gc
from nifty.field_types import FieldType from nifty.domain_object import DomainObject
from nifty.spaces.space import Space
from nifty.spaces.power_space import PowerSpace from nifty.spaces.power_space import PowerSpace
import nifty.nifty_utilities as utilities import nifty.nifty_utilities as utilities
...@@ -21,25 +20,15 @@ from nifty.random import Random ...@@ -21,25 +20,15 @@ from nifty.random import Random
class Field(Loggable, Versionable, object): class Field(Loggable, Versionable, object):
# ---Initialization methods--- # ---Initialization methods---
def __init__(self, domain=None, val=None, dtype=None, field_type=None, def __init__(self, domain=None, val=None, dtype=None,
distribution_strategy=None, copy=False): distribution_strategy=None, copy=False):
self.domain = self._parse_domain(domain=domain, val=val) self.domain = self._parse_domain(domain=domain, val=val)
self.domain_axes = self._get_axes_tuple(self.domain) self.domain_axes = self._get_axes_tuple(self.domain)
self.field_type = self._parse_field_type(field_type, val=val)
try:
start = len(reduce(lambda x, y: x+y, self.domain_axes))
except TypeError:
start = 0
self.field_type_axes = self._get_axes_tuple(self.field_type,
start=start)
self.dtype = self._infer_dtype(dtype=dtype, self.dtype = self._infer_dtype(dtype=dtype,
val=val, val=val,
domain=self.domain, domain=self.domain)
field_type=self.field_type)
self.distribution_strategy = self._parse_distribution_strategy( self.distribution_strategy = self._parse_distribution_strategy(
distribution_strategy=distribution_strategy, distribution_strategy=distribution_strategy,
...@@ -53,34 +42,18 @@ class Field(Loggable, Versionable, object): ...@@ -53,34 +42,18 @@ class Field(Loggable, Versionable, object):
domain = val.domain domain = val.domain
else: else:
domain = () domain = ()
elif isinstance(domain, Space): elif isinstance(domain, DomainObject):
domain = (domain,) domain = (domain,)
elif not isinstance(domain, tuple): elif not isinstance(domain, tuple):
domain = tuple(domain) domain = tuple(domain)
for d in domain: for d in domain:
if not isinstance(d, Space): if not isinstance(d, DomainObject):
raise TypeError( raise TypeError(
"Given domain contains something that is not a " "Given domain contains something that is not a "
"nifty.space.") "DomainObject instance.")
return domain return domain
def _parse_field_type(self, field_type, val=None):
if field_type is None:
if isinstance(val, Field):
field_type = val.field_type
else:
field_type = ()
elif isinstance(field_type, FieldType):
field_type = (field_type,)
elif not isinstance(field_type, tuple):
field_type = tuple(field_type)
for ft in field_type:
if not isinstance(ft, FieldType):
raise TypeError(
"Given object is not a nifty.FieldType.")
return field_type
def _get_axes_tuple(self, things_with_shape, start=0): def _get_axes_tuple(self, things_with_shape, start=0):
i = start i = start
axes_list = [] axes_list = []
...@@ -92,7 +65,7 @@ class Field(Loggable, Versionable, object): ...@@ -92,7 +65,7 @@ class Field(Loggable, Versionable, object):
axes_list += [tuple(l)] axes_list += [tuple(l)]
return tuple(axes_list) return tuple(axes_list)
def _infer_dtype(self, dtype, val, domain, field_type): def _infer_dtype(self, dtype, val, domain):
if dtype is None: if dtype is None:
if isinstance(val, Field) or \ if isinstance(val, Field) or \
isinstance(val, distributed_data_object): isinstance(val, distributed_data_object):
...@@ -102,8 +75,6 @@ class Field(Loggable, Versionable, object): ...@@ -102,8 +75,6 @@ class Field(Loggable, Versionable, object):
dtype_tuple = (np.dtype(dtype),) dtype_tuple = (np.dtype(dtype),)
if domain is not None: if domain is not None:
dtype_tuple += tuple(np.dtype(sp.dtype) for sp in domain) dtype_tuple += tuple(np.dtype(sp.dtype) for sp in domain)
if field_type is not None:
dtype_tuple += tuple(np.dtype(ft.dtype) for ft in field_type)
dtype = reduce(lambda x, y: np.result_type(x, y), dtype_tuple) dtype = reduce(lambda x, y: np.result_type(x, y), dtype_tuple)
...@@ -127,10 +98,10 @@ class Field(Loggable, Versionable, object): ...@@ -127,10 +98,10 @@ class Field(Loggable, Versionable, object):
# ---Factory methods--- # ---Factory methods---
@classmethod @classmethod
def from_random(cls, random_type, domain=None, dtype=None, field_type=None, def from_random(cls, random_type, domain=None, dtype=None,
distribution_strategy=None, **kwargs): distribution_strategy=None, **kwargs):
# create a initially empty field # create a initially empty field
f = cls(domain=domain, dtype=dtype, field_type=field_type, f = cls(domain=domain, dtype=dtype,
distribution_strategy=distribution_strategy) distribution_strategy=distribution_strategy)
# now use the processed input in terms of f in order to parse the # now use the processed input in terms of f in order to parse the
...@@ -363,7 +334,6 @@ class Field(Loggable, Versionable, object): ...@@ -363,7 +334,6 @@ class Field(Loggable, Versionable, object):
std=std, std=std,
domain=result_domain, domain=result_domain,
dtype=harmonic_domain.dtype, dtype=harmonic_domain.dtype,
field_type=self.field_type,
distribution_strategy=self.distribution_strategy) distribution_strategy=self.distribution_strategy)
for x in result_list] for x in result_list]
...@@ -451,9 +421,7 @@ class Field(Loggable, Versionable, object): ...@@ -451,9 +421,7 @@ class Field(Loggable, Versionable, object):
@property @property
def shape(self): def shape(self):
shape_tuple = () shape_tuple = tuple(sp.shape for sp in self.domain)
shape_tuple += tuple(sp.shape for sp in self.domain)
shape_tuple += tuple(ft.shape for ft in self.field_type)
try: try:
global_shape = reduce(lambda x, y: x + y, shape_tuple) global_shape = reduce(lambda x, y: x + y, shape_tuple)
except TypeError: except TypeError:
...@@ -463,9 +431,7 @@ class Field(Loggable, Versionable, object): ...@@ -463,9 +431,7 @@ class Field(Loggable, Versionable, object):
@property @property
def dim(self): def dim(self):
dim_tuple = () dim_tuple = tuple(sp.dim for sp in self.domain)
dim_tuple += tuple(sp.dim for sp in self.domain)
dim_tuple += tuple(ft.dim for ft in self.field_type)
try: try:
return reduce(lambda x, y: x * y, dim_tuple) return reduce(lambda x, y: x * y, dim_tuple)
except TypeError: except TypeError:
...@@ -500,20 +466,12 @@ class Field(Loggable, Versionable, object): ...@@ -500,20 +466,12 @@ class Field(Loggable, Versionable, object):
casted_x = sp.pre_cast(casted_x, casted_x = sp.pre_cast(casted_x,
axes=self.domain_axes[ind]) axes=self.domain_axes[ind])
for ind, ft in enumerate(self.field_type):
casted_x = ft.pre_cast(casted_x,
axes=self.field_type_axes[ind])
casted_x = self._actual_cast(casted_x, dtype=dtype) casted_x = self._actual_cast(casted_x, dtype=dtype)
for ind, sp in enumerate(self.domain): for ind, sp in enumerate(self.domain):
casted_x = sp.post_cast(casted_x, casted_x = sp.post_cast(casted_x,
axes=self.domain_axes[ind]) axes=self.domain_axes[ind])
for ind, ft in enumerate(self.field_type):
casted_x = ft.post_cast(casted_x,
axes=self.field_type_axes[ind])
return casted_x return casted_x
def _actual_cast(self, x, dtype=None): def _actual_cast(self, x, dtype=None):
...@@ -530,19 +488,16 @@ class Field(Loggable, Versionable, object): ...@@ -530,19 +488,16 @@ class Field(Loggable, Versionable, object):
return_x.set_full_data(x, copy=False) return_x.set_full_data(x, copy=False)
return return_x return return_x
def copy(self, domain=None, dtype=None, field_type=None, def copy(self, domain=None, dtype=None, distribution_strategy=None):
distribution_strategy=None):
copied_val = self.get_val(copy=True) copied_val = self.get_val(copy=True)
new_field = self.copy_empty( new_field = self.copy_empty(
domain=domain, domain=domain,
dtype=dtype, dtype=dtype,
field_type=field_type,
distribution_strategy=distribution_strategy) distribution_strategy=distribution_strategy)
new_field.set_val(new_val=copied_val, copy=False) new_field.set_val(new_val=copied_val, copy=False)
return new_field return new_field
def copy_empty(self, domain=None, dtype=None, field_type=None, def copy_empty(self, domain=None, dtype=None, distribution_strategy=None):
distribution_strategy=None):
if domain is None: if domain is None:
domain = self.domain domain = self.domain
else: else:
...@@ -553,11 +508,6 @@ class Field(Loggable, Versionable, object): ...@@ -553,11 +508,6 @@ class Field(Loggable, Versionable, object):
else: else:
dtype = np.dtype(dtype) dtype = np.dtype(dtype)
if field_type is None:
field_type = self.field_type
else:
field_type = self._parse_field_type(field_type)
if distribution_strategy is None: if distribution_strategy is None:
distribution_strategy = self.distribution_strategy distribution_strategy = self.distribution_strategy
...@@ -567,10 +517,6 @@ class Field(Loggable, Versionable, object): ...@@ -567,10 +517,6 @@ class Field(Loggable, Versionable, object):
if self.domain[i] is not domain[i]: if self.domain[i] is not domain[i]:
fast_copyable = False fast_copyable = False
break break
for i in xrange(len(self.field_type)):
if self.field_type[i] is not field_type[i]:
fast_copyable = False
break
except IndexError: except IndexError:
fast_copyable = False fast_copyable = False
...@@ -580,7 +526,6 @@ class Field(Loggable, Versionable, object): ...@@ -580,7 +526,6 @@ class Field(Loggable, Versionable, object):
else: else:
new_field = Field(domain=domain, new_field = Field(domain=domain,
dtype=dtype, dtype=dtype,
field_type=field_type,
distribution_strategy=distribution_strategy) distribution_strategy=distribution_strategy)
return new_field return new_field
...@@ -626,8 +571,6 @@ class Field(Loggable, Versionable, object): ...@@ -626,8 +571,6 @@ class Field(Loggable, Versionable, object):
assert len(x.domain) == len(self.domain) assert len(x.domain) == len(self.domain)
for index in xrange(len(self.domain)): for index in xrange(len(self.domain)):
assert x.domain[index] == self.domain[index] assert x.domain[index] == self.domain[index]
for index in xrange(len(self.field_type)):
assert x.field_type[index] == self.field_type[index]
except AssertionError: except AssertionError:
raise ValueError( raise ValueError(
"domains are incompatible.") "domains are incompatible.")
...@@ -707,22 +650,15 @@ class Field(Loggable, Versionable, object): ...@@ -707,22 +650,15 @@ class Field(Loggable, Versionable, object):
return_field.set_val(new_val, copy=False) return_field.set_val(new_val, copy=False)
return return_field return return_field
def _contraction_helper(self, op, spaces, types): def _contraction_helper(self, op, spaces):
# build a list of all axes # build a list of all axes
if spaces is None: if spaces is None:
spaces = xrange(len(self.domain)) spaces = xrange(len(self.domain))
else: else:
spaces = utilities.cast_axis_to_tuple(spaces, len(self.domain)) spaces = utilities.cast_axis_to_tuple(spaces, len(self.domain))
if types is None: axes_list = tuple(self.domain_axes[sp_index] for sp_index in spaces)
types = xrange(len(self.field_type))
else:
types = utilities.cast_axis_to_tuple(types, len(self.field_type))
axes_list = ()
axes_list += tuple(self.domain_axes[sp_index] for sp_index in spaces)
axes_list += tuple(self.field_type_axes[ft_index] for
ft_index in types)
try: try:
axes_list = reduce(lambda x, y: x+y, axes_list) axes_list = reduce(lambda x, y: x+y, axes_list)
except TypeError: except TypeError:
...@@ -739,47 +675,44 @@ class Field(Loggable, Versionable, object): ...@@ -739,47 +675,44 @@ class Field(Loggable, Versionable, object):
return_domain = tuple(self.domain[i] return_domain = tuple(self.domain[i]
for i in xrange(len(self.domain)) for i in xrange(len(self.domain))
if i not in spaces) if i not in spaces)
return_field_type = tuple(self.field_type[i]
for i in xrange(len(self.field_type))
if i not in types)
return_field = Field(domain=return_domain, return_field = Field(domain=return_domain,
val=data, val=data,
field_type=return_field_type,
copy=False) copy=False)
return return_field return return_field
def sum(self, spaces=None, types=None): def sum(self, spaces=None):
return self._contraction_helper('sum', spaces, types) return self._contraction_helper('sum', spaces)
def prod(self, spaces=None, types=None): def prod(self, spaces=None):
return self._contraction_helper('prod', spaces, types) return self._contraction_helper('prod', spaces)
def all(self, spaces=None, types=None): def all(self, spaces=None):
return self._contraction_helper('all', spaces, types) return self._contraction_helper('all', spaces)
def any(self, spaces=None, types=None): def any(self, spaces=None):
return self._contraction_helper('any', spaces, types) return self._contraction_helper('any', spaces)
def min(self, spaces=None, types=None): def min(self, spaces=None):
return self._contraction_helper('min', spaces, types) return self._contraction_helper('min', spaces)
def nanmin(self, spaces=None, types=None): def nanmin(self, spaces=None):
return self._contraction_helper('nanmin', spaces, types) return self._contraction_helper('nanmin', spaces)
def max(self, spaces=None, types=None): def max(self, spaces=None):
return self._contraction_helper('max', spaces, types) return self._contraction_helper('max', spaces)
def nanmax(self, spaces=None, types=None): def nanmax(self, spaces=None):
return self._contraction_helper('nanmax', spaces, types) return self._contraction_helper('nanmax', spaces)
def mean(self, spaces=None, types=None): def mean(self, spaces=None):
return self._contraction_helper('mean', spaces, types) return self._contraction_helper('mean', spaces)
def var(self, spaces=None, types=None): def var(self, spaces=None):
return self._contraction_helper('var', spaces, types) return self._contraction_helper('var', spaces)
def std(self, spaces=None, types=None): def std(self, spaces=None):
return self._contraction_helper('std', spaces, types) return self._contraction_helper('std', spaces)
# ---General binary methods--- # ---General binary methods---
...@@ -790,9 +723,6 @@ class Field(Loggable, Versionable, object): ...@@ -790,9 +723,6 @@ class Field(Loggable, Versionable, object):
assert len(other.domain) == len(self.domain) assert len(other.domain) == len(self.domain)
for index in xrange(len(self.domain)): for index in xrange(len(self.domain)):
assert other.domain[index] == self.domain[index] assert other.domain[index] == self.domain[index]
assert len(other.field_type) == len(self.field_type)
for index in xrange(len(self.field_type)):
assert other.field_type[index] == self.field_type[index]
except AssertionError: except AssertionError:
raise ValueError( raise ValueError(
"domains are incompatible.") "domains are incompatible.")
...@@ -895,19 +825,14 @@ class Field(Loggable, Versionable, object): ...@@ -895,19 +825,14 @@ class Field(Loggable, Versionable, object):
def _to_hdf5(self, hdf5_group): def _to_hdf5(self, hdf5_group):
hdf5_group.attrs['dtype'] = self.dtype.name hdf5_group.attrs['dtype'] = self.dtype.name
hdf5_group.attrs['distribution_strategy'] = self.distribution_strategy hdf5_group.attrs['distribution_strategy'] = self.distribution_strategy
hdf5_group.attrs['field_type_axes'] = str(self.field_type_axes)
hdf5_group.attrs['domain_axes'] = str(self.domain_axes) hdf5_group.attrs['domain_axes'] = str(self.domain_axes)
hdf5_group['num_domain'] = len(self.domain) hdf5_group['num_domain'] = len(self.domain)
hdf5_group['num_ft'] = len(self.field_type)
ret_dict = {'val': self.val} ret_dict = {'val': self.val}
for i in range(len(self.domain)): for i in range(len(self.domain)):
ret_dict['s_' + str(i)] = self.domain[i] ret_dict['s_' + str(i)] = self.domain[i]
for i in range(len(self.field_type)):
ret_dict['ft_' + str(i)] = self.field_type[i]
return ret_dict return ret_dict
@classmethod @classmethod
...@@ -922,14 +847,7 @@ class Field(Loggable, Versionable, object): ...@@ -922,14 +847,7 @@ class Field(Loggable, Versionable, object):
temp_domain.append(repository.get('s_' + str(i), hdf5_group)) temp_domain.append(repository.get('s_' + str(i), hdf5_group))
new_field.domain = tuple(temp_domain) new_field.domain = tuple(temp_domain)
temp_ft = []
for i in range(hdf5_group['num_ft'][()]):
temp_domain.append(repository.get('ft_' + str(i), hdf5_group))
new_field.field_type = tuple(temp_ft)
exec('new_field.domain_axes = ' + hdf5_group.attrs['domain_axes']) exec('new_field.domain_axes = ' + hdf5_group.attrs['domain_axes'])
exec('new_field.field_type_axes = ' +
hdf5_group.attrs['field_type_axes'])
new_field._val = repository.get('val', hdf5_group) new_field._val = repository.get('val', hdf5_group)
new_field.dtype = np.dtype(hdf5_group.attrs['dtype']) new_field.dtype = np.dtype(hdf5_group.attrs['dtype'])
new_field.distribution_strategy =\ new_field.distribution_strategy =\
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import pickle
from field_type import FieldType from field_type import FieldType
class FieldArray(FieldType): class FieldArray(FieldType):
def __init__(self, dtype, shape):
try:
new_shape = tuple([int(i) for i in shape])
except TypeError:
new_shape = (int(shape), )
self._shape = new_shape
super(FieldArray, self).__init__(dtype=dtype)
@property
def shape(self):
return self._shape
@property @property
def dim(self): def dim(self):
return reduce(lambda x, y: x*y, self.shape) return reduce(lambda x, y: x*y, self.shape)
# ---Serialization---
def _to_hdf5(self, hdf5_group):
hdf5_group['shape'] = self.shape
hdf5_group['dtype'] = pickle.dumps(self.dtype)
return None
@classmethod
def _from_hdf5(cls, hdf5_group, loopback_get):
result = cls(
hdf5_group['shape'][:],
pickle.loads(hdf5_group['dtype'][()])
)