diff --git a/nifty/library/nonlinear_signal_energy.py b/nifty/library/nonlinear_signal_energy.py index 9e6aed8e98ee64a2305a386f1639020a18caafd0..a62b9f625722b34307f780092ae1e91035ee7515 100644 --- a/nifty/library/nonlinear_signal_energy.py +++ b/nifty/library/nonlinear_signal_energy.py @@ -45,9 +45,7 @@ class NonlinearWienerFilterEnergy(Energy): @property @memo def gradient(self): - gradient = self._t1.copy() - gradient -= self.LinearizedResponse.adjoint_times(self._t2) - return gradient + return self._t1 - self.LinearizedResponse.adjoint_times(self._t2) @property @memo diff --git a/nifty/library/nonlinearities.py b/nifty/library/nonlinearities.py index 10ab15f02ee371f513ac263bb98342be165b46db..baaefdb6d09873cfcc0011e9c46e4a18807d4c9d 100644 --- a/nifty/library/nonlinearities.py +++ b/nifty/library/nonlinearities.py @@ -1,13 +1,12 @@ -from numpy import logical_and, where -from .. import Field, exp, tanh +from ..field import Field, exp, tanh -class Linear: +class Linear(object): def __call__(self, x): return x def derivative(self, x): - return 1 + return Field.ones_like(x) class Exponential(object): @@ -26,55 +25,9 @@ class Tanh(object): return (1. - tanh(x)**2) -class PositiveTanh: +class PositiveTanh(object): def __call__(self, x): return 0.5 * tanh(x) + 0.5 def derivative(self, x): return 0.5 * (1. - tanh(x)**2) - - -class LinearThenQuadraticWithJump(object): - def __call__(self, x): - dom = x.domain - x = x.copy().val.get_full_data() - cond = where(x > 0.) - not_cond = where(x <= 0.) - x[cond] **= 2 - x[not_cond] -= 1 - return Field(domain=dom, val=x) - - def derivative(self, x): - dom = x.domain - x = x.copy().val.get_full_data() - cond = where(x > 0.) - not_cond = where(x <= 0.) - x[cond] *= 2 - x[not_cond] = 1 - return Field(domain=dom, val=x) - - -class ReallyStupidNonlinearity(object): - def __call__(self, x): - dom = x.domain - x = x.copy().val.get_full_data() - cond1 = where(logical_and(x > 0., x < .5)) - cond2 = where(x >= .5) - not_cond = where(x <= 0.) - x[cond2] -= 0.5 - x[cond2] **= 2 - x[cond1] = 0. - x[not_cond] -= 1 - return Field(domain=dom, val=x) - - def derivative(self, x): - dom = x.domain - x = x.copy().val.get_full_data() - cond1 = where(logical_and(x > 0., x < 0.5)) - cond2 = where(x > .5) - not_cond = where(x <= 0.) - x[cond2] -= 0.5 - x[cond2] *= 2 - x[cond1] = 0. - x[not_cond] = 1 - return Field(domain=dom, val=x) diff --git a/nifty/library/response_operators.py b/nifty/library/response_operators.py index c080df1ed7e095b554ffc283d8003a8dceb7ebf8..f2f05c129672310d93da2b31e9206ed707342ffe 100644 --- a/nifty/library/response_operators.py +++ b/nifty/library/response_operators.py @@ -1,31 +1,35 @@ -from .. import exp +from ..field import exp from ..operators.linear_operator import LinearOperator class LinearizedSignalResponse(LinearOperator): def __init__(self, Instrument, nonlinearity, FFT, power, m): super(LinearizedSignalResponse, self).__init__() - self._target = Instrument.target - self._domain = FFT.target self.Instrument = Instrument self.FFT = FFT self.power = power - position = FFT.adjoint_times(self.power * m) + position = FFT.adjoint_times(self.power*m) self.linearization = nonlinearity.derivative(position) def _times(self, x): - return self.Instrument(self.linearization * self.FFT.adjoint_times(self.power * x)) + tmp = self.FFT.adjoint_times(self.power*x) + tmp *= self.linearization + return self.Instrument(tmp) def _adjoint_times(self, x): - return self.power * self.FFT(self.linearization * self.Instrument.adjoint_times(x)) + tmp = self.Instrument.adjoint_times(x) + tmp *= self.linearization + tmp = self.FFT(tmp) + tmp *= self.power + return tmp @property def domain(self): - return self._domain + return self.FFT.target @property def target(self): - return self._target + return self.Instrument.target @property def unitary(self): @@ -35,8 +39,6 @@ class LinearizedSignalResponse(LinearOperator): class LinearizedPowerResponse(LinearOperator): def __init__(self, Instrument, nonlinearity, FFT, Projection, t, m): super(LinearizedPowerResponse, self).__init__() - self._target = Instrument.target - self._domain = t.domain self.Instrument = Instrument self.FFT = FFT self.Projection = Projection @@ -47,22 +49,31 @@ class LinearizedPowerResponse(LinearOperator): self.linearization = nonlinearity.derivative(position) def _times(self, x): - return 0.5 * self.Instrument(self.linearization - * self.FFT.adjoint_times(self.m - * self.Projection.adjoint_times(self.power * x))) + tmp = self.Projection.adjoint_times(self.power*x) + tmp *= self.m + tmp = self.FFT.adjoint_times(tmp) + tmp *= self.linearization + tmp = self.Instrument(tmp) + tmp *= 0.5 + return tmp def _adjoint_times(self, x): - return 0.5 * self.power * self.Projection(self.m.conjugate() - * self.FFT(self.linearization - * self.Instrument.adjoint_times(x))) # .weight(-1) + tmp = self.Instrument.adjoint_times(x) + tmp *= self.linearization + tmp = self.FFT(tmp) + tmp *= self.m.conjugate() + tmp = self.Projection(tmp) + tmp *= self.power + tmp *= 0.5 + return tmp @property def domain(self): - return self._domain + return self.power.domain @property def target(self): - return self._target + return self.Instrument.target @property def unitary(self):