Commit 55c9fea0 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

first steps

parent 9740e2f3
......@@ -40,4 +40,4 @@ def WienerFilterCurvature(R, N, S, inverter):
The minimizer to use during numerical inversion
"""
op = SandwichOperator(R, N.inverse) + S.inverse
return InversionEnabler(op, inverter, S)
return InversionEnabler(op, inverter, S.inverse)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from .linear_operator import LinearOperator
class AdjointOperator(LinearOperator):
"""Adapter class representing the adjoint of a given operator."""
def __init__(self, op):
super(AdjointOperator, self).__init__()
self._op = op
@property
def domain(self):
return self._op.target
@property
def target(self):
return self._op.domain
@property
def capability(self):
return self._adjointCapability[self._op.capability]
@property
def adjoint(self):
return self._op
def apply(self, x, mode):
return self._op.apply(x, self._adjointMode[mode])
......@@ -96,13 +96,15 @@ class ChainOperator(LinearOperator):
def target(self):
return self._ops[0].target
@property
def inverse(self):
return self.make([op.inverse for op in reversed(self._ops)])
@property
def adjoint(self):
return self.make([op.adjoint for op in reversed(self._ops)])
def _flip_modes(self, mode):
if mode == 0:
return self
if mode == 1 or mode == 2:
return self.make([op._flip_modes(mode)
for op in reversed(self._ops)])
if mode == 3:
return self.make([op._flip_modes(mode) for op in self._ops])
raise ValueError("bad operator flipping mode")
@property
def capability(self):
......
......@@ -158,18 +158,20 @@ class DiagonalOperator(EndomorphicOperator):
def capability(self):
return self._all_ops
@property
def inverse(self):
res = self._skeleton(())
res._ldiag = 1./self._ldiag
return res
@property
def adjoint(self):
if np.issubdtype(self._ldiag.dtype, np.floating):
def _flip_modes(self, mode):
if mode == 0:
return self
if mode == 1 and np.issubdtype(self._ldiag.dtype, np.floating):
return self
res = self._skeleton(())
res._ldiag = self._ldiag.conjugate()
if mode == 1:
res._ldiag = self._ldiag.conjugate()
elif mode == 2:
res._ldiag = 1./self._ldiag
elif mode == 3:
res._ldiag = 1./self._ldiag.conjugate()
else:
raise ValueError("bad operator flipping mode")
return res
def draw_sample(self, dtype=np.float64):
......
......@@ -66,12 +66,14 @@ class InversionEnabler(LinearOperator):
if self._op.capability & mode:
return self._op.apply(x, mode)
def func(x):
return self._op.apply(x, self._inverseMode[mode])
x0 = Field.zeros(self._tgt(mode), dtype=x.dtype)
energy = QuadraticEnergy(A=func, b=x, position=x0)
r, stat = self._inverter(energy, preconditioner=self._preconditioner)
invmode = self._modeTable[self.INVERSE_BIT][self._ilog[mode]]
invop = self._op._flip_modes(self._ilog[invmode])
prec = self._preconditioner
if prec is not None:
prec = prec._flip_modes(self._ilog[mode])
energy = QuadraticEnergy(A=invop, b=x, position=x0)
r, stat = self._inverter(energy, preconditioner=prec)
if stat != IterationController.CONVERGED:
logger.warning("Error detected during operator inversion")
return r.position
......
......@@ -48,11 +48,16 @@ class LinearOperator(NiftyMetaBase()):
by means of a single integer number.
"""
_ilog = (-1, 0, 1, -1, 2, -1, -1, -1, 3)
_validMode = (False, True, True, False, True, False, False, False, True)
_inverseMode = (0, 4, 8, 0, 1, 0, 0, 0, 2)
_inverseCapability = (0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15)
_adjointMode = (0, 2, 1, 0, 8, 0, 0, 0, 4)
_adjointCapability = (0, 2, 1, 3, 8, 10, 9, 11, 4, 6, 5, 7, 12, 14, 13, 15)
_modeTable = ((1, 2, 4, 8),
(2, 1, 8, 4),
(4, 8, 1, 2),
(8, 4, 2, 1))
_capTable = ((0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15),
(0, 2, 1, 3, 8, 10, 9, 11, 4, 6, 5, 7, 12, 14, 13, 15),
(0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15),
(0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15))
_addInverse = (0, 5, 10, 15, 5, 5, 15, 15, 10, 15, 10, 15, 15, 15, 15, 15)
_backwards = 6
_all_ops = 15
......@@ -61,6 +66,8 @@ class LinearOperator(NiftyMetaBase()):
INVERSE_TIMES = 4
ADJOINT_INVERSE_TIMES = 8
INVERSE_ADJOINT_TIMES = 8
ADJOINT_BIT = 1
INVERSE_BIT = 2
def _dom(self, mode):
return self.domain if (mode & 9) else self.target
......@@ -85,14 +92,17 @@ class LinearOperator(NiftyMetaBase()):
The domain on which the Operator's output Field lives."""
raise NotImplementedError
def _flip_modes(self, mode):
from .operator_adapter import OperatorAdapter
return self if mode == 0 else OperatorAdapter(self, mode)
@property
def inverse(self):
"""LinearOperator : the inverse of `self`
Returns a LinearOperator object which behaves as if it were
the inverse of this operator."""
from .inverse_operator import InverseOperator
return InverseOperator(self)
return self._flip_modes(self.INVERSE_BIT)
@property
def adjoint(self):
......@@ -100,8 +110,7 @@ class LinearOperator(NiftyMetaBase()):
Returns a LinearOperator object which behaves as if it were
the adjoint of this operator."""
from .adjoint_operator import AdjointOperator
return AdjointOperator(self)
return self._flip_modes(self.ADJOINT_BIT)
@staticmethod
def _toOperator(thing, dom):
......
......@@ -20,34 +20,41 @@ from .linear_operator import LinearOperator
import numpy as np
class InverseOperator(LinearOperator):
"""Adapter class representing the inverse of a given operator."""
class OperatorAdapter(LinearOperator):
"""Class representing the inverse and/or adjoint of another operator."""
def __init__(self, op):
super(InverseOperator, self).__init__()
def __init__(self, op, mode):
super(OperatorAdapter, self).__init__()
self._op = op
self._mode = int(mode)
if self._mode < 1 or self._mode > 3:
raise ValueError("invalid mode")
@property
def domain(self):
return self._op.target
return self._op._dom(1 << self._mode)
@property
def target(self):
return self._op.domain
return self._op._tgt(1 << self._mode)
@property
def capability(self):
return self._inverseCapability[self._op.capability]
return self._capTable[self._mode][self._op.capability]
@property
def inverse(self):
return self._op
def _flip_modes(self, mode):
newmode = mode ^ self._mode
return self._op if newmode == 0 else OperatorAdapter(self._op, newmode)
def apply(self, x, mode):
return self._op.apply(x, self._inverseMode[mode])
return self._op.apply(x, self._modeTable[self._mode][self._ilog[mode]])
def draw_sample(self, dtype=np.float64):
return self._op.inverse_draw_sample(dtype)
if self._mode & self.INVERSE_BIT:
return self._op.inverse_draw_sample(dtype)
return self._op.draw_sample(dtype)
def inverse_draw_sample(self, dtype=np.float64):
return self._op.draw_sample(dtype)
if self._mode & self.INVERSE_BIT:
return self._op.draw_sample(dtype)
return self._op.inverse_draw_sample(dtype)
......@@ -72,18 +72,18 @@ class ScalingOperator(EndomorphicOperator):
else:
return x*(1./np.conj(self._factor))
@property
def inverse(self):
if self._factor != 0.:
return ScalingOperator(1./self._factor, self._domain)
from .inverse_operator import InverseOperator
return InverseOperator(self)
@property
def adjoint(self):
if np.issubdtype(type(self._factor), np.floating):
def _flip_modes(self, mode):
if mode == 0:
return self
return ScalingOperator(np.conj(self._factor), self._domain)
if mode == 1 and np.issubdtype(type(self._factor), np.floating):
return self
if mode == 1:
return ScalingOperator(np.conj(self._factor), self._domain)
elif mode == 2:
return ScalingOperator(1./self._factor, self._domain)
elif mode == 3:
return ScalingOperator(1./np.conj(self._factor), self._domain)
raise ValueError("bad operator flipping mode")
@property
def domain(self):
......@@ -91,8 +91,6 @@ class ScalingOperator(EndomorphicOperator):
@property
def capability(self):
if self._factor == 0.:
return self.TIMES | self.ADJOINT_TIMES
return self._all_ops
def _sample_helper(self, fct, dtype):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment