Commit ab20d935 authored by Philipp Arras's avatar Philipp Arras

Fix Scalar multiplication

parent 5b696445
Pipeline #31476 failed with stages
in 4 minutes and 39 seconds
......@@ -57,7 +57,7 @@ class Model(NiftyMetaBase()):
def __mul__(self, other):
if isinstance(other, (float, int)):
return ScalarMul(self._position, other, self)
return ScalarMul(other, self)
if isinstance(other, Model):
return Mul.make(self, other)
raise NotImplementedError
......@@ -127,21 +127,19 @@ class Add(Model):
class ScalarMul(Model):
def __init__(self, position, factor, op):
super(ScalarMul, self).__init__(position)
def __init__(self, factor, op):
super(ScalarMul, self).__init__(op.position)
if not isinstance(factor, (float, int)):
raise TypeError
if not isinstance(op, Model):
raise TypeError
self._op = op.at(position)
self._op = op
self._factor = factor
self._value = self._factor * self._op.value
self._gradient = self._factor * self._op.gradient
def at(self, position):
return self.__class__(position, self._factor, self._op)
return self.__class__(self._factor, self._op.at(position))
class LinearModel(Model):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment