binary_helpers.py 3.46 KB
Newer Older
Philipp Arras's avatar
Philipp Arras committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

Martin Reinecke's avatar
Martin Reinecke committed
19
20
from __future__ import absolute_import, division, print_function
from ..compat import *
Martin Reinecke's avatar
Martin Reinecke committed
21
from ..multi.multi_field import MultiField
Philipp Arras's avatar
Philipp Arras committed
22
23
24
25
from ..sugar import makeOp
from .model import Model


26
def _joint_position(model1, model2):
    """Merge the positions of two models into one joint position.

    The positions of both models are converted to dictionaries via
    ``to_dict()`` and combined. For keys present in both dictionaries the
    entry coming from `model2` replaces the one coming from `model1`.

    Parameters
    ----------
    model1 : Model
        First model. Only ``model1.position.to_dict()`` is used.
    model2 : Model
        Second model. Only ``model2.position.to_dict()`` is used.

    Returns
    -------
    The result of ``MultiField.from_dict`` applied to the merged dictionary.
    """
    a = model1.position.to_dict()
    b = model2.position.to_dict()
    # Note: In python >3.5 one could do {**a, **b}
    # Merge into a copy: the dictionary handed out by to_dict() is never
    # modified in place, whether or not to_dict() returns a fresh object.
    ab = dict(a)
    ab.update(b)
    return MultiField.from_dict(ab)
Philipp Arras's avatar
Philipp Arras committed
33
34
35


class ScalarMul(Model):
    """Class representing a model multiplied by a scalar factor.

    The value and the Jacobian of the wrapped model are both multiplied by
    `factor` when the object is constructed.

    Parameters
    ----------
    factor : float or int
        Scalar by which the model is multiplied.
    model : Model
        Model to be scaled. Its position is used as the position of the
        new object.

    Raises
    ------
    TypeError
        If `factor` is not an instance of float or int.
    """

    def __init__(self, factor, model):
        super(ScalarMul, self).__init__(model.position)
        # TODO -> floating
        if not isinstance(factor, (float, int)):
            raise TypeError(
                "factor must be a float or an int, got {}".format(
                    type(factor).__name__))

        self._model = model
        self._factor = factor

        self._value = self._factor * self._model.value
        self._jacobian = self._factor * self._model.jacobian

    def at(self, position):
        """Return the scaled model evaluated at a different position.

        Parameters
        ----------
        position
            Position passed on to ``at`` of the wrapped model.

        Returns
        -------
        A new object of the same class, built from the same factor and the
        wrapped model moved to `position`.
        """
        return self.__class__(self._factor, self._model.at(position))
Philipp Arras's avatar
Philipp Arras committed
51
52
53


class Add(Model):
    """Sum of two models, evaluated at a common position.

    Parameters
    ----------
    position
        Position at which both summands are evaluated.
    model1 : Model
        First summand.
    model2 : Model
        Second summand.
    """

    def __init__(self, position, model1, model2):
        super(Add, self).__init__(position)

        # Move both summands to the requested position before combining.
        lhs = model1.at(position)
        rhs = model2.at(position)
        self._model1 = lhs
        self._model2 = rhs

        # Values and Jacobians of the summands simply add up.
        self._value = lhs.value + rhs.value
        self._jacobian = lhs.jacobian + rhs.jacobian

    @staticmethod
    def make(model1, model2):
        """Construct the sum of two models at their joint position.

        Parameters
        ----------
        model1 : Model
            First summand.
        model2 : Model
            Second summand.

        Returns
        -------
        Add
            Sum of both models, located at the position obtained from
            ``_joint_position(model1, model2)``.
        """
        return Add(_joint_position(model1, model2), model1, model2)

    def at(self, position):
        """Return the sum of the same two models at another position.

        Parameters
        ----------
        position
            Position at which the new object is built.
        """
        lhs, rhs = self._model1, self._model2
        return self.__class__(position, lhs, rhs)
Philipp Arras's avatar
Philipp Arras committed
81
82
83


class Mul(Model):
    """Pointwise product of two models, evaluated at a common position.

    Parameters
    ----------
    position
        Position at which both factors are evaluated.
    model1 : Model
        First factor.
    model2 : Model
        Second factor.
    """

    def __init__(self, position, model1, model2):
        super(Mul, self).__init__(position)

        # Move both factors to the requested position before combining.
        lhs = model1.at(position)
        rhs = model2.at(position)
        self._model1 = lhs
        self._model2 = rhs

        self._value = lhs.value * rhs.value

        # Product rule: each factor's value (wrapped by makeOp) is applied
        # to the Jacobian of the other factor, and the two terms are added.
        first_term = makeOp(lhs.value) * rhs.jacobian
        second_term = makeOp(rhs.value) * lhs.jacobian
        self._jacobian = first_term + second_term

    @staticmethod
    def make(model1, model2):
        """Construct the pointwise product of two models.

        Parameters
        ----------
        model1 : Model
            First factor.
        model2 : Model
            Second factor.

        Returns
        -------
        Mul
            Product of both models, located at the position obtained from
            ``_joint_position(model1, model2)``.
        """
        return Mul(_joint_position(model1, model2), model1, model2)

    def at(self, position):
        """Return the product of the same two models at another position.

        Parameters
        ----------
        position
            Position at which the new object is built.
        """
        lhs, rhs = self._model1, self._model2
        return self.__class__(position, lhs, rhs)