Commit ca3ed015 authored by Theo Steininger's avatar Theo Steininger

Refactored __init__ -> decomposed it into dedicated functions.

parent c6f46393
......@@ -22,8 +22,6 @@ from d2o import distributed_data_object
from nifty.spaces.space import Space
_PSCache = {}
class PowerSpace(Space):
""" NIFTY class for spaces of power spectra.
......@@ -84,10 +82,11 @@ class PowerSpace(Space):
"""
_powerIndexCache = {}
# ---Overwritten properties and methods---
def __init__(self, harmonic_partner,
distribution_strategy='not',
def __init__(self, harmonic_partner, distribution_strategy='not',
logarithmic=None, nbin=None, binbounds=None):
super(PowerSpace, self).__init__()
self._ignore_for_hash += ['_pindex', '_kindex', '_rho']
......@@ -107,14 +106,36 @@ class PowerSpace(Space):
key = (harmonic_partner, distribution_strategy, logarithmic, nbin,
binbounds)
if _PSCache.get(key) is not None:
(self._binbounds, self._pindex, self._kindex, self._rho) \
= _PSCache[key]
return
if self._powerIndexCache.get(key) is None:
distance_array = \
self.harmonic_partner.get_distance_array(distribution_strategy)
temp_binbounds = self._compute_binbounds(
harmonic_partner=self.harmonic_partner,
distribution_strategy=distribution_strategy,
logarithmic=logarithmic,
nbin=nbin,
binbounds=binbounds)
temp_pindex = self._compute_pindex(
distance_array=distance_array,
binbounds=temp_binbounds,
distribution_strategy=distribution_strategy)
temp_rho = temp_pindex.bincount().get_full_data()
temp_kindex = \
(temp_pindex.bincount(weights=distance_array).get_full_data() /
temp_rho)
self._powerIndexCache[key] = (temp_binbounds,
temp_pindex,
temp_kindex,
temp_rho)
(self._binbounds, self._pindex, self._kindex, self._rho) = \
self._powerIndexCache[key]
def _compute_binbounds(self, harmonic_partner, distribution_strategy,
logarithmic, nbin, binbounds):
self._binbounds = None
if logarithmic is None and nbin is None and binbounds is None:
bb = self._harmonic_partner.get_natural_binbounds()
result = None
else:
if binbounds is not None:
bb = np.sort(np.array(binbounds))
......@@ -143,24 +164,22 @@ class PowerSpace(Space):
0.5 * (k[1] + k[2]) + dk * np.arange(nbin-2)]
if(logarithmic):
bb = np.exp(bb)
self._binbounds = tuple(bb)
result = tuple(bb)
return result
def _compute_pindex(self, distance_array, binbounds,
distribution_strategy):
dists = self._harmonic_partner.get_distance_array(
distribution_strategy)
# Compute pindex, kindex and rho according to bb
self._pindex = distributed_data_object(
global_shape=dists.shape,
dtype=np.int,
distribution_strategy=distribution_strategy)
self._pindex.set_local_data(np.searchsorted(
bb, dists.get_local_data())) # also expensive!
self._rho = self._pindex.bincount().get_full_data()
self._kindex = self._pindex.bincount(
weights=dists).get_full_data()/self._rho
_PSCache[key] = \
(self._binbounds, self._pindex, self._kindex, self._rho)
pindex = distributed_data_object(
global_shape=distance_array.shape,
dtype=np.int,
distribution_strategy=distribution_strategy)
if binbounds is None:
binbounds = self.harmonic_partner.get_natural_binbounds()
pindex.set_local_data(
np.searchsorted(binbounds, distance_array.get_local_data()))
return pindex
def pre_cast(self, x, axes):
""" Casts power spectrum functions to discretized power spectra.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment