local_nonlinearity.py 893 Bytes
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1 2
from ..library.nonlinearities import Exponential, PositiveTanh, Tanh
from ..sugar import makeOp
3

Philipp Arras's avatar
Philipp Arras committed
4
from .model import Model
5 6


Philipp Arras's avatar
Philipp Arras committed
7
class LocalModel(Model):
    """Model whose value is a nonlinearity applied to another model's value.

    The value is ``nonlinearity(inp.value)``. The gradient follows the chain
    rule: ``makeOp(nonlinearity.derivative(inp.value)) * inp.gradient``.
    """

    def __init__(self, inp, nonlinearity):
        """
        Computes nonlinearity(inp)

        Parameters
        ----------
        inp : Model
            Input model. Its ``position``, ``value`` and ``gradient`` are
            read here, and its ``at`` method is used by ``at``.
        nonlinearity : callable
            Applied to ``inp.value``. It must also provide a
            ``derivative(value)`` method, used for the gradient.
        """
        super(LocalModel, self).__init__(inp.position)
        self._inp = inp
        self._nonlinearity = nonlinearity
        self._value = nonlinearity(self._inp.value)

        # Chain rule: d f(g(x)) = f'(g(x)) * dg(x), with g = inp and
        # f = nonlinearity.
        # NOTE(review): makeOp presumably wraps the derivative field as a
        # diagonal operator, which is what makes the nonlinearity act
        # pointwise ("local") -- confirm against ..sugar.makeOp.
        d_inner = self._inp.gradient
        d_outer = makeOp(self._nonlinearity.derivative(self._inp.value))
        self._gradient = d_outer * d_inner

    def at(self, position):
        """Return a new instance of this class evaluated at ``position``.

        The input model is moved to ``position`` via ``inp.at``, and the
        same nonlinearity object is reused.
        """
        return self.__class__(self._inp.at(position), self._nonlinearity)
Philipp Arras's avatar
Philipp Arras committed
22 23 24 25 26 27 28 29 30 31 32 33


def PointwiseExponential(inp):
    """Return a LocalModel that applies ``Exponential()`` to ``inp``."""
    nonlinearity = Exponential()
    return LocalModel(inp, nonlinearity)


def PointwiseTanh(inp):
    """Return a LocalModel that applies ``Tanh()`` to ``inp``."""
    nonlinearity = Tanh()
    return LocalModel(inp, nonlinearity)


def PointwisePositiveTanh(inp):
    """Return a LocalModel that applies ``PositiveTanh()`` to ``inp``."""
    nonlinearity = PositiveTanh()
    return LocalModel(inp, nonlinearity)