Commit 0360693e authored by Philipp Arras's avatar Philipp Arras
Browse files

Fix value checks in LinearInterpolator

parent 5689ba2d
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from functools import reduce from functools import reduce
from operator import mul from operator import add
import numpy as np import numpy as np
from scipy.sparse import coo_matrix from scipy.sparse import coo_matrix
...@@ -53,7 +53,7 @@ class LinearInterpolator(LinearOperator): ...@@ -53,7 +53,7 @@ class LinearInterpolator(LinearOperator):
dims = [len(dom.shape) for dom in self.domain] dims = [len(dom.shape) for dom in self.domain]
# FIXME This needs to be removed as soon as the bug below is fixed. # FIXME This needs to be removed as soon as the bug below is fixed.
if not dims.count(dims[0]) == len(dims): if dims.count(dims[0]) != len(dims):
raise TypeError( raise TypeError(
'This is a bug. Please extend LinearInterpolators functionality!' 'This is a bug. Please extend LinearInterpolators functionality!'
) )
...@@ -62,7 +62,7 @@ class LinearInterpolator(LinearOperator): ...@@ -62,7 +62,7 @@ class LinearInterpolator(LinearOperator):
if not (isinstance(sampling_points, np.ndarray) and len(shp) == 2): if not (isinstance(sampling_points, np.ndarray) and len(shp) == 2):
raise TypeError raise TypeError
n_dim, n_points = shp n_dim, n_points = shp
if not n_dim == reduce(mul, dims): if n_dim != reduce(add, dims):
raise TypeError raise TypeError
self._target = makeDomain(UnstructuredDomain(n_points)) self._target = makeDomain(UnstructuredDomain(n_points))
self._capability = self.TIMES | self.ADJOINT_TIMES self._capability = self.TIMES | self.ADJOINT_TIMES
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment