From 2e434668ccc50a1b67e9de0acadb11f85b11b64c Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Wed, 8 Aug 2018 16:59:23 +0200
Subject: [PATCH] simplification and cosmetics

---
 nifty5/__init__.py                     |   2 -
 nifty5/field.py                        | 118 +++----------------------
 nifty5/linearization.py                |   4 +-
 nifty5/minimization/scipy_minimizer.py |  18 ++--
 4 files changed, 23 insertions(+), 119 deletions(-)

diff --git a/nifty5/__init__.py b/nifty5/__init__.py
index e54fc6b0f..6d30fb855 100644
--- a/nifty5/__init__.py
+++ b/nifty5/__init__.py
@@ -78,8 +78,6 @@ from .library.amplitude_model import AmplitudeModel
 from .library.inverse_gamma_model import InverseGammaModel
 from .library.los_response import LOSResponse
 
-#from .library.inverse_gamma_model import InverseGammaModel
-
 from .library.wiener_filter_curvature import WienerFilterCurvature
 from .library.correlated_fields import CorrelatedField
 #                                         make_mf_correlated_field)
diff --git a/nifty5/field.py b/nifty5/field.py
index ba6d5787a..fbedb1e90 100644
--- a/nifty5/field.py
+++ b/nifty5/field.py
@@ -47,13 +47,11 @@ class Field(object):
     """
 
     def __init__(self, domain, val):
-        self._uni = None
         if not isinstance(domain, DomainTuple):
             raise TypeError("domain must be of type DomainTuple")
-        if not isinstance(val, dobj.data_object):
+        if type(val) is not dobj.data_object:
             if np.isscalar(val):
-                self._uni = val
-                val = dobj.uniform_full(domain.shape, val)
+                val = dobj.full(domain.shape, val)
             else:
                 raise TypeError("val must be of type dobj.data_object")
         if domain.shape != val.shape:
@@ -394,14 +392,10 @@ class Field(object):
         return self
 
     def __neg__(self):
-        if self._uni is None:
-            return Field(self._domain, -self._val)
-        return Field(self._domain, -self._uni)
+        return Field(self._domain, -self._val)
 
     def __abs__(self):
-        if self._uni is None:
-            return Field(self._domain, abs(self._val))
-        return Field(self._domain, abs(self._uni))
+        return Field(self._domain, abs(self._val))
 
     def _contraction_helper(self, op, spaces):
         if spaces is None:
@@ -617,96 +611,12 @@ class Field(object):
         return self + other
 
     def positive_tanh(self):
-        if self._uni is None:
-            return 0.5*(1.+self.tanh())
-        return Field(self._domain, 0.5*(1.+np.tanh(self._uni)))
-
-    def __add__(self, other):
-        # if other is a field, make sure that the domains match
-        if isinstance(other, Field):
-            if other._domain is not self._domain:
-                raise ValueError("domains are incompatible.")
-            if self._uni is None:
-                if other._uni is None:
-                    return Field(self._domain, self._val+other._val)
-                if other._uni == 0:
-                    return self
-                return Field(self._domain, self._val+other._uni)
-            else:
-                if self._uni == 0:
-                    return other
-                if other._uni is None:
-                    return Field(self._domain, other._val+self._uni)
-                return Field(self._domain, self._uni+other._uni)
-
-        if np.isscalar(other):
-            if self._uni is None:
-                return Field(self._domain, self._val+other)
-            return Field(self._domain, self._uni+other)
-        return NotImplemented
-
-    def __radd__(self, other):
-        return self.__add__(other)
-
-    def __sub__(self, other):
-        # if other is a field, make sure that the domains match
-        if isinstance(other, Field):
-            if other._domain is not self._domain:
-                raise ValueError("domains are incompatible.")
-            if self._uni is None:
-                if other._uni is None:
-                    return Field(self._domain, self._val-other._val)
-                if other._uni == 0:
-                    return self
-                return Field(self._domain, self._val-other._uni)
-            else:
-                if self._uni == 0:
-                    return -other
-                if other._uni is None:
-                    return Field(self._domain, self._uni-other._val)
-                return Field(self._domain, self._uni-other._uni)
-
-        if np.isscalar(other):
-            if self._uni is None:
-                return Field(self._domain, self._val-other)
-            return Field(self._domain, self._uni-other)
-        return NotImplemented
-
-    def __mul__(self, other):
-        # if other is a field, make sure that the domains match
-        if isinstance(other, Field):
-            if other._domain is not self._domain:
-                raise ValueError("domains are incompatible.")
-            if self._uni is None:
-                if other._uni is None:
-                    return Field(self._domain, self._val*other._val)
-                if other._uni == 1:
-                    return self
-                if other._uni == 0:
-                    return other
-                return Field(self._domain, self._val*other._uni)
-            else:
-                if self._uni == 1:
-                    return other
-                if self._uni == 0:
-                    return self
-                if other._uni is None:
-                    return Field(self._domain, other._val*self._uni)
-                return Field(self._domain, self._uni*other._uni)
-
-        if np.isscalar(other):
-            if self._uni is None:
-                if other == 1:
-                    return self
-                if other == 0:
-                    return Field(self._domain, other)
-                return Field(self._domain, self._val*other)
-            return Field(self._domain, self._uni*other)
-        return NotImplemented
-
-
-for op in ["__rsub__",
-           "__rmul__",
+        return 0.5*(1.+self.tanh())
+
+
+for op in ["__add__", "__radd__",
+           "__sub__", "__rsub__",
+           "__mul__", "__rmul__",
            "__div__", "__rdiv__",
            "__truediv__", "__rtruediv__",
            "__floordiv__", "__rfloordiv__",
@@ -739,11 +649,7 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
 for f in ["sqrt", "exp", "log", "tanh"]:
     def func(f):
         def func2(self):
-            if self._uni is None:
-                fu = getattr(dobj, f)
-                return Field(domain=self._domain, val=fu(self.val))
-            else:
-                fu = getattr(np, f)
-                return Field(domain=self._domain, val=fu(self._uni))
+            fu = getattr(dobj, f)
+            return Field(domain=self._domain, val=fu(self.val))
         return func2
     setattr(Field, f, func(f))
diff --git a/nifty5/linearization.py b/nifty5/linearization.py
index 88d7b4f15..d3394e6d9 100644
--- a/nifty5/linearization.py
+++ b/nifty5/linearization.py
@@ -102,10 +102,10 @@ class Linearization(object):
         from .operators.simple_linear_operators import VdotOperator
         if isinstance(other, (Field, MultiField)):
             return Linearization(
-                Field(DomainTuple.scalar_domain(),self._val.vdot(other)),
+                Field(DomainTuple.scalar_domain(), self._val.vdot(other)),
                 VdotOperator(other)(self._jac))
         return Linearization(
-            Field(DomainTuple.scalar_domain(),self._val.vdot(other._val)),
+            Field(DomainTuple.scalar_domain(), self._val.vdot(other._val)),
             VdotOperator(self._val)(other._jac) +
             VdotOperator(other._val)(self._jac))
 
diff --git a/nifty5/minimization/scipy_minimizer.py b/nifty5/minimization/scipy_minimizer.py
index 7cc3830d2..9fad98daa 100644
--- a/nifty5/minimization/scipy_minimizer.py
+++ b/nifty5/minimization/scipy_minimizer.py
@@ -26,12 +26,12 @@ from .iteration_controller import IterationController
 from .minimizer import Minimizer
 
 
-def _toNdarray(fld):
+def _toArray(fld):
     return fld.to_global_data().reshape(-1)
 
 
-def _toFlatNdarray(fld):
-    return fld.val.flatten()
+def _toArray_rw(fld):
+    return fld.to_global_data_rw().reshape(-1)
 
 
 def _toField(arr, dom):
@@ -54,12 +54,12 @@ class _MinHelper(object):
 
     def jac(self, x):
         self._update(x)
-        return _toFlatNdarray(self._energy.gradient)
+        return _toArray_rw(self._energy.gradient)
 
     def hessp(self, x, p):
         self._update(x)
         res = self._energy.metric(_toField(p, self._domain))
-        return _toFlatNdarray(res)
+        return _toArray_rw(res)
 
 
 class ScipyMinimizer(Minimizer):
@@ -95,7 +95,7 @@ class ScipyMinimizer(Minimizer):
             else:
                 raise ValueError("unrecognized bounds")
 
-        x = hlp._energy.position.val.flatten()
+        x = _toArray_rw(hlp._energy.position)
         hessp = hlp.hessp if self._need_hessp else None
         r = opt.minimize(hlp.fun, x, method=self._method, jac=hlp.jac,
                          hessp=hessp, options=self._options, bounds=bounds)
@@ -147,11 +147,11 @@ class ScipyCG(Minimizer):
                 self._op = op
 
             def __call__(self, inp):
-                return _toNdarray(self._op(_toField(inp, self._op.domain)))
+                return _toArray(self._op(_toField(inp, self._op.domain)))
 
         op = energy._A
-        b = _toNdarray(energy._b)
-        sx = _toNdarray(energy.position)
+        b = _toArray(energy._b)
+        sx = _toArray(energy.position)
         sci_op = scipy_linop(shape=(op.domain.size, op.target.size),
                              matvec=mymatvec(op))
         prec_op = None
-- 
GitLab