Commit b5e38b94 authored by csongor's avatar csongor
Browse files

WIP: fix _complement_cast

parent 0ba20735
Pipeline #3934 skipped
......@@ -272,17 +272,23 @@ class lm_space(point_space):
mol[self.paradict['lmax'] + 1:] = 2 # redundant: (l,m) and (l,-m)
return mol
def _cast_to_d2o(self, x, dtype=None, **kwargs):
casted_x = super(lm_space, self)._cast_to_d2o(x=x,
dtype=dtype,
**kwargs)
lmax = self.paradict['lmax']
complexity_mask = casted_x[:lmax+1].iscomplex()
if complexity_mask.any():
about.warnings.cprint("WARNING: Taking the absolute values for " +
"all complex entries where lmax==0")
casted_x[:lmax+1] = abs(casted_x[:lmax+1])
return casted_x
def _complement_cast(self, x, axis=None, **kwargs):
if axis is None:
lmax = self.paradict['lmax']
complexity_mask = x[:lmax+1].iscomplex()
if complexity_mask.any():
about.warnings.cprint("WARNING: Taking the absolute values for " +
"all complex entries where lmax==0")
x[:lmax+1] = abs(x[:lmax+1])
else:
# TODO hermitianize only on specific axis
lmax = self.paradict['lmax']
complexity_mask = x[:lmax+1].iscomplex()
if complexity_mask.any():
about.warnings.cprint("WARNING: Taking the absolute values for " +
"all complex entries where lmax==0")
x[:lmax+1] = abs(x[:lmax+1])
return x
# TODO: Extend to binning/log
def enforce_power(self, spec, size=None, kindex=None):
......
......@@ -319,31 +319,8 @@ class space(object):
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'dof'."))
def cast(self, x, verbose=False):
"""
Computes valid field values from a given object, trying
to translate the given data into a valid form. Thereby it is as
benevolent as possible.
Parameters
----------
x : {float, numpy.ndarray, nifty.field}
Object to be transformed into an array of valid field values.
Returns
-------
x : numpy.ndarray, distributed_data_object
Array containing the field values, which are compatible to the
space.
Other parameters
----------------
verbose : bool, *optional*
Whether the method should raise a warning if information is
lost during casting (default: False).
"""
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'cast'."))
def _complement_cast(self, x, axis=None):
return x
# TODO: Move enforce power into power_indices class
def enforce_power(self, spec, **kwargs):
......@@ -1006,93 +983,6 @@ class point_space(space):
mol = self.cast(1, dtype=np.dtype('float'))
return self.calc_weight(mol, power=1)
def cast(self, x=None, dtype=None, **kwargs):
return self._cast_to_d2o(x=x, dtype=dtype, **kwargs)
def _cast_to_d2o(self, x, dtype=None, **kwargs):
"""
Computes valid field values from a given object, trying
to translate the given data into a valid form. Thereby it is as
benevolent as possible.
Parameters
----------
x : {float, numpy.ndarray, nifty.field}
Object to be transformed into an array of valid field values.
Returns
-------
x : numpy.ndarray, distributed_data_object
Array containing the field values, which are compatible to the
space.
Other parameters
----------------
verbose : bool, *optional*
Whether the method should raise a warning if information is
lost during casting (default: False).
"""
if dtype is not None:
dtype = np.dtype(dtype)
if dtype is None:
dtype = self.dtype
# Case 1: x is a distributed_data_object
if isinstance(x, distributed_data_object):
to_copy = False
# Check the shape
if np.any(np.array(x.shape) != np.array(self.get_shape())):
# Check if at least the number of degrees of freedom is equal
if x.get_dim() == self.get_dim():
try:
temp = x.copy_empty(global_shape=self.get_shape())
temp.set_local_data(x.get_local_data(), copy=False)
except:
# If the number of dof is equal or 1, use np.reshape...
about.warnings.cflush(
"WARNING: Trying to reshape the data. This " +
"operation is expensive as it consolidates the " +
"full data!\n")
temp = x.get_full_data()
temp = np.reshape(temp, self.get_shape())
# ... and cast again
return self._cast_to_d2o(temp,
dtype=dtype,
**kwargs)
else:
raise ValueError(about._errors.cstring(
"ERROR: Data has incompatible shape!"))
# Check the dtype
if x.dtype != dtype:
if x.dtype > dtype:
about.warnings.cflush(
"WARNING: Datatypes are of conflicting precision " +
"(own: " + str(dtype) + " <> foreign: " +
str(x.dtype) + ") and will be casted! Potential " +
"loss of precision!\n")
to_copy = True
if to_copy:
temp = x.copy_empty(dtype=dtype)
temp.set_data(to_key=(slice(None),),
data=x,
from_key=(slice(None),))
temp.hermitian = x.hermitian
x = temp
return x
# Case 2: x is something else
# Use general d2o casting
else:
x = distributed_data_object(x, global_shape=self.get_shape(),
dtype=dtype)
# Cast the d2o
return self.cast(x, dtype=dtype)
def enforce_power(self, spec, **kwargs):
"""
Raises an error since the power spectrum is ill-defined for point
......
......@@ -102,9 +102,9 @@ class field(object):
"""
def __init__(self, domain=None, val=None, codomain=None, comm=gc[
'default_comm'], copy=False, dtype=np.dtype('float64'), datamodel='not',
**kwargs):
def __init__(self, domain=None, val=None, codomain=None,
comm=gc['default_comm'], copy=False, dtype=np.dtype('float64'),
datamodel='fftw', **kwargs):
"""
Sets the attributes for a field class instance.
......@@ -149,15 +149,14 @@ class field(object):
**kwargs)
def _init_from_field(self, f, domain, codomain, comm, copy, dtype,
datamodel,
**kwargs):
datamodel, **kwargs):
# check domain
if domain is None:
domain = f.domain
# check codomain
if codomain is None:
if self.check_codomain(domain, f.codomain):
if self._check_codomain(domain, f.codomain):
codomain = f.codomain
else:
codomain = self.get_codomain(domain)
......@@ -181,10 +180,18 @@ class field(object):
def _init_from_array(self, val, domain, codomain, comm, copy, dtype,
datamodel, **kwargs):
if dtype is None:
dtype = np.dtype('float64')
dtype = self._get_dtype_from_domain(domain)
self.dtype = dtype
self.comm = self._parse_comm(comm)
# if val is a distributed data object, we take it's datamodel,
# since we don't want to redistribute large amounts of data, if not
# necessary
if isinstance(val, distributed_data_object):
if datamodel != val.distribution_strategy:
about.warnings.cprint("WARNING: datamodel set to val's "
"datamodel.")
datamodel = val.distribution_strategy
if datamodel not in DISTRIBUTION_STRATEGIES['global']:
about.warnings.cprint("WARNING: datamodel set to default.")
self.datamodel = \
......@@ -192,12 +199,13 @@ class field(object):
else:
self.datamodel = datamodel
# check domain
self.domain = self.check_valid_domain(domain=domain)
self.domain = self._check_valid_domain(domain=domain)
self._axis_list = self._get_axis_list_from_domain(domain=domain)
# check codomain
if codomain is None:
codomain = self.get_codomain(domain)
elif not self.check_codomain(domain=domain, codomain=codomain):
codomain = self.get_codomain(domain=domain)
elif not self._check_codomain(domain=domain, codomain=codomain):
raise ValueError(about._errors.cstring(
"ERROR: The given codomain is not compatible to the domain."))
self.codomain = codomain
......@@ -210,6 +218,19 @@ class field(object):
codomain=z, **kwargs), self.codomain)
self.set_val(new_val=val, copy=copy)
def _get_dtype_from_domain(self, domain=None):
if domain is None:
domain = self.domain
dtype_tuple = tuple(space.dtype for space in domain)
dtype = np.result_type(dtype_tuple)
return dtype
def _get_axis_list_from_domain(self, domain=None):
if domain is None:
domain = self.domain
axis_list = [space.get_shape() for space in domain]
return axis_list
def _parse_comm(self, comm):
# check if comm is a string -> the name of comm is given
# -> Extract it from the mpi_module
......@@ -229,7 +250,7 @@ class field(object):
"default-MPI-module's Intracomm Class."))
return result_comm
def check_valid_domain(self, domain):
def _check_valid_domain(self, domain):
if not isinstance(domain, tuple):
raise TypeError(about._errors.cstring(
"ERROR: The given domain is not a list."))
......@@ -237,13 +258,13 @@ class field(object):
if not isinstance(d, space):
raise TypeError(about._errors.cstring(
"ERROR: Given domain is not a space."))
elif d.dtype != self.dtype:
elif d.dtype > self.dtype:
raise AttributeError(about._errors.cstring(
"ERROR: The dtype of a space in the domain missmatches "
"ERROR: The dtype of a space in the domain is larger than "
"the field's dtype."))
return domain
def check_codomain(self, domain, codomain):
def _check_codomain(self, domain, codomain):
if codomain is None:
return False
if len(domain) == len(codomain):
......@@ -453,9 +474,7 @@ class field(object):
temp = x
temp = np.reshape(temp, shape)
# ... and cast again
return self._cast_to_d2o(temp,
dtype=dtype,
**kwargs)
return self._cast_to_d2o(temp, dtype=dtype, **kwargs)
else:
raise ValueError(about._errors.cstring(
......@@ -497,7 +516,8 @@ class field(object):
return self.cast(x, dtype=dtype)
def _complement_cast(self, x):
# TODO implement complement cast for multiple spaces.
for ind, space in enumerate(self.domain):
space._complement_cast(x, axis=self._axis_list[ind])
return x
def set_domain(self, new_domain=None, force=False):
......
......@@ -251,30 +251,23 @@ class rg_space(point_space):
def get_shape(self):
return tuple(self.paradict['shape'])
def _cast_to_d2o(self, x, dtype=None, hermitianize=True, **kwargs):
casted_x = super(rg_space, self)._cast_to_d2o(x=x,
dtype=dtype,
**kwargs)
if x is not None and hermitianize and \
self.paradict['complexity'] == 1 and not casted_x.hermitian:
about.warnings.cflush(
"WARNING: Data gets hermitianized. This operation is " +
"extremely expensive\n")
casted_x = utilities.hermitianize(casted_x)
return casted_x
def _cast_to_np(self, x, dtype=None, hermitianize=True, **kwargs):
casted_x = super(rg_space, self)._cast_to_np(x=x,
dtype=dtype,
**kwargs)
if x is not None and hermitianize and self.paradict['complexity'] == 1:
about.warnings.cflush(
"WARNING: Data gets hermitianized. This operation is " +
"extremely expensive\n")
casted_x = utilities.hermitianize(casted_x)
return casted_x
def _complement_cast(self, x, axis=None, hermitianize=True):
if axis is None:
if x is not None and hermitianize and self.paradict['complexity']\
== 1 and not x.hermitian:
about.warnings.cflush(
"WARNING: Data gets hermitianized. This operation is " +
"extremely expensive\n")
x = utilities.hermitianize(x)
else:
# TODO hermitianize only on specific axis
if x is not None and hermitianize and self.paradict['complexity']\
== 1 and not x.hermitian:
about.warnings.cflush(
"WARNING: Data gets hermitianized. This operation is " +
"extremely expensive\n")
x = utilities.hermitianize(x)
return x
def enforce_power(self, spec, size=None, kindex=None, codomain=None,
**kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment