From e1157422fb99e8a6a7b0e14fa5b2da705791a2f3 Mon Sep 17 00:00:00 2001
From: Philipp Arras <parras@mpa-garching.mpg.de>
Date: Fri, 15 Jun 2018 00:37:58 +0200
Subject: [PATCH] Fixups

---
 nifty4/__init__.py                               |  2 +-
 nifty4/nonlinear_operators/nonlinear_operator.py | 13 +++++++------
 nifty4/nonlinear_operators/selection_operator.py |  3 ++-
 nifty4/operators/linear_operator.py              |  7 ++++++-
 4 files changed, 16 insertions(+), 9 deletions(-)

diff --git a/nifty4/__init__.py b/nifty4/__init__.py
index fcd050114..b58472d9d 100644
--- a/nifty4/__init__.py
+++ b/nifty4/__init__.py
@@ -4,8 +4,8 @@ from . import dobj
 from .domains import *
 from .domain_tuple import DomainTuple
 from .field import Field
-from .operators import *
 from .nonlinear_operators import *
+from .operators import *
 from .probing.utils import probe_with_posterior_samples, probe_diagonal, \
     StatCalculator
 from .minimization import *
diff --git a/nifty4/nonlinear_operators/nonlinear_operator.py b/nifty4/nonlinear_operators/nonlinear_operator.py
index d3554b15c..ee040119f 100644
--- a/nifty4/nonlinear_operators/nonlinear_operator.py
+++ b/nifty4/nonlinear_operators/nonlinear_operator.py
@@ -1,6 +1,5 @@
 import nifty4 as ift
 
-from ..operators import LinearOperator
 from .selection_operator import SelectionOperator
 
 
@@ -25,7 +24,7 @@ class NonlinearOperator(object):
 
     def __getitem__(self, key):
         sel = SelectionOperator(self.value.domain, key)
-        return LinearModel(self.position, self, sel)
+        return sel(self)
 
     # TODO Support addition and multiplication with fields
     def __add__(self, other):
@@ -49,6 +48,7 @@ class NonlinearOperator(object):
         raise NotImplementedError
 
 
+
 def _joint_position(op1, op2):
     a = op1.position._val
     b = op2.position._val
@@ -123,17 +123,18 @@ class ScalarMul(NonlinearOperator):
 
 
 class LinearModel(NonlinearOperator):
-    def __init__(self, position, inp, lin_op):
+    def __init__(self, inp, lin_op):
         """
         Computes lin_op(inp) where lin_op is a Linear Operator
         """
-        super(LinearModel, self).__init__(position)
+        from ..operators import LinearOperator
+        super(LinearModel, self).__init__(inp.position)
 
         if not isinstance(lin_op, LinearOperator):
             raise TypeError("needs a LinearOperator as input")
 
-        self._inp = inp.at(position)
         self._lin_op = lin_op
+        self._inp = inp
         # FIXME This is a dirty hack!
         if isinstance(self._lin_op, SelectionOperator):
             self._lin_op = SelectionOperator(self._inp.value.domain,
@@ -143,4 +144,4 @@ class LinearModel(NonlinearOperator):
         self._gradient = self._lin_op*self._inp.gradient
 
     def at(self, position):
-        return self.__class__(position, self._inp, self._lin_op)
+        return self.__class__(self._inp.at(position), self._lin_op)
diff --git a/nifty4/nonlinear_operators/selection_operator.py b/nifty4/nonlinear_operators/selection_operator.py
index a11a302d3..af62db0df 100644
--- a/nifty4/nonlinear_operators/selection_operator.py
+++ b/nifty4/nonlinear_operators/selection_operator.py
@@ -1,10 +1,10 @@
-from ..multi import MultiDomain, MultiField
 from ..operators import LinearOperator
 from ..sugar import full
 
 
 class SelectionOperator(LinearOperator):
     def __init__(self, domain, key):
+        from ..multi import MultiDomain
         if not isinstance(domain, MultiDomain):
             raise TypeError("Domain must be a MultiDomain")
         self._target = domain[key]
@@ -34,4 +34,5 @@ class SelectionOperator(LinearOperator):
                     result[key] = full(val, 0.)
                 else:
                     result[key] = x.copy()
+            from ..multi import MultiField
             return MultiField(result)
diff --git a/nifty4/operators/linear_operator.py b/nifty4/operators/linear_operator.py
index 8819d271f..9d202e4f8 100644
--- a/nifty4/operators/linear_operator.py
+++ b/nifty4/operators/linear_operator.py
@@ -17,9 +17,11 @@
 # and financially supported by the Studienstiftung des deutschen Volkes.
 
 import abc
-from ..utilities import NiftyMetaBase
+
 import numpy as np
 
+from ..utilities import NiftyMetaBase
+
 
 class LinearOperator(NiftyMetaBase()):
     """NIFTY base class for linear operators.
@@ -196,7 +198,10 @@ class LinearOperator(NiftyMetaBase()):
         raise NotImplementedError
 
     def __call__(self, x):
+        from ..nonlinear_operators import LinearModel, NonlinearOperator
         """Same as :meth:`times`"""
+        if isinstance(x, NonlinearOperator):
+            return LinearModel(x, self)
         return self.apply(x, self.TIMES)
 
     def times(self, x):
-- 
GitLab