diff --git a/nifty5/models/__init__.py b/nifty5/models/__init__.py
index ac43207f1edad32336f861ef1f02a37d33c47292..0949e7516186709b70a9efdb04102dd9f5ff1cb7 100644
--- a/nifty5/models/__init__.py
+++ b/nifty5/models/__init__.py
@@ -1,7 +1,8 @@
from .constant import Constant
+from .linear import LinearModel
from .local_nonlinearity import (LocalModel, PointwiseExponential,
PointwisePositiveTanh, PointwiseTanh)
-from .model import LinearModel, Model
+from .model import Model
from .variable import Variable
__all__ = ['Model', 'Constant', 'LocalModel', 'Variable',
diff --git a/nifty5/models/binary_helpers.py b/nifty5/models/binary_helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbb7529f59ecfc9bb6d748c82655fb5e79a704c3
--- /dev/null
+++ b/nifty5/models/binary_helpers.py
@@ -0,0 +1,85 @@
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+#
+# Copyright(C) 2013-2018 Max-Planck-Society
+#
+# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
+# and financially supported by the Studienstiftung des deutschen Volkes.
+
+from ..multi import MultiField
+from ..sugar import makeOp
+from .model import Model
+
+
+def _joint_position(op1, op2):
+ a = op1.position._val
+ b = op2.position._val
+ # Note: In python >3.5 one could do {**a, **b}
+ ab = a.copy()
+ ab.update(b)
+ return MultiField(ab)
+
+
+class ScalarMul(Model):
+ def __init__(self, factor, op):
+ super(ScalarMul, self).__init__(op.position)
+ if not isinstance(factor, (float, int)):
+ raise TypeError
+
+ self._op = op
+ self._factor = factor
+
+ self._value = self._factor * self._op.value
+ self._gradient = self._factor * self._op.gradient
+
+ def at(self, position):
+ return self.__class__(self._factor, self._op.at(position))
+
+
+class Add(Model):
+ def __init__(self, position, op1, op2):
+ super(Add, self).__init__(position)
+
+ self._op1 = op1.at(position)
+ self._op2 = op2.at(position)
+
+ self._value = self._op1.value + self._op2.value
+ self._gradient = self._op1.gradient + self._op2.gradient
+
+ @staticmethod
+ def make(op1, op2):
+ position = _joint_position(op1, op2)
+ return Add(position, op1, op2)
+
+ def at(self, position):
+ return self.__class__(position, self._op1, self._op2)
+
+
+class Mul(Model):
+ def __init__(self, position, op1, op2):
+ super(Mul, self).__init__(position)
+
+ self._op1 = op1.at(position)
+ self._op2 = op2.at(position)
+
+ self._value = self._op1.value * self._op2.value
+ self._gradient = (makeOp(self._op1.value) * self._op2.gradient +
+ makeOp(self._op2.value) * self._op1.gradient)
+
+ @staticmethod
+ def make(op1, op2):
+ position = _joint_position(op1, op2)
+ return Mul(position, op1, op2)
+
+ def at(self, position):
+ return self.__class__(position, self._op1, self._op2)
diff --git a/nifty5/models/linear.py b/nifty5/models/linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..de391df99338f316a7836f7533c100ef6e4fc2c8
--- /dev/null
+++ b/nifty5/models/linear.py
@@ -0,0 +1,43 @@
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+# Copyright(C) 2013-2018 Max-Planck-Society
+#
+# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
+# and financially supported by the Studienstiftung des deutschen Volkes.
+
+from ..operators.selection_operator import SelectionOperator
+from .model import Model
+
+
+class LinearModel(Model):
+ def __init__(self, inp, lin_op):
+ """
+ Computes lin_op(inp) where lin_op is a Linear Operator
+ """
+ from ..operators import LinearOperator
+ super(LinearModel, self).__init__(inp.position)
+
+ if not isinstance(lin_op, LinearOperator):
+ raise TypeError("needs a LinearOperator as input")
+
+ self._lin_op = lin_op
+ self._inp = inp
+ if isinstance(self._lin_op, SelectionOperator):
+ self._lin_op = SelectionOperator(self._inp.value.domain,
+ self._lin_op._key)
+
+ self._value = self._lin_op(self._inp.value)
+ self._gradient = self._lin_op*self._inp.gradient
+
+ def at(self, position):
+ return self.__class__(self._inp.at(position), self._lin_op)
diff --git a/nifty5/models/model.py b/nifty5/models/model.py
index ebb0912661686725934d93b145091a395ae29739..fb320b422bec2cd6ba6bf9ccf2f5e87a8a07333c 100644
--- a/nifty5/models/model.py
+++ b/nifty5/models/model.py
@@ -16,10 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
-from ..multi import MultiField
-from ..sugar import makeOp
+from ..operators.selection_operator import SelectionOperator
from ..utilities import NiftyMetaBase
-from .selection_operator import SelectionOperator
class Model(NiftyMetaBase()):
@@ -48,17 +46,21 @@ class Model(NiftyMetaBase()):
def __add__(self, other):
if not isinstance(other, Model):
raise TypeError
+ from .binary_helpers import Add
return Add.make(self, other)
def __sub__(self, other):
if not isinstance(other, Model):
raise TypeError
+ from .binary_helpers import Add
return Add.make(self, (-1) * other)
def __mul__(self, other):
if isinstance(other, (float, int)):
+ from .binary_helpers import ScalarMul
return ScalarMul(other, self)
if isinstance(other, Model):
+ from .binary_helpers import Mul
return Mul.make(self, other)
raise NotImplementedError
@@ -66,101 +68,3 @@ class Model(NiftyMetaBase()):
if isinstance(other, (float, int)):
return self.__mul__(other)
raise NotImplementedError
-
-
-def _joint_position(op1, op2):
- a = op1.position._val
- b = op2.position._val
- # Note: In python >3.5 one could do {**a, **b}
- ab = a.copy()
- ab.update(b)
- return MultiField(ab)
-
-
-class Mul(Model):
- """
- Please note: If you multiply two operators which share some keys in the
- position but have different values there, it is not guaranteed which value
- will be used for the product.
- """
- def __init__(self, position, op1, op2):
- super(Mul, self).__init__(position)
-
- self._op1 = op1.at(position)
- self._op2 = op2.at(position)
-
- self._value = self._op1.value * self._op2.value
- self._gradient = (makeOp(self._op1.value) * self._op2.gradient +
- makeOp(self._op2.value) * self._op1.gradient)
-
- @staticmethod
- def make(op1, op2):
- position = _joint_position(op1, op2)
- return Mul(position, op1, op2)
-
- def at(self, position):
- return self.__class__(position, self._op1, self._op2)
-
-
-class Add(Model):
- """
- Please note: If you add two operators which share some keys in the position
- but have different values there, it is not guaranteed which value will be
- used for the sum.
- """
- def __init__(self, position, op1, op2):
- super(Add, self).__init__(position)
-
- self._op1 = op1.at(position)
- self._op2 = op2.at(position)
-
- self._value = self._op1.value + self._op2.value
- self._gradient = self._op1.gradient + self._op2.gradient
-
- @staticmethod
- def make(op1, op2):
- position = _joint_position(op1, op2)
- return Add(position, op1, op2)
-
- def at(self, position):
- return self.__class__(position, self._op1, self._op2)
-
-
-class ScalarMul(Model):
- def __init__(self, factor, op):
- super(ScalarMul, self).__init__(op.position)
- if not isinstance(factor, (float, int)):
- raise TypeError
-
- self._op = op
- self._factor = factor
-
- self._value = self._factor * self._op.value
- self._gradient = self._factor * self._op.gradient
-
- def at(self, position):
- return self.__class__(self._factor, self._op.at(position))
-
-
-class LinearModel(Model):
- def __init__(self, inp, lin_op):
- """
- Computes lin_op(inp) where lin_op is a Linear Operator
- """
- from ..operators import LinearOperator
- super(LinearModel, self).__init__(inp.position)
-
- if not isinstance(lin_op, LinearOperator):
- raise TypeError("needs a LinearOperator as input")
-
- self._lin_op = lin_op
- self._inp = inp
- if isinstance(self._lin_op, SelectionOperator):
- self._lin_op = SelectionOperator(self._inp.value.domain,
- self._lin_op._key)
-
- self._value = self._lin_op(self._inp.value)
- self._gradient = self._lin_op*self._inp.gradient
-
- def at(self, position):
- return self.__class__(self._inp.at(position), self._lin_op)
diff --git a/nifty5/models/variable.py b/nifty5/models/variable.py
index 509d6fbfedfe4921279db9c0ffa6d63710ec13e8..75fb5603890236ea60a654d40dabee5d75085fc8 100644
--- a/nifty5/models/variable.py
+++ b/nifty5/models/variable.py
@@ -16,8 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
-from .model import Model
from ..operators.scaling_operator import ScalingOperator
+from .model import Model
class Variable(Model):
diff --git a/nifty5/operators/__init__.py b/nifty5/operators/__init__.py
index 03a580a2503a31babf03f031014732a5b5350d05..35b8753be5ab041bbf7019c39311413b9be01a2b 100644
--- a/nifty5/operators/__init__.py
+++ b/nifty5/operators/__init__.py
@@ -8,11 +8,12 @@ from .harmonic_transform_operator import HarmonicTransformOperator
from .inversion_enabler import InversionEnabler
from .laplace_operator import LaplaceOperator
from .linear_operator import LinearOperator
+from .model_gradient_operator import ModelGradientOperator
from .power_distributor import PowerDistributor
from .sampling_enabler import SamplingEnabler
from .sandwich_operator import SandwichOperator
from .scaling_operator import ScalingOperator
-from .model_gradient_operator import ModelGradientOperator
+from .selection_operator import SelectionOperator
from .smoothness_operator import SmoothnessOperator
__all__ = ["LinearOperator", "EndomorphicOperator", "ScalingOperator",
diff --git a/nifty5/models/selection_operator.py b/nifty5/operators/selection_operator.py
similarity index 100%
rename from nifty5/models/selection_operator.py
rename to nifty5/operators/selection_operator.py