diff --git a/nifty4/operators/chain_operator.py b/nifty4/operators/chain_operator.py index 9bc205bf8a9fc3cf4f7209085848921d1373005f..965dbb156fff49aababa205d2dd286645cf014a2 100644 --- a/nifty4/operators/chain_operator.py +++ b/nifty4/operators/chain_operator.py @@ -96,15 +96,18 @@ class ChainOperator(LinearOperator): def target(self): return self._ops[0].target - def _flip_modes(self, mode): - if mode == 0: + def _flip_modes(self, trafo): + ADJ = self.ADJOINT_BIT + INV = self.INVERSE_BIT + + if trafo == 0: return self - if mode == 1 or mode == 2: - return self.make([op._flip_modes(mode) + if trafo == ADJ or trafo == INV: + return self.make([op._flip_modes(trafo) for op in reversed(self._ops)]) - if mode == 3: - return self.make([op._flip_modes(mode) for op in self._ops]) - raise ValueError("bad operator flipping mode") + if trafo == ADJ | INV: + return self.make([op._flip_modes(trafo) for op in self._ops]) + raise ValueError("invalid operator transformation") @property def capability(self): diff --git a/nifty4/operators/diagonal_operator.py b/nifty4/operators/diagonal_operator.py index 9c36522262c44a753a0ea7501a0a8b62bbb0347e..5d68227a43c1b1b826eb1e6f77b14b478ec4ef5b 100644 --- a/nifty4/operators/diagonal_operator.py +++ b/nifty4/operators/diagonal_operator.py @@ -158,20 +158,23 @@ class DiagonalOperator(EndomorphicOperator): def capability(self): return self._all_ops - def _flip_modes(self, mode): - if mode == 0: + def _flip_modes(self, trafo): + ADJ = self.ADJOINT_BIT + INV = self.INVERSE_BIT + + if trafo == 0: return self - if mode == 1 and np.issubdtype(self._ldiag.dtype, np.floating): + if trafo == ADJ and np.issubdtype(self._ldiag.dtype, np.floating): return self res = self._skeleton(()) - if mode == 1: + if trafo == ADJ: res._ldiag = self._ldiag.conjugate() - elif mode == 2: + elif trafo == INV: res._ldiag = 1./self._ldiag - elif mode == 3: + elif trafo == ADJ | INV: res._ldiag = 1./self._ldiag.conjugate() else: - raise ValueError("bad operator flipping mode") + raise ValueError("invalid operator transformation") return res def draw_sample(self, from_inverse=False, dtype=np.float64): diff --git a/nifty4/operators/linear_operator.py b/nifty4/operators/linear_operator.py index 97dd6ee5bbbb2988bfbcfd7a0177ee6f36d9a87d..2d7b333ff7a0c639e17f4c9cc7d402ecbe1f7e46 100644 --- a/nifty4/operators/linear_operator.py +++ b/nifty4/operators/linear_operator.py @@ -96,10 +96,9 @@ class LinearOperator(NiftyMetaBase()): The domain on which the Operator's output Field lives.""" raise NotImplementedError - def _flip_modes(self, op_transform): + def _flip_modes(self, trafo): from .operator_adapter import OperatorAdapter - return self if op_transform == 0 \ - else OperatorAdapter(self, op_transform) + return self if trafo == 0 else OperatorAdapter(self, trafo) @property def inverse(self): diff --git a/nifty4/operators/operator_adapter.py b/nifty4/operators/operator_adapter.py index 01bfb79e336737740ba4d829975049cb4196b3ef..c722623f1159bbb8b21b24a6c4dc153286e05899 100644 --- a/nifty4/operators/operator_adapter.py +++ b/nifty4/operators/operator_adapter.py @@ -28,7 +28,7 @@ class OperatorAdapter(LinearOperator): self._op = op self._trafo = int(op_transform) if self._trafo < 1 or self._trafo > 3: - raise ValueError("invalid mode") + raise ValueError("invalid operator transformation") @property def domain(self): @@ -42,9 +42,10 @@ class OperatorAdapter(LinearOperator): def capability(self): return self._capTable[self._trafo][self._op.capability] - def _flip_modes(self, op_transform): - newmode = op_transform ^ self._trafo - return self._op if newmode == 0 else OperatorAdapter(self._op, newmode) + def _flip_modes(self, trafo): + newtrafo = trafo ^ self._trafo + return self._op if newtrafo == 0 \ + else OperatorAdapter(self._op, newtrafo) def apply(self, x, mode): return self._op.apply(x, diff --git a/nifty4/operators/scaling_operator.py b/nifty4/operators/scaling_operator.py index cce8fee4bf63b78f3d543c7b45149159cb9aeedd..c802b050f6322c90646340f495f440d93e9b8e67 100644 --- a/nifty4/operators/scaling_operator.py +++ b/nifty4/operators/scaling_operator.py @@ -72,18 +72,21 @@ class ScalingOperator(EndomorphicOperator): else: return x*(1./np.conj(self._factor)) - def _flip_modes(self, mode): - if mode == 0: + def _flip_modes(self, trafo): + ADJ = self.ADJOINT_BIT + INV = self.INVERSE_BIT + + if trafo == 0: return self - if mode == 1 and np.issubdtype(type(self._factor), np.floating): + if trafo == ADJ and np.issubdtype(type(self._factor), np.floating): return self - if mode == 1: + if trafo == ADJ: return ScalingOperator(np.conj(self._factor), self._domain) - elif mode == 2: + elif trafo == INV: return ScalingOperator(1./self._factor, self._domain) - elif mode == 3: + elif trafo == ADJ | INV: return ScalingOperator(1./np.conj(self._factor), self._domain) - raise ValueError("bad operator flipping mode") + raise ValueError("invalid operator transformation") @property def domain(self):