Skip to content
Snippets Groups Projects
Commit acb67ac2 authored by Philipp Arras's avatar Philipp Arras
Browse files

Restructuring models

parent ab20d935
No related branches found
No related tags found
1 merge request!271Restructuring models
Pipeline #
from .constant import Constant from .constant import Constant
from .linear import LinearModel
from .local_nonlinearity import (LocalModel, PointwiseExponential, from .local_nonlinearity import (LocalModel, PointwiseExponential,
PointwisePositiveTanh, PointwiseTanh) PointwisePositiveTanh, PointwiseTanh)
from .model import LinearModel, Model from .model import Model
from .variable import Variable from .variable import Variable
__all__ = ['Model', 'Constant', 'LocalModel', 'Variable', __all__ = ['Model', 'Constant', 'LocalModel', 'Variable',
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from ..multi import MultiField
from ..sugar import makeOp
from .model import Model
def _joint_position(op1, op2):
a = op1.position._val
b = op2.position._val
# Note: In python >3.5 one could do {**a, **b}
ab = a.copy()
ab.update(b)
return MultiField(ab)
class ScalarMul(Model):
def __init__(self, factor, op):
super(ScalarMul, self).__init__(op.position)
if not isinstance(factor, (float, int)):
raise TypeError
self._op = op
self._factor = factor
self._value = self._factor * self._op.value
self._gradient = self._factor * self._op.gradient
def at(self, position):
return self.__class__(self._factor, self._op.at(position))
class Add(Model):
def __init__(self, position, op1, op2):
super(Add, self).__init__(position)
self._op1 = op1.at(position)
self._op2 = op2.at(position)
self._value = self._op1.value + self._op2.value
self._gradient = self._op1.gradient + self._op2.gradient
@staticmethod
def make(op1, op2):
position = _joint_position(op1, op2)
return Add(position, op1, op2)
def at(self, position):
return self.__class__(position, self._op1, self._op2)
class Mul(Model):
def __init__(self, position, op1, op2):
super(Mul, self).__init__(position)
self._op1 = op1.at(position)
self._op2 = op2.at(position)
self._value = self._op1.value * self._op2.value
self._gradient = (makeOp(self._op1.value) * self._op2.gradient +
makeOp(self._op2.value) * self._op1.gradient)
@staticmethod
def make(op1, op2):
position = _joint_position(op1, op2)
return Mul(position, op1, op2)
def at(self, position):
return self.__class__(position, self._op1, self._op2)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from ..operators.selection_operator import SelectionOperator
from .model import Model
class LinearModel(Model):
def __init__(self, inp, lin_op):
"""
Computes lin_op(inp) where lin_op is a Linear Operator
"""
from ..operators import LinearOperator
super(LinearModel, self).__init__(inp.position)
if not isinstance(lin_op, LinearOperator):
raise TypeError("needs a LinearOperator as input")
self._lin_op = lin_op
self._inp = inp
if isinstance(self._lin_op, SelectionOperator):
self._lin_op = SelectionOperator(self._inp.value.domain,
self._lin_op._key)
self._value = self._lin_op(self._inp.value)
self._gradient = self._lin_op*self._inp.gradient
def at(self, position):
return self.__class__(self._inp.at(position), self._lin_op)
...@@ -16,10 +16,8 @@ ...@@ -16,10 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from ..multi import MultiField from ..operators.selection_operator import SelectionOperator
from ..sugar import makeOp
from ..utilities import NiftyMetaBase from ..utilities import NiftyMetaBase
from .selection_operator import SelectionOperator
class Model(NiftyMetaBase()): class Model(NiftyMetaBase()):
...@@ -48,17 +46,21 @@ class Model(NiftyMetaBase()): ...@@ -48,17 +46,21 @@ class Model(NiftyMetaBase()):
def __add__(self, other): def __add__(self, other):
if not isinstance(other, Model): if not isinstance(other, Model):
raise TypeError raise TypeError
from .binary_helpers import Add
return Add.make(self, other) return Add.make(self, other)
def __sub__(self, other): def __sub__(self, other):
if not isinstance(other, Model): if not isinstance(other, Model):
raise TypeError raise TypeError
from .binary_helpers import Add
return Add.make(self, (-1) * other) return Add.make(self, (-1) * other)
def __mul__(self, other): def __mul__(self, other):
if isinstance(other, (float, int)): if isinstance(other, (float, int)):
from .binary_helpers import ScalarMul
return ScalarMul(other, self) return ScalarMul(other, self)
if isinstance(other, Model): if isinstance(other, Model):
from .binary_helpers import Mul
return Mul.make(self, other) return Mul.make(self, other)
raise NotImplementedError raise NotImplementedError
...@@ -66,101 +68,3 @@ class Model(NiftyMetaBase()): ...@@ -66,101 +68,3 @@ class Model(NiftyMetaBase()):
if isinstance(other, (float, int)): if isinstance(other, (float, int)):
return self.__mul__(other) return self.__mul__(other)
raise NotImplementedError raise NotImplementedError
def _joint_position(op1, op2):
a = op1.position._val
b = op2.position._val
# Note: In python >3.5 one could do {**a, **b}
ab = a.copy()
ab.update(b)
return MultiField(ab)
class Mul(Model):
"""
Please note: If you multiply two operators which share some keys in the
position but have different values there, it is not guaranteed which value
will be used for the product.
"""
def __init__(self, position, op1, op2):
super(Mul, self).__init__(position)
self._op1 = op1.at(position)
self._op2 = op2.at(position)
self._value = self._op1.value * self._op2.value
self._gradient = (makeOp(self._op1.value) * self._op2.gradient +
makeOp(self._op2.value) * self._op1.gradient)
@staticmethod
def make(op1, op2):
position = _joint_position(op1, op2)
return Mul(position, op1, op2)
def at(self, position):
return self.__class__(position, self._op1, self._op2)
class Add(Model):
"""
Please note: If you add two operators which share some keys in the position
but have different values there, it is not guaranteed which value will be
used for the sum.
"""
def __init__(self, position, op1, op2):
super(Add, self).__init__(position)
self._op1 = op1.at(position)
self._op2 = op2.at(position)
self._value = self._op1.value + self._op2.value
self._gradient = self._op1.gradient + self._op2.gradient
@staticmethod
def make(op1, op2):
position = _joint_position(op1, op2)
return Add(position, op1, op2)
def at(self, position):
return self.__class__(position, self._op1, self._op2)
class ScalarMul(Model):
def __init__(self, factor, op):
super(ScalarMul, self).__init__(op.position)
if not isinstance(factor, (float, int)):
raise TypeError
self._op = op
self._factor = factor
self._value = self._factor * self._op.value
self._gradient = self._factor * self._op.gradient
def at(self, position):
return self.__class__(self._factor, self._op.at(position))
class LinearModel(Model):
def __init__(self, inp, lin_op):
"""
Computes lin_op(inp) where lin_op is a Linear Operator
"""
from ..operators import LinearOperator
super(LinearModel, self).__init__(inp.position)
if not isinstance(lin_op, LinearOperator):
raise TypeError("needs a LinearOperator as input")
self._lin_op = lin_op
self._inp = inp
if isinstance(self._lin_op, SelectionOperator):
self._lin_op = SelectionOperator(self._inp.value.domain,
self._lin_op._key)
self._value = self._lin_op(self._inp.value)
self._gradient = self._lin_op*self._inp.gradient
def at(self, position):
return self.__class__(self._inp.at(position), self._lin_op)
...@@ -16,8 +16,8 @@ ...@@ -16,8 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from .model import Model
from ..operators.scaling_operator import ScalingOperator from ..operators.scaling_operator import ScalingOperator
from .model import Model
class Variable(Model): class Variable(Model):
......
...@@ -8,11 +8,12 @@ from .harmonic_transform_operator import HarmonicTransformOperator ...@@ -8,11 +8,12 @@ from .harmonic_transform_operator import HarmonicTransformOperator
from .inversion_enabler import InversionEnabler from .inversion_enabler import InversionEnabler
from .laplace_operator import LaplaceOperator from .laplace_operator import LaplaceOperator
from .linear_operator import LinearOperator from .linear_operator import LinearOperator
from .model_gradient_operator import ModelGradientOperator
from .power_distributor import PowerDistributor from .power_distributor import PowerDistributor
from .sampling_enabler import SamplingEnabler from .sampling_enabler import SamplingEnabler
from .sandwich_operator import SandwichOperator from .sandwich_operator import SandwichOperator
from .scaling_operator import ScalingOperator from .scaling_operator import ScalingOperator
from .model_gradient_operator import ModelGradientOperator from .selection_operator import SelectionOperator
from .smoothness_operator import SmoothnessOperator from .smoothness_operator import SmoothnessOperator
__all__ = ["LinearOperator", "EndomorphicOperator", "ScalingOperator", __all__ = ["LinearOperator", "EndomorphicOperator", "ScalingOperator",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment