Commit 0f4a5c18 authored by Theo Steininger's avatar Theo Steininger
Browse files

Added invertible_operator_mixin.

parent a4b43a07
Pipeline #9760 failed with stages
in 29 minutes and 54 seconds
......@@ -31,6 +31,8 @@ from smoothing_operator import SmoothingOperator
from fft_operator import *
from invertible_operator_mixin import InvertibleOperatorMixin
from propagator_operator import PropagatorOperator
from composed_operator import ComposedOperator
# -*- coding: utf-8 -*-
from invertible_operator_mixin import InvertibleOperatorMixin
\ No newline at end of file
# -*- coding: utf-8 -*-
from nifty.minimization import ConjugateGradient
from nifty.field import Field
class InvertibleOperatorMixin(object):
def __init__(self, inverter=None, preconditioner=None):
self.__preconditioner = preconditioner
if inverter is not None:
self.__inverter = inverter
else:
self.__inverter = ConjugateGradient(
preconditioner=self.preconditioner)
def _times(self, x, spaces, types, x0=None):
if x0 is None:
x0 = Field(self.target, val=0., dtype=x.dtype)
(result, convergence) = self.inverter(A=self.inverse_times,
b=x,
x0=x0)
return result
def _adjoint_times(self, x, spaces, types, x0=None):
if x0 is None:
x0 = Field(self.domain, val=0., dtype=x.dtype)
(result, convergence) = self.inverter(A=self.adjoint_inverse_times,
b=x,
x0=x0)
return result
def _inverse_times(self, x, spaces, types, x0=None):
if x0 is None:
x0 = Field(self.domain, val=0., dtype=x.dtype)
(result, convergence) = self.inverter(A=self.times,
b=x,
x0=x0)
return result
def _adjoint_inverse_times(self, x, spaces, types, x0=None):
if x0 is None:
x0 = Field(self.target, val=0., dtype=x.dtype)
(result, convergence) = self.inverter(A=self.adjoint_times,
b=x,
x0=x0)
return result
def _inverse_adjoint_times(self, x, spaces, types):
raise NotImplementedError(
"no generic instance method 'inverse_adjoint_times'.")
......@@ -109,7 +109,7 @@ class PropagatorOperator(EndomorphicOperator):
def _times(self, x, spaces, types, x0=None):
if x0 is None:
x0 = Field(self.domain, val=0., dtype=x.dtype)
x0 = Field(self.target, val=0., dtype=x.dtype)
(result, convergence) = self.inverter(A=self.inverse_times,
b=x,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment