From 12e3d59777aa8e01af2fa846f482598f003da432 Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Sun, 24 Sep 2017 17:55:24 +0200
Subject: [PATCH] simplify DiagonalOperator

---
 nifty/data_objects/numpy_do.py                | 65 ++++++++++++++++++
 nifty/field.py                                |  6 +-
 nifty/nifty_utilities.py                      | 66 -------------------
 .../diagonal_operator/diagonal_operator.py    | 41 ++----------
 .../response_operator/response_operator.py    |  4 +-
 test/test_minimization/test_minimizers.py     |  2 +-
 test/test_operators/test_composed_operator.py | 12 ++--
 test/test_operators/test_diagonal_operator.py | 18 ++---
 8 files changed, 93 insertions(+), 121 deletions(-)

diff --git a/nifty/data_objects/numpy_do.py b/nifty/data_objects/numpy_do.py
index 97ea8dd45..de04b230e 100644
--- a/nifty/data_objects/numpy_do.py
+++ b/nifty/data_objects/numpy_do.py
@@ -3,6 +3,71 @@
 import numpy as np
 from numpy import ndarray as data_object
 from numpy import full, empty, sqrt, ones, zeros, vdot, abs, bincount
+from ..nifty_utilities import cast_iseq_to_tuple, get_slice_list
+from functools import reduce
 
 def from_object(object, dtype=None, copy=True):
     return np.array(object, dtype=dtype, copy=copy)
+
+def bincount_axis(obj, minlength=None, weights=None, axis=None):
+    if minlength is not None:
+        length = max(np.amax(obj) + 1, minlength)
+    else:
+        length = np.amax(obj) + 1
+
+    if obj.shape == ():
+        raise ValueError("object of too small depth for desired array")
+    data = obj
+
+    # if present, parse the axis keyword and transpose/reorder self.data
+    # such that all affected axes follow each other. Only if they are in a
+    # sequence flattening will be possible
+    if axis is not None:
+        # do the reordering
+        ndim = len(obj.shape)
+        axis = sorted(cast_iseq_to_tuple(axis))
+        reordering = [x for x in range(ndim) if x not in axis]
+        reordering += axis
+
+        data = np.transpose(data, reordering)
+        if weights is not None:
+            weights = np.transpose(weights, reordering)
+
+        reord_axis = list(range(ndim-len(axis), ndim))
+
+        # semi-flatten the dimensions in `axis`, i.e. after reordering
+        # the last ones.
+        semi_flat_dim = reduce(lambda x, y: x*y,
+                               data.shape[ndim-len(reord_axis):])
+        flat_shape = data.shape[:ndim-len(reord_axis)] + (semi_flat_dim, )
+    else:
+        flat_shape = (reduce(lambda x, y: x*y, data.shape), )
+
+    data = np.ascontiguousarray(data.reshape(flat_shape))
+    if weights is not None:
+        weights = np.ascontiguousarray(weights.reshape(flat_shape))
+
+    # compute the local bincount results
+    # -> prepare the local result array
+    result_dtype = np.int if weights is None else np.float
+    local_counts = np.empty(flat_shape[:-1] + (length, ), dtype=result_dtype)
+    # iterate over all entries in the surviving axes and compute the local
+    # bincounts
+    for slice_list in get_slice_list(flat_shape, axes=(len(flat_shape)-1,)):
+        current_weights = None if weights is None else weights[slice_list]
+        local_counts[slice_list] = np.bincount(data[slice_list],
+                                               weights=current_weights,
+                                               minlength=length)
+
+    # restore the original ordering
+    # place the bincount stuff at the location of the first `axis` entry
+    if axis is not None:
+        # axis has been sorted above
+        insert_position = axis[0]
+        new_ndim = len(local_counts.shape)
+        return_order = (list(range(0, insert_position)) +
+                        [new_ndim-1, ] +
+                        list(range(insert_position, new_ndim-1)))
+        local_counts = np.ascontiguousarray(
+                            local_counts.transpose(return_order))
+    return local_counts
diff --git a/nifty/field.py b/nifty/field.py
index 626ae42e8..e1685c24a 100644
--- a/nifty/field.py
+++ b/nifty/field.py
@@ -231,8 +231,8 @@ class Field(object):
             new_pindex_shape[ax] = pindex.shape[i]
         pindex = np.broadcast_to(pindex.reshape(new_pindex_shape), field.shape)
 
-        power_spectrum = utilities.bincount_axis(pindex, weights=field.val,
-                                                 axis=axes)
+        power_spectrum = dobj.bincount_axis(pindex, weights=field.val,
+                                            axis=axes)
         new_rho_shape = [1] * len(power_spectrum.shape)
         new_rho_shape[axes[0]] = len(power_domain.rho)
         power_spectrum /= power_domain.rho.reshape(new_rho_shape)
@@ -510,7 +510,7 @@ class Field(object):
             # create a diagonal operator which is capable of taking care of the
             # axes-matching
             from .operators.diagonal_operator import DiagonalOperator
-            diag = DiagonalOperator(y.domain, y.conjugate(), copy=False)
+            diag = DiagonalOperator(y.conjugate(), copy=False)
             dotted = diag(x, spaces=spaces)
             return fct*dotted.sum(spaces=spaces)
 
diff --git a/nifty/nifty_utilities.py b/nifty/nifty_utilities.py
index 0b4d5dfce..c8df74b54 100644
--- a/nifty/nifty_utilities.py
+++ b/nifty/nifty_utilities.py
@@ -19,8 +19,6 @@
 from builtins import next, range
 import numpy as np
 from itertools import product
-from functools import reduce
-from .domain_object import DomainObject
 
 
 def get_slice_list(shape, axes):
@@ -74,67 +72,3 @@ def cast_iseq_to_tuple(seq):
     if np.isscalar(seq):
         return (int(seq),)
     return tuple(int(item) for item in seq)
-
-
-def bincount_axis(obj, minlength=None, weights=None, axis=None):
-    if minlength is not None:
-        length = max(np.amax(obj) + 1, minlength)
-    else:
-        length = np.amax(obj) + 1
-
-    if obj.shape == ():
-        raise ValueError("object of too small depth for desired array")
-    data = obj
-
-    # if present, parse the axis keyword and transpose/reorder self.data
-    # such that all affected axes follow each other. Only if they are in a
-    # sequence flattening will be possible
-    if axis is not None:
-        # do the reordering
-        ndim = len(obj.shape)
-        axis = sorted(cast_iseq_to_tuple(axis))
-        reordering = [x for x in range(ndim) if x not in axis]
-        reordering += axis
-
-        data = np.transpose(data, reordering)
-        if weights is not None:
-            weights = np.transpose(weights, reordering)
-
-        reord_axis = list(range(ndim-len(axis), ndim))
-
-        # semi-flatten the dimensions in `axis`, i.e. after reordering
-        # the last ones.
-        semi_flat_dim = reduce(lambda x, y: x*y,
-                               data.shape[ndim-len(reord_axis):])
-        flat_shape = data.shape[:ndim-len(reord_axis)] + (semi_flat_dim, )
-    else:
-        flat_shape = (reduce(lambda x, y: x*y, data.shape), )
-
-    data = np.ascontiguousarray(data.reshape(flat_shape))
-    if weights is not None:
-        weights = np.ascontiguousarray(weights.reshape(flat_shape))
-
-    # compute the local bincount results
-    # -> prepare the local result array
-    result_dtype = np.int if weights is None else np.float
-    local_counts = np.empty(flat_shape[:-1] + (length, ), dtype=result_dtype)
-    # iterate over all entries in the surviving axes and compute the local
-    # bincounts
-    for slice_list in get_slice_list(flat_shape, axes=(len(flat_shape)-1,)):
-        current_weights = None if weights is None else weights[slice_list]
-        local_counts[slice_list] = np.bincount(data[slice_list],
-                                               weights=current_weights,
-                                               minlength=length)
-
-    # restore the original ordering
-    # place the bincount stuff at the location of the first `axis` entry
-    if axis is not None:
-        # axis has been sorted above
-        insert_position = axis[0]
-        new_ndim = len(local_counts.shape)
-        return_order = (list(range(0, insert_position)) +
-                        [new_ndim-1, ] +
-                        list(range(insert_position, new_ndim-1)))
-        local_counts = np.ascontiguousarray(
-                            local_counts.transpose(return_order))
-    return local_counts
diff --git a/nifty/operators/diagonal_operator/diagonal_operator.py b/nifty/operators/diagonal_operator/diagonal_operator.py
index 34672ec3b..778aa41b7 100644
--- a/nifty/operators/diagonal_operator/diagonal_operator.py
+++ b/nifty/operators/diagonal_operator/diagonal_operator.py
@@ -35,9 +35,7 @@ class DiagonalOperator(EndomorphicOperator):
 
     Parameters
     ----------
-    domain : tuple of DomainObjects, i.e. Spaces and FieldTypes
-        The domain on which the Operator's input Field lives.
-    diagonal : {scalar, list, array, Field}
+    diagonal : Field
         The diagonal entries of the operator.
     copy : boolean
         Internal copy of the diagonal (default: True)
@@ -68,15 +66,14 @@ class DiagonalOperator(EndomorphicOperator):
 
     # ---Overwritten properties and methods---
 
-    def __init__(self, domain=(), diagonal=None, copy=True,
-                 default_spaces=None):
+    def __init__(self, diagonal, copy=True, default_spaces=None):
         super(DiagonalOperator, self).__init__(default_spaces)
 
-        self._domain = DomainTuple.make(domain)
-
+        if not isinstance(diagonal, Field):
+            raise TypeError("Field object required")
+        self._diagonal = diagonal if not copy else diagonal.copy()
         self._self_adjoint = None
         self._unitary = None
-        self.set_diagonal(diagonal=diagonal, copy=copy)
 
     def _times(self, x, spaces):
         return self._times_helper(x, spaces, operation=lambda z: z.__mul__)
@@ -119,13 +116,13 @@ class DiagonalOperator(EndomorphicOperator):
             The inverse of the diagonal of the Operator.
 
         """
-        return 1./self.diagonal(copy=False)
+        return 1./self._diagonal
 
     # ---Mandatory properties and methods---
 
     @property
     def domain(self):
-        return self._domain
+        return self._diagonal.domain
 
     @property
     def self_adjoint(self):
@@ -144,30 +141,6 @@ class DiagonalOperator(EndomorphicOperator):
 
     # ---Added properties and methods---
 
-    def set_diagonal(self, diagonal, copy=True):
-        """ Sets the diagonal of the Operator.
-
-        Parameters
-        ----------
-        diagonal : {scalar, list, array, Field}
-            The diagonal entries of the operator.
-        copy : boolean
-            Specifies if a copy of the input shall be made (default: True).
-
-        """
-
-        # use the casting functionality from Field to process `diagonal`
-        f = Field(domain=self.domain, val=diagonal, copy=copy)
-
-        # Reset the self_adjoint property:
-        self._self_adjoint = None
-
-        # Reset the unitarity property
-        self._unitary = None
-
-        # store the diagonal-field
-        self._diagonal = f
-
     def _times_helper(self, x, spaces, operation):
         # if the domain matches directly
         # -> multiply the fields directly
diff --git a/nifty/operators/response_operator/response_operator.py b/nifty/operators/response_operator/response_operator.py
index f8e89e958..d215b0928 100644
--- a/nifty/operators/response_operator/response_operator.py
+++ b/nifty/operators/response_operator/response_operator.py
@@ -59,8 +59,8 @@ class ResponseOperator(LinearOperator):
 
         kernel_smoothing = [FFTSmoothingOperator(self._domain[x], sigma[x])
                             for x in range(nsigma)]
-        kernel_exposure = [DiagonalOperator(self._domain[x],
-                           diagonal=exposure[x]) for x in range(nsigma)]
+        kernel_exposure = [DiagonalOperator(Field(self._domain[x],exposure[x]))
+                           for x in range(nsigma)]
 
         self._composed_kernel = ComposedOperator(kernel_smoothing)
         self._composed_exposure = ComposedOperator(kernel_exposure)
diff --git a/test/test_minimization/test_minimizers.py b/test/test_minimization/test_minimizers.py
index ac9fe5619..59fddb5d8 100644
--- a/test/test_minimization/test_minimizers.py
+++ b/test/test_minimization/test_minimizers.py
@@ -21,7 +21,7 @@ class Test_Minimizers(unittest.TestCase):
         starting_point = ift.Field.from_random('normal', domain=space)*10
         covariance_diagonal = ift.Field.from_random(
                                   'uniform', domain=space) + 0.5
-        covariance = ift.DiagonalOperator(space, diagonal=covariance_diagonal)
+        covariance = ift.DiagonalOperator(covariance_diagonal)
         required_result = ift.Field(space, val=1.)
 
         IC = ift.DefaultIterationController(tol_abs_gradnorm=1e-5)
diff --git a/test/test_operators/test_composed_operator.py b/test/test_operators/test_composed_operator.py
index 0119a4d23..27f97606d 100644
--- a/test/test_operators/test_composed_operator.py
+++ b/test/test_operators/test_composed_operator.py
@@ -20,8 +20,8 @@ class ComposedOperator_Tests(unittest.TestCase):
     def test_property(self, space1, space2):
         rand1 = Field.from_random('normal', domain=space1)
         rand2 = Field.from_random('normal', domain=space2)
-        op1 = DiagonalOperator(space1, diagonal=rand1)
-        op2 = DiagonalOperator(space2, diagonal=rand2)
+        op1 = DiagonalOperator(rand1)
+        op2 = DiagonalOperator(rand2)
         op = ComposedOperator((op1, op2))
         if op.domain != (op1.domain[0], op2.domain[0]):
             raise TypeError
@@ -32,8 +32,8 @@ class ComposedOperator_Tests(unittest.TestCase):
     def test_times_adjoint_times(self, space1, space2):
         diag1 = Field.from_random('normal', domain=space1)
         diag2 = Field.from_random('normal', domain=space2)
-        op1 = DiagonalOperator(space1, diagonal=diag1)
-        op2 = DiagonalOperator(space2, diagonal=diag2)
+        op1 = DiagonalOperator(diag1)
+        op2 = DiagonalOperator(diag2)
 
         op = ComposedOperator((op1, op2))
 
@@ -48,8 +48,8 @@ class ComposedOperator_Tests(unittest.TestCase):
     def test_times_inverse_times(self, space1, space2):
         diag1 = Field.from_random('normal', domain=space1)
         diag2 = Field.from_random('normal', domain=space2)
-        op1 = DiagonalOperator(space1, diagonal=diag1)
-        op2 = DiagonalOperator(space2, diagonal=diag2)
+        op1 = DiagonalOperator(diag1)
+        op2 = DiagonalOperator(diag2)
 
         op = ComposedOperator((op1, op2))
 
diff --git a/test/test_operators/test_diagonal_operator.py b/test/test_operators/test_diagonal_operator.py
index 98618af70..36261e6b9 100644
--- a/test/test_operators/test_diagonal_operator.py
+++ b/test/test_operators/test_diagonal_operator.py
@@ -20,7 +20,7 @@ class DiagonalOperator_Tests(unittest.TestCase):
     @expand(product(spaces, [True, False]))
     def test_property(self, space, copy):
         diag = Field.from_random('normal', domain=space)
-        D = DiagonalOperator(space, diagonal=diag)
+        D = DiagonalOperator(diag)
         if D.domain[0] != space:
             raise TypeError
         if D.unitary != False:
@@ -33,7 +33,7 @@ class DiagonalOperator_Tests(unittest.TestCase):
         rand1 = Field.from_random('normal', domain=space)
         rand2 = Field.from_random('normal', domain=space)
         diag = Field.from_random('normal', domain=space)
-        D = DiagonalOperator(space, diagonal=diag, copy=copy)
+        D = DiagonalOperator(diag, copy=copy)
         tt1 = rand1.vdot(D.times(rand2))
         tt2 = rand2.vdot(D.times(rand1))
         assert_approx_equal(tt1, tt2)
@@ -42,7 +42,7 @@ class DiagonalOperator_Tests(unittest.TestCase):
     def test_times_inverse(self, space, copy):
         rand1 = Field.from_random('normal', domain=space)
         diag = Field.from_random('normal', domain=space)
-        D = DiagonalOperator(space, diagonal=diag, copy=copy)
+        D = DiagonalOperator(diag, copy=copy)
         tt1 = D.times(D.inverse_times(rand1))
         assert_allclose(rand1.val, tt1.val)
 
@@ -50,7 +50,7 @@ class DiagonalOperator_Tests(unittest.TestCase):
     def test_times(self, space, copy):
         rand1 = Field.from_random('normal', domain=space)
         diag = Field.from_random('normal', domain=space)
-        D = DiagonalOperator(space, diagonal=diag, copy=copy)
+        D = DiagonalOperator(diag, copy=copy)
         tt = D.times(rand1)
         assert_equal(tt.domain[0], space)
 
@@ -58,7 +58,7 @@ class DiagonalOperator_Tests(unittest.TestCase):
     def test_adjoint_times(self, space, copy):
         rand1 = Field.from_random('normal', domain=space)
         diag = Field.from_random('normal', domain=space)
-        D = DiagonalOperator(space, diagonal=diag, copy=copy)
+        D = DiagonalOperator(diag, copy=copy)
         tt = D.adjoint_times(rand1)
         assert_equal(tt.domain[0], space)
 
@@ -66,7 +66,7 @@ class DiagonalOperator_Tests(unittest.TestCase):
     def test_inverse_times(self, space, copy):
         rand1 = Field.from_random('normal', domain=space)
         diag = Field.from_random('normal', domain=space)
-        D = DiagonalOperator(space, diagonal=diag, copy=copy)
+        D = DiagonalOperator(diag, copy=copy)
         tt = D.inverse_times(rand1)
         assert_equal(tt.domain[0], space)
 
@@ -74,20 +74,20 @@ class DiagonalOperator_Tests(unittest.TestCase):
     def test_adjoint_inverse_times(self, space, copy):
         rand1 = Field.from_random('normal', domain=space)
         diag = Field.from_random('normal', domain=space)
-        D = DiagonalOperator(space, diagonal=diag, copy=copy)
+        D = DiagonalOperator(diag, copy=copy)
         tt = D.adjoint_inverse_times(rand1)
         assert_equal(tt.domain[0], space)
 
     @expand(product(spaces, [True, False]))
     def test_diagonal(self, space, copy):
         diag = Field.from_random('normal', domain=space)
-        D = DiagonalOperator(space, diagonal=diag, copy=copy)
+        D = DiagonalOperator(diag, copy=copy)
         diag_op = D.diagonal()
         assert_allclose(diag.val, diag_op.val)
 
     @expand(product(spaces, [True, False]))
     def test_inverse(self, space, copy):
         diag = Field.from_random('normal', domain=space)
-        D = DiagonalOperator(space, diagonal=diag, copy=copy)
+        D = DiagonalOperator(diag, copy=copy)
         diag_op = D.inverse_diagonal()
         assert_allclose(1./diag.val, diag_op.val)
-- 
GitLab