From 2ec3da9a4bedcd7076bd4792c5776323caf979f4 Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Fri, 6 Jul 2018 11:56:18 +0200
Subject: [PATCH] optimize the Field constructor

---
 nifty5/domain_tuple.py                    |  4 +-
 nifty5/field.py                           | 81 +++++++----------------
 nifty5/sugar.py                           |  4 +-
 test/test_field.py                        | 28 ++++----
 test/test_minimization/test_minimizers.py | 13 ++--
 5 files changed, 49 insertions(+), 81 deletions(-)

diff --git a/nifty5/domain_tuple.py b/nifty5/domain_tuple.py
index c44653955..ea446d8e1 100644
--- a/nifty5/domain_tuple.py
+++ b/nifty5/domain_tuple.py
@@ -25,8 +25,8 @@ from .domains.domain import Domain
 class DomainTuple(object):
     """Ordered sequence of Domain objects.
 
-    This class holds a set of :class:`Domain` objects, which together form the
-    space on which a :class:`Field` is defined.
+    This class holds a tuple of :class:`Domain` objects, which together form
+    the space on which a :class:`Field` is defined.
 
     Notes
     -----
diff --git a/nifty5/field.py b/nifty5/field.py
index 624762638..f827ddced 100644
--- a/nifty5/field.py
+++ b/nifty5/field.py
@@ -33,45 +33,28 @@ class Field(object):
 
     Parameters
     ----------
-    domain : None, DomainTuple, tuple of Domain, or Domain
+    domain : DomainTuple
+        the domain of the new Field
 
-    val : Field, data_object or scalar
-        The values the array should contain after init. A scalar input will
-        fill the whole array with this scalar. If a data_object is provided,
-        its dimensions must match the domain's.
-
-    dtype : type
-        A numpy.type. Most common are float and complex.
+    val : data_object
+        This object's global shape must match the domain shape
+        After construction, the object will no longer be writeable!
 
     Notes
     -----
     If possible, do not invoke the constructor directly, but use one of the
-    many convenience functions for Field conatruction!
+    many convenience functions for Field construction!
     """
 
-    def __init__(self, domain=None, val=None, dtype=None):
-        self._domain = self._infer_domain(domain=domain, val=val)
-
-        dtype = self._infer_dtype(dtype=dtype, val=val)
-        if isinstance(val, Field):
-            if self._domain != val._domain:
-                raise ValueError("Domain mismatch")
-            self._val = val._val
-
-        elif (np.isscalar(val)):
-            self._val = dobj.full(self._domain.shape, dtype=dtype,
-                                  fill_value=val)
-        elif isinstance(val, dobj.data_object):
-            if self._domain.shape == val.shape:
-                if dtype == val.dtype:
-                    self._val = val
-                else:
-                    self._val = dobj.from_object(val, dtype, True, True)
-            else:
-                raise ValueError("Shape mismatch")
-        else:
-            raise TypeError("unknown source type")
-
+    def __init__(self, domain, val):
+        if not isinstance(domain, DomainTuple):
+            raise TypeError("domain must be of type DomainTuple")
+        if not isinstance(val, dobj.data_object):
+            raise TypeError("val must be of type dobj.data_object")
+        if domain.shape != val.shape:
+            raise ValueError("mismatch between the shapes of val and domain")
+        self._domain = domain
+        self._val = val
         dobj.lock(self._val)
 
     # prevent implicit conversion to bool
@@ -99,7 +82,10 @@ class Field(object):
         """
         if not np.isscalar(val):
             raise TypeError("val must be a scalar")
-        return Field(DomainTuple.make(domain), val)
+        if not (np.isreal(val) or np.iscomplex(val)):
+            raise TypeError("need arithmetic scalar")
+        domain = DomainTuple.make(domain)
+        return Field(domain, dobj.full(domain.shape, fill_value=val))
 
     @staticmethod
     def from_global_data(domain, arr, sum_up=False):
@@ -118,12 +104,13 @@ class Field(object):
             If False, the contens of `arr` are used directly, and must be
             identical on all MPI tasks.
         """
-        return Field(domain, dobj.from_global_data(arr, sum_up))
+        return Field(DomainTuple.make(domain),
+                     dobj.from_global_data(arr, sum_up))
 
     @staticmethod
     def from_local_data(domain, arr):
-        domain = DomainTuple.make(domain)
-        return Field(domain, dobj.from_local_data(domain.shape, arr))
+        return Field(DomainTuple.make(domain),
+            dobj.from_local_data(domain.shape, arr))
 
     def to_global_data(self):
         """Returns an array containing the full data of the field.
@@ -167,25 +154,7 @@ class Field(object):
         -----
         No copy is made. If needed, use an additional copy() invocation.
         """
-        return Field(new_domain, self._val)
-
-    @staticmethod
-    def _infer_domain(domain, val=None):
-        if domain is None:
-            if isinstance(val, Field):
-                return val._domain
-            if np.isscalar(val):
-                return DomainTuple.make(())  # empty domain tuple
-            raise TypeError("could not infer domain from value")
-        return DomainTuple.make(domain)
-
-    @staticmethod
-    def _infer_dtype(dtype, val):
-        if dtype is not None:
-            return dtype
-        if val is None:
-            raise ValueError("could not infer dtype")
-        return np.result_type(val)
+        return Field(DomainTuple.make(new_domain), self._val)
 
     @staticmethod
     def from_random(random_type, domain, dtype=np.float64, **kwargs):
@@ -444,7 +413,7 @@ class Field(object):
                                   for i, dom in enumerate(self._domain)
                                   if i not in spaces)
 
-            return Field(domain=return_domain, val=data)
+            return Field(DomainTuple.make(return_domain), data)
 
     def sum(self, spaces=None):
         """Sums up over the sub-domains given by `spaces`.
diff --git a/nifty5/sugar.py b/nifty5/sugar.py
index b5069e2ba..0dd149ac5 100644
--- a/nifty5/sugar.py
+++ b/nifty5/sugar.py
@@ -40,7 +40,7 @@ def PS_field(pspace, func):
     if not isinstance(pspace, PowerSpace):
         raise TypeError
     data = dobj.from_global_data(func(pspace.k_lengths))
-    return Field(pspace, val=data)
+    return Field(DomainTuple.make(pspace), data)
 
 
 def get_signal_variance(spec, space):
@@ -158,7 +158,7 @@ def _create_power_field(domain, power_spectrum):
         if not isinstance(power_spectrum.domain[0], PowerSpace):
             raise TypeError("PowerSpace required")
         power_domain = power_spectrum.domain[0]
-        fp = Field(power_domain, val=power_spectrum.val)
+        fp = power_spectrum
     else:
         power_domain = PowerSpace(domain)
         fp = PS_field(power_domain, power_spectrum)
diff --git a/test/test_field.py b/test/test_field.py
index c5687e2ba..42a573420 100644
--- a/test/test_field.py
+++ b/test/test_field.py
@@ -139,16 +139,16 @@ class Test_Functionality(unittest.TestCase):
         assert_equal(d, d2)
 
     def test_empty_domain(self):
-        f = ift.Field((), 5)
+        f = ift.Field.full((), 5)
         assert_equal(f.local_data, 5)
-        f = ift.Field(None, 5)
+        f = ift.Field.full(None, 5)
         assert_equal(f.local_data, 5)
 
     def test_trivialities(self):
         s1 = ift.RGSpace((10,))
-        f1 = ift.Field(s1, 27)
+        f1 = ift.Field.full(s1, 27)
         assert_equal(f1.local_data, f1.real.local_data)
-        f1 = ift.Field(s1, 27.+3j)
+        f1 = ift.Field.full(s1, 27.+3j)
         assert_equal(f1.real.local_data, 27.)
         assert_equal(f1.imag.local_data, 3.)
         assert_equal(f1.local_data, +f1.local_data)
@@ -160,7 +160,7 @@ class Test_Functionality(unittest.TestCase):
 
     def test_weight(self):
         s1 = ift.RGSpace((10,))
-        f = ift.Field(s1, 10.)
+        f = ift.Field.full(s1, 10.)
         f2 = f.weight(1)
         assert_equal(f.weight(1).local_data, f2.local_data)
         assert_equal(f.total_volume(), 1)
@@ -170,7 +170,7 @@ class Test_Functionality(unittest.TestCase):
         assert_equal(f.scalar_weight(0), 0.1)
         assert_equal(f.scalar_weight((0,)), 0.1)
         s1 = ift.GLSpace(10)
-        f = ift.Field(s1, 10.)
+        f = ift.Field.full(s1, 10.)
         assert_equal(f.scalar_weight(), None)
         assert_equal(f.scalar_weight(0), None)
         assert_equal(f.scalar_weight((0,)), None)
@@ -178,7 +178,7 @@ class Test_Functionality(unittest.TestCase):
     @expand(product([ift.RGSpace(10), ift.GLSpace(10)],
                     [np.float64, np.complex128]))
     def test_reduction(self, dom, dt):
-        s1 = ift.Field(dom, 1., dtype=dt)
+        s1 = ift.Field.full(dom, dt(1.))
         assert_allclose(s1.mean(), 1.)
         assert_allclose(s1.mean(0), 1.)
         assert_allclose(s1.var(), 0., atol=1e-14)
@@ -189,13 +189,11 @@ class Test_Functionality(unittest.TestCase):
     def test_err(self):
         s1 = ift.RGSpace((10,))
         s2 = ift.RGSpace((11,))
-        f1 = ift.Field(s1, 27)
+        f1 = ift.Field.full(s1, 27)
         with assert_raises(ValueError):
-            f2 = ift.Field(s2, f1)
-        with assert_raises(ValueError):
-            f2 = ift.Field(s2, f1.val)
+            f2 = ift.Field(ift.DomainTuple.make(s2), f1.val)
         with assert_raises(TypeError):
-            f2 = ift.Field(s2, "xyz")
+            f2 = ift.Field.full(s2, "xyz")
         with assert_raises(TypeError):
             if f1:
                 pass
@@ -203,20 +201,20 @@ class Test_Functionality(unittest.TestCase):
             f1.full((2, 4, 6))
         with assert_raises(TypeError):
             f2 = ift.Field(None, None)
-        with assert_raises(ValueError):
+        with assert_raises(TypeError):
             f2 = ift.Field(s1, None)
         with assert_raises(ValueError):
             f1.imag
         with assert_raises(TypeError):
             f1.vdot(42)
         with assert_raises(ValueError):
-            f1.vdot(ift.Field(s2, 1.))
+            f1.vdot(ift.Field.full(s2, 1.))
         with assert_raises(TypeError):
             ift.full(s1, [2, 3])
 
     def test_stdfunc(self):
         s = ift.RGSpace((200,))
-        f = ift.Field(s, 27)
+        f = ift.Field.full(s, 27)
         assert_equal(f.local_data, 27)
         assert_equal(f.shape, (200,))
         assert_equal(f.dtype, np.int)
diff --git a/test/test_minimization/test_minimizers.py b/test/test_minimization/test_minimizers.py
index 9685b8710..8607a8b98 100644
--- a/test/test_minimization/test_minimizers.py
+++ b/test/test_minimization/test_minimizers.py
@@ -133,7 +133,7 @@ class Test_Minimizers(unittest.TestCase):
     @expand(product(minimizers+slow_minimizers))
     def test_gauss(self, minimizer):
         space = ift.UnstructuredDomain((1,))
-        starting_point = ift.Field(space, val=3.)
+        starting_point = ift.Field.full(space, 3.)
 
         class ExpEnergy(ift.Energy):
             def __init__(self, position):
@@ -147,14 +147,15 @@ class Test_Minimizers(unittest.TestCase):
             @property
             def gradient(self):
                 x = self.position.to_global_data()[0]
-                return ift.Field(self.position.domain, val=2*x*np.exp(-(x**2)))
+                return ift.Field.full(self.position.domain,
+                                      2*x*np.exp(-(x**2)))
 
             @property
             def metric(self):
                 x = self.position.to_global_data()[0]
                 v = (2 - 4*x*x)*np.exp(-x**2)
                 return ift.DiagonalOperator(
-                    ift.Field(self.position.domain, val=v))
+                    ift.Field.full(self.position.domain, v))
 
         try:
             minimizer = eval(minimizer)
@@ -171,7 +172,7 @@ class Test_Minimizers(unittest.TestCase):
     @expand(product(minimizers+newton_minimizers+slow_minimizers))
     def test_cosh(self, minimizer):
         space = ift.UnstructuredDomain((1,))
-        starting_point = ift.Field(space, val=3.)
+        starting_point = ift.Field.full(space, 3.)
 
         class CoshEnergy(ift.Energy):
             def __init__(self, position):
@@ -185,14 +186,14 @@ class Test_Minimizers(unittest.TestCase):
             @property
             def gradient(self):
                 x = self.position.to_global_data()[0]
-                return ift.Field(self.position.domain, val=np.sinh(x))
+                return ift.Field.full(self.position.domain, np.sinh(x))
 
             @property
             def metric(self):
                 x = self.position.to_global_data()[0]
                 v = np.cosh(x)
                 return ift.DiagonalOperator(
-                    ift.Field(self.position.domain, val=v))
+                    ift.Field.full(self.position.domain, v))
 
         try:
             minimizer = eval(minimizer)
-- 
GitLab