local_nonlinearity.py 637 Bytes
Newer Older
1
2
from nifty4.sugar import makeOp

Philipp Arras's avatar
Philipp Arras committed
3
from .model import Model
4
5


Philipp Arras's avatar
Philipp Arras committed
6
class LocalModel(Model):
    """Model whose value is a nonlinearity applied to another model's value.

    At construction the wrapped model ``inp`` is re-evaluated at ``position``
    and both the value ``nonlinearity(inp.value)`` and the gradient
    ``d_outer * d_inner`` (outer derivative combined with the gradient of
    ``inp``) are computed eagerly and stored on the instance.
    """

    def __init__(self, position, inp, nonlinearity):
        """
        Computes nonlinearity(inp)

        Parameters
        ----------
        position :
            Point at which the model is evaluated; forwarded unchanged to
            the ``Model`` base class constructor.
        inp :
            Inner model. It must provide ``at(position)`` returning an
            object with ``value`` and ``gradient`` attributes.
        nonlinearity :
            Callable applied to ``inp.value``. It must also provide a
            ``derivative`` method that accepts the same argument.
        """
        super(LocalModel, self).__init__(position)
        # NOTE(review): ``self.position`` is not defined in this class; it is
        # presumably exposed by ``Model`` after the call above -- confirm.
        self._inp = inp.at(self.position)
        self._nonlinearity = nonlinearity
        self._value = nonlinearity(self._inp.value)

        # Chain rule: derivative of the nonlinearity (outer) combined with
        # the gradient of the inner model.
        d_inner = self._inp.gradient
        # NOTE(review): ``makeOp`` presumably turns the derivative into an
        # operator that can be composed with ``d_inner`` -- confirm against
        # nifty4.sugar.
        d_outer = makeOp(self._nonlinearity.derivative(self._inp.value))
        self._gradient = d_outer * d_inner

    def at(self, position):
        """Return a new instance of the same class evaluated at ``position``.

        The stored inner model and nonlinearity are passed on unchanged; the
        constructor re-evaluates the inner model at the new position.
        """
        return self.__class__(position, self._inp, self._nonlinearity)