diff --git a/nifty5/operators/operator.py b/nifty5/operators/operator.py index 712cd3d96801b4a427a043beafd0497d524d49bd..a674ccd5f03ac66d268cdd38d2ca46e1519b9908 100644 --- a/nifty5/operators/operator.py +++ b/nifty5/operators/operator.py @@ -219,10 +219,16 @@ class _ConstantOperator(Operator): def apply(self, x): from ..linearization import Linearization from .simple_linear_operators import NullOperator + from ..domain_tuple import DomainTuple self._check_input(x) if not isinstance(x, Linearization): return self._output - return x.new(self._output, NullOperator(self._domain, self._target)) + if x.want_metric and self._target is DomainTuple.scalar_domain(): + met = NullOperator(self._domain, self._domain) + else: + met = None + return x.new(self._output, NullOperator(self._domain, self._target), + met) def __repr__(self): return 'ConstantOperator <- {}'.format(self.domain.keys()) diff --git a/test/test_operators/test_simplification.py b/test/test_operators/test_simplification.py index 9306a9ebe8a9b144ea5f8f800a31617fdabe87a1..bce790f27c226383479802e13239395b31f17573 100644 --- a/test/test_operators/test_simplification.py +++ b/test/test_operators/test_simplification.py @@ -48,3 +48,8 @@ def test_simplification(): assert_equal(isinstance(op2._op1, _ConstantOperator), True) assert_allclose(op(f1)["a"].local_data, op2(f1)["a"].local_data) assert_allclose(op(f1)["b"].local_data, op2(f1)["b"].local_data) + lin = ift.Linearization.make_var(ift.MultiField.full(op2.domain, 2.), True) + assert_allclose(op(lin).val["a"].local_data, + op2(lin).val["a"].local_data) + assert_allclose(op(lin).val["b"].local_data, + op2(lin).val["b"].local_data)