There is a maintenance of MPCDF Gitlab on Thursday, April 22st 2020, 9:00 am CEST - Expect some service interruptions during this time

Commit 5a460b08 authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'no_implicit_op_conversion' into 'NIFTy_4'

No implicit op conversion

See merge request ift/NIFTy!264
parents 5288f3bd 127e4a3e
Pipeline #30884 passed with stages
in 9 minutes and 36 seconds
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from ..sugar import exp from ..sugar import exp, makeOp
from ..minimization.energy import Energy from ..minimization.energy import Energy
from ..operators.smoothness_operator import SmoothnessOperator from ..operators.smoothness_operator import SmoothnessOperator
from ..operators.inversion_enabler import InversionEnabler from ..operators.inversion_enabler import InversionEnabler
...@@ -27,8 +27,9 @@ def _LinearizedPowerResponse(Instrument, nonlinearity, ht, Distributor, tau, ...@@ -27,8 +27,9 @@ def _LinearizedPowerResponse(Instrument, nonlinearity, ht, Distributor, tau,
xi): xi):
power = exp(0.5*tau) power = exp(0.5*tau)
position = ht(Distributor(power)*xi) position = ht(Distributor(power)*xi)
linearization = nonlinearity.derivative(position) linearization = makeOp(nonlinearity.derivative(position))
return 0.5*Instrument*linearization*ht*xi*Distributor*power return (0.5 * Instrument * linearization * ht * makeOp(xi) * Distributor *
makeOp(power))
class NonlinearPowerEnergy(Energy): class NonlinearPowerEnergy(Energy):
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
from .wiener_filter_curvature import WienerFilterCurvature from .wiener_filter_curvature import WienerFilterCurvature
from ..utilities import memo from ..utilities import memo
from ..minimization.energy import Energy from ..minimization.energy import Energy
from ..sugar import makeOp
class NonlinearWienerFilterEnergy(Energy): class NonlinearWienerFilterEnergy(Energy):
...@@ -39,7 +40,8 @@ class NonlinearWienerFilterEnergy(Energy): ...@@ -39,7 +40,8 @@ class NonlinearWienerFilterEnergy(Energy):
t1 = S.inverse_times(position) t1 = S.inverse_times(position)
t2 = N.inverse_times(residual) t2 = N.inverse_times(residual)
self._value = 0.5 * (position.vdot(t1) + residual.vdot(t2)).real self._value = 0.5 * (position.vdot(t1) + residual.vdot(t2)).real
self.R = Instrument * nonlinearity.derivative(m) * ht * power self.R = (Instrument * makeOp(nonlinearity.derivative(m)) *
ht * makeOp(power))
self._gradient = (t1 - self.R.adjoint_times(t2)).lock() self._gradient = (t1 - self.R.adjoint_times(t2)).lock()
def at(self, position): def at(self, position):
......
...@@ -118,12 +118,9 @@ class LinearOperator(NiftyMetaBase()): ...@@ -118,12 +118,9 @@ class LinearOperator(NiftyMetaBase()):
@staticmethod @staticmethod
def _toOperator(thing, dom): def _toOperator(thing, dom):
from .diagonal_operator import DiagonalOperator
from .scaling_operator import ScalingOperator from .scaling_operator import ScalingOperator
if isinstance(thing, LinearOperator): if isinstance(thing, LinearOperator):
return thing return thing
if isinstance(thing, Field):
return DiagonalOperator(thing)
if np.isscalar(thing): if np.isscalar(thing):
return ScalingOperator(thing, dom) return ScalingOperator(thing, dom)
return NotImplemented return NotImplemented
......
...@@ -21,6 +21,7 @@ import numpy as np ...@@ -21,6 +21,7 @@ import numpy as np
from .domains.power_space import PowerSpace from .domains.power_space import PowerSpace
from .field import Field from .field import Field
from .multi.multi_field import MultiField from .multi.multi_field import MultiField
from .multi.block_diagonal_operator import BlockDiagonalOperator
from .multi.multi_domain import MultiDomain from .multi.multi_domain import MultiDomain
from .operators.diagonal_operator import DiagonalOperator from .operators.diagonal_operator import DiagonalOperator
from .operators.power_distributor import PowerDistributor from .operators.power_distributor import PowerDistributor
...@@ -32,7 +33,7 @@ __all__ = ['PS_field', 'power_analyze', 'create_power_operator', ...@@ -32,7 +33,7 @@ __all__ = ['PS_field', 'power_analyze', 'create_power_operator',
'create_harmonic_smoothing_operator', 'from_random', 'create_harmonic_smoothing_operator', 'from_random',
'full', 'empty', 'from_global_data', 'from_local_data', 'full', 'empty', 'from_global_data', 'from_local_data',
'makeDomain', 'sqrt', 'exp', 'log', 'tanh', 'conjugate', 'makeDomain', 'sqrt', 'exp', 'log', 'tanh', 'conjugate',
'get_signal_variance'] 'get_signal_variance', 'makeOp']
def PS_field(pspace, func): def PS_field(pspace, func):
...@@ -232,6 +233,14 @@ def makeDomain(domain): ...@@ -232,6 +233,14 @@ def makeDomain(domain):
return DomainTuple.make(domain) return DomainTuple.make(domain)
def makeOp(input):
if isinstance(input, Field):
return DiagonalOperator(input)
if isinstance(input, MultiField):
return BlockDiagonalOperator({key: makeOp(val)
for key, val in input.items()})
raise NotImplementedError
# Arithmetic functions working on Fields # Arithmetic functions working on Fields
_current_module = sys.modules[__name__] _current_module = sys.modules[__name__]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment