From c3ed466f547cd9ee25fbe1ab84ecaf86e8a998a6 Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Sun, 5 Aug 2018 14:00:28 +0200
Subject: [PATCH] no more chains

---
 demos/Wiener_Filter.ipynb                     |  4 +--
 demos/bernoulli_demo.py                       |  4 +--
 demos/getting_started_1.py                    |  6 ++--
 demos/getting_started_2.py                    |  6 ++--
 demos/getting_started_3.py                    |  8 ++---
 demos/polynomial_fit.py                       |  2 +-
 nifty5/energies/hamiltonian.py                |  2 +-
 nifty5/energies/kl.py                         |  2 +-
 nifty5/library/amplitude_model.py             |  4 +--
 nifty5/library/bernoulli_energy.py            |  2 +-
 nifty5/library/correlated_fields.py           |  2 +-
 nifty5/library/gaussian_energy.py             |  2 +-
 nifty5/library/poissonian_energy.py           |  2 +-
 nifty5/linearization.py                       | 32 +++++++++---------
 nifty5/multi/block_diagonal_operator.py       |  2 +-
 .../operators/harmonic_smoothing_operator.py  |  2 +-
 nifty5/operators/linear_operator.py           | 11 ++++---
 nifty5/operators/operator.py                  | 33 ++++++-------------
 nifty5/operators/sandwich_operator.py         |  4 +--
 nifty5/operators/smoothness_operator.py       |  2 +-
 nifty5/utilities.py                           | 13 ++++----
 test/test_energies/test_map.py                |  8 ++---
 test/test_field.py                            |  4 +--
 test/test_models/test_model_gradients.py      |  6 ++--
 test/test_multi_field.py                      |  2 +-
 test/test_operators/test_adjoint.py           |  2 +-
 test/test_operators/test_composed_operator.py |  9 +++--
 27 files changed, 82 insertions(+), 94 deletions(-)

diff --git a/demos/Wiener_Filter.ipynb b/demos/Wiener_Filter.ipynb
index dd8c7f374..8b6969725 100644
--- a/demos/Wiener_Filter.ipynb
+++ b/demos/Wiener_Filter.ipynb
@@ -429,7 +429,7 @@
     "mask[l:h] = 0\n",
     "mask = ift.Field.from_global_data(s_space, mask)\n",
     "\n",
-    "R = ift.DiagonalOperator(mask).chain(HT)\n",
+    "R = ift.DiagonalOperator(mask)(HT)\n",
     "n = n.to_global_data_rw()\n",
     "n[l:h] = 0\n",
     "n = ift.Field.from_global_data(s_space, n)\n",
@@ -585,7 +585,7 @@
     "mask[l:h,l:h] = 0.\n",
     "mask = ift.Field.from_global_data(s_space, mask)\n",
     "\n",
-    "R = ift.DiagonalOperator(mask).chain(HT)\n",
+    "R = ift.DiagonalOperator(mask)(HT)\n",
     "n = n.to_global_data_rw()\n",
     "n[l:h, l:h] = 0\n",
     "n = ift.Field.from_global_data(s_space, n)\n",
diff --git a/demos/bernoulli_demo.py b/demos/bernoulli_demo.py
index 7efd68a31..d517362c9 100644
--- a/demos/bernoulli_demo.py
+++ b/demos/bernoulli_demo.py
@@ -53,7 +53,7 @@ if __name__ == '__main__':
     A = pd(a)
 
     # Set up a sky model
-    sky = HT.chain(ift.makeOp(A)).positive_tanh()
+    sky = HT(ift.makeOp(A)).positive_tanh()
 
     GR = ift.GeometryRemover(position_space)
     # Set up instrumental response
@@ -61,7 +61,7 @@ if __name__ == '__main__':
 
     # Generate mock data
     d_space = R.target[0]
-    p = R.chain(sky)
+    p = R(sky)
     mock_position = ift.from_random('normal', harmonic_space)
     pp = p(mock_position)
     data = np.random.binomial(1, pp.to_global_data().astype(np.float64))
diff --git a/demos/getting_started_1.py b/demos/getting_started_1.py
index 1e36c2c18..76425bea5 100644
--- a/demos/getting_started_1.py
+++ b/demos/getting_started_1.py
@@ -78,7 +78,7 @@ if __name__ == '__main__':
     GR = ift.GeometryRemover(position_space)
     mask = ift.Field.from_global_data(position_space, mask)
     Mask = ift.DiagonalOperator(mask)
-    R = GR.chain(Mask).chain(HT)
+    R = GR(Mask(HT))
 
     data_space = GR.target
 
@@ -93,7 +93,7 @@ if __name__ == '__main__':
 
     # Build propagator D and information source j
     j = R.adjoint_times(N.inverse_times(data))
-    D_inv = R.adjoint.chain(N.inverse).chain(R) + S.inverse
+    D_inv = R.adjoint(N.inverse(R)) + S.inverse
     # Make it invertible
     IC = ift.GradientNormController(iteration_limit=500, tol_abs_gradnorm=1e-3)
     D = ift.InversionEnabler(D_inv, IC, approximation=S.inverse).inverse
@@ -112,7 +112,7 @@ if __name__ == '__main__':
                         title="getting_started_1")
     else:
         ift.plot(HT(MOCK_SIGNAL), title='Mock Signal')
-        ift.plot(mask_to_nan(mask, (GR.chain(Mask)).adjoint(data)),
+        ift.plot(mask_to_nan(mask, (GR(Mask)).adjoint(data)),
                  title='Data')
         ift.plot(HT(m), title='Reconstruction')
         ift.plot(mask_to_nan(mask, HT(m-MOCK_SIGNAL)), title='Residuals')
diff --git a/demos/getting_started_2.py b/demos/getting_started_2.py
index d878ec0df..c6622395e 100644
--- a/demos/getting_started_2.py
+++ b/demos/getting_started_2.py
@@ -70,16 +70,16 @@ if __name__ == '__main__':
     A = pd(a)
 
     # Set up a sky model
-    sky = ift.exp(HT.chain(ift.makeOp(A)))
+    sky = ift.exp(HT(ift.makeOp(A)))
 
     M = ift.DiagonalOperator(exposure)
     GR = ift.GeometryRemover(position_space)
     # Set up instrumental response
-    R = GR.chain(M)
+    R = GR(M)
 
     # Generate mock data
     d_space = R.target[0]
-    lamb = R.chain(sky)
+    lamb = R(sky)
     mock_position = ift.from_random('normal', domain)
     data = lamb(mock_position)
     data = np.random.poisson(data.to_global_data().astype(np.float64))
diff --git a/demos/getting_started_3.py b/demos/getting_started_3.py
index dc6174a86..3aa3a0a5a 100644
--- a/demos/getting_started_3.py
+++ b/demos/getting_started_3.py
@@ -44,8 +44,8 @@ if __name__ == '__main__':
     domain = ift.MultiDomain.union(
         (A.domain, ift.MultiDomain.make({'xi': harmonic_space})))
 
-    correlated_field = ht.chain(
-        power_distributor.chain(A)*ift.FieldAdapter(domain, "xi"))
+    correlated_field = ht(
+        power_distributor(A)*ift.FieldAdapter(domain, "xi"))
     # alternatively to the block above one can do:
     # correlated_field = ift.CorrelatedField(position_space, A)
 
@@ -57,7 +57,7 @@ if __name__ == '__main__':
     R = ift.LOSResponse(position_space, starts=LOS_starts,
                         ends=LOS_ends)
     # build signal response model and model likelihood
-    signal_response = R.chain(signal)
+    signal_response = R(signal)
     # specify noise
     data_space = R.target
     noise = .001
@@ -69,7 +69,7 @@ if __name__ == '__main__':
 
     # set up model likelihood
     likelihood = ift.GaussianEnergy(
-        mean=data, covariance=N).chain(signal_response)
+        mean=data, covariance=N)(signal_response)
 
     # set up minimization and inversion schemes
     ic_cg = ift.GradientNormController(iteration_limit=10)
diff --git a/demos/polynomial_fit.py b/demos/polynomial_fit.py
index 23ea6e726..c8eba5cec 100644
--- a/demos/polynomial_fit.py
+++ b/demos/polynomial_fit.py
@@ -97,7 +97,7 @@ d = ift.from_global_data(d_space, y)
 N = ift.DiagonalOperator(ift.from_global_data(d_space, var))
 
 IC = ift.GradientNormController(tol_abs_gradnorm=1e-8)
-likelihood = ift.GaussianEnergy(d, N).chain(R)
+likelihood = ift.GaussianEnergy(d, N)(R)
 H = ift.Hamiltonian(likelihood, IC)
 H = ift.EnergyAdapter(params, H)
 H = H.make_invertible(IC)
diff --git a/nifty5/energies/hamiltonian.py b/nifty5/energies/hamiltonian.py
index 48f4729ab..fb5a45176 100644
--- a/nifty5/energies/hamiltonian.py
+++ b/nifty5/energies/hamiltonian.py
@@ -40,7 +40,7 @@ class Hamiltonian(Operator):
     def target(self):
         return DomainTuple.scalar_domain()
 
-    def __call__(self, x):
+    def apply(self, x):
         if self._ic_samp is None or not isinstance(x, Linearization):
             return self._lh(x) + self._prior(x)
         else:
diff --git a/nifty5/energies/kl.py b/nifty5/energies/kl.py
index 6c9bb07fd..d513d1101 100644
--- a/nifty5/energies/kl.py
+++ b/nifty5/energies/kl.py
@@ -42,6 +42,6 @@ class SampledKullbachLeiblerDivergence(Operator):
     def target(self):
         return DomainTuple.scalar_domain()
 
-    def __call__(self, x):
+    def apply(self, x):
         return (my_sum(map(lambda v: self._h(x+v), self._res_samples)) *
                 (1./len(self._res_samples)))
diff --git a/nifty5/library/amplitude_model.py b/nifty5/library/amplitude_model.py
index 7716a42ba..c2e27ef1f 100644
--- a/nifty5/library/amplitude_model.py
+++ b/nifty5/library/amplitude_model.py
@@ -130,7 +130,7 @@ class AmplitudeModel(Operator):
         cepstrum = create_cepstrum_amplitude_field(dof_space, kern)
 
         ceps = makeOp(sqrt(cepstrum))
-        self._smooth_op = sym.chain(qht).chain(ceps)
+        self._smooth_op = sym(qht(ceps))
         self._keys = tuple(keys)
 
     @property
@@ -141,7 +141,7 @@ class AmplitudeModel(Operator):
     def target(self):
         return self._target
 
-    def __call__(self, x):
+    def apply(self, x):
         smooth_spec = self._smooth_op(x[self._keys[0]])
         phi = x[self._keys[1]] + self._norm_phi_mean
         linear_spec = self._slope(phi)
diff --git a/nifty5/library/bernoulli_energy.py b/nifty5/library/bernoulli_energy.py
index 2ee3a0b53..5a6c33dee 100644
--- a/nifty5/library/bernoulli_energy.py
+++ b/nifty5/library/bernoulli_energy.py
@@ -39,7 +39,7 @@ class BernoulliEnergy(Operator):
     def target(self):
         return DomainTuple.scalar_domain()
 
-    def __call__(self, x):
+    def apply(self, x):
         x = self._p(x)
         v = x.log().vdot(-self._d) - (1.-x).log().vdot(1.-self._d)
         if not isinstance(x, Linearization):
diff --git a/nifty5/library/correlated_fields.py b/nifty5/library/correlated_fields.py
index 53c2cacb8..d139259b8 100644
--- a/nifty5/library/correlated_fields.py
+++ b/nifty5/library/correlated_fields.py
@@ -58,7 +58,7 @@ class CorrelatedField(Operator):
     def target(self):
         return self._ht.target
 
-    def __call__(self, x):
+    def apply(self, x):
         A = self._power_distributor(self._amplitude_model(x))
         correlated_field_h = A * x["xi"]
         correlated_field = self._ht(correlated_field_h)
diff --git a/nifty5/library/gaussian_energy.py b/nifty5/library/gaussian_energy.py
index 3c462f589..85a3511d5 100644
--- a/nifty5/library/gaussian_energy.py
+++ b/nifty5/library/gaussian_energy.py
@@ -55,7 +55,7 @@ class GaussianEnergy(Operator):
     def target(self):
         return DomainTuple.scalar_domain()
 
-    def __call__(self, x):
+    def apply(self, x):
         residual = x if self._mean is None else x-self._mean
         icovres = residual if self._icov is None else self._icov(residual)
         res = .5*residual.vdot(icovres)
diff --git a/nifty5/library/poissonian_energy.py b/nifty5/library/poissonian_energy.py
index 4260f9d6b..0a42ba93c 100644
--- a/nifty5/library/poissonian_energy.py
+++ b/nifty5/library/poissonian_energy.py
@@ -41,7 +41,7 @@ class PoissonianEnergy(Operator):
     def target(self):
         return DomainTuple.scalar_domain()
 
-    def __call__(self, x):
+    def apply(self, x):
         x = self._op(x)
         res = x.sum() - x.log().vdot(self._d)
         if not isinstance(x, Linearization):
diff --git a/nifty5/linearization.py b/nifty5/linearization.py
index 03985f05e..c9020c0cb 100644
--- a/nifty5/linearization.py
+++ b/nifty5/linearization.py
@@ -46,8 +46,8 @@ class Linearization(object):
 
     def __neg__(self):
         return Linearization(
-            -self._val, self._jac.chain(-1),
-            None if self._metric is None else self._metric.chain(-1))
+            -self._val, self._jac*(-1),
+            None if self._metric is None else self._metric*(-1))
 
     def __add__(self, other):
         if isinstance(other, Linearization):
@@ -77,24 +77,24 @@ class Linearization(object):
             d2 = makeOp(other._val)
             return Linearization(
                 self._val*other._val,
-                d2.chain(self._jac) + d1.chain(other._jac))
+                d2(self._jac) + d1(other._jac))
         if isinstance(other, (int, float, complex)):
             # if other == 0:
             #     return ...
-            met = None if self._metric is None else self._metric.chain(other)
-            return Linearization(self._val*other, self._jac.chain(other), met)
+            met = None if self._metric is None else self._metric(other)
+            return Linearization(self._val*other, self._jac(other), met)
         if isinstance(other, (Field, MultiField)):
             d2 = makeOp(other)
-            return Linearization(self._val*other, d2.chain(self._jac))
+            return Linearization(self._val*other, d2(self._jac))
         raise TypeError
 
     def __rmul__(self, other):
         from .sugar import makeOp
         if isinstance(other, (int, float, complex)):
-            return Linearization(self._val*other, self._jac.chain(other))
+            return Linearization(self._val*other, self._jac(other))
         if isinstance(other, (Field, MultiField)):
             d1 = makeOp(other)
-            return Linearization(self._val*other, d1.chain(self._jac))
+            return Linearization(self._val*other, d1(self._jac))
 
     def vdot(self, other):
         from .domain_tuple import DomainTuple
@@ -102,11 +102,11 @@ class Linearization(object):
         if isinstance(other, (Field, MultiField)):
             return Linearization(
                 Field(DomainTuple.scalar_domain(),self._val.vdot(other)),
-                VdotOperator(other).chain(self._jac))
+                VdotOperator(other)(self._jac))
         return Linearization(
             Field(DomainTuple.scalar_domain(),self._val.vdot(other._val)),
-            VdotOperator(self._val).chain(other._jac) +
-            VdotOperator(other._val).chain(self._jac))
+            VdotOperator(self._val)(other._jac) +
+            VdotOperator(other._val)(self._jac))
 
     def sum(self):
         from .domain_tuple import DomainTuple
@@ -114,24 +114,24 @@ class Linearization(object):
         from .sugar import full
         return Linearization(
             Field(DomainTuple.scalar_domain(), self._val.sum()),
-            SumReductionOperator(self._jac.target).chain(self._jac))
+            SumReductionOperator(self._jac.target)(self._jac))
 
     def exp(self):
         tmp = self._val.exp()
-        return Linearization(tmp, makeOp(tmp).chain(self._jac))
+        return Linearization(tmp, makeOp(tmp)(self._jac))
 
     def log(self):
         tmp = self._val.log()
-        return Linearization(tmp, makeOp(1./self._val).chain(self._jac))
+        return Linearization(tmp, makeOp(1./self._val)(self._jac))
 
     def tanh(self):
         tmp = self._val.tanh()
-        return Linearization(tmp, makeOp(1.-tmp**2).chain(self._jac))
+        return Linearization(tmp, makeOp(1.-tmp**2)(self._jac))
 
     def positive_tanh(self):
         tmp = self._val.tanh()
         tmp2 = 0.5*(1.+tmp)
-        return Linearization(tmp2, makeOp(0.5*(1.-tmp**2)).chain(self._jac))
+        return Linearization(tmp2, makeOp(0.5*(1.-tmp**2))(self._jac))
 
     def add_metric(self, metric):
         return Linearization(self._val, self._jac, metric)
diff --git a/nifty5/multi/block_diagonal_operator.py b/nifty5/multi/block_diagonal_operator.py
index 41c03043b..9d176a353 100644
--- a/nifty5/multi/block_diagonal_operator.py
+++ b/nifty5/multi/block_diagonal_operator.py
@@ -68,7 +68,7 @@ class BlockDiagonalOperator(EndomorphicOperator):
     def _combine_chain(self, op):
         if self._domain is not op._domain:
             raise ValueError("domain mismatch")
-        res = tuple(v1.chain(v2) for v1, v2 in zip(self._ops, op._ops))
+        res = tuple(v1(v2) for v1, v2 in zip(self._ops, op._ops))
         return BlockDiagonalOperator(self._domain, res)
 
     def _combine_sum(self, op, selfneg, opneg):
diff --git a/nifty5/operators/harmonic_smoothing_operator.py b/nifty5/operators/harmonic_smoothing_operator.py
index ca40ce806..a1bfbac0b 100644
--- a/nifty5/operators/harmonic_smoothing_operator.py
+++ b/nifty5/operators/harmonic_smoothing_operator.py
@@ -67,4 +67,4 @@ def HarmonicSmoothingOperator(domain, sigma, space=None):
     ddom = list(domain)
     ddom[space] = codomain
     diag = DiagonalOperator(kernel, ddom, space)
-    return Hartley.inverse.chain(diag).chain(Hartley)
+    return Hartley.inverse(diag(Hartley))
diff --git a/nifty5/operators/linear_operator.py b/nifty5/operators/linear_operator.py
index 783065ffc..5b4bbf799 100644
--- a/nifty5/operators/linear_operator.py
+++ b/nifty5/operators/linear_operator.py
@@ -142,9 +142,6 @@ class LinearOperator(Operator):
         from .chain_operator import ChainOperator
         return ChainOperator.make([self, other2])
 
-    def chain(self, other):
-        return self.__matmul__(other)
-
     def __rmatmul__(self, other):
         if np.isscalar(other) and other == 1.:
             return self
@@ -213,10 +210,14 @@ class LinearOperator(Operator):
 
     def __call__(self, x):
         """Same as :meth:`times`"""
+        from ..field import Field
+        from ..multi.multi_field import MultiField
+        if isinstance(x, (Field, MultiField)):
+            return self.apply(x, self.TIMES)
         from ..linearization import Linearization
         if isinstance(x, Linearization):
-            return Linearization(self(x._val), self.chain(x._jac))
-        return self.apply(x, self.TIMES)
+            return Linearization(self(x._val), self(x._jac))
+        return self.__matmul__(x)
 
     def times(self, x):
         """ Applies the Operator to a given Field.
diff --git a/nifty5/operators/operator.py b/nifty5/operators/operator.py
index 61346cde3..55416578e 100644
--- a/nifty5/operators/operator.py
+++ b/nifty5/operators/operator.py
@@ -33,26 +33,13 @@ class Operator(NiftyMetaBase()):
             return NotImplemented
         return _OpProd.make((self, x))
 
-    def chain(self, x):
-        res = self.__matmul__(x)
-        if res == NotImplemented:
-            raise TypeError("operator expected")
-        return res
+    def apply(self, x):
+        raise NotImplementedError
 
     def __call__(self, x):
-        """Returns transformed x
-
-        Parameters
-        ----------
-        x : Linearization
-            input
-
-        Returns
-        -------
-        Linearization
-            output
-        """
-        raise NotImplementedError
+       if isinstance(x, Operator):
+           return _OpChain.make((self, x))
+       return self.apply(x)
 
 
 for f in ["sqrt", "exp", "log", "tanh", "positive_tanh"]:
@@ -78,7 +65,7 @@ class _FunctionApplier(Operator):
     def target(self):
         return self._domain
 
-    def __call__(self, x):
+    def apply(self, x):
         return getattr(x, self._funcname)()
 
 
@@ -117,7 +104,7 @@ class _OpChain(_CombinedOperator):
     def target(self):
         return self._ops[0].target
 
-    def __call__(self, x):
+    def apply(self, x):
         for op in reversed(self._ops):
             x = op(x)
         return x
@@ -135,7 +122,7 @@ class _OpProd(_CombinedOperator):
     def target(self):
         return self._ops[0].target
 
-    def __call__(self, x):
+    def apply(self, x):
         from ..utilities import my_product
         return my_product(map(lambda op: op(x), self._ops))
 
@@ -154,7 +141,7 @@ class _OpSum(_CombinedOperator):
     def target(self):
         return self._target
 
-    def __call__(self, x):
+    def apply(self, x):
         raise NotImplementedError
 
 
@@ -193,7 +180,7 @@ class QuadraticFormOperator(Operator):
     def target(self):
         return self._target
 
-    def __call__(self, x):
+    def apply(self, x):
         if isinstance(x, Linearization):
             jac = self._op(x)
             val = Field(self._target, 0.5 * x.vdot(jac))
diff --git a/nifty5/operators/sandwich_operator.py b/nifty5/operators/sandwich_operator.py
index 20236a580..480e098db 100644
--- a/nifty5/operators/sandwich_operator.py
+++ b/nifty5/operators/sandwich_operator.py
@@ -56,9 +56,9 @@ class SandwichOperator(EndomorphicOperator):
             raise TypeError("cheese must be a linear operator")
         if cheese is None:
             cheese = ScalingOperator(1., bun.target)
-            op = bun.adjoint.chain(bun)
+            op = bun.adjoint(bun)
         else:
-            op = bun.adjoint.chain(cheese).chain(bun)
+            op = bun.adjoint(cheese(bun))
 
         # if our sandwich is diagonal, we can return immediately
         if isinstance(op, (ScalingOperator, DiagonalOperator)):
diff --git a/nifty5/operators/smoothness_operator.py b/nifty5/operators/smoothness_operator.py
index 981863927..38e068730 100644
--- a/nifty5/operators/smoothness_operator.py
+++ b/nifty5/operators/smoothness_operator.py
@@ -54,4 +54,4 @@ def SmoothnessOperator(domain, strength=1., logarithmic=True, space=None):
     if strength == 0.:
         return ScalingOperator(0., domain)
     laplace = LaplaceOperator(domain, logarithmic=logarithmic, space=space)
-    return (strength**2)*laplace.adjoint.chain(laplace)
+    return (strength**2)*laplace.adjoint(laplace)
diff --git a/nifty5/utilities.py b/nifty5/utilities.py
index 8bf84bc2c..d7436ccf8 100644
--- a/nifty5/utilities.py
+++ b/nifty5/utilities.py
@@ -23,6 +23,8 @@ from itertools import product
 
 import numpy as np
 from future.utils import with_metaclass
+import pyfftw
+from pyfftw.interfaces.numpy_fft import rfftn, fftn
 
 from .compat import *
 
@@ -201,9 +203,11 @@ _fft_extra_args = dict(planner_effort='FFTW_ESTIMATE')
 
 
 def fft_prep():
-    import pyfftw
-    pyfftw.interfaces.cache.enable()
-    pyfftw.interfaces.cache.set_keepalive_time(1000.)
+    if not fft_prep._initialized:
+        pyfftw.interfaces.cache.enable()
+        pyfftw.interfaces.cache.set_keepalive_time(1000.)
+        fft_prep._initialized = True
+fft_prep._initialized = False
 
 
 def hartley(a, axes=None):
@@ -214,7 +218,6 @@ def hartley(a, axes=None):
     if iscomplextype(a.dtype):
         raise TypeError("Hartley transform requires real-valued arrays.")
 
-    from pyfftw.interfaces.numpy_fft import rfftn
     tmp = rfftn(a, axes=axes, threads=nthreads(), **_fft_extra_args)
 
     def _fill_array(tmp, res, axes):
@@ -258,7 +261,6 @@ def my_fftn_r2c(a, axes=None):
     if iscomplextype(a.dtype):
         raise TypeError("Transform requires real-valued input arrays.")
 
-    from pyfftw.interfaces.numpy_fft import rfftn
     tmp = rfftn(a, axes=axes, threads=nthreads(), **_fft_extra_args)
 
     def _fill_complex_array(tmp, res, axes):
@@ -293,7 +295,6 @@ def my_fftn_r2c(a, axes=None):
 
 
 def my_fftn(a, axes=None):
-    from pyfftw.interfaces.numpy_fft import fftn
     return fftn(a, axes=axes, **_fft_extra_args)
 
 
diff --git a/test/test_energies/test_map.py b/test/test_energies/test_map.py
index a88a4b272..ae53e8259 100644
--- a/test/test_energies/test_map.py
+++ b/test/test_energies/test_map.py
@@ -56,18 +56,18 @@ class Energy_Tests(unittest.TestCase):
 
         def d_model():
             if nonlinearity == "":
-                return R.chain(ht.chain(ift.makeOp(A)))
+                return R(ht(ift.makeOp(A)))
             else:
-                tmp = ht.chain(ift.makeOp(A))
+                tmp = ht(ift.makeOp(A))
                 nonlin = getattr(tmp, nonlinearity)()
-                return R.chain(nonlin)
+                return R(nonlin)
 
         d = d_model()(xi0) + n
 
         if noise == 1:
             N = None
 
-        energy = ift.GaussianEnergy(d, N).chain(d_model())
+        energy = ift.GaussianEnergy(d, N)(d_model())
         if nonlinearity == "":
             ift.extra.check_value_gradient_metric_consistency(
                 energy, xi0, ntries=10)
diff --git a/test/test_field.py b/test/test_field.py
index e08463c90..cfc6f7d0e 100644
--- a/test/test_field.py
+++ b/test/test_field.py
@@ -66,7 +66,7 @@ class Test_Functionality(unittest.TestCase):
 
         op1 = ift.create_power_operator((space1, space2), _spec1, 0)
         op2 = ift.create_power_operator((space1, space2), _spec2, 1)
-        opfull = op2.chain(op1)
+        opfull = op2(op1)
 
         samples = 500
         sc1 = ift.StatCalculator()
@@ -94,7 +94,7 @@ class Test_Functionality(unittest.TestCase):
 
         S_1 = ift.create_power_operator((space1, space2), _spec1, 0)
         S_2 = ift.create_power_operator((space1, space2), _spec2, 1)
-        S_full = S_2.chain(S_1)
+        S_full = S_2(S_1)
 
         samples = 500
         sc1 = ift.StatCalculator()
diff --git a/test/test_models/test_model_gradients.py b/test/test_models/test_model_gradients.py
index 8f1bdf839..6a3abd2d5 100644
--- a/test/test_models/test_model_gradients.py
+++ b/test/test_models/test_model_gradients.py
@@ -71,16 +71,16 @@ class Model_Tests(unittest.TestCase):
         model = ift.FieldAdapter(dom, "s1")*3.
         pos = ift.from_random("normal", dom)
         ift.extra.check_value_gradient_consistency(model, pos)
-        model = ift.ScalingOperator(2.456, space).chain(
+        model = ift.ScalingOperator(2.456, space)(
             ift.FieldAdapter(dom, "s1")*ift.FieldAdapter(dom, "s2"))
         pos = ift.from_random("normal", dom)
         ift.extra.check_value_gradient_consistency(model, pos)
-        model = ift.positive_tanh(ift.ScalingOperator(2.456, space).chain(
+        model = ift.positive_tanh(ift.ScalingOperator(2.456, space)(
             ift.FieldAdapter(dom, "s1")*ift.FieldAdapter(dom, "s2")))
         pos = ift.from_random("normal", dom)
         ift.extra.check_value_gradient_consistency(model, pos)
         if isinstance(space, ift.RGSpace):
-            model = ift.FFTOperator(space).chain(
+            model = ift.FFTOperator(space)(
                 ift.FieldAdapter(dom, "s1")*ift.FieldAdapter(dom, "s2"))
             pos = ift.from_random("normal", dom)
             ift.extra.check_value_gradient_consistency(model, pos)
diff --git a/test/test_multi_field.py b/test/test_multi_field.py
index f2dbdf082..d9854bded 100644
--- a/test/test_multi_field.py
+++ b/test/test_multi_field.py
@@ -40,7 +40,7 @@ class Test_Functionality(unittest.TestCase):
     def test_blockdiagonal(self):
         op = ift.BlockDiagonalOperator(
             dom, (ift.ScalingOperator(20., dom["d1"]),))
-        op2 = op.chain(op)
+        op2 = op(op)
         ift.extra.consistency_check(op2)
         assert_equal(type(op2), ift.BlockDiagonalOperator)
         f1 = op2(ift.full(dom, 1))
diff --git a/test/test_operators/test_adjoint.py b/test/test_operators/test_adjoint.py
index e56ee11df..540ec34d8 100644
--- a/test/test_operators/test_adjoint.py
+++ b/test/test_operators/test_adjoint.py
@@ -53,7 +53,7 @@ class Consistency_Tests(unittest.TestCase):
                                                        dtype=dtype))
         op = ift.SandwichOperator.make(a, b)
         ift.extra.consistency_check(op, dtype, dtype)
-        op = a.chain(b)
+        op = a(b)
         ift.extra.consistency_check(op, dtype, dtype)
         op = a+b
         ift.extra.consistency_check(op, dtype, dtype)
diff --git a/test/test_operators/test_composed_operator.py b/test/test_operators/test_composed_operator.py
index 81bb1acd4..46a0cd8cf 100644
--- a/test/test_operators/test_composed_operator.py
+++ b/test/test_operators/test_composed_operator.py
@@ -37,7 +37,7 @@ class ComposedOperator_Tests(unittest.TestCase):
         op1 = ift.DiagonalOperator(diag1, cspace, spaces=(0,))
         op2 = ift.DiagonalOperator(diag2, cspace, spaces=(1,))
 
-        op = op2.chain(op1)
+        op = op2(op1)
 
         rand1 = ift.Field.from_random('normal', domain=(space1, space2))
         rand2 = ift.Field.from_random('normal', domain=(space1, space2))
@@ -54,7 +54,7 @@ class ComposedOperator_Tests(unittest.TestCase):
         op1 = ift.DiagonalOperator(diag1, cspace, spaces=(0,))
         op2 = ift.DiagonalOperator(diag2, cspace, spaces=(1,))
 
-        op = op2.chain(op1)
+        op = op2(op1)
 
         rand1 = ift.Field.from_random('normal', domain=(space1, space2))
         tt1 = op.inverse_times(op.times(rand1))
@@ -75,8 +75,7 @@ class ComposedOperator_Tests(unittest.TestCase):
     def test_chain(self, space):
         op1 = ift.makeOp(ift.Field.full(space, 2.))
         op2 = 3.
-        full_op = (op1.chain(op2).chain(op2).chain(op1).
-                   chain(op1).chain(op1).chain(op2))
+        full_op = op1(op2)(op2)(op1)(op1)(op1)(op2)
         x = ift.Field.full(space, 1.)
         res = full_op(x)
         assert_equal(isinstance(full_op, ift.DiagonalOperator), True)
@@ -86,7 +85,7 @@ class ComposedOperator_Tests(unittest.TestCase):
     def test_mix(self, space):
         op1 = ift.makeOp(ift.Field.full(space, 2.))
         op2 = 3.
-        full_op = op1.chain(op2 + op2).chain(op1).chain(op1) - op1.chain(op2)
+        full_op = op1(op2+op2)(op1)(op1) - op1(op2)
         x = ift.Field.full(space, 1.)
         res = full_op(x)
         assert_equal(isinstance(full_op, ift.DiagonalOperator), True)
-- 
GitLab