From acb67ac294fafadf0b7daef48eccf60286217a70 Mon Sep 17 00:00:00 2001
From: Philipp Arras <parras@mpa-garching.mpg.de>
Date: Wed, 20 Jun 2018 23:21:23 +0200
Subject: [PATCH] Restructuring models

---
 nifty5/models/__init__.py                     |   3 +-
 nifty5/models/binary_helpers.py               |  85 ++++++++++++++
 nifty5/models/linear.py                       |  43 +++++++
 nifty5/models/model.py                        | 106 +-----------------
 nifty5/models/variable.py                     |   2 +-
 nifty5/operators/__init__.py                  |   3 +-
 .../selection_operator.py                     |   0
 7 files changed, 138 insertions(+), 104 deletions(-)
 create mode 100644 nifty5/models/binary_helpers.py
 create mode 100644 nifty5/models/linear.py
 rename nifty5/{models => operators}/selection_operator.py (100%)

diff --git a/nifty5/models/__init__.py b/nifty5/models/__init__.py
index ac43207f1..0949e7516 100644
--- a/nifty5/models/__init__.py
+++ b/nifty5/models/__init__.py
@@ -1,7 +1,8 @@
 from .constant import Constant
+from .linear import LinearModel
 from .local_nonlinearity import (LocalModel, PointwiseExponential,
                                  PointwisePositiveTanh, PointwiseTanh)
-from .model import LinearModel, Model
+from .model import Model
 from .variable import Variable
 
 __all__ = ['Model', 'Constant', 'LocalModel', 'Variable',
diff --git a/nifty5/models/binary_helpers.py b/nifty5/models/binary_helpers.py
new file mode 100644
index 000000000..fbb7529f5
--- /dev/null
+++ b/nifty5/models/binary_helpers.py
@@ -0,0 +1,85 @@
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+#
+# Copyright(C) 2013-2018 Max-Planck-Society
+#
+# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
+# and financially supported by the Studienstiftung des deutschen Volkes.
+
+from ..multi import MultiField
+from ..sugar import makeOp
+from .model import Model
+
+
+def _joint_position(op1, op2):
+    a = op1.position._val
+    b = op2.position._val
+    # Note: In python >3.5 one could do {**a, **b}
+    ab = a.copy()
+    ab.update(b)
+    return MultiField(ab)
+
+
+class ScalarMul(Model):
+    def __init__(self, factor, op):
+        super(ScalarMul, self).__init__(op.position)
+        if not isinstance(factor, (float, int)):
+            raise TypeError
+
+        self._op = op
+        self._factor = factor
+
+        self._value = self._factor * self._op.value
+        self._gradient = self._factor * self._op.gradient
+
+    def at(self, position):
+        return self.__class__(self._factor, self._op.at(position))
+
+
+class Add(Model):
+    def __init__(self, position, op1, op2):
+        super(Add, self).__init__(position)
+
+        self._op1 = op1.at(position)
+        self._op2 = op2.at(position)
+
+        self._value = self._op1.value + self._op2.value
+        self._gradient = self._op1.gradient + self._op2.gradient
+
+    @staticmethod
+    def make(op1, op2):
+        position = _joint_position(op1, op2)
+        return Add(position, op1, op2)
+
+    def at(self, position):
+        return self.__class__(position, self._op1, self._op2)
+
+
+class Mul(Model):
+    def __init__(self, position, op1, op2):
+        super(Mul, self).__init__(position)
+
+        self._op1 = op1.at(position)
+        self._op2 = op2.at(position)
+
+        self._value = self._op1.value * self._op2.value
+        self._gradient = (makeOp(self._op1.value) * self._op2.gradient +
+                          makeOp(self._op2.value) * self._op1.gradient)
+
+    @staticmethod
+    def make(op1, op2):
+        position = _joint_position(op1, op2)
+        return Mul(position, op1, op2)
+
+    def at(self, position):
+        return self.__class__(position, self._op1, self._op2)
diff --git a/nifty5/models/linear.py b/nifty5/models/linear.py
new file mode 100644
index 000000000..de391df99
--- /dev/null
+++ b/nifty5/models/linear.py
@@ -0,0 +1,43 @@
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+# Copyright(C) 2013-2018 Max-Planck-Society
+#
+# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
+# and financially supported by the Studienstiftung des deutschen Volkes.
+
+from ..operators.selection_operator import SelectionOperator
+from .model import Model
+
+
+class LinearModel(Model):
+    def __init__(self, inp, lin_op):
+        """
+        Computes lin_op(inp) where lin_op is a Linear Operator
+        """
+        from ..operators import LinearOperator
+        super(LinearModel, self).__init__(inp.position)
+
+        if not isinstance(lin_op, LinearOperator):
+            raise TypeError("needs a LinearOperator as input")
+
+        self._lin_op = lin_op
+        self._inp = inp
+        if isinstance(self._lin_op, SelectionOperator):
+            self._lin_op = SelectionOperator(self._inp.value.domain,
+                                             self._lin_op._key)
+
+        self._value = self._lin_op(self._inp.value)
+        self._gradient = self._lin_op*self._inp.gradient
+
+    def at(self, position):
+        return self.__class__(self._inp.at(position), self._lin_op)
diff --git a/nifty5/models/model.py b/nifty5/models/model.py
index ebb091266..fb320b422 100644
--- a/nifty5/models/model.py
+++ b/nifty5/models/model.py
@@ -16,10 +16,8 @@
 # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
 # and financially supported by the Studienstiftung des deutschen Volkes.
 
-from ..multi import MultiField
-from ..sugar import makeOp
+from ..operators.selection_operator import SelectionOperator
 from ..utilities import NiftyMetaBase
-from .selection_operator import SelectionOperator
 
 
 class Model(NiftyMetaBase()):
@@ -48,17 +46,21 @@ class Model(NiftyMetaBase()):
     def __add__(self, other):
         if not isinstance(other, Model):
             raise TypeError
+        from .binary_helpers import Add
         return Add.make(self, other)
 
     def __sub__(self, other):
         if not isinstance(other, Model):
             raise TypeError
+        from .binary_helpers import Add
         return Add.make(self, (-1) * other)
 
     def __mul__(self, other):
         if isinstance(other, (float, int)):
+            from .binary_helpers import ScalarMul
             return ScalarMul(other, self)
         if isinstance(other, Model):
+            from .binary_helpers import Mul
             return Mul.make(self, other)
         raise NotImplementedError
 
@@ -66,101 +68,3 @@ class Model(NiftyMetaBase()):
         if isinstance(other, (float, int)):
             return self.__mul__(other)
         raise NotImplementedError
-
-
-def _joint_position(op1, op2):
-    a = op1.position._val
-    b = op2.position._val
-    # Note: In python >3.5 one could do {**a, **b}
-    ab = a.copy()
-    ab.update(b)
-    return MultiField(ab)
-
-
-class Mul(Model):
-    """
-    Please note: If you multiply two operators which share some keys in the
-    position but have different values there, it is not guaranteed which value
-    will be used for the product.
-    """
-    def __init__(self, position, op1, op2):
-        super(Mul, self).__init__(position)
-
-        self._op1 = op1.at(position)
-        self._op2 = op2.at(position)
-
-        self._value = self._op1.value * self._op2.value
-        self._gradient = (makeOp(self._op1.value) * self._op2.gradient +
-                          makeOp(self._op2.value) * self._op1.gradient)
-
-    @staticmethod
-    def make(op1, op2):
-        position = _joint_position(op1, op2)
-        return Mul(position, op1, op2)
-
-    def at(self, position):
-        return self.__class__(position, self._op1, self._op2)
-
-
-class Add(Model):
-    """
-    Please note: If you add two operators which share some keys in the position
-    but have different values there, it is not guaranteed which value will be
-    used for the sum.
-    """
-    def __init__(self, position, op1, op2):
-        super(Add, self).__init__(position)
-
-        self._op1 = op1.at(position)
-        self._op2 = op2.at(position)
-
-        self._value = self._op1.value + self._op2.value
-        self._gradient = self._op1.gradient + self._op2.gradient
-
-    @staticmethod
-    def make(op1, op2):
-        position = _joint_position(op1, op2)
-        return Add(position, op1, op2)
-
-    def at(self, position):
-        return self.__class__(position, self._op1, self._op2)
-
-
-class ScalarMul(Model):
-    def __init__(self, factor, op):
-        super(ScalarMul, self).__init__(op.position)
-        if not isinstance(factor, (float, int)):
-            raise TypeError
-
-        self._op = op
-        self._factor = factor
-
-        self._value = self._factor * self._op.value
-        self._gradient = self._factor * self._op.gradient
-
-    def at(self, position):
-        return self.__class__(self._factor, self._op.at(position))
-
-
-class LinearModel(Model):
-    def __init__(self, inp, lin_op):
-        """
-        Computes lin_op(inp) where lin_op is a Linear Operator
-        """
-        from ..operators import LinearOperator
-        super(LinearModel, self).__init__(inp.position)
-
-        if not isinstance(lin_op, LinearOperator):
-            raise TypeError("needs a LinearOperator as input")
-
-        self._lin_op = lin_op
-        self._inp = inp
-        if isinstance(self._lin_op, SelectionOperator):
-            self._lin_op = SelectionOperator(self._inp.value.domain,
-                                             self._lin_op._key)
-
-        self._value = self._lin_op(self._inp.value)
-        self._gradient = self._lin_op*self._inp.gradient
-
-    def at(self, position):
-        return self.__class__(self._inp.at(position), self._lin_op)
diff --git a/nifty5/models/variable.py b/nifty5/models/variable.py
index 509d6fbfe..75fb56038 100644
--- a/nifty5/models/variable.py
+++ b/nifty5/models/variable.py
@@ -16,8 +16,8 @@
 # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
 # and financially supported by the Studienstiftung des deutschen Volkes.
 
-from .model import Model
 from ..operators.scaling_operator import ScalingOperator
+from .model import Model
 
 
 class Variable(Model):
diff --git a/nifty5/operators/__init__.py b/nifty5/operators/__init__.py
index 03a580a25..35b8753be 100644
--- a/nifty5/operators/__init__.py
+++ b/nifty5/operators/__init__.py
@@ -8,11 +8,12 @@ from .harmonic_transform_operator import HarmonicTransformOperator
 from .inversion_enabler import InversionEnabler
 from .laplace_operator import LaplaceOperator
 from .linear_operator import LinearOperator
+from .model_gradient_operator import ModelGradientOperator
 from .power_distributor import PowerDistributor
 from .sampling_enabler import SamplingEnabler
 from .sandwich_operator import SandwichOperator
 from .scaling_operator import ScalingOperator
-from .model_gradient_operator import ModelGradientOperator
+from .selection_operator import SelectionOperator
 from .smoothness_operator import SmoothnessOperator
 
 __all__ = ["LinearOperator", "EndomorphicOperator", "ScalingOperator",
diff --git a/nifty5/models/selection_operator.py b/nifty5/operators/selection_operator.py
similarity index 100%
rename from nifty5/models/selection_operator.py
rename to nifty5/operators/selection_operator.py
-- 
GitLab