From bc6cb4bc196074232cc3e019660a5b42b6ddd0ad Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Tue, 12 Sep 2017 16:38:18 +0200
Subject: [PATCH] field simplifications

---
 nifty2go/field.py | 65 +++++++++++++----------------------------------
 1 file changed, 17 insertions(+), 48 deletions(-)

diff --git a/nifty2go/field.py b/nifty2go/field.py
index becc8148b..38117375c 100644
--- a/nifty2go/field.py
+++ b/nifty2go/field.py
@@ -166,8 +166,6 @@ class Field(object):
         See Also
         --------
         power_synthesize
-
-
         """
 
         generator_function = getattr(Random, random_type)
@@ -230,9 +228,8 @@ class Field(object):
         # power_space instances
         for sp in self.domain:
             if not sp.harmonic and not isinstance(sp, PowerSpace):
-                raise TypeError(
-                    "Field has a space in `domain` which is neither "
-                    "harmonic nor a PowerSpace.")
+                raise TypeError("Field has a space in `domain` which is "
+                                "neither harmonic nor a PowerSpace.")
 
         # check if the `spaces` input is valid
         spaces = utilities.cast_axis_to_tuple(spaces, len(self.domain))
@@ -267,10 +264,8 @@ class Field(object):
 
     @classmethod
     def _single_power_analyze(cls, work_field, space_index, binbounds):
-
         if not work_field.domain[space_index].harmonic:
-            raise ValueError(
-                "The analyzed space must be harmonic.")
+            raise ValueError("The analyzed space must be harmonic.")
 
         # Create the target PowerSpace instance:
         # If the associated signal-space field was real, we extract the
@@ -319,9 +314,8 @@ class Field(object):
         semiscaled_local_shape = [1] * len(target_shape)
         for i in range(len(axes)):
             semiscaled_local_shape[axes[i]] = pindex.shape[i]
-        semiscaled_local_data = pindex.reshape(semiscaled_local_shape)
         result_obj = np.empty(target_shape, dtype=pindex.dtype)
-        result_obj[()] = semiscaled_local_data
+        result_obj[()] = pindex.reshape(semiscaled_local_shape)
         return result_obj
 
     def power_synthesize(self, spaces=None, real_power=True, real_signal=True,
@@ -379,18 +373,15 @@ class Field(object):
         if spaces is None:
             spaces = list(range(len(self.domain)))
 
-        for power_space_index in spaces:
-            power_space = self.domain[power_space_index]
-            if not isinstance(power_space, PowerSpace):
+        for i in spaces:
+            if not isinstance(self.domain[i], PowerSpace):
                 raise ValueError("A PowerSpace is needed for field "
                                  "synthetization.")
 
         # create the result domain
         result_domain = list(self.domain)
-        for power_space_index in spaces:
-            power_space = self.domain[power_space_index]
-            harmonic_domain = power_space.harmonic_partner
-            result_domain[power_space_index] = harmonic_domain
+        for i in spaces:
+            result_domain[i] = self.domain[i].harmonic_partner
 
         # create random samples: one or two, depending on whether the
         # power spectrum is real or complex
@@ -422,14 +413,9 @@ class Field(object):
                            for i in result_list]
 
         if real_power:
-            result = result_list[0]
-            if not issubclass(result_list[0].dtype.type,
-                              np.complexfloating):
-                result = result.real
+            return result_list[0]
         else:
-            result = result_list[0] + 1j*result_list[1]
-
-        return result
+            return result_list[0] + 1j*result_list[1]
 
     @staticmethod
     def _hermitian_decomposition(val, preserve_gaussian_variance=False):
@@ -497,11 +483,10 @@ class Field(object):
     def total_volume(self):
         """ Returns the total volume of all spaces in the domain.
         """
-        volume_tuple = tuple(sp.total_volume for sp in self.domain)
-        try:
-            return reduce(lambda x, y: x * y, volume_tuple)
-        except TypeError:
+        if len(self.domain) == 0:
             return 0.
+        volume_tuple = tuple(sp.total_volume for sp in self.domain)
+        return reduce(lambda x, y: x * y, volume_tuple)
 
     @property
     def real(self):
@@ -517,31 +502,17 @@ class Field(object):
 
     # ---Special unary/binary operations---
 
-    def copy(self, domain=None, dtype=None):
+    def copy(self):
         """ Returns a full copy of the Field.
 
-        If no keyword arguments are given, the returned object will be an
-        identical copy of the original Field. By explicit specification one is
-        able to define the domain and the dtype of the returned Field.
-
-        Parameters
-        ----------
-        domain : DomainObject
-            The new domain the Field shall have.
-
-        dtype : type
-            The new dtype the Field shall have.
+        The returned object will be an identical copy of the original Field.
 
         Returns
         -------
         out : Field
             The output object. An identical copy of 'self'.
-
         """
-
-        if domain is None:
-            domain = self.domain
-        return Field(domain=domain, val=self._val, dtype=dtype, copy=True)
+        return Field(val=self, copy=True)
 
     def scalar_weight(self, spaces=None):
         if np.isscalar(spaces):
@@ -699,10 +670,8 @@ class Field(object):
 
         axes_list = tuple(self.domain_axes[sp_index] for sp_index in spaces)
 
-        try:
+        if len(axes_list) > 0:
             axes_list = reduce(lambda x, y: x+y, axes_list)
-        except TypeError:
-            axes_list = ()
 
         # perform the contraction on the data
         data = getattr(self.val, op)(axis=axes_list)
-- 
GitLab