Commit 87bae7e9 by Martin Reinecke

### tweak and remove some code that would need more adjustments

parent 7ee6cbfa
Pipeline #22137 passed with stage
in 4 minutes and 43 seconds
 ... ... @@ -45,9 +45,7 @@ class NonlinearWienerFilterEnergy(Energy): @property @memo def gradient(self): gradient = self._t1.copy() gradient -= self.LinearizedResponse.adjoint_times(self._t2) return gradient return self._t1 - self.LinearizedResponse.adjoint_times(self._t2) @property @memo ... ...
 from numpy import logical_and, where from .. import Field, exp, tanh from ..field import Field, exp, tanh class Linear: class Linear(object): def __call__(self, x): return x def derivative(self, x): return 1 return Field.ones_like(x) class Exponential(object): ... ... @@ -26,55 +25,9 @@ class Tanh(object): return (1. - tanh(x)**2) class PositiveTanh: class PositiveTanh(object): def __call__(self, x): return 0.5 * tanh(x) + 0.5 def derivative(self, x): return 0.5 * (1. - tanh(x)**2) class LinearThenQuadraticWithJump(object): def __call__(self, x): dom = x.domain x = x.copy().val.get_full_data() cond = where(x > 0.) not_cond = where(x <= 0.) x[cond] **= 2 x[not_cond] -= 1 return Field(domain=dom, val=x) def derivative(self, x): dom = x.domain x = x.copy().val.get_full_data() cond = where(x > 0.) not_cond = where(x <= 0.) x[cond] *= 2 x[not_cond] = 1 return Field(domain=dom, val=x) class ReallyStupidNonlinearity(object): def __call__(self, x): dom = x.domain x = x.copy().val.get_full_data() cond1 = where(logical_and(x > 0., x < .5)) cond2 = where(x >= .5) not_cond = where(x <= 0.) x[cond2] -= 0.5 x[cond2] **= 2 x[cond1] = 0. x[not_cond] -= 1 return Field(domain=dom, val=x) def derivative(self, x): dom = x.domain x = x.copy().val.get_full_data() cond1 = where(logical_and(x > 0., x < 0.5)) cond2 = where(x > .5) not_cond = where(x <= 0.) x[cond2] -= 0.5 x[cond2] *= 2 x[cond1] = 0. x[not_cond] = 1 return Field(domain=dom, val=x)
 from .. import exp from ..field import exp from ..operators.linear_operator import LinearOperator class LinearizedSignalResponse(LinearOperator): def __init__(self, Instrument, nonlinearity, FFT, power, m): super(LinearizedSignalResponse, self).__init__() self._target = Instrument.target self._domain = FFT.target self.Instrument = Instrument self.FFT = FFT self.power = power position = FFT.adjoint_times(self.power * m) position = FFT.adjoint_times(self.power*m) self.linearization = nonlinearity.derivative(position) def _times(self, x): return self.Instrument(self.linearization * self.FFT.adjoint_times(self.power * x)) tmp = self.FFT.adjoint_times(self.power*x) tmp *= self.linearization return self.Instrument(tmp) def _adjoint_times(self, x): return self.power * self.FFT(self.linearization * self.Instrument.adjoint_times(x)) tmp = self.Instrument.adjoint_times(x) tmp *= self.linearization tmp = self.FFT(tmp) tmp *= self.power return tmp @property def domain(self): return self._domain return self.FFT.target @property def target(self): return self._target return self.Instrument.target @property def unitary(self): ... ... @@ -35,8 +39,6 @@ class LinearizedSignalResponse(LinearOperator): class LinearizedPowerResponse(LinearOperator): def __init__(self, Instrument, nonlinearity, FFT, Projection, t, m): super(LinearizedPowerResponse, self).__init__() self._target = Instrument.target self._domain = t.domain self.Instrument = Instrument self.FFT = FFT self.Projection = Projection ... ... @@ -47,22 +49,31 @@ class LinearizedPowerResponse(LinearOperator): self.linearization = nonlinearity.derivative(position) def _times(self, x): return 0.5 * self.Instrument(self.linearization * self.FFT.adjoint_times(self.m * self.Projection.adjoint_times(self.power * x))) tmp = self.Projection.adjoint_times(self.power*x) tmp *= self.m tmp = self.FFT.adjoint_times(tmp) tmp *= self.linearization tmp = self.Instrument(tmp) tmp *= 0.5 return tmp def _adjoint_times(self, x): return 0.5 * self.power * self.Projection(self.m.conjugate() * self.FFT(self.linearization * self.Instrument.adjoint_times(x))) # .weight(-1) tmp = self.Instrument.adjoint_times(x) tmp *= self.linearization tmp = self.FFT(tmp) tmp *= self.m.conjugate() tmp = self.Projection(tmp) tmp *= self.power tmp *= 0.5 return tmp @property def domain(self): return self._domain return self.power.domain @property def target(self): return self._target return self.Instrument.target @property def unitary(self): ... ...
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment