Commit 5a460b08 authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'no_implicit_op_conversion' into 'NIFTy_4'

No implicit op conversion

See merge request ift/NIFTy!264
parents 5288f3bd 127e4a3e
Pipeline #30884 passed with stages
in 9 minutes and 36 seconds
......@@ -16,7 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from ..sugar import exp
from ..sugar import exp, makeOp
from ..minimization.energy import Energy
from ..operators.smoothness_operator import SmoothnessOperator
from ..operators.inversion_enabler import InversionEnabler
......@@ -27,8 +27,9 @@ def _LinearizedPowerResponse(Instrument, nonlinearity, ht, Distributor, tau,
xi):
power = exp(0.5*tau)
position = ht(Distributor(power)*xi)
linearization = nonlinearity.derivative(position)
return 0.5*Instrument*linearization*ht*xi*Distributor*power
linearization = makeOp(nonlinearity.derivative(position))
return (0.5 * Instrument * linearization * ht * makeOp(xi) * Distributor *
makeOp(power))
class NonlinearPowerEnergy(Energy):
......
......@@ -19,6 +19,7 @@
from .wiener_filter_curvature import WienerFilterCurvature
from ..utilities import memo
from ..minimization.energy import Energy
from ..sugar import makeOp
class NonlinearWienerFilterEnergy(Energy):
......@@ -39,7 +40,8 @@ class NonlinearWienerFilterEnergy(Energy):
t1 = S.inverse_times(position)
t2 = N.inverse_times(residual)
self._value = 0.5 * (position.vdot(t1) + residual.vdot(t2)).real
self.R = Instrument * nonlinearity.derivative(m) * ht * power
self.R = (Instrument * makeOp(nonlinearity.derivative(m)) *
ht * makeOp(power))
self._gradient = (t1 - self.R.adjoint_times(t2)).lock()
def at(self, position):
......
......@@ -118,12 +118,9 @@ class LinearOperator(NiftyMetaBase()):
@staticmethod
def _toOperator(thing, dom):
from .diagonal_operator import DiagonalOperator
from .scaling_operator import ScalingOperator
if isinstance(thing, LinearOperator):
return thing
if isinstance(thing, Field):
return DiagonalOperator(thing)
if np.isscalar(thing):
return ScalingOperator(thing, dom)
return NotImplemented
......
......@@ -21,6 +21,7 @@ import numpy as np
from .domains.power_space import PowerSpace
from .field import Field
from .multi.multi_field import MultiField
from .multi.block_diagonal_operator import BlockDiagonalOperator
from .multi.multi_domain import MultiDomain
from .operators.diagonal_operator import DiagonalOperator
from .operators.power_distributor import PowerDistributor
......@@ -32,7 +33,7 @@ __all__ = ['PS_field', 'power_analyze', 'create_power_operator',
'create_harmonic_smoothing_operator', 'from_random',
'full', 'empty', 'from_global_data', 'from_local_data',
'makeDomain', 'sqrt', 'exp', 'log', 'tanh', 'conjugate',
'get_signal_variance']
'get_signal_variance', 'makeOp']
def PS_field(pspace, func):
......@@ -232,6 +233,14 @@ def makeDomain(domain):
return DomainTuple.make(domain)
def makeOp(input):
if isinstance(input, Field):
return DiagonalOperator(input)
if isinstance(input, MultiField):
return BlockDiagonalOperator({key: makeOp(val)
for key, val in input.items()})
raise NotImplementedError
# Arithmetic functions working on Fields
_current_module = sys.modules[__name__]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment