binary_helpers.py 3.36 KB
Newer Older
Philipp Arras's avatar
Philipp Arras committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

Martin Reinecke's avatar
Martin Reinecke committed
19
from ..multi.multi_field import MultiField
Philipp Arras's avatar
Philipp Arras committed
20
21
22
23
from ..sugar import makeOp
from .model import Model


24
25
26
def _joint_position(model1, model2):
    """Combine the positions of two models into one MultiField.

    The underlying entry dicts of both positions are merged; for keys that
    appear in both, the entry from ``model2`` wins (``dict.update``
    semantics).

    Parameters
    ----------
    model1, model2 : Model
        Models whose ``position`` attributes are merged.

    Returns
    -------
    MultiField
        A new MultiField holding the merged entries.
    """
    first, second = model1.position._val, model2.position._val
    # ``{**first, **second}`` would be terser but needs Python >= 3.5.
    merged = first.copy()
    merged.update(second)
    return MultiField(merged)


class ScalarMul(Model):
    """Class representing a model multiplied by a scalar factor.

    Parameters
    ----------
    factor : int or float
        Constant by which both the value and the gradient of `model` are
        scaled.
    model : Model
        The model to be scaled. Its position becomes the position of the
        resulting model.

    Raises
    ------
    TypeError
        If `factor` is not an ``int`` or ``float``.
    """

    def __init__(self, factor, model):
        # TODO -> floating  (original note; presumably: widen the accepted
        # scalar types -- NOTE(review): confirm intended scope)
        # Validate before touching `model`, so a bad factor is reported
        # immediately and with a useful message.
        if not isinstance(factor, (float, int)):
            raise TypeError(
                "ScalarMul: factor must be an int or float, got {}".format(
                    type(factor).__name__))
        super(ScalarMul, self).__init__(model.position)

        self._model = model
        self._factor = factor

        # Scaling by a constant is linear, so value and gradient are both
        # simply multiplied by the factor.
        self._value = self._factor * self._model.value
        self._gradient = self._factor * self._model.gradient

    def at(self, position):
        """Return this scaled model re-evaluated at `position`."""
        return self.__class__(self._factor, self._model.at(position))
Philipp Arras's avatar
Philipp Arras committed
49
50
51


class Add(Model):
    """Model whose value and gradient are the sums of two sub-models.

    Parameters
    ----------
    position : MultiField
        Position at which both summands are evaluated (normally the joint
        position produced by `Add.make`).
    model1, model2 : Model
        The two summands.
    """

    def __init__(self, position, model1, model2):
        super(Add, self).__init__(position)

        # Re-evaluate both operands at the shared position.
        lhs = model1.at(position)
        rhs = model2.at(position)
        self._model1, self._model2 = lhs, rhs

        # Differentiation is linear: the gradient of a sum is the sum of
        # the gradients.
        self._value = lhs.value + rhs.value
        self._gradient = lhs.gradient + rhs.gradient

    @staticmethod
    def make(model1, model2):
        """Create the sum ``model1 + model2``.

        Parameters
        ----------
        model1 : Model
            First summand.
        model2 : Model
            Second summand.

        Returns
        -------
        Add
            The sum, evaluated at the joint position of both operands.
        """
        return Add(_joint_position(model1, model2), model1, model2)

    def at(self, position):
        """Return the same sum re-evaluated at `position`."""
        cls = type(self)
        return cls(position, self._model1, self._model2)
Philipp Arras's avatar
Philipp Arras committed
79
80
81


class Mul(Model):
    """Model given by the pointwise product of two sub-models.

    Parameters
    ----------
    position : MultiField
        Position at which both factors are evaluated (normally the joint
        position produced by `Mul.make`).
    model1, model2 : Model
        The two factors.
    """

    def __init__(self, position, model1, model2):
        super(Mul, self).__init__(position)

        # Re-evaluate both operands at the shared position.
        left = model1.at(position)
        right = model2.at(position)
        self._model1, self._model2 = left, right

        self._value = left.value * right.value
        # Product rule: d(f*g) = f*dg + g*df. ``makeOp`` turns each value
        # into an operator (presumably diagonal) acting on the other
        # factor's gradient.
        self._gradient = (makeOp(left.value) * right.gradient +
                          makeOp(right.value) * left.gradient)

    @staticmethod
    def make(model1, model2):
        """Create the pointwise product ``model1 * model2``.

        Parameters
        ----------
        model1 : Model
            First factor.
        model2 : Model
            Second factor.

        Returns
        -------
        Mul
            The product, evaluated at the joint position of both operands.
        """
        return Mul(_joint_position(model1, model2), model1, model2)

    def at(self, position):
        """Return the same product re-evaluated at `position`."""
        cls = type(self)
        return cls(position, self._model1, self._model2)