diff --git a/nifty/data_objects/my_own_do.py b/nifty/data_objects/my_own_do.py index 594ae2838c57fcbab8ad5b3afe6f1eac0048f0ce..56c2d31ddddcb78b4348bc4a484d4d72a3041f69 100644 --- a/nifty/data_objects/my_own_do.py +++ b/nifty/data_objects/my_own_do.py @@ -56,8 +56,8 @@ class data_object(object): a = self._data if isinstance(other, data_object): b = other._data - # if a.shape != b.shape: - # print("shapes are incompatible.") + if a.shape != b.shape: + raise ValueError("shapes are incompatible.") else: b = other diff --git a/nifty/dobj.py b/nifty/dobj.py index 2264d49154ef4731987ef183ac06742a1bf3839c..cb3af9f93c9eda671cad98b9d2e074f04fcb3da2 100644 --- a/nifty/dobj.py +++ b/nifty/dobj.py @@ -1 +1,2 @@ from .data_objects.my_own_do import * +#from .data_objects.numpy_do import * diff --git a/nifty/operators/diagonal_operator.py b/nifty/operators/diagonal_operator.py index 1757c127159aed80b77fc3ffa2bfd1919c3f3588..5006540e1f2940292eb1ec83c39d49f69a6e90dc 100644 --- a/nifty/operators/diagonal_operator.py +++ b/nifty/operators/diagonal_operator.py @@ -22,6 +22,7 @@ from ..field import Field from ..domain_tuple import DomainTuple from .endomorphic_operator import EndomorphicOperator from ..nifty_utilities import cast_iseq_to_tuple +from ..dobj import to_ndarray as to_np class DiagonalOperator(EndomorphicOperator): @@ -97,16 +98,16 @@ class DiagonalOperator(EndomorphicOperator): self._unitary = None def _times(self, x): - return self._times_helper(x, lambda z: z.__mul__) + return self._times_helper(x, self._diagonal) def _adjoint_times(self, x): - return self._times_helper(x, lambda z: z.conjugate().__mul__) + return self._times_helper(x, self._diagonal.conj()) def _inverse_times(self, x): - return self._times_helper(x, lambda z: z.__rtruediv__) + return self._times_helper(x, 1./self._diagonal) def _adjoint_inverse_times(self, x): - return self._times_helper(x, lambda z: z.conjugate().__rtruediv__) + return self._times_helper(x, 1./self._diagonal.conj()) def diagonal(self): """ Returns the diagonal of the Operator. @@ -137,9 +138,9 @@ class DiagonalOperator(EndomorphicOperator): self._unitary = (abs(self._diagonal.val) == 1.).all() return self._unitary - def _times_helper(self, x, operation): + def _times_helper(self, x, diag): if self._spaces is None: - return operation(self._diagonal)(x) + return diag*x active_axes = [] for space_index in self._spaces: @@ -147,7 +148,7 @@ class DiagonalOperator(EndomorphicOperator): reshaper = [shp if i in active_axes else 1 for i, shp in enumerate(x.shape)] - reshaped_local_diagonal = self._diagonal.val.reshape(reshaper) + reshaped_local_diagonal = to_np(diag.val).reshape(reshaper) # here the actual multiplication takes place - return Field(x.domain, val=operation(reshaped_local_diagonal)(x.val)) + return Field(x.domain, val=x.val*reshaped_local_diagonal)