Commit 68b71aca authored by Martin Reinecke's avatar Martin Reinecke
Browse files

locate all places where adjustments are needed for distributed fields

parent 5bbff04b
Pipeline #20800 passed with stage
in 4 minutes and 8 seconds
......@@ -56,8 +56,8 @@ class data_object(object):
a = self._data
if isinstance(other, data_object):
b = other._data
# if a.shape != b.shape:
# print("shapes are incompatible.")
if a.shape != b.shape:
raise ValueError("shapes are incompatible.")
b = other
from .data_objects.my_own_do import *
#from .data_objects.numpy_do import *
......@@ -22,6 +22,7 @@ from ..field import Field
from ..domain_tuple import DomainTuple
from .endomorphic_operator import EndomorphicOperator
from ..nifty_utilities import cast_iseq_to_tuple
from ..dobj import to_ndarray as to_np
class DiagonalOperator(EndomorphicOperator):
......@@ -97,16 +98,16 @@ class DiagonalOperator(EndomorphicOperator):
self._unitary = None
def _times(self, x):
return self._times_helper(x, lambda z: z.__mul__)
return self._times_helper(x, self._diagonal)
def _adjoint_times(self, x):
return self._times_helper(x, lambda z: z.conjugate().__mul__)
return self._times_helper(x, self._diagonal.conj())
def _inverse_times(self, x):
return self._times_helper(x, lambda z: z.__rtruediv__)
return self._times_helper(x, 1./self._diagonal)
def _adjoint_inverse_times(self, x):
return self._times_helper(x, lambda z: z.conjugate().__rtruediv__)
return self._times_helper(x, 1./self._diagonal.conj())
def diagonal(self):
""" Returns the diagonal of the Operator.
......@@ -137,9 +138,9 @@ class DiagonalOperator(EndomorphicOperator):
self._unitary = (abs(self._diagonal.val) == 1.).all()
return self._unitary
def _times_helper(self, x, operation):
def _times_helper(self, x, diag):
if self._spaces is None:
return operation(self._diagonal)(x)
return diag*x
active_axes = []
for space_index in self._spaces:
......@@ -147,7 +148,7 @@ class DiagonalOperator(EndomorphicOperator):
reshaper = [shp if i in active_axes else 1
for i, shp in enumerate(x.shape)]
reshaped_local_diagonal = self._diagonal.val.reshape(reshaper)
reshaped_local_diagonal = to_np(diag.val).reshape(reshaper)
# here the actual multiplication takes place
return Field(x.domain, val=operation(reshaped_local_diagonal)(x.val))
return Field(x.domain, val=x.val*reshaped_local_diagonal)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment