diff --git a/nifty4/library/nonlinear_power_energy.py b/nifty4/library/nonlinear_power_energy.py index dff636a7934983cb3c494b38bcec5e5351223770..f7b90d8fe492ba2a3da4edaebddc4d4bc88b1738 100644 --- a/nifty4/library/nonlinear_power_energy.py +++ b/nifty4/library/nonlinear_power_energy.py @@ -16,7 +16,7 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # and financially supported by the Studienstiftung des deutschen Volkes. -from ..sugar import exp +from ..sugar import exp, makeOp from ..minimization.energy import Energy from ..operators.smoothness_operator import SmoothnessOperator from ..operators.inversion_enabler import InversionEnabler @@ -27,8 +27,9 @@ def _LinearizedPowerResponse(Instrument, nonlinearity, ht, Distributor, tau, xi): power = exp(0.5*tau) position = ht(Distributor(power)*xi) - linearization = nonlinearity.derivative(position) - return 0.5*Instrument*linearization*ht*xi*Distributor*power + linearization = makeOp(nonlinearity.derivative(position)) + return (makeOp(0.5, Instrument.target) * Instrument * linearization * ht * + makeOp(xi) * Distributor * makeOp(power)) class NonlinearPowerEnergy(Energy): @@ -137,5 +138,6 @@ class NonlinearPowerEnergy(Energy): self.position, xi_sample) op = LinearizedResponse.adjoint*self.N.inverse*LinearizedResponse result = op if result is None else result + op - result = result*(1./len(self.xi_sample_list)) + self.T + result = (result*makeOp(1./len(self.xi_sample_list), result.domain) + + self.T) return InversionEnabler(result, self.inverter) diff --git a/nifty4/library/nonlinear_wiener_filter_energy.py b/nifty4/library/nonlinear_wiener_filter_energy.py index a74235837c9dbc721538de781ee1967c043eedb0..d54a52cc83081f3afe1fbbcbbd15c9e14abcf89d 100644 --- a/nifty4/library/nonlinear_wiener_filter_energy.py +++ b/nifty4/library/nonlinear_wiener_filter_energy.py @@ -19,6 +19,7 @@ from .wiener_filter_curvature import WienerFilterCurvature from ..utilities import memo from ..minimization.energy import Energy +from ..sugar import makeOp class NonlinearWienerFilterEnergy(Energy): @@ -39,7 +40,8 @@ class NonlinearWienerFilterEnergy(Energy): t1 = S.inverse_times(position) t2 = N.inverse_times(residual) self._value = 0.5 * (position.vdot(t1) + residual.vdot(t2)).real - self.R = Instrument * nonlinearity.derivative(m) * ht * power + self.R = (Instrument * makeOp(nonlinearity.derivative(m)) * + ht * makeOp(power)) self._gradient = (t1 - self.R.adjoint_times(t2)).lock() def at(self, position): diff --git a/nifty4/operators/linear_operator.py b/nifty4/operators/linear_operator.py index 7799c98a8f25fd92aa0d3567d2c208649afa59a2..167f9f7021cbc9292aa0e7ea8a9a8360e19c55b7 100644 --- a/nifty4/operators/linear_operator.py +++ b/nifty4/operators/linear_operator.py @@ -116,54 +116,18 @@ class LinearOperator(NiftyMetaBase()): the adjoint of this operator.""" return self._flip_modes(self.ADJOINT_BIT) - @staticmethod - def _toOperator(thing, dom): - from .diagonal_operator import DiagonalOperator - from .scaling_operator import ScalingOperator - if isinstance(thing, LinearOperator): - return thing - if isinstance(thing, Field): - return DiagonalOperator(thing) - if np.isscalar(thing): - return ScalingOperator(thing, dom) - return NotImplemented - def __mul__(self, other): from .chain_operator import ChainOperator - if np.isscalar(other) and other == 1.: - return self - other = self._toOperator(other, self.domain) return ChainOperator.make([self, other]) - def __rmul__(self, other): - from .chain_operator import ChainOperator - if np.isscalar(other) and other == 1.: - return self - other = self._toOperator(other, self.target) - return ChainOperator.make([other, self]) - def __add__(self, other): from .sum_operator import SumOperator - if np.isscalar(other) and other == 0.: - return self - other = self._toOperator(other, self.domain) return SumOperator.make([self, other], [False, False]) - def __radd__(self, other): - return self.__add__(other) - def __sub__(self, other): from .sum_operator import SumOperator - if np.isscalar(other) and other == 0.: - return self - other = self._toOperator(other, self.domain) return SumOperator.make([self, other], [False, True]) - def __rsub__(self, other): - from .sum_operator import SumOperator - other = self._toOperator(other, self.domain) - return SumOperator.make([other, self], [False, True]) - @abc.abstractproperty def capability(self): """int : the supported operation modes diff --git a/nifty4/operators/smoothness_operator.py b/nifty4/operators/smoothness_operator.py index ba6aa76a3a18e4abe3b764c7f033e3171b03ee22..58b60b41439b1f78d5895c930ac6cd53485c26ae 100644 --- a/nifty4/operators/smoothness_operator.py +++ b/nifty4/operators/smoothness_operator.py @@ -18,6 +18,7 @@ from .scaling_operator import ScalingOperator from .laplace_operator import LaplaceOperator +from ..sugar import makeOp def SmoothnessOperator(domain, strength=1., logarithmic=True, space=None): @@ -51,4 +52,4 @@ def SmoothnessOperator(domain, strength=1., logarithmic=True, space=None): if strength == 0.: return ScalingOperator(0., domain) laplace = LaplaceOperator(domain, logarithmic=logarithmic, space=space) - return (strength**2)*laplace.adjoint*laplace + return makeOp(strength**2, laplace.domain)*laplace.adjoint*laplace diff --git a/nifty4/sugar.py b/nifty4/sugar.py index 8f84ddefed6b2999ab87c69aaeba1c54b37abf66..0dd0be5d8e3af5a25e3519fde6606d62ae082ded 100644 --- a/nifty4/sugar.py +++ b/nifty4/sugar.py @@ -23,6 +23,7 @@ from .field import Field from .multi.multi_field import MultiField from .multi.multi_domain import MultiDomain from .operators.diagonal_operator import DiagonalOperator +from .operators.scaling_operator import ScalingOperator from .operators.power_distributor import PowerDistributor from .domain_tuple import DomainTuple from . import dobj, utilities @@ -32,7 +33,7 @@ __all__ = ['PS_field', 'power_analyze', 'create_power_operator', 'create_harmonic_smoothing_operator', 'from_random', 'full', 'empty', 'from_global_data', 'from_local_data', 'makeDomain', 'sqrt', 'exp', 'log', 'tanh', 'conjugate', - 'get_signal_variance'] + 'get_signal_variance', 'makeOp'] def PS_field(pspace, func): @@ -232,6 +233,18 @@ def makeDomain(domain): return DomainTuple.make(domain) +def makeOp(input, domain=None): + if isinstance(input, Field): + return DiagonalOperator(input) + if isinstance(input, MultiField): + return BlockDiagonalOperator({key: makeOp(val) + for key, val in input.items()}) + if np.isscalar(input): + if domain is None: + raise ValueError("domain needs to be set") + return ScalingOperator(input, domain) + raise NotImplementedError + # Arithmetic functions working on Fields _current_module = sys.modules[__name__]