Commit 9dbf2577 authored by Theo Steininger's avatar Theo Steininger

Adding forward_x0 and backward_x0 to InvertibleOperatorMixin

parent c034e8f1
......@@ -21,7 +21,7 @@ class CriticalPowerCurvature(InvertibleOperatorMixin, EndomorphicOperator):
# ---Overwritten properties and methods---
def __init__(self, theta, T, inverter=None, preconditioner=None):
def __init__(self, theta, T, inverter=None, preconditioner=None, **kwargs):
self.theta = DiagonalOperator(theta.domain, diagonal=theta)
self.T = T
......@@ -30,7 +30,8 @@ class CriticalPowerCurvature(InvertibleOperatorMixin, EndomorphicOperator):
self._domain = self.theta.domain
super(CriticalPowerCurvature, self).__init__(
inverter=inverter,
preconditioner=preconditioner)
preconditioner=preconditioner,
**kwargs)
def _times(self, x, spaces):
return self.T(x) + self.theta(x)
......
......@@ -22,7 +22,7 @@ class WienerFilterCurvature(InvertibleOperatorMixin, EndomorphicOperator):
"""
def __init__(self, R, N, S, inverter=None, preconditioner=None):
def __init__(self, R, N, S, inverter=None, preconditioner=None, **kwargs):
self.R = R
self.N = N
......@@ -32,7 +32,8 @@ class WienerFilterCurvature(InvertibleOperatorMixin, EndomorphicOperator):
self._domain = self.S.domain
super(WienerFilterCurvature, self).__init__(
inverter=inverter,
preconditioner=preconditioner)
preconditioner=preconditioner,
**kwargs)
@property
def domain(self):
......
......@@ -60,17 +60,23 @@ class InvertibleOperatorMixin(object):
"""
def __init__(self, inverter=None, preconditioner=None, *args, **kwargs):
def __init__(self, inverter=None, preconditioner=None,
forward_x0=None, backward_x0=None, *args, **kwargs):
self.__preconditioner = preconditioner
if inverter is not None:
self.__inverter = inverter
else:
self.__inverter = ConjugateGradient(
preconditioner=self.__preconditioner)
self.__forward_x0 = forward_x0
self.__backward_x0 = backward_x0
super(InvertibleOperatorMixin, self).__init__(*args, **kwargs)
def _times(self, x, spaces, x0=None):
if x0 is None:
def _times(self, x, spaces):
if self.__forward_x0 is not None:
x0 = self.__forward_x0
else:
x0 = Field(self.target, val=0., dtype=x.dtype,
distribution_strategy=x.distribution_strategy)
......@@ -79,8 +85,10 @@ class InvertibleOperatorMixin(object):
x0=x0)
return result
def _adjoint_times(self, x, spaces, x0=None):
if x0 is None:
def _adjoint_times(self, x, spaces):
if self.__backward_x0 is not None:
x0 = self.__backward_x0
else:
x0 = Field(self.domain, val=0., dtype=x.dtype,
distribution_strategy=x.distribution_strategy)
......@@ -89,8 +97,10 @@ class InvertibleOperatorMixin(object):
x0=x0)
return result
def _inverse_times(self, x, spaces, x0=None):
if x0 is None:
def _inverse_times(self, x, spaces):
if self.__backward_x0 is not None:
x0 = self.__backward_x0
else:
x0 = Field(self.domain, val=0., dtype=x.dtype,
distribution_strategy=x.distribution_strategy)
......@@ -99,8 +109,10 @@ class InvertibleOperatorMixin(object):
x0=x0)
return result
def _adjoint_inverse_times(self, x, spaces, x0=None):
if x0 is None:
def _adjoint_inverse_times(self, x, spaces):
if self.__forward_x0 is not None:
x0 = self.__forward_x0
else:
x0 = Field(self.target, val=0., dtype=x.dtype,
distribution_strategy=x.distribution_strategy)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment