Commit 93314cc2 authored by Philipp Arras's avatar Philipp Arras

Simplify LocalModel

parent ece7b47f
Pipeline #31076 passed with stages
in 1 minute and 29 seconds
......@@ -4,12 +4,12 @@ from .model import Model
class LocalModel(Model):
def __init__(self, position, inp, nonlinearity):
def __init__(self, inp, nonlinearity):
"""
Computes nonlinearity(inp)
"""
super(LocalModel, self).__init__(position)
self._inp = inp.at(self.position)
super(LocalModel, self).__init__(inp.position)
self._inp = inp
self._nonlinearity = nonlinearity
self._value = nonlinearity(self._inp.value)
d_inner = self._inp.gradient
......@@ -17,4 +17,4 @@ class LocalModel(Model):
self._gradient = d_outer * d_inner
def at(self, position):
return self.__class__(position, self._inp, self._nonlinearity)
return self.__class__(self._inp.at(position), self._nonlinearity)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment