From eadb48d653b4aed37fd9d431122a1d428555ce21 Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Tue, 14 Nov 2017 21:15:58 +0100
Subject: [PATCH] consolidation

---
 nifty/__init__.py                             |  4 +-
 nifty/domain_object.py                        |  3 +-
 nifty/domain_tuple.py                         | 16 +-----
 nifty/energies/energy.py                      |  3 +-
 nifty/energies/quadratic_energy.py            |  2 +-
 nifty/field.py                                | 54 ++++++++-----------
 .../critical_filter/critical_power_energy.py  |  2 +-
 .../log_normal_wiener_filter_curvature.py     |  2 +-
 .../log_normal_wiener_filter_energy.py        |  2 +-
 .../wiener_filter/wiener_filter_energy.py     |  2 +-
 nifty/memoization.py                          | 31 -----------
 .../iteration_controller.py                   |  2 +-
 nifty/minimization/minimizer.py               |  2 +-
 nifty/nifty_meta.py                           | 38 -------------
 nifty/operators/diagonal_operator.py          |  2 +-
 nifty/operators/fft_operator_support.py       |  2 +-
 nifty/operators/linear_operator.py            |  2 +-
 nifty/probing/prober.py                       |  2 +-
 nifty/spaces/rg_space.py                      | 13 ++---
 nifty/sugar.py                                |  4 +-
 nifty/{nifty_utilities.py => utilities.py}    | 54 ++++++++++++++++++-
 test/test_field.py                            |  2 +-
 22 files changed, 97 insertions(+), 147 deletions(-)
 delete mode 100644 nifty/memoization.py
 delete mode 100644 nifty/nifty_meta.py
 rename nifty/{nifty_utilities.py => utilities.py} (55%)

diff --git a/nifty/__init__.py b/nifty/__init__.py
index 8595f2c0c..e122f454e 100644
--- a/nifty/__init__.py
+++ b/nifty/__init__.py
@@ -10,7 +10,7 @@ from .domain_object import DomainObject
 
 from .basic_arithmetics import *
 
-from .nifty_utilities import *
+from .utilities import *
 
 from .field_types import *
 
@@ -31,5 +31,3 @@ from . import plotting
 from . import library
 
 from . import dobj
-
-from .memoization import memo
diff --git a/nifty/domain_object.py b/nifty/domain_object.py
index 22193c51b..301d4599a 100644
--- a/nifty/domain_object.py
+++ b/nifty/domain_object.py
@@ -18,7 +18,7 @@
 
 from __future__ import division
 import abc
-from .nifty_meta import NiftyMeta
+from .utilities import NiftyMeta
 from future.utils import with_metaclass
 
 
@@ -38,7 +38,6 @@ class DomainObject(with_metaclass(
         raise NotImplementedError
 
     def __hash__(self):
-        # Extract the identifying parts from the vars(self) dict.
         result_hash = 0
         for key in self._needed_for_hash:
             result_hash ^= hash(vars(self)[key])
diff --git a/nifty/domain_tuple.py b/nifty/domain_tuple.py
index 87e5abb56..df52725ff 100644
--- a/nifty/domain_tuple.py
+++ b/nifty/domain_tuple.py
@@ -105,24 +105,10 @@ class DomainTuple(object):
         return self._dom == x._dom
 
     def __ne__(self, x):
-        if not isinstance(x, DomainTuple):
-            x = DomainTuple.make(x)
-        if self is x:
-            return False
-        return self._dom != x._dom
+        return not self.__eq__(x)
 
     def __str__(self):
         res = "DomainTuple, len: " + str(len(self.domains))
         for i in self.domains:
             res += "\n" + str(i)
         return res
-
-    def collapsed_shape_for_domain(self, ispace):
-        """Returns a three-component shape, with the total number of pixels
-        in the domains before the requested space in res[0], the total number
-        of pixels in the requested space in res[1], and the remaining pixels in
-        res[2].
-        """
-        return (self._accdims[ispace],
-                self._accdims[ispace+1]//self._accdims[ispace],
-                self._accdims[-1]//self._accdims[ispace+1])
diff --git a/nifty/energies/energy.py b/nifty/energies/energy.py
index c6df68e11..d53e3c1ec 100644
--- a/nifty/energies/energy.py
+++ b/nifty/energies/energy.py
@@ -16,8 +16,7 @@
 # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
 # and financially supported by the Studienstiftung des deutschen Volkes.
 
-from ..nifty_meta import NiftyMeta
-from ..memoization import memo
+from ..utilities import memo, NiftyMeta
 from future.utils import with_metaclass
 
 
diff --git a/nifty/energies/quadratic_energy.py b/nifty/energies/quadratic_energy.py
index f50e4aae4..b87d6bc56 100644
--- a/nifty/energies/quadratic_energy.py
+++ b/nifty/energies/quadratic_energy.py
@@ -1,5 +1,5 @@
 from .energy import Energy
-from ..memoization import memo
+from ..utilities import memo
 
 
 class QuadraticEnergy(Energy):
diff --git a/nifty/field.py b/nifty/field.py
index fa6481a19..99861349d 100644
--- a/nifty/field.py
+++ b/nifty/field.py
@@ -19,7 +19,7 @@
 from __future__ import division
 from builtins import range
 import numpy as np
-from . import nifty_utilities as utilities
+from . import utilities
 from .domain_tuple import DomainTuple
 from functools import reduce
 from . import dobj
@@ -65,7 +65,6 @@ class Field(object):
             *the given domain contains something that is not a DomainObject
              instance
             *val is an array that has a different dimension than the domain
-
     """
 
     # ---Initialization methods---
@@ -179,7 +178,6 @@ class Field(object):
         out : Field
             The output object.
         """
-
         domain = DomainTuple.make(domain)
         return Field(domain=domain,
                      val=dobj.from_random(random_type, dtype=dtype,
@@ -194,10 +192,6 @@ class Field(object):
     def val(self):
         """ Returns the data object associated with this Field.
         No copy is made.
-
-        Returns
-        -------
-        out : numpy.ndarray
         """
         return self._val
 
@@ -211,10 +205,8 @@ class Field(object):
 
         Returns
         -------
-        out : tuple
-            The output object. The tuple contains the dimensions of the spaces
-            in domain.
-       """
+        Integer tuple containing the dimensions of the spaces in domain.
+        """
         return self.domain.shape
 
     @property
@@ -232,14 +224,12 @@ class Field(object):
 
     @property
     def real(self):
-        """ The real part of the field (data is not copied).
-        """
+        """ The real part of the field (data is not copied)."""
         return Field(self.domain, self.val.real)
 
     @property
     def imag(self):
-        """ The imaginary part of the field (data is not copied).
-        """
+        """ The imaginary part of the field (data is not copied)."""
         return Field(self.domain, self.val.imag)
 
     # ---Special unary/binary operations---
@@ -290,7 +280,6 @@ class Field(object):
         -------
         out : Field
             The weighted field.
-
         """
         if out is None:
             out = self.copy()
@@ -313,7 +302,8 @@ class Field(object):
                 new_shape[self.domain.axes[ind][0]:
                           self.domain.axes[ind][-1]+1] = wgt.shape
                 wgt = wgt.reshape(new_shape)
-                if dobj.distaxis(self._val) >= 0 and ind == 0:  # we need to distribute the weights along axis 0
+                if dobj.distaxis(self._val) >= 0 and ind == 0:
+                    # we need to distribute the weights along axis 0
                     wgt = dobj.local_data(dobj.from_global_data(wgt))
                 out *= wgt**power
         fct = fct**power
@@ -336,8 +326,7 @@ class Field(object):
 
         Returns
         -------
-        out : float, complex
-
+        out : float, complex, either scalar or Field
         """
         if not isinstance(x, Field):
             raise ValueError("The dot-partner must be an instance of " +
@@ -354,15 +343,19 @@ class Field(object):
 
         if spaces is None:
             return fct*dobj.vdot(y.val, x.val)
-        else:
-            spaces = utilities.cast_iseq_to_tuple(spaces)
-            active_axes = []
-            for i in spaces:
-                active_axes += self.domain.axes[i]
-            res = 0.
-            for sl in utilities.get_slice_list(self.shape, active_axes):
-                res += dobj.vdot(y.val, x.val[sl])
-            return res*fct
+
+        spaces = utilities.cast_iseq_to_tuple(spaces)
+        if spaces == tuple(range(len(self.domain))):  # full contraction
+            return fct*dobj.vdot(y.val, x.val)
+
+        raise NotImplementedError("special case for vdot not yet implemented")
+        active_axes = []
+        for i in spaces:
+            active_axes += self.domain.axes[i]
+        res = 0.
+        for sl in utilities.get_slice_list(self.shape, active_axes):
+            res += dobj.vdot(y.val, x.val[sl])
+        return res*fct
 
     def norm(self):
         """ Computes the L2-norm of the field values.
@@ -371,7 +364,6 @@ class Field(object):
         -------
         norm : float
             The L2-norm of the field values.
-
         """
         return np.sqrt(np.abs(self.vdot(x=self)))
 
@@ -380,9 +372,7 @@ class Field(object):
 
         Returns
         -------
-        cc : field
-            The complex conjugated field.
-
+        The complex conjugated field.
         """
         return Field(self.domain, self.val.conjugate(), self.dtype)
 
diff --git a/nifty/library/critical_filter/critical_power_energy.py b/nifty/library/critical_filter/critical_power_energy.py
index 0c7af23a5..30d73826f 100644
--- a/nifty/library/critical_filter/critical_power_energy.py
+++ b/nifty/library/critical_filter/critical_power_energy.py
@@ -3,7 +3,7 @@ from ...operators.smoothness_operator import SmoothnessOperator
 from ...operators.power_projection_operator import PowerProjectionOperator
 from ...operators.inversion_enabler import InversionEnabler
 from . import CriticalPowerCurvature
-from ...memoization import memo
+from ...utilities import memo
 from ... import Field, exp
 from ...sugar import generate_posterior_sample
 
diff --git a/nifty/library/log_normal_wiener_filter/log_normal_wiener_filter_curvature.py b/nifty/library/log_normal_wiener_filter/log_normal_wiener_filter_curvature.py
index 74309c4f6..add9cc58f 100644
--- a/nifty/library/log_normal_wiener_filter/log_normal_wiener_filter_curvature.py
+++ b/nifty/library/log_normal_wiener_filter/log_normal_wiener_filter_curvature.py
@@ -1,5 +1,5 @@
 from ...operators import EndomorphicOperator
-from ...memoization import memo
+from ...utilities import memo
 from ...basic_arithmetics import exp
 from ...sugar import create_composed_fft_operator
 
diff --git a/nifty/library/log_normal_wiener_filter/log_normal_wiener_filter_energy.py b/nifty/library/log_normal_wiener_filter/log_normal_wiener_filter_energy.py
index 56f411841..f3a7a59dc 100644
--- a/nifty/library/log_normal_wiener_filter/log_normal_wiener_filter_energy.py
+++ b/nifty/library/log_normal_wiener_filter/log_normal_wiener_filter_energy.py
@@ -1,5 +1,5 @@
 from ...energies.energy import Energy
-from ...memoization import memo
+from ...utilities import memo
 from . import LogNormalWienerFilterCurvature
 from ...sugar import create_composed_fft_operator
 from ...operators.inversion_enabler import InversionEnabler
diff --git a/nifty/library/wiener_filter/wiener_filter_energy.py b/nifty/library/wiener_filter/wiener_filter_energy.py
index a6210a88c..51e33c2db 100644
--- a/nifty/library/wiener_filter/wiener_filter_energy.py
+++ b/nifty/library/wiener_filter/wiener_filter_energy.py
@@ -1,5 +1,5 @@
 from ...energies.energy import Energy
-from ...memoization import memo
+from ...utilities import memo
 from ...operators.inversion_enabler import InversionEnabler
 from . import WienerFilterCurvature
 
diff --git a/nifty/memoization.py b/nifty/memoization.py
deleted file mode 100644
index daa21d55c..000000000
--- a/nifty/memoization.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with this program.  If not, see <http://www.gnu.org/licenses/>.
-#
-# Copyright(C) 2013-2017 Max-Planck-Society
-#
-# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
-# and financially supported by the Studienstiftung des deutschen Volkes.
-
-
-def memo(f):
-    name = f.__name__
-
-    def wrapped_f(self):
-        if not hasattr(self, "_cache"):
-            self._cache = {}
-        try:
-            return self._cache[name]
-        except KeyError:
-            self._cache[name] = f(self)
-            return self._cache[name]
-    return wrapped_f
diff --git a/nifty/minimization/iteration_controlling/iteration_controller.py b/nifty/minimization/iteration_controlling/iteration_controller.py
index 001f06eca..ccdcb3899 100644
--- a/nifty/minimization/iteration_controlling/iteration_controller.py
+++ b/nifty/minimization/iteration_controlling/iteration_controller.py
@@ -18,7 +18,7 @@
 
 from builtins import range
 import abc
-from ...nifty_meta import NiftyMeta
+from ...utilities import NiftyMeta
 from future.utils import with_metaclass
 
 
diff --git a/nifty/minimization/minimizer.py b/nifty/minimization/minimizer.py
index ee472a345..0809da8a3 100644
--- a/nifty/minimization/minimizer.py
+++ b/nifty/minimization/minimizer.py
@@ -17,7 +17,7 @@
 # and financially supported by the Studienstiftung des deutschen Volkes.
 
 import abc
-from ..nifty_meta import NiftyMeta
+from ..utilities import NiftyMeta
 from future.utils import with_metaclass
 
 
diff --git a/nifty/nifty_meta.py b/nifty/nifty_meta.py
deleted file mode 100644
index 3d24bc621..000000000
--- a/nifty/nifty_meta.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import abc
-
-
-class DocStringInheritor(type):
-    """
-    A variation on
-    http://groups.google.com/group/comp.lang.python/msg/26f7b4fcb4d66c95
-    by Paul McGuire
-    """
-    def __new__(meta, name, bases, clsdict):
-        if not('__doc__' in clsdict and clsdict['__doc__']):
-            for mro_cls in (mro_cls for base in bases
-                            for mro_cls in base.mro()):
-                doc = mro_cls.__doc__
-                if doc:
-                    clsdict['__doc__'] = doc
-                    break
-        for attr, attribute in list(clsdict.items()):
-            if not attribute.__doc__:
-                for mro_cls in (mro_cls for base in bases
-                                for mro_cls in base.mro()
-                                if hasattr(mro_cls, attr)):
-                    doc = getattr(getattr(mro_cls, attr), '__doc__')
-                    if doc:
-                        if isinstance(attribute, property):
-                            clsdict[attr] = property(attribute.fget,
-                                                     attribute.fset,
-                                                     attribute.fdel,
-                                                     doc)
-                        else:
-                            attribute.__doc__ = doc
-                        break
-        return super(DocStringInheritor, meta).__new__(meta, name,
-                                                       bases, clsdict)
-
-
-class NiftyMeta(DocStringInheritor, abc.ABCMeta):
-    pass
diff --git a/nifty/operators/diagonal_operator.py b/nifty/operators/diagonal_operator.py
index 11b0ed469..6ff48998d 100644
--- a/nifty/operators/diagonal_operator.py
+++ b/nifty/operators/diagonal_operator.py
@@ -21,7 +21,7 @@ import numpy as np
 from ..field import Field
 from ..domain_tuple import DomainTuple
 from .endomorphic_operator import EndomorphicOperator
-from ..nifty_utilities import cast_iseq_to_tuple
+from ..utilities import cast_iseq_to_tuple
 from .. import dobj
 
 
diff --git a/nifty/operators/fft_operator_support.py b/nifty/operators/fft_operator_support.py
index c1f170912..37730afe8 100644
--- a/nifty/operators/fft_operator_support.py
+++ b/nifty/operators/fft_operator_support.py
@@ -18,7 +18,7 @@
 
 from __future__ import division
 import numpy as np
-from .. import nifty_utilities as utilities
+from .. import utilities
 from ..low_level_library import hartley
 from .. import dobj
 from ..field import Field
diff --git a/nifty/operators/linear_operator.py b/nifty/operators/linear_operator.py
index e6400bc26..fad51f958 100644
--- a/nifty/operators/linear_operator.py
+++ b/nifty/operators/linear_operator.py
@@ -18,7 +18,7 @@
 
 from builtins import str
 import abc
-from ..nifty_meta import NiftyMeta
+from ..utilities import NiftyMeta
 from ..field import Field
 from future.utils import with_metaclass
 
diff --git a/nifty/probing/prober.py b/nifty/probing/prober.py
index cc9207d5d..70fd69a6a 100644
--- a/nifty/probing/prober.py
+++ b/nifty/probing/prober.py
@@ -21,7 +21,7 @@ from builtins import range
 from builtins import object
 import numpy as np
 from ..field import Field, DomainTuple
-from .. import nifty_utilities as utilities
+from .. import utilities
 
 
 class Prober(object):
diff --git a/nifty/spaces/rg_space.py b/nifty/spaces/rg_space.py
index 52736a7cd..04272e2d8 100644
--- a/nifty/spaces/rg_space.py
+++ b/nifty/spaces/rg_space.py
@@ -51,7 +51,9 @@ class RGSpace(Space):
         self._needed_for_hash += ["_distances", "_shape", "_harmonic"]
 
         self._harmonic = bool(harmonic)
-        self._shape = self._parse_shape(shape)
+        if np.isscalar(shape):
+            shape = (shape,)
+        self._shape = tuple(int(i) for i in shape)
         self._distances = self._parse_distances(distances)
         self._dvol = float(reduce(lambda x, y: x*y, self._distances))
         self._dim = int(reduce(lambda x, y: x*y, self._shape))
@@ -163,17 +165,12 @@ class RGSpace(Space):
 
     @property
     def distances(self):
-        """Distance between two grid points along each axis. It is a tuple
+        """Distance between grid points along each axis. It is a tuple
         of positive floating point numbers with the n-th entry giving the
-        distances of grid points along the n-th dimension.
+        distance between neighboring grid points along the n-th dimension.
         """
         return self._distances
 
-    def _parse_shape(self, shape):
-        if np.isscalar(shape):
-            return (shape,)
-        return tuple(np.array(shape, dtype=np.int))
-
     def _parse_distances(self, distances):
         if distances is None:
             if self.harmonic:
diff --git a/nifty/sugar.py b/nifty/sugar.py
index 8c168f74e..656e06667 100644
--- a/nifty/sugar.py
+++ b/nifty/sugar.py
@@ -18,8 +18,8 @@
 
 import numpy as np
 from . import Space, PowerSpace, Field, ComposedOperator, DiagonalOperator,\
-              PowerProjectionOperator, FFTOperator, sqrt, DomainTuple, dobj
-from . import nifty_utilities as utilities
+              PowerProjectionOperator, FFTOperator, sqrt, DomainTuple, dobj,\
+              utilities
 
 __all__ = ['PS_field',
            'power_analyze',
diff --git a/nifty/nifty_utilities.py b/nifty/utilities.py
similarity index 55%
rename from nifty/nifty_utilities.py
rename to nifty/utilities.py
index c8df74b54..6f6899ce1 100644
--- a/nifty/nifty_utilities.py
+++ b/nifty/utilities.py
@@ -19,6 +19,7 @@
 from builtins import next, range
 import numpy as np
 from itertools import product
+import abc
 
 
 def get_slice_list(shape, axes):
@@ -42,10 +43,8 @@ def get_slice_list(shape, axes):
     ------
     ValueError
         If shape is empty.
-    ValueError
         If axes(axis) does not match shape.
     """
-
     if shape is None:
         raise ValueError("shape cannot be None.")
 
@@ -72,3 +71,54 @@ def cast_iseq_to_tuple(seq):
     if np.isscalar(seq):
         return (int(seq),)
     return tuple(int(item) for item in seq)
+
+
+def memo(f):
+    name = f.__name__
+
+    def wrapped_f(self):
+        if not hasattr(self, "_cache"):
+            self._cache = {}
+        try:
+            return self._cache[name]
+        except KeyError:
+            self._cache[name] = f(self)
+            return self._cache[name]
+    return wrapped_f
+
+
+class _DocStringInheritor(type):
+    """
+    A variation on
+    http://groups.google.com/group/comp.lang.python/msg/26f7b4fcb4d66c95
+    by Paul McGuire
+    """
+    def __new__(meta, name, bases, clsdict):
+        if not('__doc__' in clsdict and clsdict['__doc__']):
+            for mro_cls in (mro_cls for base in bases
+                            for mro_cls in base.mro()):
+                doc = mro_cls.__doc__
+                if doc:
+                    clsdict['__doc__'] = doc
+                    break
+        for attr, attribute in list(clsdict.items()):
+            if not attribute.__doc__:
+                for mro_cls in (mro_cls for base in bases
+                                for mro_cls in base.mro()
+                                if hasattr(mro_cls, attr)):
+                    doc = getattr(getattr(mro_cls, attr), '__doc__')
+                    if doc:
+                        if isinstance(attribute, property):
+                            clsdict[attr] = property(attribute.fget,
+                                                     attribute.fset,
+                                                     attribute.fdel,
+                                                     doc)
+                        else:
+                            attribute.__doc__ = doc
+                        break
+        return super(_DocStringInheritor, meta).__new__(meta, name,
+                                                        bases, clsdict)
+
+
+class NiftyMeta(_DocStringInheritor, abc.ABCMeta):
+    pass
diff --git a/test/test_field.py b/test/test_field.py
index 9da6e095f..52f998a20 100644
--- a/test/test_field.py
+++ b/test/test_field.py
@@ -127,5 +127,5 @@ class Test_Functionality(unittest.TestCase):
         s = ift.RGSpace((10,))
         f1 = ift.Field.from_random("normal", domain=s, dtype=np.complex128)
         f2 = ift.Field.from_random("normal", domain=s, dtype=np.complex128)
-        # assert_allclose(f1.vdot(f2), f1.vdot(f2, spaces=0))
+        assert_allclose(f1.vdot(f2), f1.vdot(f2, spaces=0))
         assert_allclose(f1.vdot(f2), np.conj(f2.vdot(f1)))
-- 
GitLab