From e1157422fb99e8a6a7b0e14fa5b2da705791a2f3 Mon Sep 17 00:00:00 2001 From: Philipp Arras <parras@mpa-garching.mpg.de> Date: Fri, 15 Jun 2018 00:37:58 +0200 Subject: [PATCH] Fixups --- nifty4/__init__.py | 2 +- nifty4/nonlinear_operators/nonlinear_operator.py | 13 +++++++------ nifty4/nonlinear_operators/selection_operator.py | 3 ++- nifty4/operators/linear_operator.py | 7 ++++++- 4 files changed, 16 insertions(+), 9 deletions(-) diff --git a/nifty4/__init__.py b/nifty4/__init__.py index fcd050114..b58472d9d 100644 --- a/nifty4/__init__.py +++ b/nifty4/__init__.py @@ -4,8 +4,8 @@ from . import dobj from .domains import * from .domain_tuple import DomainTuple from .field import Field -from .operators import * from .nonlinear_operators import * +from .operators import * from .probing.utils import probe_with_posterior_samples, probe_diagonal, \ StatCalculator from .minimization import * diff --git a/nifty4/nonlinear_operators/nonlinear_operator.py b/nifty4/nonlinear_operators/nonlinear_operator.py index d3554b15c..ee040119f 100644 --- a/nifty4/nonlinear_operators/nonlinear_operator.py +++ b/nifty4/nonlinear_operators/nonlinear_operator.py @@ -1,6 +1,5 @@ import nifty4 as ift -from ..operators import LinearOperator from .selection_operator import SelectionOperator @@ -25,7 +24,7 @@ class NonlinearOperator(object): def __getitem__(self, key): sel = SelectionOperator(self.value.domain, key) - return LinearModel(self.position, self, sel) + return sel(self) # TODO Support addition and multiplication with fields def __add__(self, other): @@ -49,6 +48,7 @@ class NonlinearOperator(object): raise NotImplementedError + def _joint_position(op1, op2): a = op1.position._val b = op2.position._val @@ -123,17 +123,18 @@ class ScalarMul(NonlinearOperator): class LinearModel(NonlinearOperator): - def __init__(self, position, inp, lin_op): + def __init__(self, inp, lin_op): """ Computes lin_op(inp) where lin_op is a Linear Operator """ - super(LinearModel, self).__init__(position) + from ..operators import LinearOperator + super(LinearModel, self).__init__(inp.position) if not isinstance(lin_op, LinearOperator): raise TypeError("needs a LinearOperator as input") - self._inp = inp.at(position) self._lin_op = lin_op + self._inp = inp # FIXME This is a dirty hack! if isinstance(self._lin_op, SelectionOperator): self._lin_op = SelectionOperator(self._inp.value.domain, @@ -143,4 +144,4 @@ class LinearModel(NonlinearOperator): self._gradient = self._lin_op*self._inp.gradient def at(self, position): - return self.__class__(position, self._inp, self._lin_op) + return self.__class__(self._inp.at(position), self._lin_op) diff --git a/nifty4/nonlinear_operators/selection_operator.py b/nifty4/nonlinear_operators/selection_operator.py index a11a302d3..af62db0df 100644 --- a/nifty4/nonlinear_operators/selection_operator.py +++ b/nifty4/nonlinear_operators/selection_operator.py @@ -1,10 +1,10 @@ -from ..multi import MultiDomain, MultiField from ..operators import LinearOperator from ..sugar import full class SelectionOperator(LinearOperator): def __init__(self, domain, key): + from ..multi import MultiDomain if not isinstance(domain, MultiDomain): raise TypeError("Domain must be a MultiDomain") self._target = domain[key] @@ -34,4 +34,5 @@ class SelectionOperator(LinearOperator): result[key] = full(val, 0.) else: result[key] = x.copy() + from ..multi import MultiField return MultiField(result) diff --git a/nifty4/operators/linear_operator.py b/nifty4/operators/linear_operator.py index 8819d271f..9d202e4f8 100644 --- a/nifty4/operators/linear_operator.py +++ b/nifty4/operators/linear_operator.py @@ -17,9 +17,11 @@ # and financially supported by the Studienstiftung des deutschen Volkes. import abc -from ..utilities import NiftyMetaBase + import numpy as np +from ..utilities import NiftyMetaBase + class LinearOperator(NiftyMetaBase()): """NIFTY base class for linear operators. @@ -196,7 +198,10 @@ class LinearOperator(NiftyMetaBase()): raise NotImplementedError def __call__(self, x): + from ..nonlinear_operators import LinearModel, NonlinearOperator """Same as :meth:`times`""" + if isinstance(x, NonlinearOperator): + return LinearModel(x, self) return self.apply(x, self.TIMES) def times(self, x): -- GitLab