diff --git a/nifty4/__init__.py b/nifty4/__init__.py index fcd050114264daa005f43894543332b9c9dca9ea..b58472d9decd32360473a7817bcc3448aba148b6 100644 --- a/nifty4/__init__.py +++ b/nifty4/__init__.py @@ -4,8 +4,8 @@ from . import dobj from .domains import * from .domain_tuple import DomainTuple from .field import Field -from .operators import * from .nonlinear_operators import * +from .operators import * from .probing.utils import probe_with_posterior_samples, probe_diagonal, \ StatCalculator from .minimization import * diff --git a/nifty4/nonlinear_operators/nonlinear_operator.py b/nifty4/nonlinear_operators/nonlinear_operator.py index d3554b15c711ea26952cb04bcd51d5d5b40150bc..ee040119f7cdae2735252dfa4504f18920179019 100644 --- a/nifty4/nonlinear_operators/nonlinear_operator.py +++ b/nifty4/nonlinear_operators/nonlinear_operator.py @@ -1,6 +1,5 @@ import nifty4 as ift -from ..operators import LinearOperator from .selection_operator import SelectionOperator @@ -25,7 +24,7 @@ class NonlinearOperator(object): def __getitem__(self, key): sel = SelectionOperator(self.value.domain, key) - return LinearModel(self.position, self, sel) + return sel(self) # TODO Support addition and multiplication with fields def __add__(self, other): @@ -49,6 +48,7 @@ class NonlinearOperator(object): raise NotImplementedError + def _joint_position(op1, op2): a = op1.position._val b = op2.position._val @@ -123,17 +123,18 @@ class ScalarMul(NonlinearOperator): class LinearModel(NonlinearOperator): - def __init__(self, position, inp, lin_op): + def __init__(self, inp, lin_op): """ Computes lin_op(inp) where lin_op is a Linear Operator """ - super(LinearModel, self).__init__(position) + from ..operators import LinearOperator + super(LinearModel, self).__init__(inp.position) if not isinstance(lin_op, LinearOperator): raise TypeError("needs a LinearOperator as input") - self._inp = inp.at(position) self._lin_op = lin_op + self._inp = inp # FIXME This is a dirty hack! if isinstance(self._lin_op, SelectionOperator): self._lin_op = SelectionOperator(self._inp.value.domain, @@ -143,4 +144,4 @@ class LinearModel(NonlinearOperator): self._gradient = self._lin_op*self._inp.gradient def at(self, position): - return self.__class__(position, self._inp, self._lin_op) + return self.__class__(self._inp.at(position), self._lin_op) diff --git a/nifty4/nonlinear_operators/selection_operator.py b/nifty4/nonlinear_operators/selection_operator.py index a11a302d3119ce06afbab9b70035b4f48cf560dd..af62db0df6ca76f2c2fb532c95712a9d20dec422 100644 --- a/nifty4/nonlinear_operators/selection_operator.py +++ b/nifty4/nonlinear_operators/selection_operator.py @@ -1,10 +1,10 @@ -from ..multi import MultiDomain, MultiField from ..operators import LinearOperator from ..sugar import full class SelectionOperator(LinearOperator): def __init__(self, domain, key): + from ..multi import MultiDomain if not isinstance(domain, MultiDomain): raise TypeError("Domain must be a MultiDomain") self._target = domain[key] @@ -34,4 +34,5 @@ class SelectionOperator(LinearOperator): result[key] = full(val, 0.) else: result[key] = x.copy() + from ..multi import MultiField return MultiField(result) diff --git a/nifty4/operators/linear_operator.py b/nifty4/operators/linear_operator.py index 8819d271fef4bab579598158a7f20ba0948ee768..9d202e4f8e503420e71d0a123344d062ac23f25f 100644 --- a/nifty4/operators/linear_operator.py +++ b/nifty4/operators/linear_operator.py @@ -17,9 +17,11 @@ # and financially supported by the Studienstiftung des deutschen Volkes. import abc -from ..utilities import NiftyMetaBase + import numpy as np +from ..utilities import NiftyMetaBase + class LinearOperator(NiftyMetaBase()): """NIFTY base class for linear operators. @@ -196,7 +198,10 @@ class LinearOperator(NiftyMetaBase()): raise NotImplementedError def __call__(self, x): + from ..nonlinear_operators import LinearModel, NonlinearOperator """Same as :meth:`times`""" + if isinstance(x, NonlinearOperator): + return LinearModel(x, self) return self.apply(x, self.TIMES) def times(self, x):