Commit 9eed1df9 authored by Lukas Platz's avatar Lukas Platz

consistency: make the argument name of _flip_modes() the same for all ops

Commit a141d2bf "resolve degenerate use of 'mode' in OperatorAdapter
and LinearOperator" was pushed without ensuring consistency with
other operator definitons and correct exception strings. This commit
aims to resolve this.

The argument variable of _flip_modes(…) is now always called 'trafo'
and the Exception thrown for invalid values of it is
ValueError("invalid operator transformation"). To aid readability,
numerical constants were replaced with correspondingly named
constants (ADJ, INV).
parent d8c7cd5a
Pipeline #28018 passed with stage
in 22 minutes
......@@ -96,15 +96,18 @@ class ChainOperator(LinearOperator):
def target(self):
return self._ops[0].target
def _flip_modes(self, mode):
if mode == 0:
def _flip_modes(self, trafo):
ADJ = self.ADJOINT_BIT
INV = self.INVERSE_BIT
if trafo == 0:
return self
if mode == 1 or mode == 2:
return self.make([op._flip_modes(mode)
if trafo == ADJ or trafo == INV:
return self.make([op._flip_modes(trafo)
for op in reversed(self._ops)])
if mode == 3:
return self.make([op._flip_modes(mode) for op in self._ops])
raise ValueError("bad operator flipping mode")
if trafo == ADJ | INV:
return self.make([op._flip_modes(trafo) for op in self._ops])
raise ValueError("invalid operator transformation")
@property
def capability(self):
......
......@@ -158,20 +158,23 @@ class DiagonalOperator(EndomorphicOperator):
def capability(self):
return self._all_ops
def _flip_modes(self, mode):
if mode == 0:
def _flip_modes(self, trafo):
ADJ = self.ADJOINT_BIT
INV = self.INVERSE_BIT
if trafo == 0:
return self
if mode == 1 and np.issubdtype(self._ldiag.dtype, np.floating):
if trafo == ADJ and np.issubdtype(self._ldiag.dtype, np.floating):
return self
res = self._skeleton(())
if mode == 1:
if trafo == ADJ:
res._ldiag = self._ldiag.conjugate()
elif mode == 2:
elif trafo == INV:
res._ldiag = 1./self._ldiag
elif mode == 3:
elif trafo == ADJ | INV:
res._ldiag = 1./self._ldiag.conjugate()
else:
raise ValueError("bad operator flipping mode")
raise ValueError("invalid operator transformation")
return res
def draw_sample(self, from_inverse=False, dtype=np.float64):
......
......@@ -96,10 +96,9 @@ class LinearOperator(NiftyMetaBase()):
The domain on which the Operator's output Field lives."""
raise NotImplementedError
def _flip_modes(self, op_transform):
def _flip_modes(self, trafo):
from .operator_adapter import OperatorAdapter
return self if op_transform == 0 \
else OperatorAdapter(self, op_transform)
return self if trafo == 0 else OperatorAdapter(self, trafo)
@property
def inverse(self):
......
......@@ -28,7 +28,7 @@ class OperatorAdapter(LinearOperator):
self._op = op
self._trafo = int(op_transform)
if self._trafo < 1 or self._trafo > 3:
raise ValueError("invalid mode")
raise ValueError("invalid operator transformation")
@property
def domain(self):
......@@ -42,9 +42,10 @@ class OperatorAdapter(LinearOperator):
def capability(self):
return self._capTable[self._trafo][self._op.capability]
def _flip_modes(self, op_transform):
newmode = op_transform ^ self._trafo
return self._op if newmode == 0 else OperatorAdapter(self._op, newmode)
def _flip_modes(self, trafo):
newtrafo = trafo ^ self._trafo
return self._op if newtrafo == 0 \
else OperatorAdapter(self._op, newtrafo)
def apply(self, x, mode):
return self._op.apply(x,
......
......@@ -72,18 +72,21 @@ class ScalingOperator(EndomorphicOperator):
else:
return x*(1./np.conj(self._factor))
def _flip_modes(self, mode):
if mode == 0:
def _flip_modes(self, trafo):
ADJ = self.ADJOINT_BIT
INV = self.INVERSE_BIT
if trafo == 0:
return self
if mode == 1 and np.issubdtype(type(self._factor), np.floating):
if trafo == ADJ and np.issubdtype(type(self._factor), np.floating):
return self
if mode == 1:
if trafo == ADJ:
return ScalingOperator(np.conj(self._factor), self._domain)
elif mode == 2:
elif trafo == INV:
return ScalingOperator(1./self._factor, self._domain)
elif mode == 3:
elif trafo == ADJ | INV:
return ScalingOperator(1./np.conj(self._factor), self._domain)
raise ValueError("bad operator flipping mode")
raise ValueError("invalid operator transformation")
@property
def domain(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment