From c3ed466f547cd9ee25fbe1ab84ecaf86e8a998a6 Mon Sep 17 00:00:00 2001 From: Martin Reinecke <martin@mpa-garching.mpg.de> Date: Sun, 5 Aug 2018 14:00:28 +0200 Subject: [PATCH] no more chains --- demos/Wiener_Filter.ipynb | 4 +-- demos/bernoulli_demo.py | 4 +-- demos/getting_started_1.py | 6 ++-- demos/getting_started_2.py | 6 ++-- demos/getting_started_3.py | 8 ++--- demos/polynomial_fit.py | 2 +- nifty5/energies/hamiltonian.py | 2 +- nifty5/energies/kl.py | 2 +- nifty5/library/amplitude_model.py | 4 +-- nifty5/library/bernoulli_energy.py | 2 +- nifty5/library/correlated_fields.py | 2 +- nifty5/library/gaussian_energy.py | 2 +- nifty5/library/poissonian_energy.py | 2 +- nifty5/linearization.py | 32 +++++++++--------- nifty5/multi/block_diagonal_operator.py | 2 +- .../operators/harmonic_smoothing_operator.py | 2 +- nifty5/operators/linear_operator.py | 11 ++++--- nifty5/operators/operator.py | 33 ++++++------------- nifty5/operators/sandwich_operator.py | 4 +-- nifty5/operators/smoothness_operator.py | 2 +- nifty5/utilities.py | 13 ++++---- test/test_energies/test_map.py | 8 ++--- test/test_field.py | 4 +-- test/test_models/test_model_gradients.py | 6 ++-- test/test_multi_field.py | 2 +- test/test_operators/test_adjoint.py | 2 +- test/test_operators/test_composed_operator.py | 9 +++-- 27 files changed, 82 insertions(+), 94 deletions(-) diff --git a/demos/Wiener_Filter.ipynb b/demos/Wiener_Filter.ipynb index dd8c7f374..8b6969725 100644 --- a/demos/Wiener_Filter.ipynb +++ b/demos/Wiener_Filter.ipynb @@ -429,7 +429,7 @@ "mask[l:h] = 0\n", "mask = ift.Field.from_global_data(s_space, mask)\n", "\n", - "R = ift.DiagonalOperator(mask).chain(HT)\n", + "R = ift.DiagonalOperator(mask)(HT)\n", "n = n.to_global_data_rw()\n", "n[l:h] = 0\n", "n = ift.Field.from_global_data(s_space, n)\n", @@ -585,7 +585,7 @@ "mask[l:h,l:h] = 0.\n", "mask = ift.Field.from_global_data(s_space, mask)\n", "\n", - "R = ift.DiagonalOperator(mask).chain(HT)\n", + "R = ift.DiagonalOperator(mask)(HT)\n", "n = n.to_global_data_rw()\n", "n[l:h, l:h] = 0\n", "n = ift.Field.from_global_data(s_space, n)\n", diff --git a/demos/bernoulli_demo.py b/demos/bernoulli_demo.py index 7efd68a31..d517362c9 100644 --- a/demos/bernoulli_demo.py +++ b/demos/bernoulli_demo.py @@ -53,7 +53,7 @@ if __name__ == '__main__': A = pd(a) # Set up a sky model - sky = HT.chain(ift.makeOp(A)).positive_tanh() + sky = HT(ift.makeOp(A)).positive_tanh() GR = ift.GeometryRemover(position_space) # Set up instrumental response @@ -61,7 +61,7 @@ if __name__ == '__main__': # Generate mock data d_space = R.target[0] - p = R.chain(sky) + p = R(sky) mock_position = ift.from_random('normal', harmonic_space) pp = p(mock_position) data = np.random.binomial(1, pp.to_global_data().astype(np.float64)) diff --git a/demos/getting_started_1.py b/demos/getting_started_1.py index 1e36c2c18..76425bea5 100644 --- a/demos/getting_started_1.py +++ b/demos/getting_started_1.py @@ -78,7 +78,7 @@ if __name__ == '__main__': GR = ift.GeometryRemover(position_space) mask = ift.Field.from_global_data(position_space, mask) Mask = ift.DiagonalOperator(mask) - R = GR.chain(Mask).chain(HT) + R = GR(Mask(HT)) data_space = GR.target @@ -93,7 +93,7 @@ if __name__ == '__main__': # Build propagator D and information source j j = R.adjoint_times(N.inverse_times(data)) - D_inv = R.adjoint.chain(N.inverse).chain(R) + S.inverse + D_inv = R.adjoint(N.inverse(R)) + S.inverse # Make it invertible IC = ift.GradientNormController(iteration_limit=500, tol_abs_gradnorm=1e-3) D = ift.InversionEnabler(D_inv, IC, approximation=S.inverse).inverse @@ -112,7 +112,7 @@ if __name__ == '__main__': title="getting_started_1") else: ift.plot(HT(MOCK_SIGNAL), title='Mock Signal') - ift.plot(mask_to_nan(mask, (GR.chain(Mask)).adjoint(data)), + ift.plot(mask_to_nan(mask, (GR(Mask)).adjoint(data)), title='Data') ift.plot(HT(m), title='Reconstruction') ift.plot(mask_to_nan(mask, HT(m-MOCK_SIGNAL)), title='Residuals') diff --git a/demos/getting_started_2.py b/demos/getting_started_2.py index d878ec0df..c6622395e 100644 --- a/demos/getting_started_2.py +++ b/demos/getting_started_2.py @@ -70,16 +70,16 @@ if __name__ == '__main__': A = pd(a) # Set up a sky model - sky = ift.exp(HT.chain(ift.makeOp(A))) + sky = ift.exp(HT(ift.makeOp(A))) M = ift.DiagonalOperator(exposure) GR = ift.GeometryRemover(position_space) # Set up instrumental response - R = GR.chain(M) + R = GR(M) # Generate mock data d_space = R.target[0] - lamb = R.chain(sky) + lamb = R(sky) mock_position = ift.from_random('normal', domain) data = lamb(mock_position) data = np.random.poisson(data.to_global_data().astype(np.float64)) diff --git a/demos/getting_started_3.py b/demos/getting_started_3.py index dc6174a86..3aa3a0a5a 100644 --- a/demos/getting_started_3.py +++ b/demos/getting_started_3.py @@ -44,8 +44,8 @@ if __name__ == '__main__': domain = ift.MultiDomain.union( (A.domain, ift.MultiDomain.make({'xi': harmonic_space}))) - correlated_field = ht.chain( - power_distributor.chain(A)*ift.FieldAdapter(domain, "xi")) + correlated_field = ht( + power_distributor(A)*ift.FieldAdapter(domain, "xi")) # alternatively to the block above one can do: # correlated_field = ift.CorrelatedField(position_space, A) @@ -57,7 +57,7 @@ if __name__ == '__main__': R = ift.LOSResponse(position_space, starts=LOS_starts, ends=LOS_ends) # build signal response model and model likelihood - signal_response = R.chain(signal) + signal_response = R(signal) # specify noise data_space = R.target noise = .001 @@ -69,7 +69,7 @@ if __name__ == '__main__': # set up model likelihood likelihood = ift.GaussianEnergy( - mean=data, covariance=N).chain(signal_response) + mean=data, covariance=N)(signal_response) # set up minimization and inversion schemes ic_cg = ift.GradientNormController(iteration_limit=10) diff --git a/demos/polynomial_fit.py b/demos/polynomial_fit.py index 23ea6e726..c8eba5cec 100644 --- a/demos/polynomial_fit.py +++ b/demos/polynomial_fit.py @@ -97,7 +97,7 @@ d = ift.from_global_data(d_space, y) N = ift.DiagonalOperator(ift.from_global_data(d_space, var)) IC = ift.GradientNormController(tol_abs_gradnorm=1e-8) -likelihood = ift.GaussianEnergy(d, N).chain(R) +likelihood = ift.GaussianEnergy(d, N)(R) H = ift.Hamiltonian(likelihood, IC) H = ift.EnergyAdapter(params, H) H = H.make_invertible(IC) diff --git a/nifty5/energies/hamiltonian.py b/nifty5/energies/hamiltonian.py index 48f4729ab..fb5a45176 100644 --- a/nifty5/energies/hamiltonian.py +++ b/nifty5/energies/hamiltonian.py @@ -40,7 +40,7 @@ class Hamiltonian(Operator): def target(self): return DomainTuple.scalar_domain() - def __call__(self, x): + def apply(self, x): if self._ic_samp is None or not isinstance(x, Linearization): return self._lh(x) + self._prior(x) else: diff --git a/nifty5/energies/kl.py b/nifty5/energies/kl.py index 6c9bb07fd..d513d1101 100644 --- a/nifty5/energies/kl.py +++ b/nifty5/energies/kl.py @@ -42,6 +42,6 @@ class SampledKullbachLeiblerDivergence(Operator): def target(self): return DomainTuple.scalar_domain() - def __call__(self, x): + def apply(self, x): return (my_sum(map(lambda v: self._h(x+v), self._res_samples)) * (1./len(self._res_samples))) diff --git a/nifty5/library/amplitude_model.py b/nifty5/library/amplitude_model.py index 7716a42ba..c2e27ef1f 100644 --- a/nifty5/library/amplitude_model.py +++ b/nifty5/library/amplitude_model.py @@ -130,7 +130,7 @@ class AmplitudeModel(Operator): cepstrum = create_cepstrum_amplitude_field(dof_space, kern) ceps = makeOp(sqrt(cepstrum)) - self._smooth_op = sym.chain(qht).chain(ceps) + self._smooth_op = sym(qht(ceps)) self._keys = tuple(keys) @property @@ -141,7 +141,7 @@ class AmplitudeModel(Operator): def target(self): return self._target - def __call__(self, x): + def apply(self, x): smooth_spec = self._smooth_op(x[self._keys[0]]) phi = x[self._keys[1]] + self._norm_phi_mean linear_spec = self._slope(phi) diff --git a/nifty5/library/bernoulli_energy.py b/nifty5/library/bernoulli_energy.py index 2ee3a0b53..5a6c33dee 100644 --- a/nifty5/library/bernoulli_energy.py +++ b/nifty5/library/bernoulli_energy.py @@ -39,7 +39,7 @@ class BernoulliEnergy(Operator): def target(self): return DomainTuple.scalar_domain() - def __call__(self, x): + def apply(self, x): x = self._p(x) v = x.log().vdot(-self._d) - (1.-x).log().vdot(1.-self._d) if not isinstance(x, Linearization): diff --git a/nifty5/library/correlated_fields.py b/nifty5/library/correlated_fields.py index 53c2cacb8..d139259b8 100644 --- a/nifty5/library/correlated_fields.py +++ b/nifty5/library/correlated_fields.py @@ -58,7 +58,7 @@ class CorrelatedField(Operator): def target(self): return self._ht.target - def __call__(self, x): + def apply(self, x): A = self._power_distributor(self._amplitude_model(x)) correlated_field_h = A * x["xi"] correlated_field = self._ht(correlated_field_h) diff --git a/nifty5/library/gaussian_energy.py b/nifty5/library/gaussian_energy.py index 3c462f589..85a3511d5 100644 --- a/nifty5/library/gaussian_energy.py +++ b/nifty5/library/gaussian_energy.py @@ -55,7 +55,7 @@ class GaussianEnergy(Operator): def target(self): return DomainTuple.scalar_domain() - def __call__(self, x): + def apply(self, x): residual = x if self._mean is None else x-self._mean icovres = residual if self._icov is None else self._icov(residual) res = .5*residual.vdot(icovres) diff --git a/nifty5/library/poissonian_energy.py b/nifty5/library/poissonian_energy.py index 4260f9d6b..0a42ba93c 100644 --- a/nifty5/library/poissonian_energy.py +++ b/nifty5/library/poissonian_energy.py @@ -41,7 +41,7 @@ class PoissonianEnergy(Operator): def target(self): return DomainTuple.scalar_domain() - def __call__(self, x): + def apply(self, x): x = self._op(x) res = x.sum() - x.log().vdot(self._d) if not isinstance(x, Linearization): diff --git a/nifty5/linearization.py b/nifty5/linearization.py index 03985f05e..c9020c0cb 100644 --- a/nifty5/linearization.py +++ b/nifty5/linearization.py @@ -46,8 +46,8 @@ class Linearization(object): def __neg__(self): return Linearization( - -self._val, self._jac.chain(-1), - None if self._metric is None else self._metric.chain(-1)) + -self._val, self._jac*(-1), + None if self._metric is None else self._metric*(-1)) def __add__(self, other): if isinstance(other, Linearization): @@ -77,24 +77,24 @@ class Linearization(object): d2 = makeOp(other._val) return Linearization( self._val*other._val, - d2.chain(self._jac) + d1.chain(other._jac)) + d2(self._jac) + d1(other._jac)) if isinstance(other, (int, float, complex)): # if other == 0: # return ... - met = None if self._metric is None else self._metric.chain(other) - return Linearization(self._val*other, self._jac.chain(other), met) + met = None if self._metric is None else self._metric(other) + return Linearization(self._val*other, self._jac(other), met) if isinstance(other, (Field, MultiField)): d2 = makeOp(other) - return Linearization(self._val*other, d2.chain(self._jac)) + return Linearization(self._val*other, d2(self._jac)) raise TypeError def __rmul__(self, other): from .sugar import makeOp if isinstance(other, (int, float, complex)): - return Linearization(self._val*other, self._jac.chain(other)) + return Linearization(self._val*other, self._jac(other)) if isinstance(other, (Field, MultiField)): d1 = makeOp(other) - return Linearization(self._val*other, d1.chain(self._jac)) + return Linearization(self._val*other, d1(self._jac)) def vdot(self, other): from .domain_tuple import DomainTuple @@ -102,11 +102,11 @@ class Linearization(object): if isinstance(other, (Field, MultiField)): return Linearization( Field(DomainTuple.scalar_domain(),self._val.vdot(other)), - VdotOperator(other).chain(self._jac)) + VdotOperator(other)(self._jac)) return Linearization( Field(DomainTuple.scalar_domain(),self._val.vdot(other._val)), - VdotOperator(self._val).chain(other._jac) + - VdotOperator(other._val).chain(self._jac)) + VdotOperator(self._val)(other._jac) + + VdotOperator(other._val)(self._jac)) def sum(self): from .domain_tuple import DomainTuple @@ -114,24 +114,24 @@ class Linearization(object): from .sugar import full return Linearization( Field(DomainTuple.scalar_domain(), self._val.sum()), - SumReductionOperator(self._jac.target).chain(self._jac)) + SumReductionOperator(self._jac.target)(self._jac)) def exp(self): tmp = self._val.exp() - return Linearization(tmp, makeOp(tmp).chain(self._jac)) + return Linearization(tmp, makeOp(tmp)(self._jac)) def log(self): tmp = self._val.log() - return Linearization(tmp, makeOp(1./self._val).chain(self._jac)) + return Linearization(tmp, makeOp(1./self._val)(self._jac)) def tanh(self): tmp = self._val.tanh() - return Linearization(tmp, makeOp(1.-tmp**2).chain(self._jac)) + return Linearization(tmp, makeOp(1.-tmp**2)(self._jac)) def positive_tanh(self): tmp = self._val.tanh() tmp2 = 0.5*(1.+tmp) - return Linearization(tmp2, makeOp(0.5*(1.-tmp**2)).chain(self._jac)) + return Linearization(tmp2, makeOp(0.5*(1.-tmp**2))(self._jac)) def add_metric(self, metric): return Linearization(self._val, self._jac, metric) diff --git a/nifty5/multi/block_diagonal_operator.py b/nifty5/multi/block_diagonal_operator.py index 41c03043b..9d176a353 100644 --- a/nifty5/multi/block_diagonal_operator.py +++ b/nifty5/multi/block_diagonal_operator.py @@ -68,7 +68,7 @@ class BlockDiagonalOperator(EndomorphicOperator): def _combine_chain(self, op): if self._domain is not op._domain: raise ValueError("domain mismatch") - res = tuple(v1.chain(v2) for v1, v2 in zip(self._ops, op._ops)) + res = tuple(v1(v2) for v1, v2 in zip(self._ops, op._ops)) return BlockDiagonalOperator(self._domain, res) def _combine_sum(self, op, selfneg, opneg): diff --git a/nifty5/operators/harmonic_smoothing_operator.py b/nifty5/operators/harmonic_smoothing_operator.py index ca40ce806..a1bfbac0b 100644 --- a/nifty5/operators/harmonic_smoothing_operator.py +++ b/nifty5/operators/harmonic_smoothing_operator.py @@ -67,4 +67,4 @@ def HarmonicSmoothingOperator(domain, sigma, space=None): ddom = list(domain) ddom[space] = codomain diag = DiagonalOperator(kernel, ddom, space) - return Hartley.inverse.chain(diag).chain(Hartley) + return Hartley.inverse(diag(Hartley)) diff --git a/nifty5/operators/linear_operator.py b/nifty5/operators/linear_operator.py index 783065ffc..5b4bbf799 100644 --- a/nifty5/operators/linear_operator.py +++ b/nifty5/operators/linear_operator.py @@ -142,9 +142,6 @@ class LinearOperator(Operator): from .chain_operator import ChainOperator return ChainOperator.make([self, other2]) - def chain(self, other): - return self.__matmul__(other) - def __rmatmul__(self, other): if np.isscalar(other) and other == 1.: return self @@ -213,10 +210,14 @@ class LinearOperator(Operator): def __call__(self, x): """Same as :meth:`times`""" + from ..field import Field + from ..multi.multi_field import MultiField + if isinstance(x, (Field, MultiField)): + return self.apply(x, self.TIMES) from ..linearization import Linearization if isinstance(x, Linearization): - return Linearization(self(x._val), self.chain(x._jac)) - return self.apply(x, self.TIMES) + return Linearization(self(x._val), self(x._jac)) + return self.__matmul__(x) def times(self, x): """ Applies the Operator to a given Field. diff --git a/nifty5/operators/operator.py b/nifty5/operators/operator.py index 61346cde3..55416578e 100644 --- a/nifty5/operators/operator.py +++ b/nifty5/operators/operator.py @@ -33,26 +33,13 @@ class Operator(NiftyMetaBase()): return NotImplemented return _OpProd.make((self, x)) - def chain(self, x): - res = self.__matmul__(x) - if res == NotImplemented: - raise TypeError("operator expected") - return res + def apply(self, x): + raise NotImplementedError def __call__(self, x): - """Returns transformed x - - Parameters - ---------- - x : Linearization - input - - Returns - ------- - Linearization - output - """ - raise NotImplementedError + if isinstance(x, Operator): + return _OpChain.make((self, x)) + return self.apply(x) for f in ["sqrt", "exp", "log", "tanh", "positive_tanh"]: @@ -78,7 +65,7 @@ class _FunctionApplier(Operator): def target(self): return self._domain - def __call__(self, x): + def apply(self, x): return getattr(x, self._funcname)() @@ -117,7 +104,7 @@ class _OpChain(_CombinedOperator): def target(self): return self._ops[0].target - def __call__(self, x): + def apply(self, x): for op in reversed(self._ops): x = op(x) return x @@ -135,7 +122,7 @@ class _OpProd(_CombinedOperator): def target(self): return self._ops[0].target - def __call__(self, x): + def apply(self, x): from ..utilities import my_product return my_product(map(lambda op: op(x), self._ops)) @@ -154,7 +141,7 @@ class _OpSum(_CombinedOperator): def target(self): return self._target - def __call__(self, x): + def apply(self, x): raise NotImplementedError @@ -193,7 +180,7 @@ class QuadraticFormOperator(Operator): def target(self): return self._target - def __call__(self, x): + def apply(self, x): if isinstance(x, Linearization): jac = self._op(x) val = Field(self._target, 0.5 * x.vdot(jac)) diff --git a/nifty5/operators/sandwich_operator.py b/nifty5/operators/sandwich_operator.py index 20236a580..480e098db 100644 --- a/nifty5/operators/sandwich_operator.py +++ b/nifty5/operators/sandwich_operator.py @@ -56,9 +56,9 @@ class SandwichOperator(EndomorphicOperator): raise TypeError("cheese must be a linear operator") if cheese is None: cheese = ScalingOperator(1., bun.target) - op = bun.adjoint.chain(bun) + op = bun.adjoint(bun) else: - op = bun.adjoint.chain(cheese).chain(bun) + op = bun.adjoint(cheese(bun)) # if our sandwich is diagonal, we can return immediately if isinstance(op, (ScalingOperator, DiagonalOperator)): diff --git a/nifty5/operators/smoothness_operator.py b/nifty5/operators/smoothness_operator.py index 981863927..38e068730 100644 --- a/nifty5/operators/smoothness_operator.py +++ b/nifty5/operators/smoothness_operator.py @@ -54,4 +54,4 @@ def SmoothnessOperator(domain, strength=1., logarithmic=True, space=None): if strength == 0.: return ScalingOperator(0., domain) laplace = LaplaceOperator(domain, logarithmic=logarithmic, space=space) - return (strength**2)*laplace.adjoint.chain(laplace) + return (strength**2)*laplace.adjoint(laplace) diff --git a/nifty5/utilities.py b/nifty5/utilities.py index 8bf84bc2c..d7436ccf8 100644 --- a/nifty5/utilities.py +++ b/nifty5/utilities.py @@ -23,6 +23,8 @@ from itertools import product import numpy as np from future.utils import with_metaclass +import pyfftw +from pyfftw.interfaces.numpy_fft import rfftn, fftn from .compat import * @@ -201,9 +203,11 @@ _fft_extra_args = dict(planner_effort='FFTW_ESTIMATE') def fft_prep(): - import pyfftw - pyfftw.interfaces.cache.enable() - pyfftw.interfaces.cache.set_keepalive_time(1000.) + if not fft_prep._initialized: + pyfftw.interfaces.cache.enable() + pyfftw.interfaces.cache.set_keepalive_time(1000.) + fft_prep._initialized = True +fft_prep._initialized = False def hartley(a, axes=None): @@ -214,7 +218,6 @@ def hartley(a, axes=None): if iscomplextype(a.dtype): raise TypeError("Hartley transform requires real-valued arrays.") - from pyfftw.interfaces.numpy_fft import rfftn tmp = rfftn(a, axes=axes, threads=nthreads(), **_fft_extra_args) def _fill_array(tmp, res, axes): @@ -258,7 +261,6 @@ def my_fftn_r2c(a, axes=None): if iscomplextype(a.dtype): raise TypeError("Transform requires real-valued input arrays.") - from pyfftw.interfaces.numpy_fft import rfftn tmp = rfftn(a, axes=axes, threads=nthreads(), **_fft_extra_args) def _fill_complex_array(tmp, res, axes): @@ -293,7 +295,6 @@ def my_fftn_r2c(a, axes=None): def my_fftn(a, axes=None): - from pyfftw.interfaces.numpy_fft import fftn return fftn(a, axes=axes, **_fft_extra_args) diff --git a/test/test_energies/test_map.py b/test/test_energies/test_map.py index a88a4b272..ae53e8259 100644 --- a/test/test_energies/test_map.py +++ b/test/test_energies/test_map.py @@ -56,18 +56,18 @@ class Energy_Tests(unittest.TestCase): def d_model(): if nonlinearity == "": - return R.chain(ht.chain(ift.makeOp(A))) + return R(ht(ift.makeOp(A))) else: - tmp = ht.chain(ift.makeOp(A)) + tmp = ht(ift.makeOp(A)) nonlin = getattr(tmp, nonlinearity)() - return R.chain(nonlin) + return R(nonlin) d = d_model()(xi0) + n if noise == 1: N = None - energy = ift.GaussianEnergy(d, N).chain(d_model()) + energy = ift.GaussianEnergy(d, N)(d_model()) if nonlinearity == "": ift.extra.check_value_gradient_metric_consistency( energy, xi0, ntries=10) diff --git a/test/test_field.py b/test/test_field.py index e08463c90..cfc6f7d0e 100644 --- a/test/test_field.py +++ b/test/test_field.py @@ -66,7 +66,7 @@ class Test_Functionality(unittest.TestCase): op1 = ift.create_power_operator((space1, space2), _spec1, 0) op2 = ift.create_power_operator((space1, space2), _spec2, 1) - opfull = op2.chain(op1) + opfull = op2(op1) samples = 500 sc1 = ift.StatCalculator() @@ -94,7 +94,7 @@ class Test_Functionality(unittest.TestCase): S_1 = ift.create_power_operator((space1, space2), _spec1, 0) S_2 = ift.create_power_operator((space1, space2), _spec2, 1) - S_full = S_2.chain(S_1) + S_full = S_2(S_1) samples = 500 sc1 = ift.StatCalculator() diff --git a/test/test_models/test_model_gradients.py b/test/test_models/test_model_gradients.py index 8f1bdf839..6a3abd2d5 100644 --- a/test/test_models/test_model_gradients.py +++ b/test/test_models/test_model_gradients.py @@ -71,16 +71,16 @@ class Model_Tests(unittest.TestCase): model = ift.FieldAdapter(dom, "s1")*3. pos = ift.from_random("normal", dom) ift.extra.check_value_gradient_consistency(model, pos) - model = ift.ScalingOperator(2.456, space).chain( + model = ift.ScalingOperator(2.456, space)( ift.FieldAdapter(dom, "s1")*ift.FieldAdapter(dom, "s2")) pos = ift.from_random("normal", dom) ift.extra.check_value_gradient_consistency(model, pos) - model = ift.positive_tanh(ift.ScalingOperator(2.456, space).chain( + model = ift.positive_tanh(ift.ScalingOperator(2.456, space)( ift.FieldAdapter(dom, "s1")*ift.FieldAdapter(dom, "s2"))) pos = ift.from_random("normal", dom) ift.extra.check_value_gradient_consistency(model, pos) if isinstance(space, ift.RGSpace): - model = ift.FFTOperator(space).chain( + model = ift.FFTOperator(space)( ift.FieldAdapter(dom, "s1")*ift.FieldAdapter(dom, "s2")) pos = ift.from_random("normal", dom) ift.extra.check_value_gradient_consistency(model, pos) diff --git a/test/test_multi_field.py b/test/test_multi_field.py index f2dbdf082..d9854bded 100644 --- a/test/test_multi_field.py +++ b/test/test_multi_field.py @@ -40,7 +40,7 @@ class Test_Functionality(unittest.TestCase): def test_blockdiagonal(self): op = ift.BlockDiagonalOperator( dom, (ift.ScalingOperator(20., dom["d1"]),)) - op2 = op.chain(op) + op2 = op(op) ift.extra.consistency_check(op2) assert_equal(type(op2), ift.BlockDiagonalOperator) f1 = op2(ift.full(dom, 1)) diff --git a/test/test_operators/test_adjoint.py b/test/test_operators/test_adjoint.py index e56ee11df..540ec34d8 100644 --- a/test/test_operators/test_adjoint.py +++ b/test/test_operators/test_adjoint.py @@ -53,7 +53,7 @@ class Consistency_Tests(unittest.TestCase): dtype=dtype)) op = ift.SandwichOperator.make(a, b) ift.extra.consistency_check(op, dtype, dtype) - op = a.chain(b) + op = a(b) ift.extra.consistency_check(op, dtype, dtype) op = a+b ift.extra.consistency_check(op, dtype, dtype) diff --git a/test/test_operators/test_composed_operator.py b/test/test_operators/test_composed_operator.py index 81bb1acd4..46a0cd8cf 100644 --- a/test/test_operators/test_composed_operator.py +++ b/test/test_operators/test_composed_operator.py @@ -37,7 +37,7 @@ class ComposedOperator_Tests(unittest.TestCase): op1 = ift.DiagonalOperator(diag1, cspace, spaces=(0,)) op2 = ift.DiagonalOperator(diag2, cspace, spaces=(1,)) - op = op2.chain(op1) + op = op2(op1) rand1 = ift.Field.from_random('normal', domain=(space1, space2)) rand2 = ift.Field.from_random('normal', domain=(space1, space2)) @@ -54,7 +54,7 @@ class ComposedOperator_Tests(unittest.TestCase): op1 = ift.DiagonalOperator(diag1, cspace, spaces=(0,)) op2 = ift.DiagonalOperator(diag2, cspace, spaces=(1,)) - op = op2.chain(op1) + op = op2(op1) rand1 = ift.Field.from_random('normal', domain=(space1, space2)) tt1 = op.inverse_times(op.times(rand1)) @@ -75,8 +75,7 @@ class ComposedOperator_Tests(unittest.TestCase): def test_chain(self, space): op1 = ift.makeOp(ift.Field.full(space, 2.)) op2 = 3. - full_op = (op1.chain(op2).chain(op2).chain(op1). - chain(op1).chain(op1).chain(op2)) + full_op = op1(op2)(op2)(op1)(op1)(op1)(op2) x = ift.Field.full(space, 1.) res = full_op(x) assert_equal(isinstance(full_op, ift.DiagonalOperator), True) @@ -86,7 +85,7 @@ class ComposedOperator_Tests(unittest.TestCase): def test_mix(self, space): op1 = ift.makeOp(ift.Field.full(space, 2.)) op2 = 3. - full_op = op1.chain(op2 + op2).chain(op1).chain(op1) - op1.chain(op2) + full_op = op1(op2+op2)(op1)(op1) - op1(op2) x = ift.Field.full(space, 1.) res = full_op(x) assert_equal(isinstance(full_op, ift.DiagonalOperator), True) -- GitLab