binary_helpers.py 3.46 KB
Newer Older
Philipp Arras's avatar
Philipp Arras committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

Martin Reinecke's avatar
Martin Reinecke committed
19
from __future__ import absolute_import, division, print_function
20

Martin Reinecke's avatar
Martin Reinecke committed
21
from ..compat import *
Martin Reinecke's avatar
Martin Reinecke committed
22
from ..multi.multi_field import MultiField
Philipp Arras's avatar
Philipp Arras committed
23 24 25 26
from ..sugar import makeOp
from .model import Model


27
def _joint_position(model1, model2):
    """Merge the positions of two models into one joint position.

    Both positions are converted to dictionaries and combined. Where a key
    is present in both, the entry coming from `model2` replaces the one
    coming from `model1`.

    Parameters
    ----------
    model1 : Model
        Model whose position supplies the initial entries.
    model2 : Model
        Model whose position entries are merged on top of those of
        `model1`.

    Returns
    -------
    The object built by `MultiField.from_dict` from the merged dictionary.
    """
    # dict.update is used rather than {**a, **b}, which would require
    # Python >= 3.5.
    merged = model1.position.to_dict()
    merged.update(model2.position.to_dict())
    return MultiField.from_dict(merged)
Philipp Arras's avatar
Philipp Arras committed
34 35 36


class ScalarMul(Model):
    """Class representing a model multiplied by a scalar factor.

    The value and the Jacobian of the wrapped model are both multiplied by
    `factor`.

    Parameters
    ----------
    factor : real scalar
        Constant the model is multiplied with. Any instance of
        `numbers.Real` is accepted, i.e. the builtin `int` and `float` as
        well as other scalar types registered with that abstract base
        class.
    model : Model
        The model to be scaled. Its position is used as the position of
        the scaled model.

    Raises
    ------
    TypeError
        If `factor` is not a real scalar.
    """

    def __init__(self, factor, model):
        # Local import: keeps this class independent of the module-level
        # import block.
        import numbers

        super(ScalarMul, self).__init__(model.position)

        # numbers.Real is a superset of the previously accepted
        # (float, int), so every factor that used to be valid still is.
        if not isinstance(factor, numbers.Real):
            raise TypeError(
                "factor must be a real scalar, got {}".format(
                    type(factor).__name__))

        self._model = model
        self._factor = factor

        self._value = self._factor * self._model.value
        self._jacobian = self._factor * self._model.jacobian

    def at(self, position):
        """Return a new instance with the wrapped model moved to `position`.

        Parameters
        ----------
        position
            Position handed on to the `at` method of the wrapped model.

        Returns
        -------
        An instance of the same class as `self`, using the same factor.
        """
        return self.__class__(self._factor, self._model.at(position))
Philipp Arras's avatar
Philipp Arras committed
53 54 55


class Add(Model):
    """Class representing the sum of two models.

    Both operands are moved to `position` via their `at` method. The value
    of the sum is the sum of the operands' values and its Jacobian is the
    sum of the operands' Jacobians.

    Parameters
    ----------
    position
        Position at which both models are evaluated.
    model1 : Model
        First summand.
    model2 : Model
        Second summand.
    """

    def __init__(self, position, model1, model2):
        super(Add, self).__init__(position)

        self._model1 = model1.at(position)
        self._model2 = model2.at(position)

        self._value = self._model1.value + self._model2.value
        self._jacobian = self._model1.jacobian + self._model2.jacobian

    @classmethod
    def make(cls, model1, model2):
        """Build the sum of two models.

        The sum is evaluated at the joint position of both models as
        computed by `_joint_position`.

        Parameters
        ----------
        model1: Model
            First model.
        model2: Model
            Second model.

        Returns
        -------
        An instance of the class `make` is called on, so that subclasses
        get instances of themselves (matching the behaviour of `at`).
        """
        position = _joint_position(model1, model2)
        return cls(position, model1, model2)

    def at(self, position):
        """Return the same sum evaluated at `position`.

        Parameters
        ----------
        position
            New position; both stored operands are moved there.

        Returns
        -------
        A new instance of the same class as `self`.
        """
        return self.__class__(position, self._model1, self._model2)
Philipp Arras's avatar
Philipp Arras committed
84 85 86


class Mul(Model):
    """Class representing the pointwise product of two models.

    Both operands are moved to `position` via their `at` method. The value
    is the product of the operands' values. The Jacobian follows the
    product rule: each operand's value is passed through `makeOp` and
    applied to the Jacobian of the other operand, and the two terms are
    added.

    Parameters
    ----------
    position
        Position at which both models are evaluated.
    model1 : Model
        First factor.
    model2 : Model
        Second factor.
    """

    def __init__(self, position, model1, model2):
        super(Mul, self).__init__(position)

        self._model1 = model1.at(position)
        self._model2 = model2.at(position)

        self._value = self._model1.value * self._model2.value
        # Product rule: d(a*b) = a * db + b * da
        self._jacobian = (makeOp(self._model1.value) * self._model2.jacobian +
                          makeOp(self._model2.value) * self._model1.jacobian)

    @classmethod
    def make(cls, model1, model2):
        """Build the pointwise product of two models.

        The product is evaluated at the joint position of both models as
        computed by `_joint_position`.

        Parameters
        ----------
        model1: Model
            First model.
        model2: Model
            Second model.

        Returns
        -------
        An instance of the class `make` is called on, so that subclasses
        get instances of themselves (matching the behaviour of `at`).
        """
        position = _joint_position(model1, model2)
        return cls(position, model1, model2)

    def at(self, position):
        """Return the same product evaluated at `position`.

        Parameters
        ----------
        position
            New position; both stored operands are moved there.

        Returns
        -------
        A new instance of the same class as `self`.
        """
        return self.__class__(position, self._model1, self._model2)