Commit 504626ab authored by theos's avatar theos
Browse files

lm_space: Fixed the datamodel check

nifty_power_indices: Fixed the methods for d2o compatiblity
parent 22f9278e
...@@ -157,7 +157,6 @@ class lm_space(point_space): ...@@ -157,7 +157,6 @@ class lm_space(point_space):
raise ImportError(about._errors.cstring( raise ImportError(about._errors.cstring(
"ERROR: neither libsharp_wrapper_gl nor healpy activated.")) "ERROR: neither libsharp_wrapper_gl nor healpy activated."))
self._cache_dict = {'check_codomain': {}} self._cache_dict = {'check_codomain': {}}
self.paradict = lm_space_paradict(lmax=lmax, mmax=mmax) self.paradict = lm_space_paradict(lmax=lmax, mmax=mmax)
...@@ -165,13 +164,19 @@ class lm_space(point_space): ...@@ -165,13 +164,19 @@ class lm_space(point_space):
# check data type # check data type
dtype = np.dtype(dtype) dtype = np.dtype(dtype)
if dtype not in [np.dtype('complex64'), np.dtype('complex128')]: if dtype not in [np.dtype('complex64'), np.dtype('complex128')]:
about.warnings.cprint("WARNING: data type set to default.") about.warnings.cprint("WARNING: data type set to complex128.")
dtype = np.dtype('complex128') dtype = np.dtype('complex128')
self.dtype = dtype self.dtype = dtype
# set datamodel # set datamodel
if datamodel not in ['not']: if datamodel not in ['not']:
about.warnings.cprint("WARNING: datamodel set to default.") about.warnings.cprint(
"WARNING: %s is not a recommended datamodel for lm_space."
% datamodel)
if datamodel not in LM_DISTRIBUTION_STRATEGIES:
raise ValueError(about._errors.cstring(
"ERROR: %s is not a valid datamodel" % datamodel))
self.datamodel = datamodel self.datamodel = datamodel
self.discrete = True self.discrete = True
......
...@@ -508,7 +508,8 @@ class rg_power_indices(power_indices): ...@@ -508,7 +508,8 @@ class rg_power_indices(power_indices):
nkdict = distributed_data_object( nkdict = distributed_data_object(
global_shape=shape, global_shape=shape,
dtype=np.float128, dtype=np.float128,
distribution_strategy=self.datamodel) distribution_strategy=self.datamodel,
comm=self.comm)
if self.datamodel in DISTRIBUTION_STRATEGIES['slicing']: if self.datamodel in DISTRIBUTION_STRATEGIES['slicing']:
# get the node's individual slice of the first dimension # get the node's individual slice of the first dimension
slice_of_first_dimension = slice( slice_of_first_dimension = slice(
...@@ -672,15 +673,21 @@ class lm_power_indices(power_indices): ...@@ -672,15 +673,21 @@ class lm_power_indices(power_indices):
------- -------
nkdict : distributed_data_object nkdict : distributed_data_object
""" """
if self.datamodel == 'not':
if 'healpy' in gdi: # default if self.datamodel != 'not':
about.warnings.cprint(
"WARNING: full kdict is temporarily stored on every node " +
"altough disribution strategy != 'not'!")
if self.datamodel in self.allowed_distribution_strategies:
if 'healpy' in gdi:
nkdict = hp.Alm.getlm(self.lmax, i=None)[0] nkdict = hp.Alm.getlm(self.lmax, i=None)[0]
else: else:
nkdict = self._getlm()[0] nkdict = self._getlm()[0]
nkdict = distributed_data_object(
elif self.datamodel in self.allowed_distribution_strategies: nkdict,
raise NotImplementedError distribution_strategy=self.datamodel,
comm=self.comm)
else: else:
raise ValueError(about._errors.cstring( raise ValueError(about._errors.cstring(
"ERROR: Unsupported datamodel")) "ERROR: Unsupported datamodel"))
...@@ -695,22 +702,13 @@ class lm_power_indices(power_indices): ...@@ -695,22 +702,13 @@ class lm_power_indices(power_indices):
return l, m return l, m
def _compute_indices(self, nkdict): def _compute_indices(self, nkdict):
if self.datamodel in ['np','not']: if self.datamodel in self.allowed_distribution_strategies:
return self._compute_indices_np(nkdict)
elif self.datamodel in self.allowed_distribution_strategies:
return self._compute_indices_d2o(nkdict) return self._compute_indices_d2o(nkdict)
else: else:
raise ValueError(about._errors.cstring( raise ValueError(about._errors.cstring(
'ERROR: Datamodel is not supported.')) 'ERROR: Datamodel is not supported.'))
def _compute_indices_d2o(self, nkdict): def _compute_indices_d2o(self, nkdict):
"""
Internal helper function which computes pindex, kindex, rho and pundex
from a given nkdict
"""
raise NotImplementedError
def _compute_indices_np(self, nkdict):
""" """
Internal helper function which computes pindex, kindex, rho and pundex Internal helper function which computes pindex, kindex, rho and pundex
from a given nkdict from a given nkdict
...@@ -723,7 +721,7 @@ class lm_power_indices(power_indices): ...@@ -723,7 +721,7 @@ class lm_power_indices(power_indices):
########## ##########
# pindex # # pindex #
########## ##########
pindex = nkdict.astype(np.int, copy=True) pindex = nkdict.copy(dtype=np.int)
####### #######
# rho # # rho #
...@@ -733,6 +731,6 @@ class lm_power_indices(power_indices): ...@@ -733,6 +731,6 @@ class lm_power_indices(power_indices):
########## ##########
# pundex # # pundex #
########## ##########
pundex = self._compute_pundex_np(pindex, kindex) pundex = self._compute_pundex_d2o(pindex, kindex)
return pindex, kindex, rho, pundex return pindex, kindex, rho, pundex
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment