Skip to content
Snippets Groups Projects
Commit 6bd3d2ee authored by Martin Reinecke's avatar Martin Reinecke
Browse files

temporary commit

parent f69589b8
No related branches found
No related tags found
No related merge requests found
......@@ -214,3 +214,25 @@ def to_ndarray(arr):
def from_ndarray(arr):
return data_object(arr)
def local_data(arr):
return arr._data
def ibegin(arr):
return (0,)*arr._data.ndim
def create_from_template(tmpl, local_data, dtype):
res = np.ndarray(tmpl.shape, dtype=dtype)
res[()] = local_data
return data_object(res)
def np_allreduce_sum(arr):
return arr
def dist_axis(arr):
return -1
......@@ -21,3 +21,25 @@ def to_ndarray(arr):
def from_ndarray(arr):
return np.asarray(arr)
def local_data(arr):
return arr
def ibegin(arr):
return (0,)*arr.ndim
def create_from_template(tmpl, local_data, dtype):
res = np.ndarray(tmpl.shape, dtype=dtype)
res[()] = local_data
return res
def np_allreduce_sum(arr):
return arr
def dist_axis(arr):
return -1
from .data_objects.my_own_do import *
#from .data_objects.numpy_do import *
#from .data_objects.my_own_do import *
from .data_objects.numpy_do import *
......@@ -20,7 +20,7 @@ from __future__ import division
import numpy as np
from .. import nifty_utilities as utilities
from ..low_level_library import hartley
from ..dobj import to_ndarray as to_np, from_ndarray as from_np
from .. import dobj
from ..field import Field
from ..spaces.gl_space import GLSpace
......@@ -60,10 +60,13 @@ class RGRGTransformation(Transformation):
"""
axes = x.domain.axes[self.space]
p2h = x.domain == self.pdom
if dobj.dist_axis(x.val) in axes:
raise NotImplementedError
ldat = dobj.local_data(x.val)
if p2h:
Tval = Field(self.hdom, from_np(hartley(to_np(x.val), axes)))
Tval = Field(self.hdom, dobj.create_from_template(x.val,hartley(ldat, axes),dtype=x.val.dtype))
else:
Tval = Field(self.pdom, from_np(hartley(to_np(x.val), axes)))
Tval = Field(self.pdom, dobj.create_from_template(x.val,hartley(ldat, axes),dtype=x.val.dtype))
fct = self.fct_p2h if p2h else self.fct_h2p
if fct != 1:
Tval *= fct
......@@ -73,19 +76,23 @@ class RGRGTransformation(Transformation):
class SlicingTransformation(Transformation):
def transform(self, x):
axes = x.domain.axes[self.space]
if dobj.dist_axis(x.val) in axes:
raise NotImplementedError
p2h = x.domain == self.pdom
idat = dobj.local_data(x.val)
if p2h:
res = Field(self.hdom, dtype=x.dtype)
odat = dobj.local_data(res.val)
for slice in utilities.get_slice_list(x.shape,
x.domain.axes[self.space]):
res.val[slice] = from_np(self._slice_p2h(to_np(x.val[slice])))
for slice in utilities.get_slice_list(idat.shape, axes):
odat[slice] = self._slice_p2h(idat[slice])
else:
res = Field(self.pdom, dtype=x.dtype)
odat = dobj.local_data(res.val)
for slice in utilities.get_slice_list(x.shape,
x.domain.axes[self.space]):
res.val[slice] = from_np(self._slice_h2p(to_np(x.val[slice])))
for slice in utilities.get_slice_list(idat.shape, axes):
odat[slice] = self._slice_h2p(idat[slice])
return res
......
......@@ -137,14 +137,21 @@ class PowerSpace(Space):
key = (harmonic_partner, binbounds)
if self._powerIndexCache.get(key) is None:
k_length_array = self.harmonic_partner.get_k_length_array()
temp_pindex = self._compute_pindex(
harmonic_partner=self.harmonic_partner,
k_length_array=k_length_array,
binbounds=binbounds)
temp_rho = dobj.to_ndarray(dobj.bincount(temp_pindex.ravel()))
if binbounds is None:
tmp = harmonic_partner.get_unique_k_lengths()
tbb = 0.5*(tmp[:-1]+tmp[1:])
else:
tbb = binbounds
locdat = np.searchsorted(tbb, dobj.local_data(k_length_array.val))
temp_pindex = dobj.create_from_template(k_length_array.val, local_data=locdat,
dtype=locdat.dtype)
nbin = len(tbb)
temp_rho = np.bincount(dobj.local_data(temp_pindex).ravel(), minlength=nbin)
temp_rho = dobj.np_allreduce_sum(temp_rho)
assert not (temp_rho == 0).any(), "empty bins detected"
temp_k_lengths = dobj.to_ndarray(dobj.bincount(temp_pindex.ravel(),
weights=k_length_array.val.ravel())) / temp_rho
temp_k_lengths = np.bincount(dobj.local_data(temp_pindex).ravel(),
weights=dobj.local_data(k_length_array.val).ravel(), minlength=nbin)
temp_k_lengths = dobj.np_allreduce_sum(temp_k_lengths) / temp_rho
temp_dvol = temp_rho*pdvol
self._powerIndexCache[key] = (binbounds,
temp_pindex,
......@@ -154,13 +161,6 @@ class PowerSpace(Space):
(self._binbounds, self._pindex, self._k_lengths, self._dvol) = \
self._powerIndexCache[key]
@staticmethod
def _compute_pindex(harmonic_partner, k_length_array, binbounds):
if binbounds is None:
tmp = harmonic_partner.get_unique_k_lengths()
binbounds = 0.5*(tmp[:-1]+tmp[1:])
return dobj.from_ndarray(np.searchsorted(binbounds, dobj.to_ndarray(k_length_array.val)))
# ---Mandatory properties and methods---
def __repr__(self):
......
......@@ -23,7 +23,7 @@ import numpy as np
from .space import Space
from .. import Field
from ..basic_arithmetics import exp
from ..dobj import to_ndarray as to_np, from_ndarray as from_np
from .. import dobj
class RGSpace(Space):
......@@ -78,17 +78,22 @@ class RGSpace(Space):
def get_k_length_array(self):
if (not self.harmonic):
raise NotImplementedError
res = np.arange(self.shape[0], dtype=np.float64)
out = Field((self,),dtype=np.float64)
oloc = dobj.local_data(out.val)
ibegin = dobj.ibegin(out.val)
res = np.arange(oloc.shape[0], dtype=np.float64) + ibegin[0]
res = np.minimum(res, self.shape[0]-res)*self.distances[0]
if len(self.shape) == 1:
return Field((self,), from_np(res))
oloc[()] = res
return out
res *= res
for i in range(1, len(self.shape)):
tmp = np.arange(self.shape[i], dtype=np.float64)
tmp = np.arange(oloc.shape[i], dtype=np.float64) + ibegin[i]
tmp = np.minimum(tmp, self.shape[i]-tmp)*self.distances[i]
tmp *= tmp
res = np.add.outer(res, tmp)
return Field((self,), from_np(np.sqrt(res)))
oloc[()] = np.sqrt(res)
return out
def get_unique_k_lengths(self):
if (not self.harmonic):
......@@ -110,7 +115,7 @@ class RGSpace(Space):
tmp[t2] = True
return np.sqrt(np.nonzero(tmp)[0])*self.distances[0]
else: # do it the hard way
tmp = np.unique(to_np(self.get_k_length_array().val)) # expensive!
tmp = np.unique(dobj.to_ndarray(self.get_k_length_array().val)) # expensive!
tol = 1e-12*tmp[-1]
# remove all points that are closer than tol to their right
# neighbors.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment