diff --git a/nifty6/operators/linear_operator.py b/nifty6/operators/linear_operator.py index 4c5ec16009b5a7228750f47fc2f22dcec42f6d6a..9d10390b1d0fa94318a593c0389d98a001618779 100644 --- a/nifty6/operators/linear_operator.py +++ b/nifty6/operators/linear_operator.py @@ -174,10 +174,7 @@ class LinearOperator(Operator): return self.apply(x, self.TIMES) from ..linearization import Linearization if isinstance(x, Linearization): - res = x.new(self(x._val), self(x._jac)) - if x.metric is not None: - res = res.add_metric(self(x.metric)) - return res + return x.new(self(x._val), self(x._jac)) return self.__matmul__(x) def times(self, x): diff --git a/nifty6/operators/scaling_operator.py b/nifty6/operators/scaling_operator.py index 8c6ca4c5f24f08687b7a26b587745b54e5dc41b5..5a00babdd646706d9f29bd4565f6aef6ff563f16 100644 --- a/nifty6/operators/scaling_operator.py +++ b/nifty6/operators/scaling_operator.py @@ -98,14 +98,19 @@ class ScalingOperator(EndomorphicOperator): return from_random(random_type="normal", domain=self._domain, std=self._get_fct(from_inverse), dtype=dtype) - def __matmul__(self, other): - if np.isreal(self._factor) and self._factor > 0: - from .sandwich_operator import SandwichOperator - if isinstance(other, SandwichOperator): - sqrt_fac = np.sqrt(self._factor) - newop = ScalingOperator(other.domain, sqrt_fac) - return SandwichOperator.make(newop, other) - return EndomorphicOperator.__matmul__(self, other) + def __call__(self, other): + if np.isreal(self._factor) and self._factor >= 0: + from ..linearization import Linearization + if isinstance(other, Linearization): + res = EndomorphicOperator.__call__(self, other) + if other.metric is not None: + from .sandwich_operator import SandwichOperator + sqrt_fac = np.sqrt(self._factor) + newop = ScalingOperator(other.metric.domain, sqrt_fac) + met = SandwichOperator.make(newop, other.metric) + res = res.add_metric(met) + return res + return EndomorphicOperator.__call__(self, other) def __repr__(self): return "ScalingOperator ({})".format(self._factor) diff --git a/test/test_energy_gradients.py b/test/test_energy_gradients.py index 69375a8ded4012f9f1f6d78c812f55f82cde5d4d..5c82b61a0e52b1e905f39e086c57d5599fb7ca55 100644 --- a/test/test_energy_gradients.py +++ b/test/test_energy_gradients.py @@ -32,6 +32,7 @@ SPACES = [ ] SEEDS = [4, 78, 23] PARAMS = product(SEEDS, SPACES) +pmp = pytest.mark.parametrize @pytest.fixture(params=PARAMS) @@ -46,8 +47,11 @@ def test_gaussian(field): energy = ift.GaussianEnergy(domain=field.domain) ift.extra.check_jacobian_consistency(energy, field) -def test_ScaledEnergy(field): - energy = ift.GaussianEnergy(domain=field.domain) +@pmp('icov', [lambda dom:ift.ScalingOperator(dom, 1.), + lambda dom:ift.SandwichOperator.make(ift.GeometryRemover(dom))]) +def test_ScaledEnergy(field, icov): + icov = icov(field.domain) + energy = ift.GaussianEnergy(inverse_covariance=icov) ift.extra.check_jacobian_consistency(energy.scale(0.3), field) lin = ift.Linearization.make_var(field, want_metric=True) @@ -55,8 +59,8 @@ def test_ScaledEnergy(field): sE = energy.scale(0.3) linn = sE(lin) met2 = linn.metric - assert np.assert_allclose(met1(field), met2(field) / 0.3, rtol=1e-12) - assert isinstance(met2, ift.SandwichOperator) + np.testing.assert_allclose(met1(field).val, met2(field).val / 0.3, rtol=1e-12) + met2.draw_sample() def test_studentt(field): energy = ift.StudentTEnergy(domain=field.domain, theta=.5)