Commit 37582f42 authored by Theo Steininger's avatar Theo Steininger

Unified spaces and field_types into single domain object.

parent e38eae92
Pipeline #9997 passed with stage
in 33 minutes and 35 seconds
# -*- coding: utf-8 -*-
import abc
import numpy as np
from keepers import Loggable,\
Versionable
class DomainObject(Versionable, Loggable, object):
__metaclass__ = abc.ABCMeta
def __init__(self, dtype):
self._dtype = np.dtype(dtype)
self._ignore_for_hash = []
def __hash__(self):
# Extract the identifying parts from the vars(self) dict.
result_hash = 0
for key in sorted(vars(self).keys()):
item = vars(self)[key]
if key in self._ignore_for_hash or key == '_ignore_for_hash':
continue
result_hash ^= item.__hash__() ^ int(hash(key)/117)
return result_hash
def __eq__(self, x):
if isinstance(x, type(self)):
return hash(self) == hash(x)
else:
return False
def __ne__(self, x):
return not self.__eq__(x)
@property
def dtype(self):
return self._dtype
@abc.abstractproperty
def shape(self):
raise NotImplementedError(
"There is no generic shape for DomainObject.")
@abc.abstractproperty
def dim(self):
raise NotImplementedError(
"There is no generic dim for DomainObject.")
def pre_cast(self, x, axes=None):
return x
def post_cast(self, x, axes=None):
return x
# ---Serialization---
def _to_hdf5(self, hdf5_group):
hdf5_group.attrs['dtype'] = self.dtype.name
return None
@classmethod
def _from_hdf5(cls, hdf5_group, repository):
result = cls(dtype=np.dtype(hdf5_group.attrs['dtype']))
return result
......@@ -9,9 +9,8 @@ from d2o import distributed_data_object,\
from nifty.config import nifty_configuration as gc
from nifty.field_types import FieldType
from nifty.domain_object import DomainObject
from nifty.spaces.space import Space
from nifty.spaces.power_space import PowerSpace
import nifty.nifty_utilities as utilities
......@@ -21,25 +20,15 @@ from nifty.random import Random
class Field(Loggable, Versionable, object):
# ---Initialization methods---
def __init__(self, domain=None, val=None, dtype=None, field_type=None,
def __init__(self, domain=None, val=None, dtype=None,
distribution_strategy=None, copy=False):
self.domain = self._parse_domain(domain=domain, val=val)
self.domain_axes = self._get_axes_tuple(self.domain)
self.field_type = self._parse_field_type(field_type, val=val)
try:
start = len(reduce(lambda x, y: x+y, self.domain_axes))
except TypeError:
start = 0
self.field_type_axes = self._get_axes_tuple(self.field_type,
start=start)
self.dtype = self._infer_dtype(dtype=dtype,
val=val,
domain=self.domain,
field_type=self.field_type)
domain=self.domain)
self.distribution_strategy = self._parse_distribution_strategy(
distribution_strategy=distribution_strategy,
......@@ -53,34 +42,18 @@ class Field(Loggable, Versionable, object):
domain = val.domain
else:
domain = ()
elif isinstance(domain, Space):
elif isinstance(domain, DomainObject):
domain = (domain,)
elif not isinstance(domain, tuple):
domain = tuple(domain)
for d in domain:
if not isinstance(d, Space):
if not isinstance(d, DomainObject):
raise TypeError(
"Given domain contains something that is not a "
"nifty.space.")
"DomainObject instance.")
return domain
def _parse_field_type(self, field_type, val=None):
if field_type is None:
if isinstance(val, Field):
field_type = val.field_type
else:
field_type = ()
elif isinstance(field_type, FieldType):
field_type = (field_type,)
elif not isinstance(field_type, tuple):
field_type = tuple(field_type)
for ft in field_type:
if not isinstance(ft, FieldType):
raise TypeError(
"Given object is not a nifty.FieldType.")
return field_type
def _get_axes_tuple(self, things_with_shape, start=0):
i = start
axes_list = []
......@@ -92,7 +65,7 @@ class Field(Loggable, Versionable, object):
axes_list += [tuple(l)]
return tuple(axes_list)
def _infer_dtype(self, dtype, val, domain, field_type):
def _infer_dtype(self, dtype, val, domain):
if dtype is None:
if isinstance(val, Field) or \
isinstance(val, distributed_data_object):
......@@ -102,8 +75,6 @@ class Field(Loggable, Versionable, object):
dtype_tuple = (np.dtype(dtype),)
if domain is not None:
dtype_tuple += tuple(np.dtype(sp.dtype) for sp in domain)
if field_type is not None:
dtype_tuple += tuple(np.dtype(ft.dtype) for ft in field_type)
dtype = reduce(lambda x, y: np.result_type(x, y), dtype_tuple)
......@@ -127,10 +98,10 @@ class Field(Loggable, Versionable, object):
# ---Factory methods---
@classmethod
def from_random(cls, random_type, domain=None, dtype=None, field_type=None,
def from_random(cls, random_type, domain=None, dtype=None,
distribution_strategy=None, **kwargs):
# create a initially empty field
f = cls(domain=domain, dtype=dtype, field_type=field_type,
f = cls(domain=domain, dtype=dtype,
distribution_strategy=distribution_strategy)
# now use the processed input in terms of f in order to parse the
......@@ -363,7 +334,6 @@ class Field(Loggable, Versionable, object):
std=std,
domain=result_domain,
dtype=harmonic_domain.dtype,
field_type=self.field_type,
distribution_strategy=self.distribution_strategy)
for x in result_list]
......@@ -451,9 +421,7 @@ class Field(Loggable, Versionable, object):
@property
def shape(self):
shape_tuple = ()
shape_tuple += tuple(sp.shape for sp in self.domain)
shape_tuple += tuple(ft.shape for ft in self.field_type)
shape_tuple = tuple(sp.shape for sp in self.domain)
try:
global_shape = reduce(lambda x, y: x + y, shape_tuple)
except TypeError:
......@@ -463,9 +431,7 @@ class Field(Loggable, Versionable, object):
@property
def dim(self):
dim_tuple = ()
dim_tuple += tuple(sp.dim for sp in self.domain)
dim_tuple += tuple(ft.dim for ft in self.field_type)
dim_tuple = tuple(sp.dim for sp in self.domain)
try:
return reduce(lambda x, y: x * y, dim_tuple)
except TypeError:
......@@ -500,20 +466,12 @@ class Field(Loggable, Versionable, object):
casted_x = sp.pre_cast(casted_x,
axes=self.domain_axes[ind])
for ind, ft in enumerate(self.field_type):
casted_x = ft.pre_cast(casted_x,
axes=self.field_type_axes[ind])
casted_x = self._actual_cast(casted_x, dtype=dtype)
for ind, sp in enumerate(self.domain):
casted_x = sp.post_cast(casted_x,
axes=self.domain_axes[ind])
for ind, ft in enumerate(self.field_type):
casted_x = ft.post_cast(casted_x,
axes=self.field_type_axes[ind])
return casted_x
def _actual_cast(self, x, dtype=None):
......@@ -530,19 +488,16 @@ class Field(Loggable, Versionable, object):
return_x.set_full_data(x, copy=False)
return return_x
def copy(self, domain=None, dtype=None, field_type=None,
distribution_strategy=None):
def copy(self, domain=None, dtype=None, distribution_strategy=None):
copied_val = self.get_val(copy=True)
new_field = self.copy_empty(
domain=domain,
dtype=dtype,
field_type=field_type,
distribution_strategy=distribution_strategy)
new_field.set_val(new_val=copied_val, copy=False)
return new_field
def copy_empty(self, domain=None, dtype=None, field_type=None,
distribution_strategy=None):
def copy_empty(self, domain=None, dtype=None, distribution_strategy=None):
if domain is None:
domain = self.domain
else:
......@@ -553,11 +508,6 @@ class Field(Loggable, Versionable, object):
else:
dtype = np.dtype(dtype)
if field_type is None:
field_type = self.field_type
else:
field_type = self._parse_field_type(field_type)
if distribution_strategy is None:
distribution_strategy = self.distribution_strategy
......@@ -567,10 +517,6 @@ class Field(Loggable, Versionable, object):
if self.domain[i] is not domain[i]:
fast_copyable = False
break
for i in xrange(len(self.field_type)):
if self.field_type[i] is not field_type[i]:
fast_copyable = False
break
except IndexError:
fast_copyable = False
......@@ -580,7 +526,6 @@ class Field(Loggable, Versionable, object):
else:
new_field = Field(domain=domain,
dtype=dtype,
field_type=field_type,
distribution_strategy=distribution_strategy)
return new_field
......@@ -626,8 +571,6 @@ class Field(Loggable, Versionable, object):
assert len(x.domain) == len(self.domain)
for index in xrange(len(self.domain)):
assert x.domain[index] == self.domain[index]
for index in xrange(len(self.field_type)):
assert x.field_type[index] == self.field_type[index]
except AssertionError:
raise ValueError(
"domains are incompatible.")
......@@ -707,22 +650,15 @@ class Field(Loggable, Versionable, object):
return_field.set_val(new_val, copy=False)
return return_field
def _contraction_helper(self, op, spaces, types):
def _contraction_helper(self, op, spaces):
# build a list of all axes
if spaces is None:
spaces = xrange(len(self.domain))
else:
spaces = utilities.cast_axis_to_tuple(spaces, len(self.domain))
if types is None:
types = xrange(len(self.field_type))
else:
types = utilities.cast_axis_to_tuple(types, len(self.field_type))
axes_list = tuple(self.domain_axes[sp_index] for sp_index in spaces)
axes_list = ()
axes_list += tuple(self.domain_axes[sp_index] for sp_index in spaces)
axes_list += tuple(self.field_type_axes[ft_index] for
ft_index in types)
try:
axes_list = reduce(lambda x, y: x+y, axes_list)
except TypeError:
......@@ -739,47 +675,44 @@ class Field(Loggable, Versionable, object):
return_domain = tuple(self.domain[i]
for i in xrange(len(self.domain))
if i not in spaces)
return_field_type = tuple(self.field_type[i]
for i in xrange(len(self.field_type))
if i not in types)
return_field = Field(domain=return_domain,
val=data,
field_type=return_field_type,
copy=False)
return return_field
def sum(self, spaces=None, types=None):
return self._contraction_helper('sum', spaces, types)
def sum(self, spaces=None):
return self._contraction_helper('sum', spaces)
def prod(self, spaces=None, types=None):
return self._contraction_helper('prod', spaces, types)
def prod(self, spaces=None):
return self._contraction_helper('prod', spaces)
def all(self, spaces=None, types=None):
return self._contraction_helper('all', spaces, types)
def all(self, spaces=None):
return self._contraction_helper('all', spaces)
def any(self, spaces=None, types=None):
return self._contraction_helper('any', spaces, types)
def any(self, spaces=None):
return self._contraction_helper('any', spaces)
def min(self, spaces=None, types=None):
return self._contraction_helper('min', spaces, types)
def min(self, spaces=None):
return self._contraction_helper('min', spaces)
def nanmin(self, spaces=None, types=None):
return self._contraction_helper('nanmin', spaces, types)
def nanmin(self, spaces=None):
return self._contraction_helper('nanmin', spaces)
def max(self, spaces=None, types=None):
return self._contraction_helper('max', spaces, types)
def max(self, spaces=None):
return self._contraction_helper('max', spaces)
def nanmax(self, spaces=None, types=None):
return self._contraction_helper('nanmax', spaces, types)
def nanmax(self, spaces=None):
return self._contraction_helper('nanmax', spaces)
def mean(self, spaces=None, types=None):
return self._contraction_helper('mean', spaces, types)
def mean(self, spaces=None):
return self._contraction_helper('mean', spaces)
def var(self, spaces=None, types=None):
return self._contraction_helper('var', spaces, types)
def var(self, spaces=None):
return self._contraction_helper('var', spaces)
def std(self, spaces=None, types=None):
return self._contraction_helper('std', spaces, types)
def std(self, spaces=None):
return self._contraction_helper('std', spaces)
# ---General binary methods---
......@@ -790,9 +723,6 @@ class Field(Loggable, Versionable, object):
assert len(other.domain) == len(self.domain)
for index in xrange(len(self.domain)):
assert other.domain[index] == self.domain[index]
assert len(other.field_type) == len(self.field_type)
for index in xrange(len(self.field_type)):
assert other.field_type[index] == self.field_type[index]
except AssertionError:
raise ValueError(
"domains are incompatible.")
......@@ -895,19 +825,14 @@ class Field(Loggable, Versionable, object):
def _to_hdf5(self, hdf5_group):
hdf5_group.attrs['dtype'] = self.dtype.name
hdf5_group.attrs['distribution_strategy'] = self.distribution_strategy
hdf5_group.attrs['field_type_axes'] = str(self.field_type_axes)
hdf5_group.attrs['domain_axes'] = str(self.domain_axes)
hdf5_group['num_domain'] = len(self.domain)
hdf5_group['num_ft'] = len(self.field_type)
ret_dict = {'val': self.val}
for i in range(len(self.domain)):
ret_dict['s_' + str(i)] = self.domain[i]
for i in range(len(self.field_type)):
ret_dict['ft_' + str(i)] = self.field_type[i]
return ret_dict
@classmethod
......@@ -922,14 +847,7 @@ class Field(Loggable, Versionable, object):
temp_domain.append(repository.get('s_' + str(i), hdf5_group))
new_field.domain = tuple(temp_domain)
temp_ft = []
for i in range(hdf5_group['num_ft'][()]):
temp_domain.append(repository.get('ft_' + str(i), hdf5_group))
new_field.field_type = tuple(temp_ft)
exec('new_field.domain_axes = ' + hdf5_group.attrs['domain_axes'])
exec('new_field.field_type_axes = ' +
hdf5_group.attrs['field_type_axes'])
new_field._val = repository.get('val', hdf5_group)
new_field.dtype = np.dtype(hdf5_group.attrs['dtype'])
new_field.distribution_strategy =\
......
# -*- coding: utf-8 -*-
import pickle
from field_type import FieldType
class FieldArray(FieldType):
def __init__(self, dtype, shape):
try:
new_shape = tuple([int(i) for i in shape])
except TypeError:
new_shape = (int(shape), )
self._shape = new_shape
super(FieldArray, self).__init__(dtype=dtype)
@property
def shape(self):
return self._shape
@property
def dim(self):
return reduce(lambda x, y: x*y, self.shape)
# ---Serialization---
def _to_hdf5(self, hdf5_group):
hdf5_group['shape'] = self.shape
hdf5_group['dtype'] = pickle.dumps(self.dtype)
return None
@classmethod
def _from_hdf5(cls, hdf5_group, loopback_get):
result = cls(
hdf5_group['shape'][:],
pickle.loads(hdf5_group['dtype'][()])
)
return result
# -*- coding: utf-8 -*-
import pickle
import numpy as np
from keepers import Versionable
from nifty.domain_object import DomainObject
class FieldType(Versionable, object):
def __init__(self, shape, dtype):
try:
new_shape = tuple([int(i) for i in shape])
except TypeError:
new_shape = (int(shape), )
self._shape = new_shape
self._dtype = np.dtype(dtype)
def __hash__(self):
# Extract the identifying parts from the vars(self) dict.
result_hash = 0
for (key, item) in vars(self).items():
result_hash ^= item.__hash__() ^ int(hash(key)/117)
return result_hash
def __eq__(self, x):
if isinstance(x, type(self)):
return hash(self) == hash(x)
else:
return False
@property
def shape(self):
return self._shape
@property
def dtype(self):
return self._dtype
@property
def dim(self):
raise NotImplementedError
class FieldType(DomainObject):
def process(self, method_name, array, inplace=True, **kwargs):
try:
......@@ -52,25 +17,3 @@ class FieldType(Versionable, object):
result_array = array.copy()
return result_array
def pre_cast(self, x, axes=None):
return x
def post_cast(self, x, axes=None):
return x
# ---Serialization---
def _to_hdf5(self, hdf5_group):
hdf5_group['shape'] = self.shape
hdf5_group['dtype'] = pickle.dumps(self.dtype)
return None
@classmethod
def _from_hdf5(cls, hdf5_group, loopback_get):
result = cls(
hdf5_group['shape'][:],
pickle.loads(hdf5_group['dtype'][()])
)
return result
......@@ -281,33 +281,17 @@ def get_default_codomain(domain):
def parse_domain(domain):
from nifty.spaces.space import Space
from nifty.domain_object import DomainObject
if domain is None:
domain = ()
elif isinstance(domain, Space):
elif isinstance(domain, DomainObject):
domain = (domain,)
elif not isinstance(domain, tuple):
domain = tuple(domain)
for d in domain:
if not isinstance(d, Space):
if not isinstance(d, DomainObject):
raise TypeError(
"Given object contains something that is not a "
"nifty.space.")
"instance of DomainObject-class.")
return domain
def parse_field_type(field_type):
from nifty.field_types import FieldType
if field_type is None:
field_type = ()
elif isinstance(field_type, FieldType):
field_type = (field_type,)
elif not isinstance(field_type, tuple):
field_type = tuple(field_type)
for ft in field_type:
if not isinstance(ft, FieldType):
raise TypeError(
"Given object is not a nifty.FieldType.")
return field_type
......@@ -13,13 +13,13 @@ class ComposedOperator(LinearOperator):
"instances of the LinearOperator-baseclass")
self._operator_store += (op,)
def _check_input_compatibility(self, x, spaces, types, inverse=False):
def _check_input_compatibility(self, x, spaces, inverse=False):
"""
The input check must be disabled for the ComposedOperator, since it
is not easily forecasteable what the output of an operator-call
will look like.
"""
return (spaces, types)
return spaces
# ---Mandatory properties and methods---
@property
......@@ -38,22 +38,6 @@ class ComposedOperator(LinearOperator):
self._target += op.target
return self._target
@property
def field_type(self):
if not hasattr(self, '_field_type'):
self._field_type = ()
for op in self._operator_store:
self._field_type += op.field_type
return self._field_type
@property
def field_type_target(self):
if not hasattr(self, '_field_type_target'):
self._field_type_target = ()
for op in self._operator_store:
self._field_type_target += op.field_type_target
return self._field_type_target
@property
def implemented(self):
return True
......@@ -62,56 +46,39 @@ class ComposedOperator(LinearOperator):
def unitary(self):
return False
def _times(self, x, spaces, types):
return self._times_helper(x, spaces, types, func='times')
def _times(self, x, spaces):
return self._times_helper(x, spaces, func='times')
def _adjoint_times(self, x, spaces, types):
return self._inverse_times_helper(x, spaces, types,
func='adjoint_times')
def _adjoint_times(self, x, spaces):
return self._inverse_times_helper(x, spaces, func='adjoint_times')
def _inverse_times(self, x, spaces, types):
return self._inverse_times_helper(x, spaces, types,
func='inverse_times')
def _inverse_times(self, x, spaces):
return self._inverse_times_helper(x, spaces, func='inverse_times')