From 9eed1df9f6ebbe18604fe6c7c591e7c23015f2df Mon Sep 17 00:00:00 2001
From: Lukas Platz <lplatz@mpa-garching.mpg.de>
Date: Fri, 20 Apr 2018 17:51:57 +0200
Subject: [PATCH] consistency: make the argument name of _flip_modes() the same
 for all ops
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Commit a141d2bf "resolve degenerate use of 'mode' in OperatorAdapter
and LinearOperator" was pushed without ensuring consistency with
other operator definitons and correct exception strings. This commit
aims to resolve this.

The argument variable of _flip_modes(…) is now always called 'trafo'
and the Exception thrown for invalid values of it is
ValueError("invalid operator transformation"). To aid readability,
numerical constants were replaced with correspondingly named
constants (ADJ, INV).
---
 nifty4/operators/chain_operator.py    | 17 ++++++++++-------
 nifty4/operators/diagonal_operator.py | 17 ++++++++++-------
 nifty4/operators/linear_operator.py   |  5 ++---
 nifty4/operators/operator_adapter.py  |  9 +++++----
 nifty4/operators/scaling_operator.py  | 17 ++++++++++-------
 5 files changed, 37 insertions(+), 28 deletions(-)

diff --git a/nifty4/operators/chain_operator.py b/nifty4/operators/chain_operator.py
index 9bc205bf8..965dbb156 100644
--- a/nifty4/operators/chain_operator.py
+++ b/nifty4/operators/chain_operator.py
@@ -96,15 +96,18 @@ class ChainOperator(LinearOperator):
     def target(self):
         return self._ops[0].target
 
-    def _flip_modes(self, mode):
-        if mode == 0:
+    def _flip_modes(self, trafo):
+        ADJ = self.ADJOINT_BIT
+        INV = self.INVERSE_BIT
+
+        if trafo == 0:
             return self
-        if mode == 1 or mode == 2:
-            return self.make([op._flip_modes(mode)
+        if trafo == ADJ or trafo == INV:
+            return self.make([op._flip_modes(trafo)
                               for op in reversed(self._ops)])
-        if mode == 3:
-            return self.make([op._flip_modes(mode) for op in self._ops])
-        raise ValueError("bad operator flipping mode")
+        if trafo == ADJ | INV:
+            return self.make([op._flip_modes(trafo) for op in self._ops])
+        raise ValueError("invalid operator transformation")
 
     @property
     def capability(self):
diff --git a/nifty4/operators/diagonal_operator.py b/nifty4/operators/diagonal_operator.py
index 9c3652226..5d68227a4 100644
--- a/nifty4/operators/diagonal_operator.py
+++ b/nifty4/operators/diagonal_operator.py
@@ -158,20 +158,23 @@ class DiagonalOperator(EndomorphicOperator):
     def capability(self):
         return self._all_ops
 
-    def _flip_modes(self, mode):
-        if mode == 0:
+    def _flip_modes(self, trafo):
+        ADJ = self.ADJOINT_BIT
+        INV = self.INVERSE_BIT
+
+        if trafo == 0:
             return self
-        if mode == 1 and np.issubdtype(self._ldiag.dtype, np.floating):
+        if trafo == ADJ and np.issubdtype(self._ldiag.dtype, np.floating):
             return self
         res = self._skeleton(())
-        if mode == 1:
+        if trafo == ADJ:
             res._ldiag = self._ldiag.conjugate()
-        elif mode == 2:
+        elif trafo == INV:
             res._ldiag = 1./self._ldiag
-        elif mode == 3:
+        elif trafo == ADJ | INV:
             res._ldiag = 1./self._ldiag.conjugate()
         else:
-            raise ValueError("bad operator flipping mode")
+            raise ValueError("invalid operator transformation")
         return res
 
     def draw_sample(self, from_inverse=False, dtype=np.float64):
diff --git a/nifty4/operators/linear_operator.py b/nifty4/operators/linear_operator.py
index 97dd6ee5b..2d7b333ff 100644
--- a/nifty4/operators/linear_operator.py
+++ b/nifty4/operators/linear_operator.py
@@ -96,10 +96,9 @@ class LinearOperator(NiftyMetaBase()):
             The domain on which the Operator's output Field lives."""
         raise NotImplementedError
 
-    def _flip_modes(self, op_transform):
+    def _flip_modes(self, trafo):
         from .operator_adapter import OperatorAdapter
-        return self if op_transform == 0 \
-            else OperatorAdapter(self, op_transform)
+        return self if trafo == 0 else OperatorAdapter(self, trafo)
 
     @property
     def inverse(self):
diff --git a/nifty4/operators/operator_adapter.py b/nifty4/operators/operator_adapter.py
index 01bfb79e3..c722623f1 100644
--- a/nifty4/operators/operator_adapter.py
+++ b/nifty4/operators/operator_adapter.py
@@ -28,7 +28,7 @@ class OperatorAdapter(LinearOperator):
         self._op = op
         self._trafo = int(op_transform)
         if self._trafo < 1 or self._trafo > 3:
-            raise ValueError("invalid mode")
+            raise ValueError("invalid operator transformation")
 
     @property
     def domain(self):
@@ -42,9 +42,10 @@ class OperatorAdapter(LinearOperator):
     def capability(self):
         return self._capTable[self._trafo][self._op.capability]
 
-    def _flip_modes(self, op_transform):
-        newmode = op_transform ^ self._trafo
-        return self._op if newmode == 0 else OperatorAdapter(self._op, newmode)
+    def _flip_modes(self, trafo):
+        newtrafo = trafo ^ self._trafo
+        return self._op if newtrafo == 0 \
+            else OperatorAdapter(self._op, newtrafo)
 
     def apply(self, x, mode):
         return self._op.apply(x,
diff --git a/nifty4/operators/scaling_operator.py b/nifty4/operators/scaling_operator.py
index cce8fee4b..c802b050f 100644
--- a/nifty4/operators/scaling_operator.py
+++ b/nifty4/operators/scaling_operator.py
@@ -72,18 +72,21 @@ class ScalingOperator(EndomorphicOperator):
         else:
             return x*(1./np.conj(self._factor))
 
-    def _flip_modes(self, mode):
-        if mode == 0:
+    def _flip_modes(self, trafo):
+        ADJ = self.ADJOINT_BIT
+        INV = self.INVERSE_BIT
+
+        if trafo == 0:
             return self
-        if mode == 1 and np.issubdtype(type(self._factor), np.floating):
+        if trafo == ADJ and np.issubdtype(type(self._factor), np.floating):
             return self
-        if mode == 1:
+        if trafo == ADJ:
             return ScalingOperator(np.conj(self._factor), self._domain)
-        elif mode == 2:
+        elif trafo == INV:
             return ScalingOperator(1./self._factor, self._domain)
-        elif mode == 3:
+        elif trafo == ADJ | INV:
             return ScalingOperator(1./np.conj(self._factor), self._domain)
-        raise ValueError("bad operator flipping mode")
+        raise ValueError("invalid operator transformation")
 
     @property
     def domain(self):
-- 
GitLab