From 1107973c64a09f67c6000ea2cf053840a34d0610 Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Sat, 17 Feb 2018 12:42:42 +0100
Subject: [PATCH] more docs, various tweaks

---
 demos/wiener_filter_via_hamiltonian.py  |  3 ---
 nifty4/data_objects/distributed_do.py   | 18 +++++++++++---
 nifty4/data_objects/numpy_do.py         | 16 ++++++++++--
 nifty4/domain_tuple.py                  | 33 +++++++++++++++++++++++++
 nifty4/domains/power_space.py           |  2 +-
 nifty4/field.py                         | 18 ++++++++------
 nifty4/library/critical_power_energy.py |  1 +
 nifty4/sugar.py                         |  3 +--
 8 files changed, 75 insertions(+), 19 deletions(-)

diff --git a/demos/wiener_filter_via_hamiltonian.py b/demos/wiener_filter_via_hamiltonian.py
index 57a49d11e..1eedca8d4 100644
--- a/demos/wiener_filter_via_hamiltonian.py
+++ b/demos/wiener_filter_via_hamiltonian.py
@@ -76,9 +76,6 @@ if __name__ == "__main__":
     ift.plot(ht(m), name="reconstruction.png", **plotdict)
 
     # Sample uncertainty map
-    sample_variance = ift.Field.zeros(s_space)
-    sample_mean = ift.Field.zeros(s_space)
-
     mean, variance = ift.probe_with_posterior_samples(curv, ht, 50)
     ift.plot(variance, name="posterior_variance.png", **plotdict2)
     ift.plot(mean+ht(m), name="posterior_mean.png", **plotdict)
diff --git a/nifty4/data_objects/distributed_do.py b/nifty4/data_objects/distributed_do.py
index 8020a7cf2..aebe19fce 100644
--- a/nifty4/data_objects/distributed_do.py
+++ b/nifty4/data_objects/distributed_do.py
@@ -323,10 +323,20 @@ def sqrt(a, out=None):
     return _math_helper(a, np.sqrt, out)
 
 
-def from_object(object, dtype=None, copy=True):
-    return data_object(object._shape, np.array(object._data, dtype=dtype,
-                                               copy=copy),
-                       distaxis=object._distaxis)
+def from_object(object, dtype, copy, set_locked):
+    if dtype is None:
+        dtype = object.dtype
+    dtypes_equal = dtype == object.dtype
+    if set_locked and dtypes_equal and locked(object):
+        return object
+    if not dtypes_equal and not copy:
+        raise ValueError("cannot change data type without copying")
+    if set_locked and not copy:
+        raise ValueError("cannot lock object without copying")
+    data = np.array(object._data, dtype=dtype, copy=copy)
+    if set_locked:
+        lock(data)
+    return data_object(object._shape, data, distaxis=object._distaxis)
 
 
 # This function draws all random numbers on all tasks, to produce the same
diff --git a/nifty4/data_objects/numpy_do.py b/nifty4/data_objects/numpy_do.py
index 5683a0ff8..9d13bffdc 100644
--- a/nifty4/data_objects/numpy_do.py
+++ b/nifty4/data_objects/numpy_do.py
@@ -38,8 +38,20 @@ def mprint(*args):
     print(*args)
 
 
-def from_object(object, dtype=None, copy=True):
-    return np.array(object, dtype=dtype, copy=copy)
+def from_object(object, dtype, copy, set_locked):
+    if dtype is None:
+        dtype = object.dtype
+    dtypes_equal = dtype == object.dtype
+    if set_locked and dtypes_equal and locked(object):
+        return object
+    if not dtypes_equal and not copy:
+        raise ValueError("cannot change data type without copying")
+    if set_locked and not copy:
+        raise ValueError("cannot lock object without copying")
+    res = np.array(object, dtype=dtype, copy=copy)
+    if set_locked:
+        lock(res)
+    return res
 
 
 def from_random(random_type, shape, dtype=np.float64, **kwargs):
diff --git a/nifty4/domain_tuple.py b/nifty4/domain_tuple.py
index ef8c73c26..901624c00 100644
--- a/nifty4/domain_tuple.py
+++ b/nifty4/domain_tuple.py
@@ -21,6 +21,17 @@ from .domains.domain import Domain
 
 
 class DomainTuple(object):
+    """Ordered sequence of Domain objects.
+
+    This class holds a set of :class:`Domain` objects, which together form the
+    space on which a :class:`Field` is defined.
+
+    Notes
+    -----
+
+    DomainTuples should never be created using the constructor, but rather
+    via the factory function :attr:`make`!
+    """
     _tupleCache = {}
 
     def __init__(self, domain):
@@ -44,6 +55,18 @@ class DomainTuple(object):
 
     @staticmethod
     def make(domain):
+        """Returns a DomainTuple matching `domain`.
+
+        This function checks whether a matching DomainTuple already exists.
+        If yes, this object is returned, otherwise a new DomainTuple object
+        is created and returned.
+
+        Parameters
+        ----------
+        domain : Domain or tuple of Domain or DomainTuple
+            The geometrical structure for which the DomainTuple shall be
+            obtained.
+        """
         if isinstance(domain, DomainTuple):
             return domain
         domain = DomainTuple._parse_domain(domain)
@@ -75,14 +98,24 @@ class DomainTuple(object):
 
     @property
     def shape(self):
+        """tuple of int: number of pixels along each axis
+
+        The shape of the array-like object required to store information
+        living on the DomainTuple.
+        """
         return self._shape
 
     @property
     def size(self):
+        """int : total number of pixels.
+
+        Equivalent to the products over all entries in the object's shape.
+        """
         return self._size
 
     @property
     def axes(self):
+        """tuple of tuple of int : shapes of the underlying domains"""
         return self._axtuple
 
     def __len__(self):
diff --git a/nifty4/domains/power_space.py b/nifty4/domains/power_space.py
index fd49ba1de..8f5f4b9e5 100644
--- a/nifty4/domains/power_space.py
+++ b/nifty4/domains/power_space.py
@@ -239,5 +239,5 @@ class PowerSpace(StructuredDomain):
 
     @property
     def k_lengths(self):
-        """numpy.ndarray(float) : sorted array of all k-vector lengths."""
+        """numpy.ndarray(float) : k-vector length for each bin."""
         return self._k_lengths
diff --git a/nifty4/field.py b/nifty4/field.py
index a9391e681..0403b0431 100644
--- a/nifty4/field.py
+++ b/nifty4/field.py
@@ -45,23 +45,26 @@ class Field(object):
     dtype : type
         A numpy.type. Most common are float and complex.
 
-    copy: bool
+    copy : bool
     """
 
-    def __init__(self, domain=None, val=None, dtype=None, copy=False):
+    def __init__(self, domain=None, val=None, dtype=None, copy=False,
+                 locked=False):
         self._domain = self._infer_domain(domain=domain, val=val)
 
         dtype = self._infer_dtype(dtype=dtype, val=val)
         if isinstance(val, Field):
             if self._domain != val._domain:
                 raise ValueError("Domain mismatch")
-            self._val = dobj.from_object(val.val, dtype=dtype, copy=copy)
+            self._val = dobj.from_object(val.val, dtype=dtype, copy=copy,
+                                         set_locked=locked)
         elif (np.isscalar(val)):
             self._val = dobj.full(self._domain.shape, dtype=dtype,
                                   fill_value=val)
         elif isinstance(val, dobj.data_object):
             if self._domain.shape == val.shape:
-                self._val = dobj.from_object(val, dtype=dtype, copy=copy)
+                self._val = dobj.from_object(val, dtype=dtype, copy=copy,
+                                             set_locked=locked)
             else:
                 raise ValueError("Shape mismatch")
         elif val is None:
@@ -69,6 +72,9 @@ class Field(object):
         else:
             raise TypeError("unknown source type")
 
+        if locked:
+            dobj.lock(self._val)
+
     @staticmethod
     def full(domain, val, dtype=None):
         if not np.isscalar(val):
@@ -251,9 +257,7 @@ class Field(object):
         """
         if self.locked:
             return self
-        res = Field(val=self, copy=True)
-        res.lock()
-        return res
+        return Field(val=self, copy=True, locked=True)
 
     def scalar_weight(self, spaces=None):
         if np.isscalar(spaces):
diff --git a/nifty4/library/critical_power_energy.py b/nifty4/library/critical_power_energy.py
index 5ec4d32f9..3fc6cca4f 100644
--- a/nifty4/library/critical_power_energy.py
+++ b/nifty4/library/critical_power_energy.py
@@ -90,6 +90,7 @@ class CriticalPowerEnergy(Energy):
             Dist = PowerDistributor(target=self.m.domain,
                                     power_space=self.position.domain[0])
             if self.D is not None:
+                # MR FIXME: we should use stuff from probing.utils for this
                 w = Field.zeros(self.position.domain, dtype=self.m.dtype)
                 for i in range(self.samples):
                     sample = self.D.draw_sample() + self.m
diff --git a/nifty4/sugar.py b/nifty4/sugar.py
index a2f665efc..f8cccf0e4 100644
--- a/nifty4/sugar.py
+++ b/nifty4/sugar.py
@@ -239,8 +239,7 @@ def create_power_operator(domain, power_spectrum, space=None, dtype=None):
     space = int(space)
     return DiagonalOperator(
         create_power_field(domain[space], power_spectrum, dtype),
-        domain=domain,
-        spaces=space)
+        domain=domain, spaces=space)
 
 
 def create_composed_ht_operator(domain, codomain=None):
-- 
GitLab