From 67bed0b7f781fae2d7ec2b0c42317e7f3fed1b54 Mon Sep 17 00:00:00 2001 From: Martin Reinecke <martin@mpa-garching.mpg.de> Date: Wed, 8 Nov 2017 11:57:18 +0100 Subject: [PATCH] tweaks --- nifty/data_objects/my_own_do.py | 40 +++++++---------- nifty/data_objects/numpy_do.py | 8 ---- nifty/field.py | 2 +- nifty/operators/power_projection_operator.py | 44 ++++++++++--------- nifty/spaces/power_space.py | 4 +- nifty/spaces/rg_space.py | 1 + .../test_operators/test_smoothing_operator.py | 5 ++- 7 files changed, 46 insertions(+), 58 deletions(-) diff --git a/nifty/data_objects/my_own_do.py b/nifty/data_objects/my_own_do.py index 8e9d86bf6..ecca5b12b 100644 --- a/nifty/data_objects/my_own_do.py +++ b/nifty/data_objects/my_own_do.py @@ -6,6 +6,7 @@ class data_object(object): def __init__(self, npdata): self._data = np.asarray(npdata) + # FIXME: subscripting support will most likely go away def __getitem__(self, key): res = self._data[key] return res if np.isscalar(res) else data_object(res) @@ -37,6 +38,9 @@ class data_object(object): return data_object(self._data.imag) def _contraction_helper(self, op, axis): + if axis is not None: + if len(axis)==len(self._data.shape): + axis = None if axis is None: return getattr(self._data, op)() @@ -164,32 +168,28 @@ def vdot(a, b): return np.vdot(a._data, b._data) +def _math_helper(x, function, out): + if out is not None: + function(x._data, out=out._data) + return out + else: + return data_object(function(x._data)) + + def abs(a, out=None): - if out is None: - out = empty_like(a) - np.abs(a._data, out=out._data) - return out + return _math_helper(a, np.abs, out) def exp(a, out=None): - if out is None: - out = empty_like(a) - np.exp(a._data, out=out._data) - return out + return _math_helper(a, np.exp, out) def log(a, out=None): - if out is None: - out = empty_like(a) - np.log(a._data, out=out._data) - return out + return _math_helper(a, np.log, out) def sqrt(a, out=None): - if out is None: - out = empty_like(a) - np.sqrt(a._data, out=out._data) - return out + return _math_helper(a, np.sqrt, out) def bincount(x, weights=None, minlength=None): @@ -224,12 +224,6 @@ def ibegin(arr): return (0,)*arr._data.ndim -def create_from_template(tmpl, local_data, dtype): - res = np.ndarray(tmpl.shape, dtype=dtype) - res[()] = local_data - return data_object(res) - - def np_allreduce_sum(arr): return arr @@ -249,8 +243,6 @@ def from_local_data (shape, arr, dist_axis): def from_global_data (arr, dist_axis): if dist_axis!=-1: raise NotImplementedError - if shape!=arr.shape: - raise ValueError return data_object(arr) diff --git a/nifty/data_objects/numpy_do.py b/nifty/data_objects/numpy_do.py index de2eb0ccd..40a30adaf 100644 --- a/nifty/data_objects/numpy_do.py +++ b/nifty/data_objects/numpy_do.py @@ -31,12 +31,6 @@ def ibegin(arr): return (0,)*arr.ndim -def create_from_template(tmpl, local_data, dtype): - res = np.ndarray(tmpl.shape, dtype=dtype) - res[()] = local_data - return res - - def np_allreduce_sum(arr): return arr @@ -56,8 +50,6 @@ def from_local_data (shape, arr, dist_axis): def from_global_data (arr, dist_axis): if dist_axis!=-1: raise NotImplementedError - if shape!=arr.shape: - raise ValueError return arr diff --git a/nifty/field.py b/nifty/field.py index 0edb52d07..75c78cf60 100644 --- a/nifty/field.py +++ b/nifty/field.py @@ -455,7 +455,7 @@ class Field(object): raise TypeError("argument must be a Field") if other.domain != self.domain: raise ValueError("domains are incompatible.") - self.val[()] = other.val[()] + dobj.local_data(self.val)[()] = dobj.local_data(other.val)[()] # ---General binary methods--- diff --git a/nifty/operators/power_projection_operator.py b/nifty/operators/power_projection_operator.py index 8a817538b..33ca402df 100644 --- a/nifty/operators/power_projection_operator.py +++ b/nifty/operators/power_projection_operator.py @@ -55,34 +55,36 @@ class PowerProjectionOperator(LinearOperator): res = Field.zeros(self._target, dtype=x.dtype) if dobj.dist_axis(x.val) in x.domain.axes[self._space]: # the distributed axis is part of the projected space pindex = dobj.local_data(pindex) - pindex.reshape((1, pindex.size, 1)) - arr = dobj.local_data(x.weight(1).val) - firstaxis = x.domain.axes[self._space][0] - lastaxis = x.domain.axes[self._space][-1] - presize = np.prod(arr.shape[0:firstaxis], dtype=np.int) - postsize = np.prod(arr.shape[lastaxis+1:], dtype=np.int) - arr = arr.reshape((presize,pindex.size,postsize)) - oarr = dobj.local_data(res.val).reshape((presize,-1,postsize)) - np.add.at(oarr, (slice(None), pindex.ravel(), slice(None)), arr) else: pindex = dobj.to_ndarray(pindex) - pindex.reshape((1, pindex.size, 1)) - arr = dobj.local_data(x.weight(1).val) - firstaxis = x.domain.axes[self._space][0] - lastaxis = x.domain.axes[self._space][-1] - presize = np.prod(arr.shape[0:firstaxis], dtype=np.int) - postsize = np.prod(arr.shape[lastaxis+1:], dtype=np.int) - arr = arr.reshape((presize,pindex.size,postsize)) - oarr = dobj.local_data(res.val).reshape((presize,-1,postsize)) - np.add.at(oarr, (slice(None), pindex.ravel(), slice(None)), arr) + pindex.reshape((1, pindex.size, 1)) + arr = dobj.local_data(x.weight(1).val) + firstaxis = x.domain.axes[self._space][0] + lastaxis = x.domain.axes[self._space][-1] + presize = np.prod(arr.shape[0:firstaxis], dtype=np.int) + postsize = np.prod(arr.shape[lastaxis+1:], dtype=np.int) + arr = arr.reshape((presize,pindex.size,postsize)) + oarr = dobj.local_data(res.val).reshape((presize,-1,postsize)) + np.add.at(oarr, (slice(None), pindex.ravel(), slice(None)), arr) return res.weight(-1, spaces=self._space) def _adjoint_times(self, x): pindex = self._target[self._space].pindex + res = Field.empty(self._domain, dtype=x.dtype) + if dobj.dist_axis(x.val) in x.domain.axes[self._space]: # the distributed axis is part of the projected space + pindex = dobj.local_data(pindex) + else: + pindex = dobj.to_ndarray(pindex) pindex = pindex.reshape((1, pindex.size, 1)) - arr = x.val.reshape(x.domain.collapsed_shape_for_domain(self._space)) - out = arr[(slice(None), dobj.to_ndarray(pindex.ravel()), slice(None))] - return Field(self._domain, out.reshape(self._domain.shape)) + arr = dobj.local_data(x.val) + firstaxis = x.domain.axes[self._space][0] + lastaxis = x.domain.axes[self._space][-1] + presize = np.prod(arr.shape[0:firstaxis], dtype=np.int) + postsize = np.prod(arr.shape[lastaxis+1:], dtype=np.int) + arr = arr.reshape((presize,-1,postsize)) + oarr = dobj.local_data(res.val).reshape((presize,-1,postsize)) + oarr[()] = arr[(slice(None), pindex.ravel(), slice(None))] + return res @property def domain(self): diff --git a/nifty/spaces/power_space.py b/nifty/spaces/power_space.py index 05fe9ce6a..a75795226 100644 --- a/nifty/spaces/power_space.py +++ b/nifty/spaces/power_space.py @@ -143,8 +143,8 @@ class PowerSpace(Space): else: tbb = binbounds locdat = np.searchsorted(tbb, dobj.local_data(k_length_array.val)) - temp_pindex = dobj.create_from_template( - k_length_array.val, local_data=locdat, dtype=locdat.dtype) + temp_pindex = dobj.from_local_data( + k_length_array.val.shape, locdat, dobj.dist_axis(k_length_array.val)) nbin = len(tbb) temp_rho = np.bincount(dobj.local_data(temp_pindex).ravel(), minlength=nbin) diff --git a/nifty/spaces/rg_space.py b/nifty/spaces/rg_space.py index 947e98e87..42b35affc 100644 --- a/nifty/spaces/rg_space.py +++ b/nifty/spaces/rg_space.py @@ -115,6 +115,7 @@ class RGSpace(Space): tmp[t2] = True return np.sqrt(np.nonzero(tmp)[0])*self.distances[0] else: # do it the hard way + # FIXME: this needs to improve for MPI. Maybe unique()/gather()? tmp = np.unique(dobj.to_ndarray(self.get_k_length_array().val)) # expensive! tol = 1e-12*tmp[-1] # remove all points that are closer than tol to their right diff --git a/test/test_operators/test_smoothing_operator.py b/test/test_operators/test_smoothing_operator.py index 9180fabf6..2e6cd345f 100644 --- a/test/test_operators/test_smoothing_operator.py +++ b/test/test_operators/test_smoothing_operator.py @@ -56,8 +56,9 @@ class SmoothingOperator_Tests(unittest.TestCase): @expand(product(spaces, [0., .5, 5.])) def test_times(self, space, sigma): op = ift.FFTSmoothingOperator(space, sigma=sigma) - rand1 = ift.Field.zeros(space) - rand1.val[0] = 1. + fld = np.zeros(space.shape, dtype=np.float64) + fld[0] = 1. + rand1 = ift.Field(space, ift.dobj.from_global_data(fld, dist_axis=-1)) tt1 = op.times(rand1) assert_allclose(1, tt1.sum()) -- GitLab