From 37582f42ef32f8c3a4d9df3f5215b394a1ca38a7 Mon Sep 17 00:00:00 2001 From: Theo Steininger <theos@mpa-garching.mpg.de> Date: Tue, 7 Feb 2017 01:32:00 +0100 Subject: [PATCH] Unified spaces and field_types into single domain object. --- nifty/domain_object.py | 66 ++++++++ nifty/field.py | 158 +++++------------- nifty/field_types/field_array.py | 30 ++++ nifty/field_types/field_type.py | 61 +------ nifty/nifty_utilities.py | 24 +-- .../composed_operator/composed_operator.py | 65 ++----- .../diagonal_operator/diagonal_operator.py | 40 ++--- .../endomorphic_operator.py | 32 ++-- nifty/operators/fft_operator/fft_operator.py | 21 +-- .../invertible_operator_mixin.py | 10 +- .../linear_operator/linear_operator.py | 78 +++------ .../propagator_operator.py | 36 ++-- .../smoothing_operator/smoothing_operator.py | 18 +- nifty/probing/prober/prober.py | 11 +- nifty/random.py | 6 +- nifty/spaces/space/space.py | 58 +------ nifty/sugar.py | 1 - 17 files changed, 234 insertions(+), 481 deletions(-) create mode 100644 nifty/domain_object.py diff --git a/nifty/domain_object.py b/nifty/domain_object.py new file mode 100644 index 000000000..7760a7f9b --- /dev/null +++ b/nifty/domain_object.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- + +import abc + +import numpy as np + +from keepers import Loggable,\ + Versionable + + +class DomainObject(Versionable, Loggable, object): + __metaclass__ = abc.ABCMeta + + def __init__(self, dtype): + self._dtype = np.dtype(dtype) + self._ignore_for_hash = [] + + def __hash__(self): + # Extract the identifying parts from the vars(self) dict. + result_hash = 0 + for key in sorted(vars(self).keys()): + item = vars(self)[key] + if key in self._ignore_for_hash or key == '_ignore_for_hash': + continue + result_hash ^= item.__hash__() ^ int(hash(key)/117) + return result_hash + + def __eq__(self, x): + if isinstance(x, type(self)): + return hash(self) == hash(x) + else: + return False + + def __ne__(self, x): + return not self.__eq__(x) + + @property + def dtype(self): + return self._dtype + + @abc.abstractproperty + def shape(self): + raise NotImplementedError( + "There is no generic shape for DomainObject.") + + @abc.abstractproperty + def dim(self): + raise NotImplementedError( + "There is no generic dim for DomainObject.") + + def pre_cast(self, x, axes=None): + return x + + def post_cast(self, x, axes=None): + return x + + # ---Serialization--- + + def _to_hdf5(self, hdf5_group): + hdf5_group.attrs['dtype'] = self.dtype.name + return None + + @classmethod + def _from_hdf5(cls, hdf5_group, repository): + result = cls(dtype=np.dtype(hdf5_group.attrs['dtype'])) + return result diff --git a/nifty/field.py b/nifty/field.py index 04bee93ee..039c61154 100644 --- a/nifty/field.py +++ b/nifty/field.py @@ -9,9 +9,8 @@ from d2o import distributed_data_object,\ from nifty.config import nifty_configuration as gc -from nifty.field_types import FieldType +from nifty.domain_object import DomainObject -from nifty.spaces.space import Space from nifty.spaces.power_space import PowerSpace import nifty.nifty_utilities as utilities @@ -21,25 +20,15 @@ from nifty.random import Random class Field(Loggable, Versionable, object): # ---Initialization methods--- - def __init__(self, domain=None, val=None, dtype=None, field_type=None, + def __init__(self, domain=None, val=None, dtype=None, distribution_strategy=None, copy=False): self.domain = self._parse_domain(domain=domain, val=val) self.domain_axes = self._get_axes_tuple(self.domain) - self.field_type = self._parse_field_type(field_type, val=val) - - try: - start = len(reduce(lambda x, y: x+y, self.domain_axes)) - except TypeError: - start = 0 - self.field_type_axes = self._get_axes_tuple(self.field_type, - start=start) - self.dtype = self._infer_dtype(dtype=dtype, val=val, - domain=self.domain, - field_type=self.field_type) + domain=self.domain) self.distribution_strategy = self._parse_distribution_strategy( distribution_strategy=distribution_strategy, @@ -53,34 +42,18 @@ class Field(Loggable, Versionable, object): domain = val.domain else: domain = () - elif isinstance(domain, Space): + elif isinstance(domain, DomainObject): domain = (domain,) elif not isinstance(domain, tuple): domain = tuple(domain) for d in domain: - if not isinstance(d, Space): + if not isinstance(d, DomainObject): raise TypeError( "Given domain contains something that is not a " - "nifty.space.") + "DomainObject instance.") return domain - def _parse_field_type(self, field_type, val=None): - if field_type is None: - if isinstance(val, Field): - field_type = val.field_type - else: - field_type = () - elif isinstance(field_type, FieldType): - field_type = (field_type,) - elif not isinstance(field_type, tuple): - field_type = tuple(field_type) - for ft in field_type: - if not isinstance(ft, FieldType): - raise TypeError( - "Given object is not a nifty.FieldType.") - return field_type - def _get_axes_tuple(self, things_with_shape, start=0): i = start axes_list = [] @@ -92,7 +65,7 @@ class Field(Loggable, Versionable, object): axes_list += [tuple(l)] return tuple(axes_list) - def _infer_dtype(self, dtype, val, domain, field_type): + def _infer_dtype(self, dtype, val, domain): if dtype is None: if isinstance(val, Field) or \ isinstance(val, distributed_data_object): @@ -102,8 +75,6 @@ class Field(Loggable, Versionable, object): dtype_tuple = (np.dtype(dtype),) if domain is not None: dtype_tuple += tuple(np.dtype(sp.dtype) for sp in domain) - if field_type is not None: - dtype_tuple += tuple(np.dtype(ft.dtype) for ft in field_type) dtype = reduce(lambda x, y: np.result_type(x, y), dtype_tuple) @@ -127,10 +98,10 @@ class Field(Loggable, Versionable, object): # ---Factory methods--- @classmethod - def from_random(cls, random_type, domain=None, dtype=None, field_type=None, + def from_random(cls, random_type, domain=None, dtype=None, distribution_strategy=None, **kwargs): # create a initially empty field - f = cls(domain=domain, dtype=dtype, field_type=field_type, + f = cls(domain=domain, dtype=dtype, distribution_strategy=distribution_strategy) # now use the processed input in terms of f in order to parse the @@ -363,7 +334,6 @@ class Field(Loggable, Versionable, object): std=std, domain=result_domain, dtype=harmonic_domain.dtype, - field_type=self.field_type, distribution_strategy=self.distribution_strategy) for x in result_list] @@ -451,9 +421,7 @@ class Field(Loggable, Versionable, object): @property def shape(self): - shape_tuple = () - shape_tuple += tuple(sp.shape for sp in self.domain) - shape_tuple += tuple(ft.shape for ft in self.field_type) + shape_tuple = tuple(sp.shape for sp in self.domain) try: global_shape = reduce(lambda x, y: x + y, shape_tuple) except TypeError: @@ -463,9 +431,7 @@ class Field(Loggable, Versionable, object): @property def dim(self): - dim_tuple = () - dim_tuple += tuple(sp.dim for sp in self.domain) - dim_tuple += tuple(ft.dim for ft in self.field_type) + dim_tuple = tuple(sp.dim for sp in self.domain) try: return reduce(lambda x, y: x * y, dim_tuple) except TypeError: @@ -500,20 +466,12 @@ class Field(Loggable, Versionable, object): casted_x = sp.pre_cast(casted_x, axes=self.domain_axes[ind]) - for ind, ft in enumerate(self.field_type): - casted_x = ft.pre_cast(casted_x, - axes=self.field_type_axes[ind]) - casted_x = self._actual_cast(casted_x, dtype=dtype) for ind, sp in enumerate(self.domain): casted_x = sp.post_cast(casted_x, axes=self.domain_axes[ind]) - for ind, ft in enumerate(self.field_type): - casted_x = ft.post_cast(casted_x, - axes=self.field_type_axes[ind]) - return casted_x def _actual_cast(self, x, dtype=None): @@ -530,19 +488,16 @@ class Field(Loggable, Versionable, object): return_x.set_full_data(x, copy=False) return return_x - def copy(self, domain=None, dtype=None, field_type=None, - distribution_strategy=None): + def copy(self, domain=None, dtype=None, distribution_strategy=None): copied_val = self.get_val(copy=True) new_field = self.copy_empty( domain=domain, dtype=dtype, - field_type=field_type, distribution_strategy=distribution_strategy) new_field.set_val(new_val=copied_val, copy=False) return new_field - def copy_empty(self, domain=None, dtype=None, field_type=None, - distribution_strategy=None): + def copy_empty(self, domain=None, dtype=None, distribution_strategy=None): if domain is None: domain = self.domain else: @@ -553,11 +508,6 @@ class Field(Loggable, Versionable, object): else: dtype = np.dtype(dtype) - if field_type is None: - field_type = self.field_type - else: - field_type = self._parse_field_type(field_type) - if distribution_strategy is None: distribution_strategy = self.distribution_strategy @@ -567,10 +517,6 @@ class Field(Loggable, Versionable, object): if self.domain[i] is not domain[i]: fast_copyable = False break - for i in xrange(len(self.field_type)): - if self.field_type[i] is not field_type[i]: - fast_copyable = False - break except IndexError: fast_copyable = False @@ -580,7 +526,6 @@ class Field(Loggable, Versionable, object): else: new_field = Field(domain=domain, dtype=dtype, - field_type=field_type, distribution_strategy=distribution_strategy) return new_field @@ -626,8 +571,6 @@ class Field(Loggable, Versionable, object): assert len(x.domain) == len(self.domain) for index in xrange(len(self.domain)): assert x.domain[index] == self.domain[index] - for index in xrange(len(self.field_type)): - assert x.field_type[index] == self.field_type[index] except AssertionError: raise ValueError( "domains are incompatible.") @@ -707,22 +650,15 @@ class Field(Loggable, Versionable, object): return_field.set_val(new_val, copy=False) return return_field - def _contraction_helper(self, op, spaces, types): + def _contraction_helper(self, op, spaces): # build a list of all axes if spaces is None: spaces = xrange(len(self.domain)) else: spaces = utilities.cast_axis_to_tuple(spaces, len(self.domain)) - if types is None: - types = xrange(len(self.field_type)) - else: - types = utilities.cast_axis_to_tuple(types, len(self.field_type)) + axes_list = tuple(self.domain_axes[sp_index] for sp_index in spaces) - axes_list = () - axes_list += tuple(self.domain_axes[sp_index] for sp_index in spaces) - axes_list += tuple(self.field_type_axes[ft_index] for - ft_index in types) try: axes_list = reduce(lambda x, y: x+y, axes_list) except TypeError: @@ -739,47 +675,44 @@ class Field(Loggable, Versionable, object): return_domain = tuple(self.domain[i] for i in xrange(len(self.domain)) if i not in spaces) - return_field_type = tuple(self.field_type[i] - for i in xrange(len(self.field_type)) - if i not in types) + return_field = Field(domain=return_domain, val=data, - field_type=return_field_type, copy=False) return return_field - def sum(self, spaces=None, types=None): - return self._contraction_helper('sum', spaces, types) + def sum(self, spaces=None): + return self._contraction_helper('sum', spaces) - def prod(self, spaces=None, types=None): - return self._contraction_helper('prod', spaces, types) + def prod(self, spaces=None): + return self._contraction_helper('prod', spaces) - def all(self, spaces=None, types=None): - return self._contraction_helper('all', spaces, types) + def all(self, spaces=None): + return self._contraction_helper('all', spaces) - def any(self, spaces=None, types=None): - return self._contraction_helper('any', spaces, types) + def any(self, spaces=None): + return self._contraction_helper('any', spaces) - def min(self, spaces=None, types=None): - return self._contraction_helper('min', spaces, types) + def min(self, spaces=None): + return self._contraction_helper('min', spaces) - def nanmin(self, spaces=None, types=None): - return self._contraction_helper('nanmin', spaces, types) + def nanmin(self, spaces=None): + return self._contraction_helper('nanmin', spaces) - def max(self, spaces=None, types=None): - return self._contraction_helper('max', spaces, types) + def max(self, spaces=None): + return self._contraction_helper('max', spaces) - def nanmax(self, spaces=None, types=None): - return self._contraction_helper('nanmax', spaces, types) + def nanmax(self, spaces=None): + return self._contraction_helper('nanmax', spaces) - def mean(self, spaces=None, types=None): - return self._contraction_helper('mean', spaces, types) + def mean(self, spaces=None): + return self._contraction_helper('mean', spaces) - def var(self, spaces=None, types=None): - return self._contraction_helper('var', spaces, types) + def var(self, spaces=None): + return self._contraction_helper('var', spaces) - def std(self, spaces=None, types=None): - return self._contraction_helper('std', spaces, types) + def std(self, spaces=None): + return self._contraction_helper('std', spaces) # ---General binary methods--- @@ -790,9 +723,6 @@ class Field(Loggable, Versionable, object): assert len(other.domain) == len(self.domain) for index in xrange(len(self.domain)): assert other.domain[index] == self.domain[index] - assert len(other.field_type) == len(self.field_type) - for index in xrange(len(self.field_type)): - assert other.field_type[index] == self.field_type[index] except AssertionError: raise ValueError( "domains are incompatible.") @@ -895,19 +825,14 @@ class Field(Loggable, Versionable, object): def _to_hdf5(self, hdf5_group): hdf5_group.attrs['dtype'] = self.dtype.name hdf5_group.attrs['distribution_strategy'] = self.distribution_strategy - hdf5_group.attrs['field_type_axes'] = str(self.field_type_axes) hdf5_group.attrs['domain_axes'] = str(self.domain_axes) hdf5_group['num_domain'] = len(self.domain) - hdf5_group['num_ft'] = len(self.field_type) ret_dict = {'val': self.val} for i in range(len(self.domain)): ret_dict['s_' + str(i)] = self.domain[i] - for i in range(len(self.field_type)): - ret_dict['ft_' + str(i)] = self.field_type[i] - return ret_dict @classmethod @@ -922,14 +847,7 @@ class Field(Loggable, Versionable, object): temp_domain.append(repository.get('s_' + str(i), hdf5_group)) new_field.domain = tuple(temp_domain) - temp_ft = [] - for i in range(hdf5_group['num_ft'][()]): - temp_domain.append(repository.get('ft_' + str(i), hdf5_group)) - new_field.field_type = tuple(temp_ft) - exec('new_field.domain_axes = ' + hdf5_group.attrs['domain_axes']) - exec('new_field.field_type_axes = ' + - hdf5_group.attrs['field_type_axes']) new_field._val = repository.get('val', hdf5_group) new_field.dtype = np.dtype(hdf5_group.attrs['dtype']) new_field.distribution_strategy =\ diff --git a/nifty/field_types/field_array.py b/nifty/field_types/field_array.py index 5a0f64f8e..844b67f7b 100644 --- a/nifty/field_types/field_array.py +++ b/nifty/field_types/field_array.py @@ -1,10 +1,40 @@ # -*- coding: utf-8 -*- +import pickle from field_type import FieldType class FieldArray(FieldType): + + def __init__(self, dtype, shape): + try: + new_shape = tuple([int(i) for i in shape]) + except TypeError: + new_shape = (int(shape), ) + self._shape = new_shape + super(FieldArray, self).__init__(dtype=dtype) + + @property + def shape(self): + return self._shape + @property def dim(self): return reduce(lambda x, y: x*y, self.shape) + + # ---Serialization--- + + def _to_hdf5(self, hdf5_group): + hdf5_group['shape'] = self.shape + hdf5_group['dtype'] = pickle.dumps(self.dtype) + + return None + + @classmethod + def _from_hdf5(cls, hdf5_group, loopback_get): + result = cls( + hdf5_group['shape'][:], + pickle.loads(hdf5_group['dtype'][()]) + ) + return result diff --git a/nifty/field_types/field_type.py b/nifty/field_types/field_type.py index 105c66de1..16a9a8022 100644 --- a/nifty/field_types/field_type.py +++ b/nifty/field_types/field_type.py @@ -1,44 +1,9 @@ # -*- coding: utf-8 -*- -import pickle -import numpy as np -from keepers import Versionable +from nifty.domain_object import DomainObject -class FieldType(Versionable, object): - def __init__(self, shape, dtype): - try: - new_shape = tuple([int(i) for i in shape]) - except TypeError: - new_shape = (int(shape), ) - self._shape = new_shape - - self._dtype = np.dtype(dtype) - - def __hash__(self): - # Extract the identifying parts from the vars(self) dict. - result_hash = 0 - for (key, item) in vars(self).items(): - result_hash ^= item.__hash__() ^ int(hash(key)/117) - return result_hash - - def __eq__(self, x): - if isinstance(x, type(self)): - return hash(self) == hash(x) - else: - return False - - @property - def shape(self): - return self._shape - - @property - def dtype(self): - return self._dtype - - @property - def dim(self): - raise NotImplementedError +class FieldType(DomainObject): def process(self, method_name, array, inplace=True, **kwargs): try: @@ -52,25 +17,3 @@ class FieldType(Versionable, object): result_array = array.copy() return result_array - - def pre_cast(self, x, axes=None): - return x - - def post_cast(self, x, axes=None): - return x - - # ---Serialization--- - - def _to_hdf5(self, hdf5_group): - hdf5_group['shape'] = self.shape - hdf5_group['dtype'] = pickle.dumps(self.dtype) - - return None - - @classmethod - def _from_hdf5(cls, hdf5_group, loopback_get): - result = cls( - hdf5_group['shape'][:], - pickle.loads(hdf5_group['dtype'][()]) - ) - return result diff --git a/nifty/nifty_utilities.py b/nifty/nifty_utilities.py index f1ff27c9d..d187ba290 100644 --- a/nifty/nifty_utilities.py +++ b/nifty/nifty_utilities.py @@ -281,33 +281,17 @@ def get_default_codomain(domain): def parse_domain(domain): - from nifty.spaces.space import Space + from nifty.domain_object import DomainObject if domain is None: domain = () - elif isinstance(domain, Space): + elif isinstance(domain, DomainObject): domain = (domain,) elif not isinstance(domain, tuple): domain = tuple(domain) for d in domain: - if not isinstance(d, Space): + if not isinstance(d, DomainObject): raise TypeError( "Given object contains something that is not a " - "nifty.space.") + "instance of DomainObject-class.") return domain - - -def parse_field_type(field_type): - from nifty.field_types import FieldType - if field_type is None: - field_type = () - elif isinstance(field_type, FieldType): - field_type = (field_type,) - elif not isinstance(field_type, tuple): - field_type = tuple(field_type) - - for ft in field_type: - if not isinstance(ft, FieldType): - raise TypeError( - "Given object is not a nifty.FieldType.") - return field_type diff --git a/nifty/operators/composed_operator/composed_operator.py b/nifty/operators/composed_operator/composed_operator.py index 602709975..a980a76b5 100644 --- a/nifty/operators/composed_operator/composed_operator.py +++ b/nifty/operators/composed_operator/composed_operator.py @@ -13,13 +13,13 @@ class ComposedOperator(LinearOperator): "instances of the LinearOperator-baseclass") self._operator_store += (op,) - def _check_input_compatibility(self, x, spaces, types, inverse=False): + def _check_input_compatibility(self, x, spaces, inverse=False): """ The input check must be disabled for the ComposedOperator, since it is not easily forecasteable what the output of an operator-call will look like. """ - return (spaces, types) + return spaces # ---Mandatory properties and methods--- @property @@ -38,22 +38,6 @@ class ComposedOperator(LinearOperator): self._target += op.target return self._target - @property - def field_type(self): - if not hasattr(self, '_field_type'): - self._field_type = () - for op in self._operator_store: - self._field_type += op.field_type - return self._field_type - - @property - def field_type_target(self): - if not hasattr(self, '_field_type_target'): - self._field_type_target = () - for op in self._operator_store: - self._field_type_target += op.field_type_target - return self._field_type_target - @property def implemented(self): return True @@ -62,56 +46,39 @@ class ComposedOperator(LinearOperator): def unitary(self): return False - def _times(self, x, spaces, types): - return self._times_helper(x, spaces, types, func='times') + def _times(self, x, spaces): + return self._times_helper(x, spaces, func='times') - def _adjoint_times(self, x, spaces, types): - return self._inverse_times_helper(x, spaces, types, - func='adjoint_times') + def _adjoint_times(self, x, spaces): + return self._inverse_times_helper(x, spaces, func='adjoint_times') - def _inverse_times(self, x, spaces, types): - return self._inverse_times_helper(x, spaces, types, - func='inverse_times') + def _inverse_times(self, x, spaces): + return self._inverse_times_helper(x, spaces, func='inverse_times') - def _adjoint_inverse_times(self, x, spaces, types): - return self._times_helper(x, spaces, types, - func='adjoint_inverse_times') + def _adjoint_inverse_times(self, x, spaces): + return self._times_helper(x, spaces, func='adjoint_inverse_times') - def _inverse_adjoint_times(self, x, spaces, types): - return self._times_helper(x, spaces, types, - func='inverse_adjoint_times') + def _inverse_adjoint_times(self, x, spaces): + return self._times_helper(x, spaces, func='inverse_adjoint_times') - def _times_helper(self, x, spaces, types, func): + def _times_helper(self, x, spaces, func): space_index = 0 - type_index = 0 if spaces is None: spaces = range(len(self.domain)) - if types is None: - types = range(len(self.field_type)) for op in self._operator_store: active_spaces = spaces[space_index:space_index+len(op.domain)] space_index += len(op.domain) - active_types = types[type_index:type_index+len(op.field_type)] - type_index += len(op.field_type) - - x = getattr(op, func)(x, spaces=active_spaces, types=active_types) + x = getattr(op, func)(x, spaces=active_spaces) return x - def _inverse_times_helper(self, x, spaces, types, func): + def _inverse_times_helper(self, x, spaces, func): space_index = 0 - type_index = 0 if spaces is None: spaces = range(len(self.target))[::-1] - if types is None: - types = range(len(self.field_type_target))[::-1] for op in reversed(self._operator_store): active_spaces = spaces[space_index:space_index+len(op.target)] space_index += len(op.target) - active_types = types[type_index: - type_index+len(op.field_type_target)] - type_index += len(op.field_type_target) - - x = getattr(op, func)(x, spaces=active_spaces, types=active_types) + x = getattr(op, func)(x, spaces=active_spaces) return x diff --git a/nifty/operators/diagonal_operator/diagonal_operator.py b/nifty/operators/diagonal_operator/diagonal_operator.py index 8260167fb..77b6a2b1c 100644 --- a/nifty/operators/diagonal_operator/diagonal_operator.py +++ b/nifty/operators/diagonal_operator/diagonal_operator.py @@ -14,11 +14,10 @@ class DiagonalOperator(EndomorphicOperator): # ---Overwritten properties and methods--- - def __init__(self, domain=(), field_type=(), implemented=True, + def __init__(self, domain=(), implemented=True, diagonal=None, bare=False, copy=True, distribution_strategy=None): self._domain = self._parse_domain(domain) - self._field_type = self._parse_field_type(field_type) self._implemented = bool(implemented) @@ -34,20 +33,18 @@ class DiagonalOperator(EndomorphicOperator): self.set_diagonal(diagonal=diagonal, bare=bare, copy=copy) - def _times(self, x, spaces, types): - return self._times_helper(x, spaces, types, - operation=lambda z: z.__mul__) + def _times(self, x, spaces): + return self._times_helper(x, spaces, operation=lambda z: z.__mul__) - def _adjoint_times(self, x, spaces, types): - return self._times_helper(x, spaces, types, + def _adjoint_times(self, x, spaces): + return self._times_helper(x, spaces, operation=lambda z: z.adjoint().__mul__) - def _inverse_times(self, x, spaces, types): - return self._times_helper(x, spaces, types, - operation=lambda z: z.__rdiv__) + def _inverse_times(self, x, spaces): + return self._times_helper(x, spaces, operation=lambda z: z.__rdiv__) - def _adjoint_inverse_times(self, x, spaces, types): - return self._times_helper(x, spaces, types, + def _adjoint_inverse_times(self, x, spaces): + return self._times_helper(x, spaces, operation=lambda z: z.adjoint().__rdiv__) def diagonal(self, bare=False, copy=True): @@ -87,10 +84,6 @@ class DiagonalOperator(EndomorphicOperator): def domain(self): return self._domain - @property - def field_type(self): - return self._field_type - @property def implemented(self): return self._implemented @@ -127,7 +120,6 @@ class DiagonalOperator(EndomorphicOperator): # use the casting functionality from Field to process `diagonal` f = Field(domain=self.domain, val=diagonal, - field_type=self.field_type, distribution_strategy=self.distribution_strategy, copy=copy) @@ -151,10 +143,10 @@ class DiagonalOperator(EndomorphicOperator): # store the diagonal-field self._diagonal = f - def _times_helper(self, x, spaces, types, operation): - # if the domain and field_type match directly + def _times_helper(self, x, spaces, operation): + # if the domain matches directly # -> multiply the fields directly - if x.domain == self.domain and x.field_type == self.field_type: + if x.domain == self.domain: # here the actual multiplication takes place return operation(self.diagonal(copy=False))(x) @@ -169,14 +161,6 @@ class DiagonalOperator(EndomorphicOperator): for space_index in spaces: active_axes += x.domain_axes[space_index] - if types is None: - if self.field_type != (): - for axes in x.field_type_axes: - active_axes += axes - else: - for type_index in types: - active_axes += x.field_type_axes[type_index] - axes_local_distribution_strategy = \ x.val.get_axes_local_distribution_strategy(active_axes) if axes_local_distribution_strategy == self.distribution_strategy: diff --git a/nifty/operators/endomorphic_operator/endomorphic_operator.py b/nifty/operators/endomorphic_operator/endomorphic_operator.py index 56466b999..233cfce53 100644 --- a/nifty/operators/endomorphic_operator/endomorphic_operator.py +++ b/nifty/operators/endomorphic_operator/endomorphic_operator.py @@ -9,41 +9,37 @@ class EndomorphicOperator(LinearOperator): # ---Overwritten properties and methods--- - def inverse_times(self, x, spaces=None, types=None): + def inverse_times(self, x, spaces=None): if self.symmetric and self.unitary: - return self.times(x, spaces, types) + return self.times(x, spaces) else: return super(EndomorphicOperator, self).inverse_times( x=x, - spaces=spaces, - types=types) + spaces=spaces) - def adjoint_times(self, x, spaces=None, types=None): + def adjoint_times(self, x, spaces=None): if self.symmetric: - return self.times(x, spaces, types) + return self.times(x, spaces) else: return super(EndomorphicOperator, self).adjoint_times( x=x, - spaces=spaces, - types=types) + spaces=spaces) - def adjoint_inverse_times(self, x, spaces=None, types=None): + def adjoint_inverse_times(self, x, spaces=None): if self.symmetric: - return self.inverse_times(x, spaces, types) + return self.inverse_times(x, spaces) else: return super(EndomorphicOperator, self).adjoint_inverse_times( x=x, - spaces=spaces, - types=types) + spaces=spaces) - def inverse_adjoint_times(self, x, spaces=None, types=None): + def inverse_adjoint_times(self, x, spaces=None): if self.symmetric: - return self.inverse_times(x, spaces, types) + return self.inverse_times(x, spaces) else: return super(EndomorphicOperator, self).inverse_adjoint_times( x=x, - spaces=spaces, - types=types) + spaces=spaces) # ---Mandatory properties and methods--- @@ -51,10 +47,6 @@ class EndomorphicOperator(LinearOperator): def target(self): return self.domain - @property - def field_type_target(self): - return self.field_type - # ---Added properties and methods--- @abc.abstractproperty diff --git a/nifty/operators/fft_operator/fft_operator.py b/nifty/operators/fft_operator/fft_operator.py index c89e60460..58f7b359c 100644 --- a/nifty/operators/fft_operator/fft_operator.py +++ b/nifty/operators/fft_operator/fft_operator.py @@ -33,10 +33,9 @@ class FFTOperator(LinearOperator): # ---Overwritten properties and methods--- - def __init__(self, domain=(), field_type=(), target=None, module=None): + def __init__(self, domain=(), target=None, module=None): self._domain = self._parse_domain(domain) - self._field_type = self._parse_field_type(field_type) # Initialize domain and target if len(self.domain) != 1: @@ -44,12 +43,6 @@ class FFTOperator(LinearOperator): 'ERROR: TransformationOperator accepts only exactly one ' 'space as input domain.') - if self.field_type != (): - raise ValueError( - 'ERROR: TransformationOperator field-type must be an ' - 'empty tuple.' - ) - if target is None: target = (self.get_default_codomain(self.domain[0]), ) self._target = self._parse_domain(target) @@ -76,7 +69,7 @@ class FFTOperator(LinearOperator): self._backward_transformation = TransformationCache.create( backward_class, self.target[0], self.domain[0], module=module) - def _times(self, x, spaces, types): + def _times(self, x, spaces): spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain)) if spaces is None: # this case means that x lives on only one space, which is @@ -99,7 +92,7 @@ class FFTOperator(LinearOperator): return result_field - def _inverse_times(self, x, spaces, types): + def _inverse_times(self, x, spaces): spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain)) if spaces is None: # this case means that x lives on only one space, which is @@ -132,14 +125,6 @@ class FFTOperator(LinearOperator): def target(self): return self._target - @property - def field_type(self): - return self._field_type - - @property - def field_type_target(self): - return self.field_type - @property def implemented(self): return True diff --git a/nifty/operators/invertible_operator_mixin/invertible_operator_mixin.py b/nifty/operators/invertible_operator_mixin/invertible_operator_mixin.py index 7aa5c3922..a32240cec 100644 --- a/nifty/operators/invertible_operator_mixin/invertible_operator_mixin.py +++ b/nifty/operators/invertible_operator_mixin/invertible_operator_mixin.py @@ -14,7 +14,7 @@ class InvertibleOperatorMixin(object): self.__inverter = ConjugateGradient( preconditioner=self.__preconditioner) - def _times(self, x, spaces, types, x0=None): + def _times(self, x, spaces, x0=None): if x0 is None: x0 = Field(self.target, val=0., dtype=x.dtype) @@ -23,7 +23,7 @@ class InvertibleOperatorMixin(object): x0=x0) return result - def _adjoint_times(self, x, spaces, types, x0=None): + def _adjoint_times(self, x, spaces, x0=None): if x0 is None: x0 = Field(self.domain, val=0., dtype=x.dtype) @@ -32,7 +32,7 @@ class InvertibleOperatorMixin(object): x0=x0) return result - def _inverse_times(self, x, spaces, types, x0=None): + def _inverse_times(self, x, spaces, x0=None): if x0 is None: x0 = Field(self.domain, val=0., dtype=x.dtype) @@ -41,7 +41,7 @@ class InvertibleOperatorMixin(object): x0=x0) return result - def _adjoint_inverse_times(self, x, spaces, types, x0=None): + def _adjoint_inverse_times(self, x, spaces, x0=None): if x0 is None: x0 = Field(self.target, val=0., dtype=x.dtype) @@ -50,6 +50,6 @@ class InvertibleOperatorMixin(object): x0=x0) return result - def _inverse_adjoint_times(self, x, spaces, types): + def _inverse_adjoint_times(self, x, spaces): raise NotImplementedError( "no generic instance method 'inverse_adjoint_times'.") diff --git a/nifty/operators/linear_operator/linear_operator.py b/nifty/operators/linear_operator/linear_operator.py index 1a1361892..347d5e5bd 100644 --- a/nifty/operators/linear_operator/linear_operator.py +++ b/nifty/operators/linear_operator/linear_operator.py @@ -16,9 +16,6 @@ class LinearOperator(Loggable, object): def _parse_domain(self, domain): return utilities.parse_domain(domain) - def _parse_field_type(self, field_type): - return utilities.parse_field_type(field_type) - @abc.abstractproperty def domain(self): raise NotImplementedError @@ -27,14 +24,6 @@ class LinearOperator(Loggable, object): def target(self): raise NotImplementedError - @abc.abstractproperty - def field_type(self): - raise NotImplementedError - - @abc.abstractproperty - def field_type_target(self): - raise NotImplementedError - @abc.abstractproperty def implemented(self): raise NotImplementedError @@ -46,86 +35,83 @@ class LinearOperator(Loggable, object): def __call__(self, *args, **kwargs): return self.times(*args, **kwargs) - def times(self, x, spaces=None, types=None, **kwargs): - spaces, types = self._check_input_compatibility(x, spaces, types) + def times(self, x, spaces=None, **kwargs): + spaces = self._check_input_compatibility(x, spaces) if not self.implemented: x = x.weight(spaces=spaces) - y = self._times(x, spaces, types, **kwargs) + y = self._times(x, spaces, **kwargs) return y - def inverse_times(self, x, spaces=None, types=None, **kwargs): - spaces, types = self._check_input_compatibility(x, spaces, types, - inverse=True) + def inverse_times(self, x, spaces=None, **kwargs): + spaces = self._check_input_compatibility(x, spaces, inverse=True) - y = self._inverse_times(x, spaces, types, **kwargs) + y = self._inverse_times(x, spaces, **kwargs) if not self.implemented: y = y.weight(power=-1, spaces=spaces) return y - def adjoint_times(self, x, spaces=None, types=None, **kwargs): + def adjoint_times(self, x, spaces=None, **kwargs): if self.unitary: - return self.inverse_times(x, spaces, types) + return self.inverse_times(x, spaces) - spaces, types = self._check_input_compatibility(x, spaces, types, - inverse=True) + spaces = self._check_input_compatibility(x, spaces, inverse=True) if not self.implemented: x = x.weight(spaces=spaces) - y = self._adjoint_times(x, spaces, types, **kwargs) + y = self._adjoint_times(x, spaces, **kwargs) return y - def adjoint_inverse_times(self, x, spaces=None, types=None, **kwargs): + def adjoint_inverse_times(self, x, spaces=None, **kwargs): if self.unitary: - return self.times(x, spaces, types) + return self.times(x, spaces) - spaces, types = self._check_input_compatibility(x, spaces, types) + spaces = self._check_input_compatibility(x, spaces) - y = self._adjoint_inverse_times(x, spaces, types, **kwargs) + y = self._adjoint_inverse_times(x, spaces, **kwargs) if not self.implemented: y = y.weight(power=-1, spaces=spaces) return y - def inverse_adjoint_times(self, x, spaces=None, types=None, **kwargs): + def inverse_adjoint_times(self, x, spaces=None, **kwargs): if self.unitary: - return self.times(x, spaces, types, **kwargs) + return self.times(x, spaces, **kwargs) - spaces, types = self._check_input_compatibility(x, spaces, types) + spaces = self._check_input_compatibility(x, spaces) - y = self._inverse_adjoint_times(x, spaces, types) + y = self._inverse_adjoint_times(x, spaces) if not self.implemented: y = y.weight(power=-1, spaces=spaces) return y - def _times(self, x, spaces, types): + def _times(self, x, spaces): raise NotImplementedError( "no generic instance method 'times'.") - def _adjoint_times(self, x, spaces, types): + def _adjoint_times(self, x, spaces): raise NotImplementedError( "no generic instance method 'adjoint_times'.") - def _inverse_times(self, x, spaces, types): + def _inverse_times(self, x, spaces): raise NotImplementedError( "no generic instance method 'inverse_times'.") - def _adjoint_inverse_times(self, x, spaces, types): + def _adjoint_inverse_times(self, x, spaces): raise NotImplementedError( "no generic instance method 'adjoint_inverse_times'.") - def _inverse_adjoint_times(self, x, spaces, types): + def _inverse_adjoint_times(self, x, spaces): raise NotImplementedError( "no generic instance method 'inverse_adjoint_times'.") - def _check_input_compatibility(self, x, spaces, types, inverse=False): + def _check_input_compatibility(self, x, spaces, inverse=False): if not isinstance(x, Field): raise ValueError( "supplied object is not a `nifty.Field`.") # sanitize the `spaces` and `types` input spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain)) - types = utilities.cast_axis_to_tuple(types, len(x.field_type)) # if the operator's domain is set to something, there are two valid # cases: @@ -137,10 +123,8 @@ class LinearOperator(Loggable, object): if not inverse: self_domain = self.domain - self_field_type = self.field_type else: self_domain = self.target - self_field_type = self.field_type_target if spaces is None: if self_domain != () and self_domain != x.domain: @@ -154,19 +138,7 @@ class LinearOperator(Loggable, object): "The operator's and and field's domains don't " "match.") - if types is None: - if self_field_type != () and self_field_type != x.field_type: - raise ValueError( - "The operator's and and field's field_types don't " - "match.") - else: - for i, field_type_index in enumerate(types): - if x.field_type[field_type_index] != self_field_type[i]: - raise ValueError( - "The operator's and and field's field_type " - "don't match.") - - return (spaces, types) + return spaces def __repr__(self): return str(self.__class__) diff --git a/nifty/operators/propagator_operator/propagator_operator.py b/nifty/operators/propagator_operator/propagator_operator.py index f36793daa..70ca7e53d 100644 --- a/nifty/operators/propagator_operator/propagator_operator.py +++ b/nifty/operators/propagator_operator/propagator_operator.py @@ -1,7 +1,5 @@ # -*- coding: utf-8 -*- -from nifty.minimization import ConjugateGradient -from nifty.field import Field from nifty.operators import EndomorphicOperator,\ FFTOperator,\ InvertibleOperatorMixin @@ -60,10 +58,6 @@ class PropagatorOperator(InvertibleOperatorMixin, EndomorphicOperator): def domain(self): return self._domain - @property - def field_type(self): - return () - @property def implemented(self): return True @@ -78,34 +72,24 @@ class PropagatorOperator(InvertibleOperatorMixin, EndomorphicOperator): # ---Added properties and methods--- - def _S_times(self, x, spaces=None, types=None): - transformed_x = self._fft_S(x, - spaces=spaces, - types=types) - y = self._S(transformed_x, spaces=spaces, types=types) - transformed_y = self._fft_S.inverse_times(y, - spaces=spaces, - types=types) + def _S_times(self, x, spaces=None): + transformed_x = self._fft_S(x, spaces=spaces) + y = self._S(transformed_x, spaces=spaces) + transformed_y = self._fft_S.inverse_times(y, spaces=spaces) result = x.copy_empty() result.set_val(transformed_y, copy=False) return result - def _S_inverse_times(self, x, spaces=None, types=None): - transformed_x = self._fft_S(x, - spaces=spaces, - types=types) - y = self._S.inverse_times(transformed_x, - spaces=spaces, - types=types) - transformed_y = self._fft_S.inverse_times(y, - spaces=spaces, - types=types) + def _S_inverse_times(self, x, spaces=None): + transformed_x = self._fft_S(x, spaces=spaces) + y = self._S.inverse_times(transformed_x, spaces=spaces) + transformed_y = self._fft_S.inverse_times(y, spaces=spaces) result = x.copy_empty() result.set_val(transformed_y, copy=False) return result - def _inverse_times(self, x, spaces, types): - pre_result = self._S_inverse_times(x, spaces, types) + def _inverse_times(self, x, spaces): + pre_result = self._S_inverse_times(x, spaces) pre_result += self._likelihood_times(x) result = x.copy_empty() result.set_val(pre_result, copy=False) diff --git a/nifty/operators/smoothing_operator/smoothing_operator.py b/nifty/operators/smoothing_operator/smoothing_operator.py index 4abd45ea2..4306ba264 100644 --- a/nifty/operators/smoothing_operator/smoothing_operator.py +++ b/nifty/operators/smoothing_operator/smoothing_operator.py @@ -9,11 +9,9 @@ from d2o import STRATEGIES class SmoothingOperator(EndomorphicOperator): # ---Overwritten properties and methods--- - def __init__(self, domain=(), field_type=(), sigma=0, - log_distances=False): + def __init__(self, domain=(), sigma=0, log_distances=False): self._domain = self._parse_domain(domain) - self._field_type = self._parse_field_type(field_type) if len(self.domain) != 1: raise ValueError( @@ -22,21 +20,15 @@ class SmoothingOperator(EndomorphicOperator): 'space as input domain.' ) - if self.field_type != (): - raise ValueError( - 'ERROR: SmoothOperator field-type must be an ' - 'empty tuple.' - ) - self.sigma = sigma self.log_distances = log_distances self._direct_smoothing_width = 3. - def _inverse_times(self, x, spaces, types): + def _inverse_times(self, x, spaces): return self._smoothing_helper(x, spaces, inverse=True) - def _times(self, x, spaces, types): + def _times(self, x, spaces): return self._smoothing_helper(x, spaces, inverse=False) # ---Mandatory properties and methods--- @@ -44,10 +36,6 @@ class SmoothingOperator(EndomorphicOperator): def domain(self): return self._domain - @property - def field_type(self): - return self._field_type - @property def implemented(self): return True diff --git a/nifty/probing/prober/prober.py b/nifty/probing/prober/prober.py index e563c72b1..e1eb45afd 100644 --- a/nifty/probing/prober/prober.py +++ b/nifty/probing/prober/prober.py @@ -4,8 +4,6 @@ import abc import numpy as np -from nifty.field_types import FieldType -from nifty.spaces import Space from nifty.field import Field import nifty.nifty_utilities as utilities @@ -26,12 +24,10 @@ class Prober(object): __metaclass__ = abc.ABCMeta - def __init__(self, domain=None, field_type=None, - distribution_strategy=None, probe_count=8, + def __init__(self, domain=None, distribution_strategy=None, probe_count=8, random_type='pm1', compute_variance=False): self._domain = utilities.parse_domain(domain) - self._field_type = utilities.parse_field_type(field_type) self._distribution_strategy = \ self._parse_distribution_strategy(distribution_strategy) self._probe_count = self._parse_probe_count(probe_count) @@ -46,10 +42,6 @@ class Prober(object): def domain(self): return self._domain - @property - def field_type(self): - return self._field_type - @property def distribution_strategy(self): return self._distribution_strategy @@ -102,7 +94,6 @@ class Prober(object): """ a random-probe generator """ f = Field.from_random(random_type=self.random_type, domain=self.domain, - field_type=self.field_type, distribution_strategy=self.distribution_strategy) uid = np.random.randint(1e18) return (uid, f) diff --git a/nifty/random.py b/nifty/random.py index d2417f58e..8f9eceb5d 100644 --- a/nifty/random.py +++ b/nifty/random.py @@ -7,7 +7,7 @@ class Random(object): @staticmethod def pm1(dtype=np.dtype('int'), shape=1): - size = int(reduce(lambda x,y: x*y, shape)) + size = int(reduce(lambda x, y: x*y, shape)) if issubclass(dtype.type, np.complexfloating): x = np.array([1 + 0j, 0 + 1j, -1 + 0j, 0 - 1j], dtype=dtype) x = x[np.random.randint(4, high=None, size=size)] @@ -19,7 +19,7 @@ class Random(object): @staticmethod def normal(dtype=np.dtype('float64'), shape=(1,), mean=None, std=None): - size = int(reduce(lambda x,y: x*y, shape)) + size = int(reduce(lambda x, y: x*y, shape)) if issubclass(dtype.type, np.complexfloating): x = np.empty(size, dtype=dtype) x.real = np.random.normal(loc=0, scale=np.sqrt(0.5), size=size) @@ -41,7 +41,7 @@ class Random(object): @staticmethod def uniform(dtype=np.dtype('float64'), shape=1, low=0, high=1): - size = int(reduce(lambda x,y: x*y, shape)) + size = int(reduce(lambda x, y: x*y, shape)) if issubclass(dtype.type, np.complexfloating): x = np.empty(size, dtype=dtype) x.real = (high - low) * np.random.random(size=size) + low diff --git a/nifty/spaces/space/space.py b/nifty/spaces/space/space.py index 72628a040..a3e026eef 100644 --- a/nifty/spaces/space/space.py +++ b/nifty/spaces/space/space.py @@ -146,11 +146,10 @@ import abc import numpy as np -from keepers import Loggable,\ - Versionable +from nifty.domain_object import DomainObject -class Space(Versionable, Loggable, object): +class Space(DomainObject): """ .. __ __ .. /__/ / /_ @@ -185,8 +184,6 @@ class Space(Versionable, Loggable, object): Pixel volume of the :py:class:`point_space`, which is always 1. """ - __metaclass__ = abc.ABCMeta - def __init__(self, dtype=np.dtype('float')): """ Sets the attributes for a point_space class instance. @@ -207,43 +204,14 @@ class Space(Versionable, Loggable, object): casted_dtype = np.result_type(dtype, np.float64) if casted_dtype != dtype: self.Logger.warning("Input dtype reset to: %s" % str(casted_dtype)) - self.dtype = casted_dtype - - self._ignore_for_hash = ['_global_id'] - - def __hash__(self): - # Extract the identifying parts from the vars(self) dict. - result_hash = 0 - for key in sorted(vars(self).keys()): - item = vars(self)[key] - if key in self._ignore_for_hash or key == '_ignore_for_hash': - continue - result_hash ^= item.__hash__() ^ int(hash(key)/117) - return result_hash - - def __eq__(self, x): - if isinstance(x, type(self)): - return hash(self) == hash(x) - else: - return False - def __ne__(self, x): - return not self.__eq__(x) + super(Space, self).__init__(dtype=casted_dtype) + self._ignore_for_hash += ['_global_id'] @abc.abstractproperty def harmonic(self): raise NotImplementedError - @abc.abstractproperty - def shape(self): - raise NotImplementedError( - "There is no generic shape for the Space base class.") - - @abc.abstractproperty - def dim(self): - raise NotImplementedError( - "There is no generic dim for the Space base class.") - @abc.abstractproperty def total_volume(self): raise NotImplementedError( @@ -273,12 +241,6 @@ class Space(Versionable, Loggable, object): """ raise NotImplementedError - def pre_cast(self, x, axes=None): - return x - - def post_cast(self, x, axes=None): - return x - def get_distance_array(self, distribution_strategy): raise NotImplementedError( "There is no generic distance structure for Space base class.") @@ -295,15 +257,3 @@ class Space(Versionable, Loggable, object): string += str(type(self)) + "\n" string += "dtype: " + str(self.dtype) + "\n" return string - - # ---Serialization--- - - def _to_hdf5(self, hdf5_group): - hdf5_group.attrs['dtype'] = self.dtype.name - - return None - - @classmethod - def _from_hdf5(cls, hdf5_group, repository): - result = cls(dtype=np.dtype(hdf5_group.attrs['dtype'])) - return result diff --git a/nifty/sugar.py b/nifty/sugar.py index 957da133b..a5ec6cb91 100644 --- a/nifty/sugar.py +++ b/nifty/sugar.py @@ -25,4 +25,3 @@ def create_power_operator(domain, power_spectrum, distribution_strategy='not'): power_operator = DiagonalOperator(domain, diagonal=f) return power_operator - -- GitLab