model.py 4.99 KB
Newer Older
Philipp Arras's avatar
Philipp Arras committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.

Philipp Arras's avatar
Philipp Arras committed
19
20
21
from ..multi import MultiField
from ..sugar import makeOp
from ..utilities import NiftyMetaBase
22
23
24
from .selection_operator import SelectionOperator


Philipp Arras's avatar
Philipp Arras committed
25
class Model(NiftyMetaBase()):
    """Abstract base class for differentiable models.

    A model is evaluated at a fixed `position` (a multi-field of parameter
    values) and exposes its `value` and `gradient` there.  Subclasses are
    expected to set ``self._value`` and ``self._gradient`` in their
    constructor and to implement :meth:`at`, which returns a new model
    instance evaluated at another position.

    Models compose algebraically: ``+``, ``-``, model ``*`` model and
    scalar ``*`` model all return new composite models.
    """

    def __init__(self, position):
        # Position at which value and gradient are (to be) evaluated.
        self._position = position

    def at(self, position):
        """Return a copy of this model evaluated at `position`.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    @property
    def position(self):
        """Position at which this model is evaluated."""
        return self._position

    @property
    def value(self):
        """Value of the model at `position` (set by the subclass)."""
        return self._value

    @property
    def gradient(self):
        """Gradient of the model at `position` (set by the subclass)."""
        return self._gradient

    def __getitem__(self, key):
        # Select a single component of the (multi-)field value.
        sel = SelectionOperator(self.value.domain, key)
        return sel(self)

    def __add__(self, other):
        # Returning NotImplemented (rather than raising) lets Python try
        # the reflected operation on `other` before raising TypeError.
        if not isinstance(other, Model):
            return NotImplemented
        return Add.make(self, other)

    def __sub__(self, other):
        if not isinstance(other, Model):
            return NotImplemented
        return Add.make(self, (-1) * other)

    def __mul__(self, other):
        if isinstance(other, (float, int)):
            return ScalarMul(other, self)
        if isinstance(other, Model):
            return Mul.make(self, other)
        return NotImplemented

    def __rmul__(self, other):
        if isinstance(other, (float, int)):
            return self.__mul__(other)
        return NotImplemented


def _joint_position(op1, op2):
    """Merge the positions of two models into a single MultiField.

    Entries of ``op2`` take precedence when both positions share a key.
    """
    # Note: In python >3.5 one could do {**a, **b}
    merged = op1.position._val.copy()
    merged.update(op2.position._val)
    return MultiField(merged)
78
79


Philipp Arras's avatar
Philipp Arras committed
80
class Mul(Model):
    """Pointwise product of two models.

    Please note: If you multiply two operators which share some keys in the
    position but have different values there, it is not guaranteed which value
    will be used for the product.
    """
    def __init__(self, position, op1, op2):
        super(Mul, self).__init__(position)

        # Re-evaluate both factors at the joint position.
        self._op1 = op1.at(position)
        self._op2 = op2.at(position)

        v1 = self._op1.value
        v2 = self._op2.value
        self._value = v1 * v2
        # Product rule: d(f*g) = f*dg + g*df.
        self._gradient = (makeOp(v1) * self._op2.gradient +
                          makeOp(v2) * self._op1.gradient)

    @staticmethod
    def make(op1, op2):
        """Build a Mul evaluated at the joint position of both factors."""
        return Mul(_joint_position(op1, op2), op1, op2)

    def at(self, position):
        """Return this product re-evaluated at `position`."""
        return self.__class__(position, self._op1, self._op2)


Philipp Arras's avatar
Philipp Arras committed
105
class Add(Model):
    """Pointwise sum of two models.

    Please note: If you add two operators which share some keys in the position
    but have different values there, it is not guaranteed which value will be
    used for the sum.
    """
    def __init__(self, position, op1, op2):
        super(Add, self).__init__(position)

        # Re-evaluate both summands at the joint position.
        self._op1 = op1.at(position)
        self._op2 = op2.at(position)

        self._value = self._op1.value + self._op2.value
        # The gradient of a sum is the sum of the gradients.
        self._gradient = self._op1.gradient + self._op2.gradient

    @staticmethod
    def make(op1, op2):
        """Build an Add evaluated at the joint position of both summands."""
        return Add(_joint_position(op1, op2), op1, op2)

    def at(self, position):
        """Return this sum re-evaluated at `position`."""
        return self.__class__(position, self._op1, self._op2)


Philipp Arras's avatar
Philipp Arras committed
129
class ScalarMul(Model):
    """Model scaled by a constant scalar factor."""

    def __init__(self, factor, op):
        super(ScalarMul, self).__init__(op.position)
        # Only plain Python scalars are accepted as factors.
        if not isinstance(factor, (float, int)):
            raise TypeError

        self._factor = factor
        self._op = op

        # Scaling is linear, so value and gradient scale identically.
        self._value = factor * op.value
        self._gradient = factor * op.gradient

    def at(self, position):
        """Return this scaled model re-evaluated at `position`."""
        return self.__class__(self._factor, self._op.at(position))
143
144


Philipp Arras's avatar
Philipp Arras committed
145
class LinearModel(Model):
    """Application of a linear operator to a model."""

    def __init__(self, inp, lin_op):
        """
        Computes lin_op(inp) where lin_op is a Linear Operator
        """
        # Imported locally to avoid a circular import at module load time.
        from ..operators import LinearOperator
        super(LinearModel, self).__init__(inp.position)

        if not isinstance(lin_op, LinearOperator):
            raise TypeError("needs a LinearOperator as input")

        self._inp = inp
        # A SelectionOperator must act on the current value's domain, so it
        # is rebuilt here.  NOTE(review): this reaches into the private
        # `_key` attribute of SelectionOperator — confirm there is no
        # public accessor for it.
        if isinstance(lin_op, SelectionOperator):
            lin_op = SelectionOperator(self._inp.value.domain, lin_op._key)
        self._lin_op = lin_op

        self._value = self._lin_op(self._inp.value)
        # Chain rule for a linear map: the gradient is composed with it.
        self._gradient = self._lin_op * self._inp.gradient

    def at(self, position):
        """Return this model re-evaluated at `position`."""
        return self.__class__(self._inp.at(position), self._lin_op)