Commit 8eba48d3 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

tweaks

parent d752bd3b
Pipeline #18276 failed with stage
in 5 minutes and 31 seconds
......@@ -586,11 +586,8 @@ class Field(object):
# create a diagonal operator which is capable of taking care of the
# axes-matching
from .operators.diagonal_operator import DiagonalOperator
diagonal = y.val.conjugate()
diagonalOperator = DiagonalOperator(domain=y.domain,
diagonal=diagonal,
copy=False)
dotted = diagonalOperator(x, spaces=spaces)
diag = DiagonalOperator(y.domain, y.conjugate(), copy=False)
dotted = diag(x, spaces=spaces)
return fct*dotted.sum(spaces=spaces)
def norm(self):
......@@ -604,25 +601,16 @@ class Field(object):
"""
return np.sqrt(np.abs(self.vdot(x=self)))
def conjugate(self, inplace=False):
def conjugate(self):
""" Returns the complex conjugate of the field.
Parameters
----------
inplace : boolean
Decides whether the conjugation should be performed inplace.
Returns
-------
cc : field
The complex conjugated field.
"""
if inplace:
self.imag *= -1
return self
else:
return Field(self.domain, np.conj(self.val), self.dtype)
return Field(self.domain, self.val.conjugate(), self.dtype)
# ---General unary/contraction methods---
......
......@@ -188,7 +188,7 @@ class DiagonalOperator(EndomorphicOperator):
# do inverse weightening if the other way around
if bare:
# If `copy` is True, we won't change external data by weightening
# Otherwise, inplace weightening would change the external field
# Otherwise, inplace weighting would change the external field
f.weight(inplace=copy)
# Reset the self_adjoint property:
......
......@@ -122,37 +122,32 @@ class FFTOperator(LinearOperator):
self._backward_transformation = backward_class(
self.target[0], self.domain[0])
def _times(self, x, spaces):
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
def _times_helper(self, x, spaces, other, trafo):
if spaces is None:
# this case means that x lives on only one space, which is
# identical to the space in the domain of `self`. Otherwise the
# input check of LinearOperator would have failed.
axes = x.domain_axes[0]
result_domain = self.target
result_domain = other
else:
axes = x.domain_axes[spaces[0]]
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
result_domain = list(x.domain)
result_domain[spaces[0]] = self.target[0]
result_domain[spaces[0]] = other[0]
axes = x.domain_axes[spaces[0]]
new_val = self._forward_transformation.transform(x.val, axes=axes)
return Field(result_domain, new_val, copy=False)
new_val, fct = trafo.transform(x.val, axes=axes)
res = Field(result_domain, new_val, copy=False)
if fct != 1.:
res *= fct
return res
def _adjoint_times(self, x, spaces):
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
if spaces is None:
# this case means that x lives on only one space, which is
# identical to the space in the domain of `self`. Otherwise the
# input check of LinearOperator would have failed.
axes = x.domain_axes[0]
result_domain = self.domain
else:
axes = x.domain_axes[spaces[0]]
result_domain = list(x.domain)
result_domain[spaces[0]] = self.domain[0]
def _times(self, x, spaces):
return self._times_helper(x, spaces, self.target,
self._forward_transformation)
new_val = self._backward_transformation.transform(x.val, axes=axes)
return Field(result_domain, new_val, copy=False)
def _adjoint_times(self, x, spaces):
return self._times_helper(x, spaces, self.domain,
self._backward_transformation)
# ---Mandatory properties and methods---
......
......@@ -113,8 +113,7 @@ class RGRGTransformation(Transformation):
else:
Tval = self._hartley(val, axes)
Tval *= fct
return Tval
return Tval, fct
class SlicingTransformation(Transformation):
......@@ -125,7 +124,7 @@ class SlicingTransformation(Transformation):
for slice in utilities.get_slice_list(val.shape, axes):
return_val[slice] = self._transformation_of_slice(val[slice])
return return_val
return return_val, 1.
def _transformation_of_slice(self, inp):
raise NotImplementedError
......
......@@ -86,6 +86,8 @@ class RGSpace(Space):
self._harmonic = bool(harmonic)
self._shape = self._parse_shape(shape)
self._distances = self._parse_distances(distances)
self._wgt = reduce(lambda x, y: x*y, self._distances)
self._dim = int(reduce(lambda x, y: x*y, self._shape))
def __repr__(self):
return ("RGSpace(shape=%r, distances=%r, harmonic=%r)"
......@@ -101,11 +103,11 @@ class RGSpace(Space):
@property
def dim(self):
return int(reduce(lambda x, y: x*y, self.shape))
return self._dim
@property
def total_volume(self):
return self.dim * reduce(lambda x, y: x*y, self.distances)
return self.dim * self._wgt
def copy(self):
return self.__class__(shape=self.shape,
......@@ -113,10 +115,10 @@ class RGSpace(Space):
harmonic=self.harmonic)
def scalar_weight(self):
return reduce(lambda x, y: x*y, self.distances)
return self._wgt
def weight(self):
return reduce(lambda x, y: x*y, self.distances)
return self._wgt
def get_distance_array(self):
""" Calculates an n-dimensional array with its entries being the
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment