diff --git a/nifty5/linearization.py b/nifty5/linearization.py index 7602463f6ac0e2491c0037aba1cfe06a8bb234b2..baf500ccc6a594907e3e33fab3a4656ec260d803 100644 --- a/nifty5/linearization.py +++ b/nifty5/linearization.py @@ -47,8 +47,8 @@ class Linearization(object): def __neg__(self): return Linearization( - -self._val, self._jac*(-1), - None if self._metric is None else self._metric*(-1)) + -self._val, -self._jac, + None if self._metric is None else -self._metric) def __add__(self, other): if isinstance(other, Linearization): @@ -74,28 +74,19 @@ class Linearization(object): def __mul__(self, other): from .sugar import makeOp if isinstance(other, Linearization): - d1 = makeOp(self._val) - d2 = makeOp(other._val) return Linearization( self._val*other._val, - d2(self._jac) + d1(other._jac)) - if isinstance(other, (int, float, complex)): - # if other == 0: - # return ... - met = None if self._metric is None else self._metric(other) - return Linearization(self._val*other, self._jac(other), met) + makeOp(other._val)(self._jac) + makeOp(self._val)(other._jac)) + if np.isscalar(other): + if other == 1: + return self + met = None if self._metric is None else self._metric.scale(other) + return Linearization(self._val*other, self._jac.scale(other), met) if isinstance(other, (Field, MultiField)): - d2 = makeOp(other) - return Linearization(self._val*other, d2(self._jac)) - raise TypeError + return Linearization(self._val*other, makeOp(other)(self._jac)) def __rmul__(self, other): - from .sugar import makeOp - if isinstance(other, (int, float, complex)): - return Linearization(self._val*other, self._jac(other)) - if isinstance(other, (Field, MultiField)): - d1 = makeOp(other) - return Linearization(self._val*other, d1(self._jac)) + return self.__mul__(other) def vdot(self, other): from .domain_tuple import DomainTuple diff --git a/nifty5/operators/linear_operator.py b/nifty5/operators/linear_operator.py index 6eb5a74440989712499733cbf786f82d92b5c459..a74b4292eba9ea0df18561de0d1f651adc17a1f0 100644 --- a/nifty5/operators/linear_operator.py +++ b/nifty5/operators/linear_operator.py @@ -104,75 +104,38 @@ class LinearOperator(Operator): the adjoint of this operator.""" return self._flip_modes(self.ADJOINT_BIT) - @staticmethod - def _toOperator(thing, dom): - from .scaling_operator import ScalingOperator - if isinstance(thing, LinearOperator): - return thing - if np.isscalar(thing): - return ScalingOperator(thing, dom) - return NotImplemented - - def __mul__(self, other): - from .chain_operator import ChainOperator - if not np.isscalar(other): - return Operator.__mul__(self, other) - if other == 1.: - return self - from .scaling_operator import ScalingOperator - other = ScalingOperator(other, self.domain) - return ChainOperator.make([self, other]) - - def __rmul__(self, other): - from .chain_operator import ChainOperator - if not np.isscalar(other): - return Operator.__rmul__(self, other) - if other == 1.: - return self - from .scaling_operator import ScalingOperator - other = ScalingOperator(other, self.target) - return ChainOperator.make([other, self]) - def __matmul__(self, other): - if np.isscalar(other) and other == 1.: - return self - other2 = self._toOperator(other, self.domain) - if other2 == NotImplemented: - return Operator.__matmul__(self, other) - from .chain_operator import ChainOperator - return ChainOperator.make([self, other2]) + if isinstance(other, LinearOperator): + from .chain_operator import ChainOperator + return ChainOperator.make([self, other]) + return Operator.__matmul__(self, other) def __rmatmul__(self, other): - if np.isscalar(other) and other == 1.: - return self - other2 = self._toOperator(other, self.target) - if other2 == NotImplemented: + if isinstance(other, LinearOperator): from .chain_operator import ChainOperator - return Operator.__rmatmul__(self, other) - from .chain_operator import ChainOperator - return ChainOperator.make([other2, self]) + return ChainOperator.make([other, self]) + return Operator.__rmatmul__(self, other) def __add__(self, other): - from .sum_operator import SumOperator - if np.isscalar(other) and other == 0.: - return self - other = self._toOperator(other, self.domain) - return SumOperator.make([self, other], [False, False]) + if isinstance(other, LinearOperator): + from .sum_operator import SumOperator + return SumOperator.make([self, other], [False, False]) + return Operator.__add__(self, other) def __radd__(self, other): return self.__add__(other) def __sub__(self, other): - from .sum_operator import SumOperator - if np.isscalar(other) and other == 0.: - return self - other = self._toOperator(other, self.domain) - return SumOperator.make([self, other], [False, True]) + if isinstance(other, LinearOperator): + from .sum_operator import SumOperator + return SumOperator.make([self, other], [False, True]) + return Operator.__sub__(self, other) def __rsub__(self, other): - from .sum_operator import SumOperator - other = self._toOperator(other, self.domain) - return SumOperator.make([other, self], [False, True]) + if isinstance(other, LinearOperator): + from .sum_operator import SumOperator + return SumOperator.make([other, self], [False, True]) + return Operator.__rsub__(self, other) @property def capability(self): diff --git a/nifty5/operators/operator.py b/nifty5/operators/operator.py index 1afffc79ff256c5270e1adf90a7d7a98d97c70a7..86a09f65f1e0567b5f9f3e4a7c7b44fe0b7bb8f3 100644 --- a/nifty5/operators/operator.py +++ b/nifty5/operators/operator.py @@ -24,6 +24,19 @@ class Operator(NiftyMetaBase()): The domain on which the Operator's output Field lives.""" raise NotImplementedError + def scale(self, factor): + if factor == 1: + return self + from .scaling_operator import ScalingOperator + return ScalingOperator(factor, self.target)(self) + + def conjugate(self): + from .simple_linear_operators import ConjugationOperator + return ConjugationOperator(self.target)(self) + + def __neg__(self): + return self.scale(-1) + def __matmul__(self, x): if not isinstance(x, Operator): return NotImplemented diff --git a/nifty5/operators/smoothness_operator.py b/nifty5/operators/smoothness_operator.py index 38e068730060f5633cc73ea36b2499e4bee5bdc3..113f6301f3237a510b8dad084edff4217c742b93 100644 --- a/nifty5/operators/smoothness_operator.py +++ b/nifty5/operators/smoothness_operator.py @@ -54,4 +54,4 @@ def SmoothnessOperator(domain, strength=1., logarithmic=True, space=None): if strength == 0.: return ScalingOperator(0., domain) laplace = LaplaceOperator(domain, logarithmic=logarithmic, space=space) - return (strength**2)*laplace.adjoint(laplace) + return laplace.adjoint(laplace).scale(strength**2) diff --git a/test/test_models/test_model_gradients.py b/test/test_models/test_model_gradients.py index 6a3abd2d5a72dba98c95171a2c9a58758c657a23..4849857f939552a58098530f3e9fd56b952ad1fa 100644 --- a/test/test_models/test_model_gradients.py +++ b/test/test_models/test_model_gradients.py @@ -68,7 +68,7 @@ class Model_Tests(unittest.TestCase): model = ift.FieldAdapter(dom, "s1")+ift.FieldAdapter(dom, "s2") pos = ift.from_random("normal", dom) ift.extra.check_value_gradient_consistency(model, pos) - model = ift.FieldAdapter(dom, "s1")*3. + model = ift.FieldAdapter(dom, "s1").scale(3.) pos = ift.from_random("normal", dom) ift.extra.check_value_gradient_consistency(model, pos) model = ift.ScalingOperator(2.456, space)( diff --git a/test/test_operators/test_composed_operator.py b/test/test_operators/test_composed_operator.py index 46a0cd8cf91bce6956ce0941914e9031af29cbd5..0efde7a4ce7ddab5a7cf5d6768680c0579da1101 100644 --- a/test/test_operators/test_composed_operator.py +++ b/test/test_operators/test_composed_operator.py @@ -64,7 +64,7 @@ class ComposedOperator_Tests(unittest.TestCase): @expand(product(spaces)) def test_sum(self, space): op1 = ift.makeOp(ift.Field.full(space, 2.)) - op2 = 3. + op2 = ift.ScalingOperator(3., space) full_op = op1 + op2 - (op2 - op1) + op1 + op1 + op2 x = ift.Field.full(space, 1.) res = full_op(x) @@ -74,7 +74,7 @@ class ComposedOperator_Tests(unittest.TestCase): @expand(product(spaces)) def test_chain(self, space): op1 = ift.makeOp(ift.Field.full(space, 2.)) - op2 = 3. + op2 = ift.ScalingOperator(3., space) full_op = op1(op2)(op2)(op1)(op1)(op1)(op2) x = ift.Field.full(space, 1.) res = full_op(x) @@ -84,7 +84,7 @@ class ComposedOperator_Tests(unittest.TestCase): @expand(product(spaces)) def test_mix(self, space): op1 = ift.makeOp(ift.Field.full(space, 2.)) - op2 = 3. + op2 = ift.ScalingOperator(3., space) full_op = op1(op2+op2)(op1)(op1) - op1(op2) x = ift.Field.full(space, 1.) res = full_op(x)