Commit bbd1c13e authored by theos's avatar theos
Browse files

Renamed 'datamodel' to 'distribution_strategy'.

parent c3c16adb
...@@ -80,8 +80,8 @@ variable_default_field_dtype = keepers.Variable( ...@@ -80,8 +80,8 @@ variable_default_field_dtype = keepers.Variable(
_dtype_validator, _dtype_validator,
genus='str') genus='str')
variable_default_datamodel = keepers.Variable( variable_default_distribution_strategy = keepers.Variable(
'default_datamodel', 'default_distribution_strategy',
['fftw', 'equal'], ['fftw', 'equal'],
lambda z: (('pyfftw' in dependency_injector) lambda z: (('pyfftw' in dependency_injector)
if z == 'fftw' else True), if z == 'fftw' else True),
...@@ -95,7 +95,7 @@ nifty_configuration = keepers.get_Configuration( ...@@ -95,7 +95,7 @@ nifty_configuration = keepers.get_Configuration(
variable_use_libsharp, variable_use_libsharp,
variable_verbosity, variable_verbosity,
variable_default_field_dtype, variable_default_field_dtype,
variable_default_datamodel, variable_default_distribution_strategy,
], ],
path=os.path.expanduser('~') + "/.nifty/nifty_config") path=os.path.expanduser('~') + "/.nifty/nifty_config")
######## ########
......
...@@ -25,7 +25,7 @@ class Field(object): ...@@ -25,7 +25,7 @@ class Field(object):
# ---Initialization methods--- # ---Initialization methods---
def __init__(self, domain=None, val=None, dtype=None, field_type=None, def __init__(self, domain=None, val=None, dtype=None, field_type=None,
datamodel=None, copy=False): distribution_strategy=None, copy=False):
self.domain = self._parse_domain(domain=domain, val=val) self.domain = self._parse_domain(domain=domain, val=val)
self.domain_axes = self._get_axes_tuple(self.domain) self.domain_axes = self._get_axes_tuple(self.domain)
...@@ -44,8 +44,9 @@ class Field(object): ...@@ -44,8 +44,9 @@ class Field(object):
domain=self.domain, domain=self.domain,
field_type=self.field_type) field_type=self.field_type)
self.datamodel = self._parse_datamodel(datamodel=datamodel, self.distribution_strategy = self._parse_distribution_strategy(
val=val) distribution_strategy=distribution_strategy,
val=val)
self.set_val(new_val=val, copy=copy) self.set_val(new_val=val, copy=copy)
...@@ -111,28 +112,28 @@ class Field(object): ...@@ -111,28 +112,28 @@ class Field(object):
return dtype return dtype
def _parse_datamodel(self, datamodel, val): def _parse_distribution_strategy(self, distribution_strategy, val):
if datamodel is None: if distribution_strategy is None:
if isinstance(val, distributed_data_object): if isinstance(val, distributed_data_object):
datamodel = val.distribution_strategy distribution_strategy = val.distribution_strategy
elif isinstance(val, Field): elif isinstance(val, Field):
datamodel = val.datamodel distribution_strategy = val.distribution_strategy
else: else:
about.warnings.cprint("WARNING: Datamodel set to default!") about.warnings.cprint("WARNING: Datamodel set to default!")
datamodel = gc['default_datamodel'] distribution_strategy = gc['default_distribution_strategy']
elif datamodel not in DISTRIBUTION_STRATEGIES['all']: elif distribution_strategy not in DISTRIBUTION_STRATEGIES['all']:
raise ValueError(about._errors.cstring( raise ValueError(about._errors.cstring(
"ERROR: Invalid datamodel!")) "ERROR: Invalid distribution_strategy!"))
return datamodel return distribution_strategy
# ---Factory methods--- # ---Factory methods---
@classmethod @classmethod
def from_random(cls, random_type, domain=None, dtype=None, field_type=None, def from_random(cls, random_type, domain=None, dtype=None, field_type=None,
datamodel=None, **kwargs): distribution_strategy=None, **kwargs):
# create a initially empty field # create a initially empty field
f = cls(domain=domain, dtype=dtype, field_type=field_type, f = cls(domain=domain, dtype=dtype, field_type=field_type,
datamodel=datamodel) distribution_strategy=distribution_strategy)
# now use the processed input in terms of f in order to parse the # now use the processed input in terms of f in order to parse the
# random arguments # random arguments
...@@ -229,7 +230,7 @@ class Field(object): ...@@ -229,7 +230,7 @@ class Field(object):
harmonic_domain = self.domain[space_index] harmonic_domain = self.domain[space_index]
power_domain = PowerSpace(harmonic_domain=harmonic_domain, power_domain = PowerSpace(harmonic_domain=harmonic_domain,
datamodel=distribution_strategy, distribution_strategy=distribution_strategy,
log=log, nbin=nbin, binbounds=binbounds, log=log, nbin=nbin, binbounds=binbounds,
dtype=power_dtype) dtype=power_dtype)
...@@ -358,11 +359,12 @@ class Field(object): ...@@ -358,11 +359,12 @@ class Field(object):
else: else:
result_list = [None] result_list = [None]
result_list = [self.__class__.from_random('normal', result_list = [self.__class__.from_random(
result_domain, 'normal',
dtype=harmonic_domain.dtype, result_domain,
field_type=self.field_type, dtype=harmonic_domain.dtype,
datamodel=self.datamodel) field_type=self.field_type,
distribution_strategy=self.distribution_strategy)
for x in result_list] for x in result_list]
# from now on extract the values from the random fields for further # from now on extract the values from the random fields for further
...@@ -512,24 +514,25 @@ class Field(object): ...@@ -512,24 +514,25 @@ class Field(object):
dtype = self.dtype dtype = self.dtype
return_x = distributed_data_object( return_x = distributed_data_object(
global_shape=self.shape, global_shape=self.shape,
dtype=dtype, dtype=dtype,
distribution_strategy=self.datamodel) distribution_strategy=self.distribution_strategy)
return_x.set_full_data(x, copy=False) return_x.set_full_data(x, copy=False)
return return_x return return_x
def copy(self, domain=None, dtype=None, field_type=None, def copy(self, domain=None, dtype=None, field_type=None,
datamodel=None): distribution_strategy=None):
copied_val = self.get_val(copy=True) copied_val = self.get_val(copy=True)
new_field = self.copy_empty(domain=domain, new_field = self.copy_empty(
dtype=dtype, domain=domain,
field_type=field_type, dtype=dtype,
datamodel=datamodel) field_type=field_type,
distribution_strategy=distribution_strategy)
new_field.set_val(new_val=copied_val, copy=False) new_field.set_val(new_val=copied_val, copy=False)
return new_field return new_field
def copy_empty(self, domain=None, dtype=None, field_type=None, def copy_empty(self, domain=None, dtype=None, field_type=None,
datamodel=None): distribution_strategy=None):
if domain is None: if domain is None:
domain = self.domain domain = self.domain
else: else:
...@@ -545,8 +548,8 @@ class Field(object): ...@@ -545,8 +548,8 @@ class Field(object):
else: else:
field_type = self._parse_field_type(field_type) field_type = self._parse_field_type(field_type)
if datamodel is None: if distribution_strategy is None:
datamodel = self.datamodel distribution_strategy = self.distribution_strategy
fast_copyable = True fast_copyable = True
try: try:
...@@ -562,13 +565,13 @@ class Field(object): ...@@ -562,13 +565,13 @@ class Field(object):
fast_copyable = False fast_copyable = False
if (fast_copyable and dtype == self.dtype and if (fast_copyable and dtype == self.dtype and
datamodel == self.datamodel): distribution_strategy == self.distribution_strategy):
new_field = self._fast_copy_empty() new_field = self._fast_copy_empty()
else: else:
new_field = Field(domain=domain, new_field = Field(domain=domain,
dtype=dtype, dtype=dtype,
field_type=field_type, field_type=field_type,
datamodel=datamodel) distribution_strategy=distribution_strategy)
return new_field return new_field
def _fast_copy_empty(self): def _fast_copy_empty(self):
......
...@@ -16,26 +16,63 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -16,26 +16,63 @@ class DiagonalOperator(EndomorphicOperator):
# ---Overwritten properties and methods--- # ---Overwritten properties and methods---
def __init__(self, domain=(), field_type=(), implemented=False, def __init__(self, domain=(), field_type=(), implemented=False,
diagonal=None, bare=False, copy=True, datamodel=None): diagonal=None, bare=False, copy=True,
distribution_strategy=None):
super(DiagonalOperator, self).__init__(domain=domain, super(DiagonalOperator, self).__init__(domain=domain,
field_type=field_type, field_type=field_type,
implemented=implemented) implemented=implemented)
self._implemented = bool(implemented) self._implemented = bool(implemented)
if datamodel is None: if distribution_strategy is None:
if isinstance(diagonal, distributed_data_object): if isinstance(diagonal, distributed_data_object):
datamodel = diagonal.distribution_strategy distribution_strategy = diagonal.distribution_strategy
elif isinstance(diagonal, Field): elif isinstance(diagonal, Field):
datamodel = diagonal.datamodel distribution_strategy = diagonal.distribution_strategy
self.datamodel = self._parse_datamodel(datamodel=datamodel, self.distribution_strategy = self._parse_distribution_strategy(
val=diagonal) distribution_strategy=distribution_strategy,
val=diagonal)
self.set_diagonal(diagonal=diagonal, bare=bare, copy=copy) self.set_diagonal(diagonal=diagonal, bare=bare, copy=copy)
def _times(self, x, spaces, types): def _times(self, x, spaces, types):
pass # if the distribution_strategy of self is sub-slice compatible to
# the one of x, reshape the local data of self and apply it directly
active_axes = []
if spaces is None:
for axes in x.domain_axes:
active_axes += axes
else:
for space_index in spaces:
active_axes += x.domain_axes[space_index]
if types is None:
for axes in x.field_type_axes:
active_axes += axes
else:
for type_index in types:
active_axes += x.field_type_axes[type_index]
if x.val.get_axes_local_distribution_strategy(active_axes) == \
self.distribution_strategy:
local_data = self._diagonal.val.get_local_data(copy=False)
# check if domains match completely
# -> multiply directly
# check if axes_local_distribution_strategy matches.
# If yes, extract local data of self.diagonal and x and use numpy
# reshape.
# assert that indices in spaces and types are striktly increasing
# otherwise a wild transpose would be necessary
# build new shape (1,1,x,1,y,1,1,z)
# copy self.diagonal into new shape
# apply reshaped array to x
def _adjoint_times(self, x, spaces, types): def _adjoint_times(self, x, spaces, types):
pass pass
...@@ -97,29 +134,29 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -97,29 +134,29 @@ class DiagonalOperator(EndomorphicOperator):
# ---Added properties and methods--- # ---Added properties and methods---
@property @property
def datamodel(self): def distribution_strategy(self):
return self._datamodel return self._distribution_strategy
def _parse_datamodel(self, datamodel, val): def _parse_distribution_strategy(self, distribution_strategy, val):
if datamodel is None: if distribution_strategy is None:
if isinstance(val, distributed_data_object): if isinstance(val, distributed_data_object):
datamodel = val.distribution_strategy distribution_strategy = val.distribution_strategy
elif isinstance(val, Field): elif isinstance(val, Field):
datamodel = val.datamodel distribution_strategy = val.distribution_strategy
else: else:
about.warnings.cprint("WARNING: Datamodel set to default!") about.warnings.cprint("WARNING: Datamodel set to default!")
datamodel = gc['default_datamodel'] distribution_strategy = gc['default_distribution_strategy']
elif datamodel not in DISTRIBUTION_STRATEGIES['all']: elif distribution_strategy not in DISTRIBUTION_STRATEGIES['all']:
raise ValueError(about._errors.cstring( raise ValueError(about._errors.cstring(
"ERROR: Invalid datamodel!")) "ERROR: Invalid distribution_strategy!"))
return datamodel return distribution_strategy
def set_diagonal(self, diagonal, bare=False, copy=True): def set_diagonal(self, diagonal, bare=False, copy=True):
# use the casting functionality from Field to process `diagonal` # use the casting functionality from Field to process `diagonal`
f = Field(domain=self.domain, f = Field(domain=self.domain,
val=diagonal, val=diagonal,
field_type=self.field_type, field_type=self.field_type,
datamodel=self.datamodel, distribution_strategy=self.distribution_strategy,
copy=copy) copy=copy)
# weight if the given values were `bare` and `implemented` is True # weight if the given values were `bare` and `implemented` is True
......
...@@ -157,8 +157,7 @@ class LinearOperator(object): ...@@ -157,8 +157,7 @@ class LinearOperator(object):
# cases: # cases:
# 1. Case: # 1. Case:
# The user specifies with `spaces` that the operators domain should # The user specifies with `spaces` that the operators domain should
# be applied to a certain domain in the domain-tuple of x. This is # be applied to certain spaces in the domain-tuple of x.
# only valid if len(self.domain)==1.
# 2. Case: # 2. Case:
# The domains of self and x match completely. # The domains of self and x match completely.
...@@ -175,16 +174,8 @@ class LinearOperator(object): ...@@ -175,16 +174,8 @@ class LinearOperator(object):
"ERROR: The operator's and and field's domains don't " "ERROR: The operator's and and field's domains don't "
"match.")) "match."))
else: else:
if len(self_domain) > 1: for i, space_index in enumerate(spaces):
raise ValueError(about._errors.cstring( if x.domain[space_index] != self_domain[i]:
"ERROR: Specifying `spaces` for operators with multiple "
"domain spaces is not valid."))
elif len(spaces) != len(self_domain):
raise ValueError(about._errors.cstring(
"ERROR: Length of `spaces` does not match the number of "
"spaces in the operator's domain."))
elif len(spaces) == 1:
if x.domain[spaces[0]] != self_domain[0]:
raise ValueError(about._errors.cstring( raise ValueError(about._errors.cstring(
"ERROR: The operator's and and field's domains don't " "ERROR: The operator's and and field's domains don't "
"match.")) "match."))
...@@ -195,19 +186,12 @@ class LinearOperator(object): ...@@ -195,19 +186,12 @@ class LinearOperator(object):
"ERROR: The operator's and and field's field_types don't " "ERROR: The operator's and and field's field_types don't "
"match.")) "match."))
else: else:
if len(self_field_type) > 1: for i, field_type_index in enumerate(types):
raise ValueError(about._errors.cstring( if x.field_types[field_type_index] != self_field_type[i]:
"ERROR: Specifying `types` for operators with multiple "
"field-types is not valid."))
elif len(types) != len(self_field_type):
raise ValueError(about._errors.cstring(
"ERROR: Length of `types` does not match the number of "
"the operator's field-types."))
elif len(types) == 1:
if x.field_type[types[0]] != self_field_type[0]:
raise ValueError(about._errors.cstring( raise ValueError(about._errors.cstring(
"ERROR: The operator's and and field's field_type " "ERROR: The operator's and and field's field_type "
"don't match.")) "don't match."))
return (spaces, types) return (spaces, types)
def __repr__(self): def __repr__(self):
......
...@@ -2840,7 +2840,7 @@ class response_operator(operator): ...@@ -2840,7 +2840,7 @@ class response_operator(operator):
# TODO: Fix the target spaces # TODO: Fix the target spaces
target = Space(assignments, target = Space(assignments,
dtype=self.domain.dtype, dtype=self.domain.dtype,
datamodel=self.domain.datamodel) distribution_strategy=self.domain.distribution_strategy)
else: else:
# check target # check target
if not isinstance(target, Space): if not isinstance(target, Space):
......
...@@ -499,7 +499,7 @@ class LMSpace(Space): ...@@ -499,7 +499,7 @@ class LMSpace(Space):
raise ValueError(about._errors.cstring( raise ValueError(about._errors.cstring(
"ERROR: unsupported codomain.")) "ERROR: unsupported codomain."))
# if self.datamodel != 'not': # if self.distribution_strategy != 'not':
# about.warnings.cprint( # about.warnings.cprint(
# "WARNING: Field data is consolidated to all nodes for " # "WARNING: Field data is consolidated to all nodes for "
# "external alm2map method!") # "external alm2map method!")
...@@ -569,7 +569,7 @@ class LMSpace(Space): ...@@ -569,7 +569,7 @@ class LMSpace(Space):
elif sigma < 0: elif sigma < 0:
raise ValueError(about._errors.cstring("ERROR: invalid sigma.")) raise ValueError(about._errors.cstring("ERROR: invalid sigma."))
# if self.datamodel != 'not': # if self.distribution_strategy != 'not':
# about.warnings.cprint( # about.warnings.cprint(
# "WARNING: Field data is consolidated to all nodes for " # "WARNING: Field data is consolidated to all nodes for "
# "external smoothalm method!") # "external smoothalm method!")
...@@ -612,7 +612,7 @@ class LMSpace(Space): ...@@ -612,7 +612,7 @@ class LMSpace(Space):
lmax = self.paradict['lmax'] lmax = self.paradict['lmax']
mmax = self.paradict['mmax'] mmax = self.paradict['mmax']
# if self.datamodel != 'not': # if self.distribution_strategy != 'not':
# about.warnings.cprint( # about.warnings.cprint(
# "WARNING: Field data is consolidated to all nodes for " # "WARNING: Field data is consolidated to all nodes for "
# "external anaalm/alm2cl method!") # "external anaalm/alm2cl method!")
......
...@@ -14,7 +14,7 @@ class PowerSpace(Space): ...@@ -14,7 +14,7 @@ class PowerSpace(Space):
# ---Overwritten properties and methods--- # ---Overwritten properties and methods---
def __init__(self, harmonic_domain=RGSpace((1,)), datamodel='not', def __init__(self, harmonic_domain=RGSpace((1,)), distribution_strategy='not',
log=False, nbin=None, binbounds=None, log=False, nbin=None, binbounds=None,
dtype=np.dtype('float')): dtype=np.dtype('float')):
...@@ -31,7 +31,7 @@ class PowerSpace(Space): ...@@ -31,7 +31,7 @@ class PowerSpace(Space):
power_index = PowerIndexFactory.get_power_index( power_index = PowerIndexFactory.get_power_index(
domain=self.harmonic_domain, domain=self.harmonic_domain,
distribution_strategy=datamodel, distribution_strategy=distribution_strategy,
log=log, log=log,
nbin=nbin, nbin=nbin,
binbounds=binbounds) binbounds=binbounds)
...@@ -71,9 +71,9 @@ class PowerSpace(Space): ...@@ -71,9 +71,9 @@ class PowerSpace(Space):
return reduce(lambda x, y: x*y, self.pindex.shape) return reduce(lambda x, y: x*y, self.pindex.shape)
def copy(self): def copy(self):
datamodel = self.pindex.distribution_strategy distribution_strategy = self.pindex.distribution_strategy
return self.__class__(harmonic_domain=self.harmonic_domain, return self.__class__(harmonic_domain=self.harmonic_domain,
datamodel=datamodel, distribution_strategy=distribution_strategy,
log=self.log, log=self.log,
nbin=self.nbin, nbin=self.nbin,
binbounds=self.binbounds, binbounds=self.binbounds,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment