Commit ab20d935 authored by Philipp Arras's avatar Philipp Arras
Browse files

Fix Scalar multiplication

parent 5b696445
Pipeline #31476 failed with stages
in 4 minutes and 39 seconds
...@@ -57,7 +57,7 @@ class Model(NiftyMetaBase()): ...@@ -57,7 +57,7 @@ class Model(NiftyMetaBase()):
def __mul__(self, other): def __mul__(self, other):
if isinstance(other, (float, int)): if isinstance(other, (float, int)):
return ScalarMul(self._position, other, self) return ScalarMul(other, self)
if isinstance(other, Model): if isinstance(other, Model):
return Mul.make(self, other) return Mul.make(self, other)
raise NotImplementedError raise NotImplementedError
...@@ -127,21 +127,19 @@ class Add(Model): ...@@ -127,21 +127,19 @@ class Add(Model):
class ScalarMul(Model): class ScalarMul(Model):
def __init__(self, position, factor, op): def __init__(self, factor, op):
super(ScalarMul, self).__init__(position) super(ScalarMul, self).__init__(op.position)
if not isinstance(factor, (float, int)): if not isinstance(factor, (float, int)):
raise TypeError raise TypeError
if not isinstance(op, Model):
raise TypeError
self._op = op.at(position) self._op = op
self._factor = factor self._factor = factor
self._value = self._factor * self._op.value self._value = self._factor * self._op.value
self._gradient = self._factor * self._op.gradient self._gradient = self._factor * self._op.gradient
def at(self, position): def at(self, position):
return self.__class__(position, self._factor, self._op) return self.__class__(self._factor, self._op.at(position))
class LinearModel(Model): class LinearModel(Model):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment