Commit 68b71aca authored by Martin Reinecke's avatar Martin Reinecke
Browse files

locate all places where adjustments are needed for distributed fields

parent 5bbff04b
Pipeline #20800 passed with stage
in 4 minutes and 8 seconds
...@@ -56,8 +56,8 @@ class data_object(object): ...@@ -56,8 +56,8 @@ class data_object(object):
a = self._data a = self._data
if isinstance(other, data_object): if isinstance(other, data_object):
b = other._data b = other._data
# if a.shape != b.shape: if a.shape != b.shape:
# print("shapes are incompatible.") raise ValueError("shapes are incompatible.")
else: else:
b = other b = other
......
from .data_objects.my_own_do import * from .data_objects.my_own_do import *
#from .data_objects.numpy_do import *
...@@ -22,6 +22,7 @@ from ..field import Field ...@@ -22,6 +22,7 @@ from ..field import Field
from ..domain_tuple import DomainTuple from ..domain_tuple import DomainTuple
from .endomorphic_operator import EndomorphicOperator from .endomorphic_operator import EndomorphicOperator
from ..nifty_utilities import cast_iseq_to_tuple from ..nifty_utilities import cast_iseq_to_tuple
from ..dobj import to_ndarray as to_np
class DiagonalOperator(EndomorphicOperator): class DiagonalOperator(EndomorphicOperator):
...@@ -97,16 +98,16 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -97,16 +98,16 @@ class DiagonalOperator(EndomorphicOperator):
self._unitary = None self._unitary = None
def _times(self, x): def _times(self, x):
return self._times_helper(x, lambda z: z.__mul__) return self._times_helper(x, self._diagonal)
def _adjoint_times(self, x): def _adjoint_times(self, x):
return self._times_helper(x, lambda z: z.conjugate().__mul__) return self._times_helper(x, self._diagonal.conj())
def _inverse_times(self, x): def _inverse_times(self, x):
return self._times_helper(x, lambda z: z.__rtruediv__) return self._times_helper(x, 1./self._diagonal)
def _adjoint_inverse_times(self, x): def _adjoint_inverse_times(self, x):
return self._times_helper(x, lambda z: z.conjugate().__rtruediv__) return self._times_helper(x, 1./self._diagonal.conj())
def diagonal(self): def diagonal(self):
""" Returns the diagonal of the Operator. """ Returns the diagonal of the Operator.
...@@ -137,9 +138,9 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -137,9 +138,9 @@ class DiagonalOperator(EndomorphicOperator):
self._unitary = (abs(self._diagonal.val) == 1.).all() self._unitary = (abs(self._diagonal.val) == 1.).all()
return self._unitary return self._unitary
def _times_helper(self, x, operation): def _times_helper(self, x, diag):
if self._spaces is None: if self._spaces is None:
return operation(self._diagonal)(x) return diag*x
active_axes = [] active_axes = []
for space_index in self._spaces: for space_index in self._spaces:
...@@ -147,7 +148,7 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -147,7 +148,7 @@ class DiagonalOperator(EndomorphicOperator):
reshaper = [shp if i in active_axes else 1 reshaper = [shp if i in active_axes else 1
for i, shp in enumerate(x.shape)] for i, shp in enumerate(x.shape)]
reshaped_local_diagonal = self._diagonal.val.reshape(reshaper) reshaped_local_diagonal = to_np(diag.val).reshape(reshaper)
# here the actual multiplication takes place # here the actual multiplication takes place
return Field(x.domain, val=operation(reshaped_local_diagonal)(x.val)) return Field(x.domain, val=x.val*reshaped_local_diagonal)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment