There is a maintenance of MPCDF Gitlab on Thursday, April 22st 2020, 9:00 am CEST - Expect some service interruptions during this time

Commit cfcdd7d8 authored by theos's avatar theos

Modified nifty_field.py to successfully initialize with multiple spaces and a field_type.

Therefore made the Field_type base class more advanced.
Added default field dtype to nifty configuration.
parent 4ff7b806
...@@ -50,7 +50,8 @@ from nifty_paradict import space_paradict,\ ...@@ -50,7 +50,8 @@ from nifty_paradict import space_paradict,\
point_space_paradict,\ point_space_paradict,\
nested_space_paradict nested_space_paradict
from field_types import Field_array from field_types import Field_type,\
Field_array
from operators import * from operators import *
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import os import os
import numpy as np
import keepers import keepers
# Setup the dependency injector # Setup the dependency injector
...@@ -50,6 +50,21 @@ variable_verbosity = keepers.Variable('verbosity', ...@@ -50,6 +50,21 @@ variable_verbosity = keepers.Variable('verbosity',
lambda z: z == abs(int(z)), lambda z: z == abs(int(z)),
genus='int') genus='int')
def _dtype_validator(dtype):
try:
np.dtype(dtype)
except(TypeError):
return False
else:
return True
variable_default_field_dtype = keepers.Variable(
'default_field_dtype',
['float'],
_dtype_validator,
genus='str')
variable_default_datamodel = keepers.Variable( variable_default_datamodel = keepers.Variable(
'default_datamodel', 'default_datamodel',
['fftw', 'equal'], ['fftw', 'equal'],
...@@ -64,6 +79,7 @@ nifty_configuration = keepers.get_Configuration( ...@@ -64,6 +79,7 @@ nifty_configuration = keepers.get_Configuration(
variable_use_healpy, variable_use_healpy,
variable_use_libsharp, variable_use_libsharp,
variable_verbosity, variable_verbosity,
variable_default_field_dtype,
variable_default_datamodel, variable_default_datamodel,
], ],
path=os.path.expanduser('~') + "/.nifty/nifty_config") path=os.path.expanduser('~') + "/.nifty/nifty_config")
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from field_array import Field_array from field_array import Field_type,\
\ No newline at end of file Field_array
\ No newline at end of file
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import numpy as np
class Base_field_type(object):
def __init__(self, shape):
self.shape = shape
@property class Field_type(object):
def shape(self): def __init__(self, shape, dtype):
return self._shape
@shape.setter
def shape(self, shape):
try: try:
new_shape = tuple([int(i) for i in shape]) new_shape = tuple([int(i) for i in shape])
except TypeError: except TypeError:
new_shape = (int(shape), ) new_shape = (int(shape), )
self._shape = new_shape self._shape = new_shape
self._dtype = np.dtype(dtype)
@property
def shape(self):
return self._shape
@property
def dtype(self):
return self._dtype
def get_dof(self, split=False):
if issubclass(self.dtype.type, np.complexfloating):
multiplicator = 2
else:
multiplicator = 1
if split:
dof = tuple(multiplicator*np.array(self.shape))
else:
dof = multiplicator*reduce(lambda x, y: x*y, self.shape)
return dof
def process(self, method_name, array, inplace=True, **kwargs): def process(self, method_name, array, inplace=True, **kwargs):
try: try:
result_array = self.__getattr__(method_name)(array, result_array = self.__getattr__(method_name)(array,
...@@ -29,3 +46,6 @@ class Base_field_type(object): ...@@ -29,3 +46,6 @@ class Base_field_type(object):
result_array = array.copy() result_array = array.copy()
return result_array return result_array
def complement_cast(self, x, axis=None):
return x
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from base_field_type import Base_field_type from base_field_type import Field_type
class Field_array(Base_field_type): class Field_array(Field_type):
pass pass
...@@ -311,7 +311,7 @@ class space(object): ...@@ -311,7 +311,7 @@ class space(object):
raise NotImplementedError(about._errors.cstring( raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'dof'.")) "ERROR: no generic instance method 'dof'."))
def _complement_cast(self, x, axis=None): def complement_cast(self, x, axis=None):
return x return x
# TODO: Move enforce power into power_indices class # TODO: Move enforce power into power_indices class
......
...@@ -9,6 +9,9 @@ from nifty.config import about, \ ...@@ -9,6 +9,9 @@ from nifty.config import about, \
nifty_configuration as gc, \ nifty_configuration as gc, \
dependency_injector as gdi dependency_injector as gdi
from nifty.field_types import Field_type,\
Field_array
from nifty.nifty_core import space from nifty.nifty_core import space
import nifty.nifty_utilities as utilities import nifty.nifty_utilities as utilities
...@@ -102,10 +105,9 @@ class field(object): ...@@ -102,10 +105,9 @@ class field(object):
""" """
def __init__(self, domain=None, val=None, codomain=None, def __init__(self, domain, val=None, codomain=None,
dtype=None, field_type=None, copy=False, dtype=None, field_type=None, copy=False,
datamodel=gc['default_datamodel'], comm=gc['default_comm'], datamodel=None, comm=None, **kwargs):
**kwargs):
""" """
Sets the attributes for a field class instance. Sets the attributes for a field class instance.
...@@ -137,6 +139,7 @@ class field(object): ...@@ -137,6 +139,7 @@ class field(object):
comm=comm, comm=comm,
copy=copy, copy=copy,
dtype=dtype, dtype=dtype,
field_type=field_type,
datamodel=datamodel, datamodel=datamodel,
**kwargs) **kwargs)
else: else:
...@@ -146,11 +149,12 @@ class field(object): ...@@ -146,11 +149,12 @@ class field(object):
comm=comm, comm=comm,
copy=copy, copy=copy,
dtype=dtype, dtype=dtype,
field_type=field_type,
datamodel=datamodel, datamodel=datamodel,
**kwargs) **kwargs)
def _init_from_field(self, f, domain, codomain, comm, copy, dtype, def _init_from_field(self, f, domain, codomain, comm, copy, dtype,
datamodel, **kwargs): field_type, datamodel, **kwargs):
# check domain # check domain
if domain is None: if domain is None:
domain = f.domain domain = f.domain
...@@ -179,37 +183,45 @@ class field(object): ...@@ -179,37 +183,45 @@ class field(object):
**kwargs) **kwargs)
def _init_from_array(self, val, domain, codomain, comm, copy, dtype, def _init_from_array(self, val, domain, codomain, comm, copy, dtype,
datamodel, **kwargs): field_type, datamodel, **kwargs):
if dtype is None:
dtype = self._get_dtype_from_domain(domain)
self.dtype = dtype
self.comm = self._parse_comm(comm)
# if val is a distributed data object, we take it's datamodel,
# since we don't want to redistribute large amounts of data, if not
# necessary
if isinstance(val, distributed_data_object):
if datamodel != val.distribution_strategy:
about.warnings.cprint("WARNING: datamodel set to val's "
"datamodel.")
datamodel = val.distribution_strategy
if datamodel not in DISTRIBUTION_STRATEGIES['global']:
about.warnings.cprint("WARNING: datamodel set to default.")
self.datamodel = \
gc['default_distribution_strategy']
else:
self.datamodel = datamodel
# check domain # check domain
self.domain = self._check_valid_domain(domain=domain) self.domain = self._parse_domain(domain=domain)
self._axis_list = self._get_axis_list_from_domain(domain=domain) self._axis_list = self._get_axis_list_from_domain(domain=self.domain)
# check codomain # check codomain
if codomain is None: if codomain is None:
codomain = self.get_codomain(domain=domain) codomain = self.get_codomain(domain=self.domain)
elif not self._check_codomain(domain=domain, codomain=codomain): else:
self.codomain = self._parse_codomain(codomain, self.domain)
if field_type is None:
field_type = Field_array(shape=(), dtype=np.float)
elif not isinstance(field_type, Field_type):
raise ValueError(about._errors.cstring( raise ValueError(about._errors.cstring(
"ERROR: The given codomain is not compatible to the domain.")) "ERROR: The given field_type object is not an "
self.codomain = codomain "instance of nifty.Field_type."))
self.field_type = field_type
if dtype is None:
dtype = self._infer_dtype(domain=self.domain,
dtype=dtype,
field_type=self.field_type)
self.dtype = dtype
if comm is not None:
self.comm = self._parse_comm(comm)
elif isinstance(val, distributed_data_object):
self.comm = val.comm
else:
self.comm = gc['default_comm']
if datamodel in DISTRIBUTION_STRATEGIES['all']:
self.datamodel = datamodel
elif isinstance(val, distributed_data_object):
self.datamodel = val.distribution_strategy
else:
self.datamodel = gc['default_datamodel']
if val is None: if val is None:
if kwargs == {}: if kwargs == {}:
...@@ -220,16 +232,17 @@ class field(object): ...@@ -220,16 +232,17 @@ class field(object):
**kwargs) **kwargs)
self.set_val(new_val=val, copy=copy) self.set_val(new_val=val, copy=copy)
def _get_dtype_from_domain(self, domain=None): def _infer_dtype(self, domain=None, dtype=None, field_type=None):
if domain is None: dtype_tuple = (np.dtype(gc['default_field_dtype']),)
domain = self.domain if domain is not None:
dtype_tuple = tuple(np.dtype(space.dtype) for space in domain) dtype_tuple += tuple(np.dtype(sp.dtype) for sp in domain)
if field_type is not None:
dtype_tuple += (field_type.dtype,)
dtype = reduce(lambda x, y: np.result_type(x, y), dtype_tuple) dtype = reduce(lambda x, y: np.result_type(x, y), dtype_tuple)
return dtype return dtype
def _get_axis_list_from_domain(self, domain=None): def _get_axis_list_from_domain(self, domain):
if domain is None:
domain = self.domain
i = 0 i = 0
axis_list = [] axis_list = []
for sp in domain: for sp in domain:
...@@ -259,36 +272,36 @@ class field(object): ...@@ -259,36 +272,36 @@ class field(object):
"default-MPI-module's Intracomm Class.")) "default-MPI-module's Intracomm Class."))
return result_comm return result_comm
def _check_valid_domain(self, domain): def _parse_domain(self, domain):
if not isinstance(domain, tuple): if not isinstance(domain, tuple):
raise TypeError(about._errors.cstring( domain = (domain,)
"ERROR: The given domain is not a list."))
for d in domain: for d in domain:
if not isinstance(d, space): if not isinstance(d, space):
raise TypeError(about._errors.cstring( raise TypeError(about._errors.cstring(
"ERROR: Given domain is not a space.")) "ERROR: Given domain contains something that is not a "
elif d.dtype > self.dtype: "nifty.space."))
raise AttributeError(about._errors.cstring(
"ERROR: The dtype of a space in the domain is larger than "
"the field's dtype."))
return domain return domain
def _check_codomain(self, domain, codomain): def _parse_codomain(self, codomain, domain):
if codomain is None: if not isinstance(codomain, tuple):
return False codomain = (codomain,)
if len(domain) == len(codomain): if len(domain) != len(codomain):
return np.all(map((lambda d, c: d._check_codomain(c)), domain, raise ValueError(about._errors.cstring(
codomain)) "ERROR: domain and codomain do not have the same length."))
else: for (cd, d) in zip(codomain, domain):
return False if not isinstance(cd, space):
raise TypeError(about._errors.cstring(
"ERROR: Given codomain contains something that is not a"
"nifty.space."))
if not d.check_codomain(cd):
raise ValueError(about._errors.cstring(
"ERROR: codomain contains a space that is not compatible "
"to its domain-counterpart."))
return codomain
def get_codomain(self, domain): def get_codomain(self, domain):
if len(domain) == 1: codomain = tuple(sp.get_codomain() for sp in domain)
return (domain[0].get_codomain(),) return codomain
else:
codomain = tuple(space.get_codomain() for space in domain)
self.codomain = codomain
return codomain
def get_random_values(self, domain=None, codomain=None, **kwargs): def get_random_values(self, domain=None, codomain=None, **kwargs):
raise NotImplementedError(about._errors.cstring( raise NotImplementedError(about._errors.cstring(
...@@ -297,10 +310,13 @@ class field(object): ...@@ -297,10 +310,13 @@ class field(object):
def __len__(self): def __len__(self):
return int(self.get_dim()[0]) return int(self.get_dim()[0])
def copy(self, domain=None, codomain=None, **kwargs): def copy(self, domain=None, codomain=None, field_type=None, **kwargs):
copied_val = self._unary_operation(self.get_val(), op='copy', **kwargs) copied_val = self._unary_operation(self.get_val(), op='copy', **kwargs)
new_field = self.copy_empty(domain=domain, codomain=codomain) new_field = self.copy_empty(domain=domain,
new_field.set_val(new_val=copied_val) codomain=codomain,
field_type=field_type)
new_field.set_val(new_val=copied_val,
copy=True)
return new_field return new_field
def _fast_copy_empty(self): def _fast_copy_empty(self):
...@@ -318,28 +334,41 @@ class field(object): ...@@ -318,28 +334,41 @@ class field(object):
return new_field return new_field
def copy_empty(self, domain=None, codomain=None, dtype=None, comm=None, def copy_empty(self, domain=None, codomain=None, dtype=None, comm=None,
datamodel=None, **kwargs): datamodel=None, field_type=None, **kwargs):
if domain is None: if domain is None:
domain = self.domain domain = self.domain
if codomain is None: if codomain is None:
codomain = self.codomain codomain = self.codomain
if dtype is None: if dtype is None:
dtype = self.dtype dtype = self.dtype
if comm is None: if comm is None:
comm = self.comm comm = self.comm
if datamodel is None: if datamodel is None:
datamodel = self.datamodel datamodel = self.datamodel
if (domain is self.domain and if field_type is None:
codomain is self.codomain and field_type = self.field_type
dtype == self.dtype and
comm == self.comm and _fast_copyable = True
datamodel == self.datamodel and for i in len(self.domain):
kwargs == {}): if self.domain[i] is not domain[i]:
_fast_copyable = False
break
if self.codomain[i] is not codomain[i]:
_fast_copyable = False
break
if (_fast_copyable and dtype == self.dtype and comm == self.comm and
datamodel == self.datamodel and
field_type is self.field_type and kwargs == {}):
new_field = self._fast_copy_empty() new_field = self._fast_copy_empty()
else: else:
new_field = field(domain=domain, codomain=codomain, dtype=dtype, new_field = field(domain=domain, codomain=codomain, dtype=dtype,
comm=comm, datamodel=datamodel, **kwargs) comm=comm, datamodel=datamodel,
field_type=field_type, **kwargs)
return new_field return new_field
def set_val(self, new_val=None, copy=False): def set_val(self, new_val=None, copy=False):
...@@ -352,34 +381,28 @@ class field(object): ...@@ -352,34 +381,28 @@ class field(object):
New field values either as a constant or an arbitrary array. New field values either as a constant or an arbitrary array.
""" """
if new_val is not None: new_val = self.cast(new_val)
if copy: if copy:
new_val = self.unary_operation(new_val, op='copy') new_val = self.unary_operation(new_val, op='copy')
self.val = self.cast(new_val) self.val = new_val
return self.val return self.val
def get_val(self): def get_val(self):
return self.val return self.val
# TODO: Add functionality for boolean indexing.
def __getitem__(self, key): def __getitem__(self, key):
return self.val[key] return self.val[key]
def __setitem__(self, key, item): def __setitem__(self, key, item):
self.val[key] = item self.val[key] = item
def get_shape(self): @property
if len(self.domain) > 1: def shape(self):
shape_tuple = tuple(space.get_shape() for space in self.domain) shape_tuple = tuple(sp.get_shape() for sp in self.domain)
global_shape = reduce(lambda x, y: x + y, shape_tuple) shape_tuple += (self.field_type.shape, )
else: global_shape = reduce(lambda x, y: x + y, shape_tuple)
global_shape = self.domain[0].get_shape()
if isinstance(global_shape, tuple): return global_shape
return global_shape
else:
return ()
def get_dim(self): def get_dim(self):
""" """
...@@ -401,19 +424,25 @@ class field(object): ...@@ -401,19 +424,25 @@ class field(object):
return np.prod(self.get_shape()) return np.prod(self.get_shape())
def get_dof(self, split=False): def get_dof(self, split=False):
dim = self.get_dim() dof_tuple = tuple(sp.get_dof(split=split) for sp in self.domain)
if np.issubdtype(self.dtype, np.complex): dof_tuple += (self.field_type.get_dof(split=split),)
return 2 * dim if split:
return reduce(lambda x, y: x + y, dof_tuple)
else: else:
return dim return reduce(lambda x, y: x * y, dof_tuple)
def cast(self, x=None, dtype=None): def cast(self, x=None, dtype=None):
if dtype is not None:
dtype = np.dtype(dtype)
if dtype is None: if dtype is None:
dtype = self.dtype dtype = self.dtype
else:
dtype = np.dtype(dtype)
casted_x = self._cast_to_d2o(x, dtype=dtype) casted_x = self._cast_to_d2o(x, dtype=dtype)
return self._complement_cast(casted_x)
for ind, sp in enumerate(self.domain):
casted_x = sp.complement_cast(casted_x, axis=self._axis_list[ind])
casted_x = self.field_type.complement_cast(casted_x)
return casted_x
def _cast_to_d2o(self, x, dtype=None, shape=None, **kwargs): def _cast_to_d2o(self, x, dtype=None, shape=None, **kwargs):
""" """
...@@ -508,49 +537,6 @@ class field(object): ...@@ -508,49 +537,6 @@ class field(object):
# Cast the d2o # Cast the d2o
return self.cast(x, dtype=dtype) return self.cast(x, dtype=dtype)
def _complement_cast(self, x):
for ind, space in enumerate(self.domain):
space._complement_cast(x, axis=self._axis_list[ind])
return x
def set_domain(self, new_domain=None, force=False):
"""
Resets the codomain of the field.
Parameters
----------
new_codomain : space