diff --git a/nifty4/library/nonlinear_power_energy.py b/nifty4/library/nonlinear_power_energy.py index f7b90d8fe492ba2a3da4edaebddc4d4bc88b1738..5d28126d963253b1fa7e4b60d24b6a8cce68ec6c 100644 --- a/nifty4/library/nonlinear_power_energy.py +++ b/nifty4/library/nonlinear_power_energy.py @@ -28,8 +28,8 @@ def _LinearizedPowerResponse(Instrument, nonlinearity, ht, Distributor, tau, power = exp(0.5*tau) position = ht(Distributor(power)*xi) linearization = makeOp(nonlinearity.derivative(position)) - return (makeOp(0.5, Instrument.target) * Instrument * linearization * ht * - makeOp(xi) * Distributor * makeOp(power)) + return (0.5 * Instrument * linearization * ht * makeOp(xi) * Distributor * + makeOp(power)) class NonlinearPowerEnergy(Energy): @@ -138,6 +138,5 @@ class NonlinearPowerEnergy(Energy): self.position, xi_sample) op = LinearizedResponse.adjoint*self.N.inverse*LinearizedResponse result = op if result is None else result + op - result = (result*makeOp(1./len(self.xi_sample_list), result.domain) + - self.T) + result = result*(1./len(self.xi_sample_list)) + self.T return InversionEnabler(result, self.inverter) diff --git a/nifty4/operators/linear_operator.py b/nifty4/operators/linear_operator.py index 167f9f7021cbc9292aa0e7ea8a9a8360e19c55b7..47661b6fb6f6c0388bd34f76f571d844461adf40 100644 --- a/nifty4/operators/linear_operator.py +++ b/nifty4/operators/linear_operator.py @@ -116,18 +116,51 @@ class LinearOperator(NiftyMetaBase()): the adjoint of this operator.""" return self._flip_modes(self.ADJOINT_BIT) + @staticmethod + def _toOperator(thing, dom): + from .scaling_operator import ScalingOperator + if isinstance(thing, LinearOperator): + return thing + if np.isscalar(thing): + return ScalingOperator(thing, dom) + return NotImplemented + def __mul__(self, other): from .chain_operator import ChainOperator + if np.isscalar(other) and other == 1.: + return self + other = self._toOperator(other, self.domain) return ChainOperator.make([self, other]) + def __rmul__(self, other): + from .chain_operator import ChainOperator + if np.isscalar(other) and other == 1.: + return self + other = self._toOperator(other, self.target) + return ChainOperator.make([other, self]) + def __add__(self, other): from .sum_operator import SumOperator + if np.isscalar(other) and other == 0.: + return self + other = self._toOperator(other, self.domain) return SumOperator.make([self, other], [False, False]) + def __radd__(self, other): + return self.__add__(other) + def __sub__(self, other): from .sum_operator import SumOperator + if np.isscalar(other) and other == 0.: + return self + other = self._toOperator(other, self.domain) return SumOperator.make([self, other], [False, True]) + def __rsub__(self, other): + from .sum_operator import SumOperator + other = self._toOperator(other, self.domain) + return SumOperator.make([other, self], [False, True]) + @abc.abstractproperty def capability(self): """int : the supported operation modes diff --git a/nifty4/operators/smoothness_operator.py b/nifty4/operators/smoothness_operator.py index 58b60b41439b1f78d5895c930ac6cd53485c26ae..ba6aa76a3a18e4abe3b764c7f033e3171b03ee22 100644 --- a/nifty4/operators/smoothness_operator.py +++ b/nifty4/operators/smoothness_operator.py @@ -18,7 +18,6 @@ from .scaling_operator import ScalingOperator from .laplace_operator import LaplaceOperator -from ..sugar import makeOp def SmoothnessOperator(domain, strength=1., logarithmic=True, space=None): @@ -52,4 +51,4 @@ def SmoothnessOperator(domain, strength=1., logarithmic=True, space=None): if strength == 0.: return ScalingOperator(0., domain) laplace = LaplaceOperator(domain, logarithmic=logarithmic, space=space) - return makeOp(strength**2, laplace.domain)*laplace.adjoint*laplace + return (strength**2)*laplace.adjoint*laplace diff --git a/nifty4/sugar.py b/nifty4/sugar.py index 33db00de2fd61a3498e43654ef4eb2a0f222429b..3775baa922819786455667c09e05cafd253f4293 100644 --- a/nifty4/sugar.py +++ b/nifty4/sugar.py @@ -24,7 +24,6 @@ from .multi.multi_field import MultiField from .multi.block_diagonal_operator import BlockDiagonalOperator from .multi.multi_domain import MultiDomain from .operators.diagonal_operator import DiagonalOperator -from .operators.scaling_operator import ScalingOperator from .operators.power_distributor import PowerDistributor from .domain_tuple import DomainTuple from . import dobj, utilities @@ -234,16 +233,12 @@ def makeDomain(domain): return DomainTuple.make(domain) -def makeOp(input, domain=None): +def makeOp(input): if isinstance(input, Field): return DiagonalOperator(input) if isinstance(input, MultiField): return BlockDiagonalOperator({key: makeOp(val) for key, val in input.items()}) - if np.isscalar(input): - if domain is None: - raise ValueError("domain needs to be set") - return ScalingOperator(input, domain) raise NotImplementedError # Arithmetic functions working on Fields