Commit c5f4f81f authored by Theo Steininger's avatar Theo Steininger
Browse files

Added InvertibleOperatorMixin to PropagatorOperator

parent 0f4a5c18
Pipeline #9762 failed with stages
in 27 minutes and 47 seconds
......@@ -227,7 +227,7 @@ class FFTW(Transform):
p = info.plan
# Load the value into the plan
if p.has_input:
p.input_array[:] = val
p.input_array[None] = val
# Execute the plan
p()
......
......@@ -6,48 +6,48 @@ from nifty.field import Field
class InvertibleOperatorMixin(object):
def __init__(self, inverter=None, preconditioner=None):
def __init__(self, inverter=None, preconditioner=None, *args, **kwargs):
self.__preconditioner = preconditioner
if inverter is not None:
self.__inverter = inverter
else:
self.__inverter = ConjugateGradient(
preconditioner=self.preconditioner)
preconditioner=self.__preconditioner)
def _times(self, x, spaces, types, x0=None):
if x0 is None:
x0 = Field(self.target, val=0., dtype=x.dtype)
(result, convergence) = self.inverter(A=self.inverse_times,
b=x,
x0=x0)
(result, convergence) = self.__inverter(A=self.inverse_times,
b=x,
x0=x0)
return result
def _adjoint_times(self, x, spaces, types, x0=None):
if x0 is None:
x0 = Field(self.domain, val=0., dtype=x.dtype)
(result, convergence) = self.inverter(A=self.adjoint_inverse_times,
b=x,
x0=x0)
(result, convergence) = self.__inverter(A=self.adjoint_inverse_times,
b=x,
x0=x0)
return result
def _inverse_times(self, x, spaces, types, x0=None):
if x0 is None:
x0 = Field(self.domain, val=0., dtype=x.dtype)
(result, convergence) = self.inverter(A=self.times,
b=x,
x0=x0)
(result, convergence) = self.__inverter(A=self.times,
b=x,
x0=x0)
return result
def _adjoint_inverse_times(self, x, spaces, types, x0=None):
if x0 is None:
x0 = Field(self.target, val=0., dtype=x.dtype)
(result, convergence) = self.inverter(A=self.adjoint_times,
b=x,
x0=x0)
(result, convergence) = self.__inverter(A=self.adjoint_times,
b=x,
x0=x0)
return result
def _inverse_adjoint_times(self, x, spaces, types):
......
# -*- coding: utf-8 -*-
from nifty.minimization import ConjugateGradient
from nifty.field import Field
from nifty.operators import EndomorphicOperator,\
FFTOperator
FFTOperator,\
InvertibleOperatorMixin
class PropagatorOperator(EndomorphicOperator):
class PropagatorOperator(InvertibleOperatorMixin, EndomorphicOperator):
# ---Overwritten properties and methods---
......@@ -49,13 +51,8 @@ class PropagatorOperator(EndomorphicOperator):
if preconditioner is None:
preconditioner = self._S_times
self.preconditioner = preconditioner
if inverter is not None:
self.inverter = inverter
else:
self.inverter = ConjugateGradient(
preconditioner=self.preconditioner)
super(PropagatorOperator, self).__init__(inverter=inverter,
preconditioner=preconditioner)
# ---Mandatory properties and methods---
......@@ -107,15 +104,6 @@ class PropagatorOperator(EndomorphicOperator):
result.set_val(transformed_y, copy=False)
return result
def _times(self, x, spaces, types, x0=None):
if x0 is None:
x0 = Field(self.target, val=0., dtype=x.dtype)
(result, convergence) = self.inverter(A=self.inverse_times,
b=x,
x0=x0)
return result
def _inverse_times(self, x, spaces, types):
pre_result = self._S_inverse_times(x, spaces, types)
pre_result += self._likelihood_times(x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment