From cfcdd7d8a8daf0ecb5d13e62ad053a384eeba870 Mon Sep 17 00:00:00 2001 From: theos Date: Thu, 23 Jun 2016 03:20:58 +0200 Subject: [PATCH] Modified nifty_field.py to successfully initialize with multiple spaces and a field_type. Therefore made the Field_type base class more advanced. Added default field dtype to nifty configuration. --- __init__.py | 3 +- config/nifty_config.py | 18 ++- field_types/__init__.py | 3 +- field_types/base_field_type.py | 38 +++-- field_types/field_array.py | 4 +- nifty_core.py | 2 +- nifty_field.py | 258 ++++++++++++++++----------------- setup.py | 3 +- 8 files changed, 177 insertions(+), 152 deletions(-) diff --git a/__init__.py b/__init__.py index 27cddcbc..24a7fc99 100644 --- a/__init__.py +++ b/__init__.py @@ -50,7 +50,8 @@ from nifty_paradict import space_paradict,\ point_space_paradict,\ nested_space_paradict -from field_types import Field_array +from field_types import Field_type,\ + Field_array from operators import * diff --git a/config/nifty_config.py b/config/nifty_config.py index 718be367..88587a1d 100644 --- a/config/nifty_config.py +++ b/config/nifty_config.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- import os - +import numpy as np import keepers # Setup the dependency injector @@ -50,6 +50,21 @@ variable_verbosity = keepers.Variable('verbosity', lambda z: z == abs(int(z)), genus='int') + +def _dtype_validator(dtype): + try: + np.dtype(dtype) + except(TypeError): + return False + else: + return True + +variable_default_field_dtype = keepers.Variable( + 'default_field_dtype', + ['float'], + _dtype_validator, + genus='str') + variable_default_datamodel = keepers.Variable( 'default_datamodel', ['fftw', 'equal'], @@ -64,6 +79,7 @@ nifty_configuration = keepers.get_Configuration( variable_use_healpy, variable_use_libsharp, variable_verbosity, + variable_default_field_dtype, variable_default_datamodel, ], path=os.path.expanduser('~') + "/.nifty/nifty_config") diff --git a/field_types/__init__.py b/field_types/__init__.py index 78db8073..18d88752 100644 --- a/field_types/__init__.py +++ b/field_types/__init__.py @@ -1,3 +1,4 @@ # -*- coding: utf-8 -*- -from field_array import Field_array \ No newline at end of file +from field_array import Field_type,\ + Field_array \ No newline at end of file diff --git a/field_types/base_field_type.py b/field_types/base_field_type.py index f1d312e5..e0aa7f8c 100644 --- a/field_types/base_field_type.py +++ b/field_types/base_field_type.py @@ -1,22 +1,39 @@ # -*- coding: utf-8 -*- +import numpy as np -class Base_field_type(object): - def __init__(self, shape): - self.shape = shape - @property - def shape(self): - return self._shape - - @shape.setter - def shape(self, shape): +class Field_type(object): + def __init__(self, shape, dtype): try: new_shape = tuple([int(i) for i in shape]) except TypeError: new_shape = (int(shape), ) self._shape = new_shape + self._dtype = np.dtype(dtype) + + @property + def shape(self): + return self._shape + + @property + def dtype(self): + return self._dtype + + def get_dof(self, split=False): + if issubclass(self.dtype.type, np.complexfloating): + multiplicator = 2 + else: + multiplicator = 1 + + if split: + dof = tuple(multiplicator*np.array(self.shape)) + else: + dof = multiplicator*reduce(lambda x, y: x*y, self.shape) + + return dof + def process(self, method_name, array, inplace=True, **kwargs): try: result_array = self.__getattr__(method_name)(array, @@ -29,3 +46,6 @@ class Base_field_type(object): result_array = array.copy() return result_array + + def complement_cast(self, x, axis=None): + return x diff --git a/field_types/field_array.py b/field_types/field_array.py index af57f61c..64616156 100644 --- a/field_types/field_array.py +++ b/field_types/field_array.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- -from base_field_type import Base_field_type +from base_field_type import Field_type -class Field_array(Base_field_type): +class Field_array(Field_type): pass diff --git a/nifty_core.py b/nifty_core.py index 006c29b7..91675fe6 100644 --- a/nifty_core.py +++ b/nifty_core.py @@ -311,7 +311,7 @@ class space(object): raise NotImplementedError(about._errors.cstring( "ERROR: no generic instance method 'dof'.")) - def _complement_cast(self, x, axis=None): + def complement_cast(self, x, axis=None): return x # TODO: Move enforce power into power_indices class diff --git a/nifty_field.py b/nifty_field.py index ec994659..56fdde33 100644 --- a/nifty_field.py +++ b/nifty_field.py @@ -9,6 +9,9 @@ from nifty.config import about, \ nifty_configuration as gc, \ dependency_injector as gdi +from nifty.field_types import Field_type,\ + Field_array + from nifty.nifty_core import space import nifty.nifty_utilities as utilities @@ -102,10 +105,9 @@ class field(object): """ - def __init__(self, domain=None, val=None, codomain=None, + def __init__(self, domain, val=None, codomain=None, dtype=None, field_type=None, copy=False, - datamodel=gc['default_datamodel'], comm=gc['default_comm'], - **kwargs): + datamodel=None, comm=None, **kwargs): """ Sets the attributes for a field class instance. @@ -137,6 +139,7 @@ class field(object): comm=comm, copy=copy, dtype=dtype, + field_type=field_type, datamodel=datamodel, **kwargs) else: @@ -146,11 +149,12 @@ class field(object): comm=comm, copy=copy, dtype=dtype, + field_type=field_type, datamodel=datamodel, **kwargs) def _init_from_field(self, f, domain, codomain, comm, copy, dtype, - datamodel, **kwargs): + field_type, datamodel, **kwargs): # check domain if domain is None: domain = f.domain @@ -179,37 +183,45 @@ class field(object): **kwargs) def _init_from_array(self, val, domain, codomain, comm, copy, dtype, - datamodel, **kwargs): - if dtype is None: - dtype = self._get_dtype_from_domain(domain) - self.dtype = dtype - self.comm = self._parse_comm(comm) - - # if val is a distributed data object, we take it's datamodel, - # since we don't want to redistribute large amounts of data, if not - # necessary - if isinstance(val, distributed_data_object): - if datamodel != val.distribution_strategy: - about.warnings.cprint("WARNING: datamodel set to val's " - "datamodel.") - datamodel = val.distribution_strategy - if datamodel not in DISTRIBUTION_STRATEGIES['global']: - about.warnings.cprint("WARNING: datamodel set to default.") - self.datamodel = \ - gc['default_distribution_strategy'] - else: - self.datamodel = datamodel + field_type, datamodel, **kwargs): # check domain - self.domain = self._check_valid_domain(domain=domain) - self._axis_list = self._get_axis_list_from_domain(domain=domain) + self.domain = self._parse_domain(domain=domain) + self._axis_list = self._get_axis_list_from_domain(domain=self.domain) # check codomain if codomain is None: - codomain = self.get_codomain(domain=domain) - elif not self._check_codomain(domain=domain, codomain=codomain): + codomain = self.get_codomain(domain=self.domain) + else: + self.codomain = self._parse_codomain(codomain, self.domain) + + if field_type is None: + field_type = Field_array(shape=(), dtype=np.float) + elif not isinstance(field_type, Field_type): raise ValueError(about._errors.cstring( - "ERROR: The given codomain is not compatible to the domain.")) - self.codomain = codomain + "ERROR: The given field_type object is not an " + "instance of nifty.Field_type.")) + + self.field_type = field_type + + if dtype is None: + dtype = self._infer_dtype(domain=self.domain, + dtype=dtype, + field_type=self.field_type) + self.dtype = dtype + + if comm is not None: + self.comm = self._parse_comm(comm) + elif isinstance(val, distributed_data_object): + self.comm = val.comm + else: + self.comm = gc['default_comm'] + + if datamodel in DISTRIBUTION_STRATEGIES['all']: + self.datamodel = datamodel + elif isinstance(val, distributed_data_object): + self.datamodel = val.distribution_strategy + else: + self.datamodel = gc['default_datamodel'] if val is None: if kwargs == {}: @@ -220,16 +232,17 @@ class field(object): **kwargs) self.set_val(new_val=val, copy=copy) - def _get_dtype_from_domain(self, domain=None): - if domain is None: - domain = self.domain - dtype_tuple = tuple(np.dtype(space.dtype) for space in domain) + def _infer_dtype(self, domain=None, dtype=None, field_type=None): + dtype_tuple = (np.dtype(gc['default_field_dtype']),) + if domain is not None: + dtype_tuple += tuple(np.dtype(sp.dtype) for sp in domain) + if field_type is not None: + dtype_tuple += (field_type.dtype,) + dtype = reduce(lambda x, y: np.result_type(x, y), dtype_tuple) return dtype - def _get_axis_list_from_domain(self, domain=None): - if domain is None: - domain = self.domain + def _get_axis_list_from_domain(self, domain): i = 0 axis_list = [] for sp in domain: @@ -259,36 +272,36 @@ class field(object): "default-MPI-module's Intracomm Class.")) return result_comm - def _check_valid_domain(self, domain): + def _parse_domain(self, domain): if not isinstance(domain, tuple): - raise TypeError(about._errors.cstring( - "ERROR: The given domain is not a list.")) + domain = (domain,) for d in domain: if not isinstance(d, space): raise TypeError(about._errors.cstring( - "ERROR: Given domain is not a space.")) - elif d.dtype > self.dtype: - raise AttributeError(about._errors.cstring( - "ERROR: The dtype of a space in the domain is larger than " - "the field's dtype.")) + "ERROR: Given domain contains something that is not a " + "nifty.space.")) return domain - def _check_codomain(self, domain, codomain): - if codomain is None: - return False - if len(domain) == len(codomain): - return np.all(map((lambda d, c: d._check_codomain(c)), domain, - codomain)) - else: - return False + def _parse_codomain(self, codomain, domain): + if not isinstance(codomain, tuple): + codomain = (codomain,) + if len(domain) != len(codomain): + raise ValueError(about._errors.cstring( + "ERROR: domain and codomain do not have the same length.")) + for (cd, d) in zip(codomain, domain): + if not isinstance(cd, space): + raise TypeError(about._errors.cstring( + "ERROR: Given codomain contains something that is not a" + "nifty.space.")) + if not d.check_codomain(cd): + raise ValueError(about._errors.cstring( + "ERROR: codomain contains a space that is not compatible " + "to its domain-counterpart.")) + return codomain def get_codomain(self, domain): - if len(domain) == 1: - return (domain[0].get_codomain(),) - else: - codomain = tuple(space.get_codomain() for space in domain) - self.codomain = codomain - return codomain + codomain = tuple(sp.get_codomain() for sp in domain) + return codomain def get_random_values(self, domain=None, codomain=None, **kwargs): raise NotImplementedError(about._errors.cstring( @@ -297,10 +310,13 @@ class field(object): def __len__(self): return int(self.get_dim()[0]) - def copy(self, domain=None, codomain=None, **kwargs): + def copy(self, domain=None, codomain=None, field_type=None, **kwargs): copied_val = self._unary_operation(self.get_val(), op='copy', **kwargs) - new_field = self.copy_empty(domain=domain, codomain=codomain) - new_field.set_val(new_val=copied_val) + new_field = self.copy_empty(domain=domain, + codomain=codomain, + field_type=field_type) + new_field.set_val(new_val=copied_val, + copy=True) return new_field def _fast_copy_empty(self): @@ -318,28 +334,41 @@ class field(object): return new_field def copy_empty(self, domain=None, codomain=None, dtype=None, comm=None, - datamodel=None, **kwargs): + datamodel=None, field_type=None, **kwargs): if domain is None: domain = self.domain + if codomain is None: codomain = self.codomain + if dtype is None: dtype = self.dtype + if comm is None: comm = self.comm + if datamodel is None: datamodel = self.datamodel - if (domain is self.domain and - codomain is self.codomain and - dtype == self.dtype and - comm == self.comm and - datamodel == self.datamodel and - kwargs == {}): + if field_type is None: + field_type = self.field_type + + _fast_copyable = True + for i in len(self.domain): + if self.domain[i] is not domain[i]: + _fast_copyable = False + break + if self.codomain[i] is not codomain[i]: + _fast_copyable = False + break + if (_fast_copyable and dtype == self.dtype and comm == self.comm and + datamodel == self.datamodel and + field_type is self.field_type and kwargs == {}): new_field = self._fast_copy_empty() else: new_field = field(domain=domain, codomain=codomain, dtype=dtype, - comm=comm, datamodel=datamodel, **kwargs) + comm=comm, datamodel=datamodel, + field_type=field_type, **kwargs) return new_field def set_val(self, new_val=None, copy=False): @@ -352,34 +381,28 @@ class field(object): New field values either as a constant or an arbitrary array. """ - if new_val is not None: - if copy: - new_val = self.unary_operation(new_val, op='copy') - self.val = self.cast(new_val) + new_val = self.cast(new_val) + if copy: + new_val = self.unary_operation(new_val, op='copy') + self.val = new_val return self.val def get_val(self): return self.val - # TODO: Add functionality for boolean indexing. - def __getitem__(self, key): return self.val[key] def __setitem__(self, key, item): self.val[key] = item - def get_shape(self): - if len(self.domain) > 1: - shape_tuple = tuple(space.get_shape() for space in self.domain) - global_shape = reduce(lambda x, y: x + y, shape_tuple) - else: - global_shape = self.domain[0].get_shape() + @property + def shape(self): + shape_tuple = tuple(sp.get_shape() for sp in self.domain) + shape_tuple += (self.field_type.shape, ) + global_shape = reduce(lambda x, y: x + y, shape_tuple) - if isinstance(global_shape, tuple): - return global_shape - else: - return () + return global_shape def get_dim(self): """ @@ -401,19 +424,25 @@ class field(object): return np.prod(self.get_shape()) def get_dof(self, split=False): - dim = self.get_dim() - if np.issubdtype(self.dtype, np.complex): - return 2 * dim + dof_tuple = tuple(sp.get_dof(split=split) for sp in self.domain) + dof_tuple += (self.field_type.get_dof(split=split),) + if split: + return reduce(lambda x, y: x + y, dof_tuple) else: - return dim + return reduce(lambda x, y: x * y, dof_tuple) def cast(self, x=None, dtype=None): - if dtype is not None: - dtype = np.dtype(dtype) if dtype is None: dtype = self.dtype + else: + dtype = np.dtype(dtype) casted_x = self._cast_to_d2o(x, dtype=dtype) - return self._complement_cast(casted_x) + + for ind, sp in enumerate(self.domain): + casted_x = sp.complement_cast(casted_x, axis=self._axis_list[ind]) + + casted_x = self.field_type.complement_cast(casted_x) + return casted_x def _cast_to_d2o(self, x, dtype=None, shape=None, **kwargs): """ @@ -508,49 +537,6 @@ class field(object): # Cast the d2o return self.cast(x, dtype=dtype) - def _complement_cast(self, x): - for ind, space in enumerate(self.domain): - space._complement_cast(x, axis=self._axis_list[ind]) - return x - - def set_domain(self, new_domain=None, force=False): - """ - Resets the codomain of the field. - - Parameters - ---------- - new_codomain : space - The new space wherein the transform of the field should live. - (default=None). - - """ - # check codomain - if new_domain is None: - new_domain = self.codomain.get_codomain() - elif not force: - assert (self.codomain.check_codomain(new_domain)) - self.domain = new_domain - return self.domain - - def set_codomain(self, new_codomain=None, force=False): - """ - Resets the codomain of the field. - - Parameters - ---------- - new_codomain : space - The new space wherein the transform of the field should live. - (default=None). - - """ - # check codomain - if new_codomain is None: - new_codomain = self.domain.get_codomain() - elif not force: - assert (self.domain.check_codomain(new_codomain)) - self.codomain = new_codomain - return self.codomain - def weight(self, new_val=None, power=1, overwrite=False, spaces=None): """ Returns the field values, weighted with the volume factors to a diff --git a/setup.py b/setup.py index 5f084800..755fcb5e 100644 --- a/setup.py +++ b/setup.py @@ -33,7 +33,8 @@ setup(name="ift_nifty", description="Numerical Information Field Theory", url="http://www.mpa-garching.mpg.de/ift/nifty/", packages=["nifty", "nifty.demos", "nifty.rg", "nifty.lm", - "nifty.operators", "nifty.dummys", "nifty.config"], + "nifty.operators", "nifty.dummys", "nifty.field_types", + "nifty.config"], package_dir={"nifty": ""}, zip_safe=False, dependency_links=[ -- GitLab