From 67bed0b7f781fae2d7ec2b0c42317e7f3fed1b54 Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Wed, 8 Nov 2017 11:57:18 +0100
Subject: [PATCH] tweaks

---
 nifty/data_objects/my_own_do.py               | 40 +++++++----------
 nifty/data_objects/numpy_do.py                |  8 ----
 nifty/field.py                                |  2 +-
 nifty/operators/power_projection_operator.py  | 44 ++++++++++---------
 nifty/spaces/power_space.py                   |  4 +-
 nifty/spaces/rg_space.py                      |  1 +
 .../test_operators/test_smoothing_operator.py |  5 ++-
 7 files changed, 46 insertions(+), 58 deletions(-)

diff --git a/nifty/data_objects/my_own_do.py b/nifty/data_objects/my_own_do.py
index 8e9d86bf6..ecca5b12b 100644
--- a/nifty/data_objects/my_own_do.py
+++ b/nifty/data_objects/my_own_do.py
@@ -6,6 +6,7 @@ class data_object(object):
     def __init__(self, npdata):
         self._data = np.asarray(npdata)
 
+    # FIXME: subscripting support will most likely go away
     def __getitem__(self, key):
         res = self._data[key]
         return res if np.isscalar(res) else data_object(res)
@@ -37,6 +38,9 @@ class data_object(object):
         return data_object(self._data.imag)
 
     def _contraction_helper(self, op, axis):
+        if axis is not None:
+            if len(axis)==len(self._data.shape):
+                axis = None
         if axis is None:
             return getattr(self._data, op)()
 
@@ -164,32 +168,28 @@ def vdot(a, b):
     return np.vdot(a._data, b._data)
 
 
+def _math_helper(x, function, out):
+    if out is not None:
+        function(x._data, out=out._data)
+        return out
+    else:
+        return data_object(function(x._data))
+
+
 def abs(a, out=None):
-    if out is None:
-        out = empty_like(a)
-    np.abs(a._data, out=out._data)
-    return out
+    return _math_helper(a, np.abs, out)
 
 
 def exp(a, out=None):
-    if out is None:
-        out = empty_like(a)
-    np.exp(a._data, out=out._data)
-    return out
+    return _math_helper(a, np.exp, out)
 
 
 def log(a, out=None):
-    if out is None:
-        out = empty_like(a)
-    np.log(a._data, out=out._data)
-    return out
+    return _math_helper(a, np.log, out)
 
 
 def sqrt(a, out=None):
-    if out is None:
-        out = empty_like(a)
-    np.sqrt(a._data, out=out._data)
-    return out
+    return _math_helper(a, np.sqrt, out)
 
 
 def bincount(x, weights=None, minlength=None):
@@ -224,12 +224,6 @@ def ibegin(arr):
     return (0,)*arr._data.ndim
 
 
-def create_from_template(tmpl, local_data, dtype):
-    res = np.ndarray(tmpl.shape, dtype=dtype)
-    res[()] = local_data
-    return data_object(res)
-
-
 def np_allreduce_sum(arr):
     return arr
 
@@ -249,8 +243,6 @@ def from_local_data (shape, arr, dist_axis):
 def from_global_data (arr, dist_axis):
     if dist_axis!=-1:
         raise NotImplementedError
-    if shape!=arr.shape:
-        raise ValueError
     return data_object(arr)
 
 
diff --git a/nifty/data_objects/numpy_do.py b/nifty/data_objects/numpy_do.py
index de2eb0ccd..40a30adaf 100644
--- a/nifty/data_objects/numpy_do.py
+++ b/nifty/data_objects/numpy_do.py
@@ -31,12 +31,6 @@ def ibegin(arr):
     return (0,)*arr.ndim
 
 
-def create_from_template(tmpl, local_data, dtype):
-    res = np.ndarray(tmpl.shape, dtype=dtype)
-    res[()] = local_data
-    return res
-
-
 def np_allreduce_sum(arr):
     return arr
 
@@ -56,8 +50,6 @@ def from_local_data (shape, arr, dist_axis):
 def from_global_data (arr, dist_axis):
     if dist_axis!=-1:
         raise NotImplementedError
-    if shape!=arr.shape:
-        raise ValueError
     return arr
 
 
diff --git a/nifty/field.py b/nifty/field.py
index 0edb52d07..75c78cf60 100644
--- a/nifty/field.py
+++ b/nifty/field.py
@@ -455,7 +455,7 @@ class Field(object):
             raise TypeError("argument must be a Field")
         if other.domain != self.domain:
             raise ValueError("domains are incompatible.")
-        self.val[()] = other.val[()]
+        dobj.local_data(self.val)[()] = dobj.local_data(other.val)[()]
 
     # ---General binary methods---
 
diff --git a/nifty/operators/power_projection_operator.py b/nifty/operators/power_projection_operator.py
index 8a817538b..33ca402df 100644
--- a/nifty/operators/power_projection_operator.py
+++ b/nifty/operators/power_projection_operator.py
@@ -55,34 +55,36 @@ class PowerProjectionOperator(LinearOperator):
         res = Field.zeros(self._target, dtype=x.dtype)
         if dobj.dist_axis(x.val) in x.domain.axes[self._space]:  # the distributed axis is part of the projected space
             pindex = dobj.local_data(pindex)
-            pindex.reshape((1, pindex.size, 1))
-            arr = dobj.local_data(x.weight(1).val)
-            firstaxis = x.domain.axes[self._space][0]
-            lastaxis = x.domain.axes[self._space][-1]
-            presize = np.prod(arr.shape[0:firstaxis], dtype=np.int)
-            postsize = np.prod(arr.shape[lastaxis+1:], dtype=np.int)
-            arr = arr.reshape((presize,pindex.size,postsize))
-            oarr = dobj.local_data(res.val).reshape((presize,-1,postsize))
-            np.add.at(oarr, (slice(None), pindex.ravel(), slice(None)), arr)
         else:
             pindex = dobj.to_ndarray(pindex)
-            pindex.reshape((1, pindex.size, 1))
-            arr = dobj.local_data(x.weight(1).val)
-            firstaxis = x.domain.axes[self._space][0]
-            lastaxis = x.domain.axes[self._space][-1]
-            presize = np.prod(arr.shape[0:firstaxis], dtype=np.int)
-            postsize = np.prod(arr.shape[lastaxis+1:], dtype=np.int)
-            arr = arr.reshape((presize,pindex.size,postsize))
-            oarr = dobj.local_data(res.val).reshape((presize,-1,postsize))
-            np.add.at(oarr, (slice(None), pindex.ravel(), slice(None)), arr)
+        pindex.reshape((1, pindex.size, 1))
+        arr = dobj.local_data(x.weight(1).val)
+        firstaxis = x.domain.axes[self._space][0]
+        lastaxis = x.domain.axes[self._space][-1]
+        presize = np.prod(arr.shape[0:firstaxis], dtype=np.int)
+        postsize = np.prod(arr.shape[lastaxis+1:], dtype=np.int)
+        arr = arr.reshape((presize,pindex.size,postsize))
+        oarr = dobj.local_data(res.val).reshape((presize,-1,postsize))
+        np.add.at(oarr, (slice(None), pindex.ravel(), slice(None)), arr)
         return res.weight(-1, spaces=self._space)
 
     def _adjoint_times(self, x):
         pindex = self._target[self._space].pindex
+        res = Field.empty(self._domain, dtype=x.dtype)
+        if dobj.dist_axis(x.val) in x.domain.axes[self._space]:  # the distributed axis is part of the projected space
+            pindex = dobj.local_data(pindex)
+        else:
+            pindex = dobj.to_ndarray(pindex)
         pindex = pindex.reshape((1, pindex.size, 1))
-        arr = x.val.reshape(x.domain.collapsed_shape_for_domain(self._space))
-        out = arr[(slice(None), dobj.to_ndarray(pindex.ravel()), slice(None))]
-        return Field(self._domain, out.reshape(self._domain.shape))
+        arr = dobj.local_data(x.val)
+        firstaxis = x.domain.axes[self._space][0]
+        lastaxis = x.domain.axes[self._space][-1]
+        presize = np.prod(arr.shape[0:firstaxis], dtype=np.int)
+        postsize = np.prod(arr.shape[lastaxis+1:], dtype=np.int)
+        arr = arr.reshape((presize,-1,postsize))
+        oarr = dobj.local_data(res.val).reshape((presize,-1,postsize))
+        oarr[()] = arr[(slice(None), pindex.ravel(), slice(None))]
+        return res
 
     @property
     def domain(self):
diff --git a/nifty/spaces/power_space.py b/nifty/spaces/power_space.py
index 05fe9ce6a..a75795226 100644
--- a/nifty/spaces/power_space.py
+++ b/nifty/spaces/power_space.py
@@ -143,8 +143,8 @@ class PowerSpace(Space):
             else:
                 tbb = binbounds
             locdat = np.searchsorted(tbb, dobj.local_data(k_length_array.val))
-            temp_pindex = dobj.create_from_template(
-                k_length_array.val, local_data=locdat, dtype=locdat.dtype)
+            temp_pindex = dobj.from_local_data(
+                k_length_array.val.shape, locdat, dobj.dist_axis(k_length_array.val))
             nbin = len(tbb)
             temp_rho = np.bincount(dobj.local_data(temp_pindex).ravel(),
                                    minlength=nbin)
diff --git a/nifty/spaces/rg_space.py b/nifty/spaces/rg_space.py
index 947e98e87..42b35affc 100644
--- a/nifty/spaces/rg_space.py
+++ b/nifty/spaces/rg_space.py
@@ -115,6 +115,7 @@ class RGSpace(Space):
             tmp[t2] = True
             return np.sqrt(np.nonzero(tmp)[0])*self.distances[0]
         else:  # do it the hard way
+            # FIXME: this needs to improve for MPI. Maybe unique()/gather()?
             tmp = np.unique(dobj.to_ndarray(self.get_k_length_array().val))  # expensive!
             tol = 1e-12*tmp[-1]
             # remove all points that are closer than tol to their right
diff --git a/test/test_operators/test_smoothing_operator.py b/test/test_operators/test_smoothing_operator.py
index 9180fabf6..2e6cd345f 100644
--- a/test/test_operators/test_smoothing_operator.py
+++ b/test/test_operators/test_smoothing_operator.py
@@ -56,8 +56,9 @@ class SmoothingOperator_Tests(unittest.TestCase):
     @expand(product(spaces, [0., .5, 5.]))
     def test_times(self, space, sigma):
         op = ift.FFTSmoothingOperator(space, sigma=sigma)
-        rand1 = ift.Field.zeros(space)
-        rand1.val[0] = 1.
+        fld = np.zeros(space.shape, dtype=np.float64)
+        fld[0] = 1.
+        rand1 = ift.Field(space, ift.dobj.from_global_data(fld, dist_axis=-1))
         tt1 = op.times(rand1)
         assert_allclose(1, tt1.sum())
 
-- 
GitLab