Commit 67bed0b7 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

tweaks

parent 93d2e2d6
Pipeline #21196 passed with stage
in 4 minutes and 57 seconds
...@@ -6,6 +6,7 @@ class data_object(object): ...@@ -6,6 +6,7 @@ class data_object(object):
def __init__(self, npdata): def __init__(self, npdata):
self._data = np.asarray(npdata) self._data = np.asarray(npdata)
# FIXME: subscripting support will most likely go away
def __getitem__(self, key): def __getitem__(self, key):
res = self._data[key] res = self._data[key]
return res if np.isscalar(res) else data_object(res) return res if np.isscalar(res) else data_object(res)
...@@ -37,6 +38,9 @@ class data_object(object): ...@@ -37,6 +38,9 @@ class data_object(object):
return data_object(self._data.imag) return data_object(self._data.imag)
def _contraction_helper(self, op, axis): def _contraction_helper(self, op, axis):
if axis is not None:
if len(axis)==len(self._data.shape):
axis = None
if axis is None: if axis is None:
return getattr(self._data, op)() return getattr(self._data, op)()
...@@ -164,32 +168,28 @@ def vdot(a, b): ...@@ -164,32 +168,28 @@ def vdot(a, b):
return np.vdot(a._data, b._data) return np.vdot(a._data, b._data)
def abs(a, out=None): def _math_helper(x, function, out):
if out is None: if out is not None:
out = empty_like(a) function(x._data, out=out._data)
np.abs(a._data, out=out._data)
return out return out
else:
return data_object(function(x._data))
def abs(a, out=None):
return _math_helper(a, np.abs, out)
def exp(a, out=None): def exp(a, out=None):
if out is None: return _math_helper(a, np.exp, out)
out = empty_like(a)
np.exp(a._data, out=out._data)
return out
def log(a, out=None): def log(a, out=None):
if out is None: return _math_helper(a, np.log, out)
out = empty_like(a)
np.log(a._data, out=out._data)
return out
def sqrt(a, out=None): def sqrt(a, out=None):
if out is None: return _math_helper(a, np.sqrt, out)
out = empty_like(a)
np.sqrt(a._data, out=out._data)
return out
def bincount(x, weights=None, minlength=None): def bincount(x, weights=None, minlength=None):
...@@ -224,12 +224,6 @@ def ibegin(arr): ...@@ -224,12 +224,6 @@ def ibegin(arr):
return (0,)*arr._data.ndim return (0,)*arr._data.ndim
def create_from_template(tmpl, local_data, dtype):
res = np.ndarray(tmpl.shape, dtype=dtype)
res[()] = local_data
return data_object(res)
def np_allreduce_sum(arr): def np_allreduce_sum(arr):
return arr return arr
...@@ -249,8 +243,6 @@ def from_local_data (shape, arr, dist_axis): ...@@ -249,8 +243,6 @@ def from_local_data (shape, arr, dist_axis):
def from_global_data (arr, dist_axis): def from_global_data (arr, dist_axis):
if dist_axis!=-1: if dist_axis!=-1:
raise NotImplementedError raise NotImplementedError
if shape!=arr.shape:
raise ValueError
return data_object(arr) return data_object(arr)
......
...@@ -31,12 +31,6 @@ def ibegin(arr): ...@@ -31,12 +31,6 @@ def ibegin(arr):
return (0,)*arr.ndim return (0,)*arr.ndim
def create_from_template(tmpl, local_data, dtype):
res = np.ndarray(tmpl.shape, dtype=dtype)
res[()] = local_data
return res
def np_allreduce_sum(arr): def np_allreduce_sum(arr):
return arr return arr
...@@ -56,8 +50,6 @@ def from_local_data (shape, arr, dist_axis): ...@@ -56,8 +50,6 @@ def from_local_data (shape, arr, dist_axis):
def from_global_data (arr, dist_axis): def from_global_data (arr, dist_axis):
if dist_axis!=-1: if dist_axis!=-1:
raise NotImplementedError raise NotImplementedError
if shape!=arr.shape:
raise ValueError
return arr return arr
......
...@@ -455,7 +455,7 @@ class Field(object): ...@@ -455,7 +455,7 @@ class Field(object):
raise TypeError("argument must be a Field") raise TypeError("argument must be a Field")
if other.domain != self.domain: if other.domain != self.domain:
raise ValueError("domains are incompatible.") raise ValueError("domains are incompatible.")
self.val[()] = other.val[()] dobj.local_data(self.val)[()] = dobj.local_data(other.val)[()]
# ---General binary methods--- # ---General binary methods---
......
...@@ -55,15 +55,6 @@ class PowerProjectionOperator(LinearOperator): ...@@ -55,15 +55,6 @@ class PowerProjectionOperator(LinearOperator):
res = Field.zeros(self._target, dtype=x.dtype) res = Field.zeros(self._target, dtype=x.dtype)
if dobj.dist_axis(x.val) in x.domain.axes[self._space]: # the distributed axis is part of the projected space if dobj.dist_axis(x.val) in x.domain.axes[self._space]: # the distributed axis is part of the projected space
pindex = dobj.local_data(pindex) pindex = dobj.local_data(pindex)
pindex.reshape((1, pindex.size, 1))
arr = dobj.local_data(x.weight(1).val)
firstaxis = x.domain.axes[self._space][0]
lastaxis = x.domain.axes[self._space][-1]
presize = np.prod(arr.shape[0:firstaxis], dtype=np.int)
postsize = np.prod(arr.shape[lastaxis+1:], dtype=np.int)
arr = arr.reshape((presize,pindex.size,postsize))
oarr = dobj.local_data(res.val).reshape((presize,-1,postsize))
np.add.at(oarr, (slice(None), pindex.ravel(), slice(None)), arr)
else: else:
pindex = dobj.to_ndarray(pindex) pindex = dobj.to_ndarray(pindex)
pindex.reshape((1, pindex.size, 1)) pindex.reshape((1, pindex.size, 1))
...@@ -79,10 +70,21 @@ class PowerProjectionOperator(LinearOperator): ...@@ -79,10 +70,21 @@ class PowerProjectionOperator(LinearOperator):
def _adjoint_times(self, x): def _adjoint_times(self, x):
pindex = self._target[self._space].pindex pindex = self._target[self._space].pindex
res = Field.empty(self._domain, dtype=x.dtype)
if dobj.dist_axis(x.val) in x.domain.axes[self._space]: # the distributed axis is part of the projected space
pindex = dobj.local_data(pindex)
else:
pindex = dobj.to_ndarray(pindex)
pindex = pindex.reshape((1, pindex.size, 1)) pindex = pindex.reshape((1, pindex.size, 1))
arr = x.val.reshape(x.domain.collapsed_shape_for_domain(self._space)) arr = dobj.local_data(x.val)
out = arr[(slice(None), dobj.to_ndarray(pindex.ravel()), slice(None))] firstaxis = x.domain.axes[self._space][0]
return Field(self._domain, out.reshape(self._domain.shape)) lastaxis = x.domain.axes[self._space][-1]
presize = np.prod(arr.shape[0:firstaxis], dtype=np.int)
postsize = np.prod(arr.shape[lastaxis+1:], dtype=np.int)
arr = arr.reshape((presize,-1,postsize))
oarr = dobj.local_data(res.val).reshape((presize,-1,postsize))
oarr[()] = arr[(slice(None), pindex.ravel(), slice(None))]
return res
@property @property
def domain(self): def domain(self):
......
...@@ -143,8 +143,8 @@ class PowerSpace(Space): ...@@ -143,8 +143,8 @@ class PowerSpace(Space):
else: else:
tbb = binbounds tbb = binbounds
locdat = np.searchsorted(tbb, dobj.local_data(k_length_array.val)) locdat = np.searchsorted(tbb, dobj.local_data(k_length_array.val))
temp_pindex = dobj.create_from_template( temp_pindex = dobj.from_local_data(
k_length_array.val, local_data=locdat, dtype=locdat.dtype) k_length_array.val.shape, locdat, dobj.dist_axis(k_length_array.val))
nbin = len(tbb) nbin = len(tbb)
temp_rho = np.bincount(dobj.local_data(temp_pindex).ravel(), temp_rho = np.bincount(dobj.local_data(temp_pindex).ravel(),
minlength=nbin) minlength=nbin)
......
...@@ -115,6 +115,7 @@ class RGSpace(Space): ...@@ -115,6 +115,7 @@ class RGSpace(Space):
tmp[t2] = True tmp[t2] = True
return np.sqrt(np.nonzero(tmp)[0])*self.distances[0] return np.sqrt(np.nonzero(tmp)[0])*self.distances[0]
else: # do it the hard way else: # do it the hard way
# FIXME: this needs to improve for MPI. Maybe unique()/gather()?
tmp = np.unique(dobj.to_ndarray(self.get_k_length_array().val)) # expensive! tmp = np.unique(dobj.to_ndarray(self.get_k_length_array().val)) # expensive!
tol = 1e-12*tmp[-1] tol = 1e-12*tmp[-1]
# remove all points that are closer than tol to their right # remove all points that are closer than tol to their right
......
...@@ -56,8 +56,9 @@ class SmoothingOperator_Tests(unittest.TestCase): ...@@ -56,8 +56,9 @@ class SmoothingOperator_Tests(unittest.TestCase):
@expand(product(spaces, [0., .5, 5.])) @expand(product(spaces, [0., .5, 5.]))
def test_times(self, space, sigma): def test_times(self, space, sigma):
op = ift.FFTSmoothingOperator(space, sigma=sigma) op = ift.FFTSmoothingOperator(space, sigma=sigma)
rand1 = ift.Field.zeros(space) fld = np.zeros(space.shape, dtype=np.float64)
rand1.val[0] = 1. fld[0] = 1.
rand1 = ift.Field(space, ift.dobj.from_global_data(fld, dist_axis=-1))
tt1 = op.times(rand1) tt1 = op.times(rand1)
assert_allclose(1, tt1.sum()) assert_allclose(1, tt1.sum())
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment