local_nonlinearity.py 893 Bytes
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
from ..library.nonlinearities import Exponential, PositiveTanh, Tanh
from ..sugar import makeOp
3

Philipp Arras's avatar
Philipp Arras committed
4
from .model import Model
5
6


Philipp Arras's avatar
Philipp Arras committed
7
class LocalModel(Model):
    """Model that applies a nonlinearity to the output of another model.

    The value is ``nonlinearity(inp.value)``. The gradient is obtained via
    the chain rule: the nonlinearity's derivative at ``inp.value`` (wrapped
    by ``makeOp``) combined with ``inp.gradient``.
    """

    def __init__(self, inp, nonlinearity):
        """Build the model ``nonlinearity(inp)``.

        Parameters
        ----------
        inp : Model
            Input model; its ``position``, ``value`` and ``gradient`` are read.
        nonlinearity : callable
            Applied to ``inp.value``; must also provide a ``derivative``
            method that accepts the same argument.
        """
        super(LocalModel, self).__init__(inp.position)
        self._inp = inp
        self._nonlinearity = nonlinearity

        # Forward evaluation of the outer function.
        self._value = nonlinearity(inp.value)

        # Chain rule: outer derivative (evaluated at the input value and
        # wrapped by makeOp -- presumably a diagonal operator, confirm in
        # ..sugar) combined with the inner model's gradient.
        inner_grad = inp.gradient
        outer_grad = makeOp(nonlinearity.derivative(inp.value))
        self._gradient = outer_grad * inner_grad

    def at(self, position):
        """Return a new instance of this class evaluated at ``position``.

        The wrapped input model is moved to ``position``; the same
        nonlinearity object is reused.
        """
        relocated_inp = self._inp.at(position)
        return self.__class__(relocated_inp, self._nonlinearity)
Philipp Arras's avatar
Philipp Arras committed
22
23
24
25
26
27
28
29
30
31
32
33


def PointwiseExponential(inp):
    """Return a LocalModel applying ``Exponential()`` to ``inp``."""
    nonlinearity = Exponential()
    return LocalModel(inp, nonlinearity)


def PointwiseTanh(inp):
    """Return a LocalModel applying ``Tanh()`` to ``inp``."""
    nonlinearity = Tanh()
    return LocalModel(inp, nonlinearity)


def PointwisePositiveTanh(inp):
    """Return a LocalModel applying ``PositiveTanh()`` to ``inp``."""
    nonlinearity = PositiveTanh()
    return LocalModel(inp, nonlinearity)