Commit a141d2bf authored by Lukas Platz's avatar Lukas Platz
Browse files

resolve degenerate use of 'mode' in OperatorAdapter and LinearOperator

'mode' was used for field operator modes and operator transformation types
concurrently, which makes unclear what type of input the function operates
on. To resolve this, operator transformation type variables were renamed
to 'op_transform' and _trafo, making the distinction easy and obvious.
parent 8bf61ef1
Pipeline #27468 passed with stage
in 1 minute and 28 seconds
......@@ -48,6 +48,17 @@ class LinearOperator(NiftyMetaBase()):
by means of a single integer number.
"""
# Field Operator Modes
TIMES = 1
ADJOINT_TIMES = 2
INVERSE_TIMES = 4
ADJOINT_INVERSE_TIMES = 8
INVERSE_ADJOINT_TIMES = 8
# Operator Transform Flags
ADJOINT_BIT = 1
INVERSE_BIT = 2
_ilog = (-1, 0, 1, -1, 2, -1, -1, -1, 3)
_validMode = (False, True, True, False, True, False, False, False, True)
_modeTable = ((1, 2, 4, 8),
......@@ -61,13 +72,6 @@ class LinearOperator(NiftyMetaBase()):
_addInverse = (0, 5, 10, 15, 5, 5, 15, 15, 10, 15, 10, 15, 15, 15, 15, 15)
_backwards = 6
_all_ops = 15
TIMES = 1
ADJOINT_TIMES = 2
INVERSE_TIMES = 4
ADJOINT_INVERSE_TIMES = 8
INVERSE_ADJOINT_TIMES = 8
ADJOINT_BIT = 1
INVERSE_BIT = 2
def _dom(self, mode):
return self.domain if (mode & 9) else self.target
......@@ -92,9 +96,9 @@ class LinearOperator(NiftyMetaBase()):
The domain on which the Operator's output Field lives."""
raise NotImplementedError
def _flip_modes(self, mode):
def _flip_modes(self, op_transform):
from .operator_adapter import OperatorAdapter
return self if mode == 0 else OperatorAdapter(self, mode)
return self if op_transform == 0 else OperatorAdapter(self, op_transform)
@property
def inverse(self):
......
......@@ -23,33 +23,33 @@ import numpy as np
class OperatorAdapter(LinearOperator):
"""Class representing the inverse and/or adjoint of another operator."""
def __init__(self, op, mode):
def __init__(self, op, op_transform):
super(OperatorAdapter, self).__init__()
self._op = op
self._mode = int(mode)
if self._mode < 1 or self._mode > 3:
self._trafo = int(op_transform)
if self._trafo < 1 or self._trafo > 3:
raise ValueError("invalid mode")
@property
def domain(self):
return self._op._dom(1 << self._mode)
return self._op._dom(1 << self._trafo)
@property
def target(self):
return self._op._tgt(1 << self._mode)
return self._op._tgt(1 << self._trafo)
@property
def capability(self):
return self._capTable[self._mode][self._op.capability]
return self._capTable[self._trafo][self._op.capability]
def _flip_modes(self, mode):
newmode = mode ^ self._mode
def _flip_modes(self, op_transform):
newmode = op_transform ^ self._trafo
return self._op if newmode == 0 else OperatorAdapter(self._op, newmode)
def apply(self, x, mode):
return self._op.apply(x, self._modeTable[self._mode][self._ilog[mode]])
return self._op.apply(x, self._modeTable[self._trafo][self._ilog[mode]])
def draw_sample(self, from_inverse=False, dtype=np.float64):
if self._mode & self.INVERSE_BIT:
if self._trafo & self.INVERSE_BIT:
return self._op.draw_sample(not from_inverse, dtype)
return self._op.draw_sample(from_inverse, dtype)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment