From 87bae7e97353804d305c75cbe3ecce3a19cbba7c Mon Sep 17 00:00:00 2001 From: Martin Reinecke <martin@mpa-garching.mpg.de> Date: Thu, 23 Nov 2017 15:50:30 +0100 Subject: [PATCH] tweak and remove some code that would need more adjustments --- nifty/library/nonlinear_signal_energy.py | 4 +- nifty/library/nonlinearities.py | 55 ++---------------------- nifty/library/response_operators.py | 47 ++++++++++++-------- 3 files changed, 34 insertions(+), 72 deletions(-) diff --git a/nifty/library/nonlinear_signal_energy.py b/nifty/library/nonlinear_signal_energy.py index 9e6aed8e9..a62b9f625 100644 --- a/nifty/library/nonlinear_signal_energy.py +++ b/nifty/library/nonlinear_signal_energy.py @@ -45,9 +45,7 @@ class NonlinearWienerFilterEnergy(Energy): @property @memo def gradient(self): - gradient = self._t1.copy() - gradient -= self.LinearizedResponse.adjoint_times(self._t2) - return gradient + return self._t1 - self.LinearizedResponse.adjoint_times(self._t2) @property @memo diff --git a/nifty/library/nonlinearities.py b/nifty/library/nonlinearities.py index 10ab15f02..baaefdb6d 100644 --- a/nifty/library/nonlinearities.py +++ b/nifty/library/nonlinearities.py @@ -1,13 +1,12 @@ -from numpy import logical_and, where -from .. import Field, exp, tanh +from ..field import Field, exp, tanh -class Linear: +class Linear(object): def __call__(self, x): return x def derivative(self, x): - return 1 + return Field.ones_like(x) class Exponential(object): @@ -26,55 +25,9 @@ class Tanh(object): return (1. - tanh(x)**2) -class PositiveTanh: +class PositiveTanh(object): def __call__(self, x): return 0.5 * tanh(x) + 0.5 def derivative(self, x): return 0.5 * (1. - tanh(x)**2) - - -class LinearThenQuadraticWithJump(object): - def __call__(self, x): - dom = x.domain - x = x.copy().val.get_full_data() - cond = where(x > 0.) - not_cond = where(x <= 0.) - x[cond] **= 2 - x[not_cond] -= 1 - return Field(domain=dom, val=x) - - def derivative(self, x): - dom = x.domain - x = x.copy().val.get_full_data() - cond = where(x > 0.) - not_cond = where(x <= 0.) - x[cond] *= 2 - x[not_cond] = 1 - return Field(domain=dom, val=x) - - -class ReallyStupidNonlinearity(object): - def __call__(self, x): - dom = x.domain - x = x.copy().val.get_full_data() - cond1 = where(logical_and(x > 0., x < .5)) - cond2 = where(x >= .5) - not_cond = where(x <= 0.) - x[cond2] -= 0.5 - x[cond2] **= 2 - x[cond1] = 0. - x[not_cond] -= 1 - return Field(domain=dom, val=x) - - def derivative(self, x): - dom = x.domain - x = x.copy().val.get_full_data() - cond1 = where(logical_and(x > 0., x < 0.5)) - cond2 = where(x > .5) - not_cond = where(x <= 0.) - x[cond2] -= 0.5 - x[cond2] *= 2 - x[cond1] = 0. - x[not_cond] = 1 - return Field(domain=dom, val=x) diff --git a/nifty/library/response_operators.py b/nifty/library/response_operators.py index c080df1ed..f2f05c129 100644 --- a/nifty/library/response_operators.py +++ b/nifty/library/response_operators.py @@ -1,31 +1,35 @@ -from .. import exp +from ..field import exp from ..operators.linear_operator import LinearOperator class LinearizedSignalResponse(LinearOperator): def __init__(self, Instrument, nonlinearity, FFT, power, m): super(LinearizedSignalResponse, self).__init__() - self._target = Instrument.target - self._domain = FFT.target self.Instrument = Instrument self.FFT = FFT self.power = power - position = FFT.adjoint_times(self.power * m) + position = FFT.adjoint_times(self.power*m) self.linearization = nonlinearity.derivative(position) def _times(self, x): - return self.Instrument(self.linearization * self.FFT.adjoint_times(self.power * x)) + tmp = self.FFT.adjoint_times(self.power*x) + tmp *= self.linearization + return self.Instrument(tmp) def _adjoint_times(self, x): - return self.power * self.FFT(self.linearization * self.Instrument.adjoint_times(x)) + tmp = self.Instrument.adjoint_times(x) + tmp *= self.linearization + tmp = self.FFT(tmp) + tmp *= self.power + return tmp @property def domain(self): - return self._domain + return self.FFT.target @property def target(self): - return self._target + return self.Instrument.target @property def unitary(self): @@ -35,8 +39,6 @@ class LinearizedSignalResponse(LinearOperator): class LinearizedPowerResponse(LinearOperator): def __init__(self, Instrument, nonlinearity, FFT, Projection, t, m): super(LinearizedPowerResponse, self).__init__() - self._target = Instrument.target - self._domain = t.domain self.Instrument = Instrument self.FFT = FFT self.Projection = Projection @@ -47,22 +49,31 @@ class LinearizedPowerResponse(LinearOperator): self.linearization = nonlinearity.derivative(position) def _times(self, x): - return 0.5 * self.Instrument(self.linearization - * self.FFT.adjoint_times(self.m - * self.Projection.adjoint_times(self.power * x))) + tmp = self.Projection.adjoint_times(self.power*x) + tmp *= self.m + tmp = self.FFT.adjoint_times(tmp) + tmp *= self.linearization + tmp = self.Instrument(tmp) + tmp *= 0.5 + return tmp def _adjoint_times(self, x): - return 0.5 * self.power * self.Projection(self.m.conjugate() - * self.FFT(self.linearization - * self.Instrument.adjoint_times(x))) # .weight(-1) + tmp = self.Instrument.adjoint_times(x) + tmp *= self.linearization + tmp = self.FFT(tmp) + tmp *= self.m.conjugate() + tmp = self.Projection(tmp) + tmp *= self.power + tmp *= 0.5 + return tmp @property def domain(self): - return self._domain + return self.power.domain @property def target(self): - return self._target + return self.Instrument.target @property def unitary(self): -- GitLab