Commit 67bed0b7 authored by Martin Reinecke's avatar Martin Reinecke

tweaks

parent 93d2e2d6
Pipeline #21196 passed with stage
in 4 minutes and 57 seconds
......@@ -6,6 +6,7 @@ class data_object(object):
def __init__(self, npdata):
self._data = np.asarray(npdata)
# FIXME: subscripting support will most likely go away
def __getitem__(self, key):
res = self._data[key]
return res if np.isscalar(res) else data_object(res)
......@@ -37,6 +38,9 @@ class data_object(object):
return data_object(self._data.imag)
def _contraction_helper(self, op, axis):
if axis is not None:
if len(axis)==len(self._data.shape):
axis = None
if axis is None:
return getattr(self._data, op)()
......@@ -164,32 +168,28 @@ def vdot(a, b):
return np.vdot(a._data, b._data)
def _math_helper(x, function, out):
if out is not None:
function(x._data, out=out._data)
return out
else:
return data_object(function(x._data))
def abs(a, out=None):
if out is None:
out = empty_like(a)
np.abs(a._data, out=out._data)
return out
return _math_helper(a, np.abs, out)
def exp(a, out=None):
if out is None:
out = empty_like(a)
np.exp(a._data, out=out._data)
return out
return _math_helper(a, np.exp, out)
def log(a, out=None):
if out is None:
out = empty_like(a)
np.log(a._data, out=out._data)
return out
return _math_helper(a, np.log, out)
def sqrt(a, out=None):
if out is None:
out = empty_like(a)
np.sqrt(a._data, out=out._data)
return out
return _math_helper(a, np.sqrt, out)
def bincount(x, weights=None, minlength=None):
......@@ -224,12 +224,6 @@ def ibegin(arr):
return (0,)*arr._data.ndim
def create_from_template(tmpl, local_data, dtype):
res = np.ndarray(tmpl.shape, dtype=dtype)
res[()] = local_data
return data_object(res)
def np_allreduce_sum(arr):
return arr
......@@ -249,8 +243,6 @@ def from_local_data (shape, arr, dist_axis):
def from_global_data (arr, dist_axis):
if dist_axis!=-1:
raise NotImplementedError
if shape!=arr.shape:
raise ValueError
return data_object(arr)
......
......@@ -31,12 +31,6 @@ def ibegin(arr):
return (0,)*arr.ndim
def create_from_template(tmpl, local_data, dtype):
res = np.ndarray(tmpl.shape, dtype=dtype)
res[()] = local_data
return res
def np_allreduce_sum(arr):
return arr
......@@ -56,8 +50,6 @@ def from_local_data (shape, arr, dist_axis):
def from_global_data (arr, dist_axis):
if dist_axis!=-1:
raise NotImplementedError
if shape!=arr.shape:
raise ValueError
return arr
......
......@@ -455,7 +455,7 @@ class Field(object):
raise TypeError("argument must be a Field")
if other.domain != self.domain:
raise ValueError("domains are incompatible.")
self.val[()] = other.val[()]
dobj.local_data(self.val)[()] = dobj.local_data(other.val)[()]
# ---General binary methods---
......
......@@ -55,34 +55,36 @@ class PowerProjectionOperator(LinearOperator):
res = Field.zeros(self._target, dtype=x.dtype)
if dobj.dist_axis(x.val) in x.domain.axes[self._space]: # the distributed axis is part of the projected space
pindex = dobj.local_data(pindex)
pindex.reshape((1, pindex.size, 1))
arr = dobj.local_data(x.weight(1).val)
firstaxis = x.domain.axes[self._space][0]
lastaxis = x.domain.axes[self._space][-1]
presize = np.prod(arr.shape[0:firstaxis], dtype=np.int)
postsize = np.prod(arr.shape[lastaxis+1:], dtype=np.int)
arr = arr.reshape((presize,pindex.size,postsize))
oarr = dobj.local_data(res.val).reshape((presize,-1,postsize))
np.add.at(oarr, (slice(None), pindex.ravel(), slice(None)), arr)
else:
pindex = dobj.to_ndarray(pindex)
pindex.reshape((1, pindex.size, 1))
arr = dobj.local_data(x.weight(1).val)
firstaxis = x.domain.axes[self._space][0]
lastaxis = x.domain.axes[self._space][-1]
presize = np.prod(arr.shape[0:firstaxis], dtype=np.int)
postsize = np.prod(arr.shape[lastaxis+1:], dtype=np.int)
arr = arr.reshape((presize,pindex.size,postsize))
oarr = dobj.local_data(res.val).reshape((presize,-1,postsize))
np.add.at(oarr, (slice(None), pindex.ravel(), slice(None)), arr)
pindex.reshape((1, pindex.size, 1))
arr = dobj.local_data(x.weight(1).val)
firstaxis = x.domain.axes[self._space][0]
lastaxis = x.domain.axes[self._space][-1]
presize = np.prod(arr.shape[0:firstaxis], dtype=np.int)
postsize = np.prod(arr.shape[lastaxis+1:], dtype=np.int)
arr = arr.reshape((presize,pindex.size,postsize))
oarr = dobj.local_data(res.val).reshape((presize,-1,postsize))
np.add.at(oarr, (slice(None), pindex.ravel(), slice(None)), arr)
return res.weight(-1, spaces=self._space)
def _adjoint_times(self, x):
pindex = self._target[self._space].pindex
res = Field.empty(self._domain, dtype=x.dtype)
if dobj.dist_axis(x.val) in x.domain.axes[self._space]: # the distributed axis is part of the projected space
pindex = dobj.local_data(pindex)
else:
pindex = dobj.to_ndarray(pindex)
pindex = pindex.reshape((1, pindex.size, 1))
arr = x.val.reshape(x.domain.collapsed_shape_for_domain(self._space))
out = arr[(slice(None), dobj.to_ndarray(pindex.ravel()), slice(None))]
return Field(self._domain, out.reshape(self._domain.shape))
arr = dobj.local_data(x.val)
firstaxis = x.domain.axes[self._space][0]
lastaxis = x.domain.axes[self._space][-1]
presize = np.prod(arr.shape[0:firstaxis], dtype=np.int)
postsize = np.prod(arr.shape[lastaxis+1:], dtype=np.int)
arr = arr.reshape((presize,-1,postsize))
oarr = dobj.local_data(res.val).reshape((presize,-1,postsize))
oarr[()] = arr[(slice(None), pindex.ravel(), slice(None))]
return res
@property
def domain(self):
......
......@@ -143,8 +143,8 @@ class PowerSpace(Space):
else:
tbb = binbounds
locdat = np.searchsorted(tbb, dobj.local_data(k_length_array.val))
temp_pindex = dobj.create_from_template(
k_length_array.val, local_data=locdat, dtype=locdat.dtype)
temp_pindex = dobj.from_local_data(
k_length_array.val.shape, locdat, dobj.dist_axis(k_length_array.val))
nbin = len(tbb)
temp_rho = np.bincount(dobj.local_data(temp_pindex).ravel(),
minlength=nbin)
......
......@@ -115,6 +115,7 @@ class RGSpace(Space):
tmp[t2] = True
return np.sqrt(np.nonzero(tmp)[0])*self.distances[0]
else: # do it the hard way
# FIXME: this needs to improve for MPI. Maybe unique()/gather()?
tmp = np.unique(dobj.to_ndarray(self.get_k_length_array().val)) # expensive!
tol = 1e-12*tmp[-1]
# remove all points that are closer than tol to their right
......
......@@ -56,8 +56,9 @@ class SmoothingOperator_Tests(unittest.TestCase):
@expand(product(spaces, [0., .5, 5.]))
def test_times(self, space, sigma):
op = ift.FFTSmoothingOperator(space, sigma=sigma)
rand1 = ift.Field.zeros(space)
rand1.val[0] = 1.
fld = np.zeros(space.shape, dtype=np.float64)
fld[0] = 1.
rand1 = ift.Field(space, ift.dobj.from_global_data(fld, dist_axis=-1))
tt1 = op.times(rand1)
assert_allclose(1, tt1.sum())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment