Commit 51db693e authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'readability_consistent_renaming_mode-op_transform' into 'NIFTy_4'

Consistency: make the argument name of _flip_modes() the same for all operators

See merge request ift/NIFTy!247
parents d8c7cd5a 9eed1df9
Pipeline #28200 passed with stage
in 2 minutes and 23 seconds
......@@ -96,15 +96,18 @@ class ChainOperator(LinearOperator):
def target(self):
return self._ops[0].target
def _flip_modes(self, mode):
if mode == 0:
def _flip_modes(self, trafo):
ADJ = self.ADJOINT_BIT
INV = self.INVERSE_BIT
if trafo == 0:
return self
if mode == 1 or mode == 2:
return self.make([op._flip_modes(mode)
if trafo == ADJ or trafo == INV:
return self.make([op._flip_modes(trafo)
for op in reversed(self._ops)])
if mode == 3:
return self.make([op._flip_modes(mode) for op in self._ops])
raise ValueError("bad operator flipping mode")
if trafo == ADJ | INV:
return self.make([op._flip_modes(trafo) for op in self._ops])
raise ValueError("invalid operator transformation")
@property
def capability(self):
......
......@@ -158,20 +158,23 @@ class DiagonalOperator(EndomorphicOperator):
def capability(self):
return self._all_ops
def _flip_modes(self, mode):
if mode == 0:
def _flip_modes(self, trafo):
ADJ = self.ADJOINT_BIT
INV = self.INVERSE_BIT
if trafo == 0:
return self
if mode == 1 and np.issubdtype(self._ldiag.dtype, np.floating):
if trafo == ADJ and np.issubdtype(self._ldiag.dtype, np.floating):
return self
res = self._skeleton(())
if mode == 1:
if trafo == ADJ:
res._ldiag = self._ldiag.conjugate()
elif mode == 2:
elif trafo == INV:
res._ldiag = 1./self._ldiag
elif mode == 3:
elif trafo == ADJ | INV:
res._ldiag = 1./self._ldiag.conjugate()
else:
raise ValueError("bad operator flipping mode")
raise ValueError("invalid operator transformation")
return res
def draw_sample(self, from_inverse=False, dtype=np.float64):
......
......@@ -96,10 +96,9 @@ class LinearOperator(NiftyMetaBase()):
The domain on which the Operator's output Field lives."""
raise NotImplementedError
def _flip_modes(self, op_transform):
def _flip_modes(self, trafo):
from .operator_adapter import OperatorAdapter
return self if op_transform == 0 \
else OperatorAdapter(self, op_transform)
return self if trafo == 0 else OperatorAdapter(self, trafo)
@property
def inverse(self):
......
......@@ -28,7 +28,7 @@ class OperatorAdapter(LinearOperator):
self._op = op
self._trafo = int(op_transform)
if self._trafo < 1 or self._trafo > 3:
raise ValueError("invalid mode")
raise ValueError("invalid operator transformation")
@property
def domain(self):
......@@ -42,9 +42,10 @@ class OperatorAdapter(LinearOperator):
def capability(self):
return self._capTable[self._trafo][self._op.capability]
def _flip_modes(self, op_transform):
newmode = op_transform ^ self._trafo
return self._op if newmode == 0 else OperatorAdapter(self._op, newmode)
def _flip_modes(self, trafo):
newtrafo = trafo ^ self._trafo
return self._op if newtrafo == 0 \
else OperatorAdapter(self._op, newtrafo)
def apply(self, x, mode):
return self._op.apply(x,
......
......@@ -72,18 +72,21 @@ class ScalingOperator(EndomorphicOperator):
else:
return x*(1./np.conj(self._factor))
def _flip_modes(self, mode):
if mode == 0:
def _flip_modes(self, trafo):
ADJ = self.ADJOINT_BIT
INV = self.INVERSE_BIT
if trafo == 0:
return self
if mode == 1 and np.issubdtype(type(self._factor), np.floating):
if trafo == ADJ and np.issubdtype(type(self._factor), np.floating):
return self
if mode == 1:
if trafo == ADJ:
return ScalingOperator(np.conj(self._factor), self._domain)
elif mode == 2:
elif trafo == INV:
return ScalingOperator(1./self._factor, self._domain)
elif mode == 3:
elif trafo == ADJ | INV:
return ScalingOperator(1./np.conj(self._factor), self._domain)
raise ValueError("bad operator flipping mode")
raise ValueError("invalid operator transformation")
@property
def domain(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment