Commit 2db45ddb authored by Martin Reinecke's avatar Martin Reinecke

demo implementation

parent 5288f3bd
Pipeline #30607 passed with stages
in 1 minute and 29 seconds
......@@ -16,7 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from ..sugar import exp
from ..sugar import exp, makeOp
from ..minimization.energy import Energy
from ..operators.smoothness_operator import SmoothnessOperator
from ..operators.inversion_enabler import InversionEnabler
......@@ -27,8 +27,9 @@ def _LinearizedPowerResponse(Instrument, nonlinearity, ht, Distributor, tau,
xi):
power = exp(0.5*tau)
position = ht(Distributor(power)*xi)
linearization = nonlinearity.derivative(position)
return 0.5*Instrument*linearization*ht*xi*Distributor*power
linearization = makeOp(nonlinearity.derivative(position))
return (makeOp(0.5, Instrument.target) * Instrument * linearization * ht *
makeOp(xi) * Distributor * makeOp(power))
class NonlinearPowerEnergy(Energy):
......@@ -137,5 +138,6 @@ class NonlinearPowerEnergy(Energy):
self.position, xi_sample)
op = LinearizedResponse.adjoint*self.N.inverse*LinearizedResponse
result = op if result is None else result + op
result = result*(1./len(self.xi_sample_list)) + self.T
result = (result*makeOp(1./len(self.xi_sample_list), result.domain) +
self.T)
return InversionEnabler(result, self.inverter)
......@@ -19,6 +19,7 @@
from .wiener_filter_curvature import WienerFilterCurvature
from ..utilities import memo
from ..minimization.energy import Energy
from ..sugar import makeOp
class NonlinearWienerFilterEnergy(Energy):
......@@ -39,7 +40,8 @@ class NonlinearWienerFilterEnergy(Energy):
t1 = S.inverse_times(position)
t2 = N.inverse_times(residual)
self._value = 0.5 * (position.vdot(t1) + residual.vdot(t2)).real
self.R = Instrument * nonlinearity.derivative(m) * ht * power
self.R = (Instrument * makeOp(nonlinearity.derivative(m)) *
ht * makeOp(power))
self._gradient = (t1 - self.R.adjoint_times(t2)).lock()
def at(self, position):
......
......@@ -116,54 +116,18 @@ class LinearOperator(NiftyMetaBase()):
the adjoint of this operator."""
return self._flip_modes(self.ADJOINT_BIT)
@staticmethod
def _toOperator(thing, dom):
from .diagonal_operator import DiagonalOperator
from .scaling_operator import ScalingOperator
if isinstance(thing, LinearOperator):
return thing
if isinstance(thing, Field):
return DiagonalOperator(thing)
if np.isscalar(thing):
return ScalingOperator(thing, dom)
return NotImplemented
def __mul__(self, other):
from .chain_operator import ChainOperator
if np.isscalar(other) and other == 1.:
return self
other = self._toOperator(other, self.domain)
return ChainOperator.make([self, other])
def __rmul__(self, other):
from .chain_operator import ChainOperator
if np.isscalar(other) and other == 1.:
return self
other = self._toOperator(other, self.target)
return ChainOperator.make([other, self])
def __add__(self, other):
from .sum_operator import SumOperator
if np.isscalar(other) and other == 0.:
return self
other = self._toOperator(other, self.domain)
return SumOperator.make([self, other], [False, False])
def __radd__(self, other):
return self.__add__(other)
def __sub__(self, other):
from .sum_operator import SumOperator
if np.isscalar(other) and other == 0.:
return self
other = self._toOperator(other, self.domain)
return SumOperator.make([self, other], [False, True])
def __rsub__(self, other):
from .sum_operator import SumOperator
other = self._toOperator(other, self.domain)
return SumOperator.make([other, self], [False, True])
@abc.abstractproperty
def capability(self):
"""int : the supported operation modes
......
......@@ -18,6 +18,7 @@
from .scaling_operator import ScalingOperator
from .laplace_operator import LaplaceOperator
from ..sugar import makeOp
def SmoothnessOperator(domain, strength=1., logarithmic=True, space=None):
......@@ -51,4 +52,4 @@ def SmoothnessOperator(domain, strength=1., logarithmic=True, space=None):
if strength == 0.:
return ScalingOperator(0., domain)
laplace = LaplaceOperator(domain, logarithmic=logarithmic, space=space)
return (strength**2)*laplace.adjoint*laplace
return makeOp(strength**2, laplace.domain)*laplace.adjoint*laplace
......@@ -23,6 +23,7 @@ from .field import Field
from .multi.multi_field import MultiField
from .multi.multi_domain import MultiDomain
from .operators.diagonal_operator import DiagonalOperator
from .operators.scaling_operator import ScalingOperator
from .operators.power_distributor import PowerDistributor
from .domain_tuple import DomainTuple
from . import dobj, utilities
......@@ -32,7 +33,7 @@ __all__ = ['PS_field', 'power_analyze', 'create_power_operator',
'create_harmonic_smoothing_operator', 'from_random',
'full', 'empty', 'from_global_data', 'from_local_data',
'makeDomain', 'sqrt', 'exp', 'log', 'tanh', 'conjugate',
'get_signal_variance']
'get_signal_variance', 'makeOp']
def PS_field(pspace, func):
......@@ -232,6 +233,18 @@ def makeDomain(domain):
return DomainTuple.make(domain)
def makeOp(input, domain=None):
if isinstance(input, Field):
return DiagonalOperator(input)
if isinstance(input, MultiField):
return BlockDiagonalOperator({key: makeOp(val)
for key, val in input.items()})
if np.isscalar(input):
if domain is None:
raise ValueError("domain needs to be set")
return ScalingOperator(input, domain)
raise NotImplementedError
# Arithmetic functions working on Fields
_current_module = sys.modules[__name__]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment