diff --git a/nifty5/models/__init__.py b/nifty5/models/__init__.py index ac43207f1edad32336f861ef1f02a37d33c47292..0949e7516186709b70a9efdb04102dd9f5ff1cb7 100644 --- a/nifty5/models/__init__.py +++ b/nifty5/models/__init__.py @@ -1,7 +1,8 @@ from .constant import Constant +from .linear import LinearModel from .local_nonlinearity import (LocalModel, PointwiseExponential, PointwisePositiveTanh, PointwiseTanh) -from .model import LinearModel, Model +from .model import Model from .variable import Variable __all__ = ['Model', 'Constant', 'LocalModel', 'Variable', diff --git a/nifty5/models/binary_helpers.py b/nifty5/models/binary_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..fbb7529f59ecfc9bb6d748c82655fb5e79a704c3 --- /dev/null +++ b/nifty5/models/binary_helpers.py @@ -0,0 +1,85 @@ +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# +# Copyright(C) 2013-2018 Max-Planck-Society +# +# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik +# and financially supported by the Studienstiftung des deutschen Volkes. + +from ..multi import MultiField +from ..sugar import makeOp +from .model import Model + + +def _joint_position(op1, op2): + a = op1.position._val + b = op2.position._val + # Note: In python >3.5 one could do {**a, **b} + ab = a.copy() + ab.update(b) + return MultiField(ab) + + +class ScalarMul(Model): + def __init__(self, factor, op): + super(ScalarMul, self).__init__(op.position) + if not isinstance(factor, (float, int)): + raise TypeError + + self._op = op + self._factor = factor + + self._value = self._factor * self._op.value + self._gradient = self._factor * self._op.gradient + + def at(self, position): + return self.__class__(self._factor, self._op.at(position)) + + +class Add(Model): + def __init__(self, position, op1, op2): + super(Add, self).__init__(position) + + self._op1 = op1.at(position) + self._op2 = op2.at(position) + + self._value = self._op1.value + self._op2.value + self._gradient = self._op1.gradient + self._op2.gradient + + @staticmethod + def make(op1, op2): + position = _joint_position(op1, op2) + return Add(position, op1, op2) + + def at(self, position): + return self.__class__(position, self._op1, self._op2) + + +class Mul(Model): + def __init__(self, position, op1, op2): + super(Mul, self).__init__(position) + + self._op1 = op1.at(position) + self._op2 = op2.at(position) + + self._value = self._op1.value * self._op2.value + self._gradient = (makeOp(self._op1.value) * self._op2.gradient + + makeOp(self._op2.value) * self._op1.gradient) + + @staticmethod + def make(op1, op2): + position = _joint_position(op1, op2) + return Mul(position, op1, op2) + + def at(self, position): + return self.__class__(position, self._op1, self._op2) diff --git a/nifty5/models/linear.py b/nifty5/models/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..de391df99338f316a7836f7533c100ef6e4fc2c8 --- /dev/null +++ b/nifty5/models/linear.py @@ -0,0 +1,43 @@ +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# Copyright(C) 2013-2018 Max-Planck-Society +# +# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik +# and financially supported by the Studienstiftung des deutschen Volkes. + +from ..operators.selection_operator import SelectionOperator +from .model import Model + + +class LinearModel(Model): + def __init__(self, inp, lin_op): + """ + Computes lin_op(inp) where lin_op is a Linear Operator + """ + from ..operators import LinearOperator + super(LinearModel, self).__init__(inp.position) + + if not isinstance(lin_op, LinearOperator): + raise TypeError("needs a LinearOperator as input") + + self._lin_op = lin_op + self._inp = inp + if isinstance(self._lin_op, SelectionOperator): + self._lin_op = SelectionOperator(self._inp.value.domain, + self._lin_op._key) + + self._value = self._lin_op(self._inp.value) + self._gradient = self._lin_op*self._inp.gradient + + def at(self, position): + return self.__class__(self._inp.at(position), self._lin_op) diff --git a/nifty5/models/model.py b/nifty5/models/model.py index ebb0912661686725934d93b145091a395ae29739..fb320b422bec2cd6ba6bf9ccf2f5e87a8a07333c 100644 --- a/nifty5/models/model.py +++ b/nifty5/models/model.py @@ -16,10 +16,8 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # and financially supported by the Studienstiftung des deutschen Volkes. -from ..multi import MultiField -from ..sugar import makeOp +from ..operators.selection_operator import SelectionOperator from ..utilities import NiftyMetaBase -from .selection_operator import SelectionOperator class Model(NiftyMetaBase()): @@ -48,17 +46,21 @@ class Model(NiftyMetaBase()): def __add__(self, other): if not isinstance(other, Model): raise TypeError + from .binary_helpers import Add return Add.make(self, other) def __sub__(self, other): if not isinstance(other, Model): raise TypeError + from .binary_helpers import Add return Add.make(self, (-1) * other) def __mul__(self, other): if isinstance(other, (float, int)): + from .binary_helpers import ScalarMul return ScalarMul(other, self) if isinstance(other, Model): + from .binary_helpers import Mul return Mul.make(self, other) raise NotImplementedError @@ -66,101 +68,3 @@ class Model(NiftyMetaBase()): if isinstance(other, (float, int)): return self.__mul__(other) raise NotImplementedError - - -def _joint_position(op1, op2): - a = op1.position._val - b = op2.position._val - # Note: In python >3.5 one could do {**a, **b} - ab = a.copy() - ab.update(b) - return MultiField(ab) - - -class Mul(Model): - """ - Please note: If you multiply two operators which share some keys in the - position but have different values there, it is not guaranteed which value - will be used for the product. - """ - def __init__(self, position, op1, op2): - super(Mul, self).__init__(position) - - self._op1 = op1.at(position) - self._op2 = op2.at(position) - - self._value = self._op1.value * self._op2.value - self._gradient = (makeOp(self._op1.value) * self._op2.gradient + - makeOp(self._op2.value) * self._op1.gradient) - - @staticmethod - def make(op1, op2): - position = _joint_position(op1, op2) - return Mul(position, op1, op2) - - def at(self, position): - return self.__class__(position, self._op1, self._op2) - - -class Add(Model): - """ - Please note: If you add two operators which share some keys in the position - but have different values there, it is not guaranteed which value will be - used for the sum. - """ - def __init__(self, position, op1, op2): - super(Add, self).__init__(position) - - self._op1 = op1.at(position) - self._op2 = op2.at(position) - - self._value = self._op1.value + self._op2.value - self._gradient = self._op1.gradient + self._op2.gradient - - @staticmethod - def make(op1, op2): - position = _joint_position(op1, op2) - return Add(position, op1, op2) - - def at(self, position): - return self.__class__(position, self._op1, self._op2) - - -class ScalarMul(Model): - def __init__(self, factor, op): - super(ScalarMul, self).__init__(op.position) - if not isinstance(factor, (float, int)): - raise TypeError - - self._op = op - self._factor = factor - - self._value = self._factor * self._op.value - self._gradient = self._factor * self._op.gradient - - def at(self, position): - return self.__class__(self._factor, self._op.at(position)) - - -class LinearModel(Model): - def __init__(self, inp, lin_op): - """ - Computes lin_op(inp) where lin_op is a Linear Operator - """ - from ..operators import LinearOperator - super(LinearModel, self).__init__(inp.position) - - if not isinstance(lin_op, LinearOperator): - raise TypeError("needs a LinearOperator as input") - - self._lin_op = lin_op - self._inp = inp - if isinstance(self._lin_op, SelectionOperator): - self._lin_op = SelectionOperator(self._inp.value.domain, - self._lin_op._key) - - self._value = self._lin_op(self._inp.value) - self._gradient = self._lin_op*self._inp.gradient - - def at(self, position): - return self.__class__(self._inp.at(position), self._lin_op) diff --git a/nifty5/models/variable.py b/nifty5/models/variable.py index 509d6fbfedfe4921279db9c0ffa6d63710ec13e8..75fb5603890236ea60a654d40dabee5d75085fc8 100644 --- a/nifty5/models/variable.py +++ b/nifty5/models/variable.py @@ -16,8 +16,8 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # and financially supported by the Studienstiftung des deutschen Volkes. -from .model import Model from ..operators.scaling_operator import ScalingOperator +from .model import Model class Variable(Model): diff --git a/nifty5/operators/__init__.py b/nifty5/operators/__init__.py index 03a580a2503a31babf03f031014732a5b5350d05..35b8753be5ab041bbf7019c39311413b9be01a2b 100644 --- a/nifty5/operators/__init__.py +++ b/nifty5/operators/__init__.py @@ -8,11 +8,12 @@ from .harmonic_transform_operator import HarmonicTransformOperator from .inversion_enabler import InversionEnabler from .laplace_operator import LaplaceOperator from .linear_operator import LinearOperator +from .model_gradient_operator import ModelGradientOperator from .power_distributor import PowerDistributor from .sampling_enabler import SamplingEnabler from .sandwich_operator import SandwichOperator from .scaling_operator import ScalingOperator -from .model_gradient_operator import ModelGradientOperator +from .selection_operator import SelectionOperator from .smoothness_operator import SmoothnessOperator __all__ = ["LinearOperator", "EndomorphicOperator", "ScalingOperator", diff --git a/nifty5/models/selection_operator.py b/nifty5/operators/selection_operator.py similarity index 100% rename from nifty5/models/selection_operator.py rename to nifty5/operators/selection_operator.py