Commit 87bae7e9 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

tweak and remove some code that would need more adjustments

parent 7ee6cbfa
Pipeline #22137 passed with stage
in 4 minutes and 43 seconds
......@@ -45,9 +45,7 @@ class NonlinearWienerFilterEnergy(Energy):
@property
@memo
def gradient(self):
gradient = self._t1.copy()
gradient -= self.LinearizedResponse.adjoint_times(self._t2)
return gradient
return self._t1 - self.LinearizedResponse.adjoint_times(self._t2)
@property
@memo
......
from numpy import logical_and, where
from .. import Field, exp, tanh
from ..field import Field, exp, tanh
class Linear:
class Linear(object):
def __call__(self, x):
return x
def derivative(self, x):
return 1
return Field.ones_like(x)
class Exponential(object):
......@@ -26,55 +25,9 @@ class Tanh(object):
return (1. - tanh(x)**2)
class PositiveTanh:
class PositiveTanh(object):
def __call__(self, x):
return 0.5 * tanh(x) + 0.5
def derivative(self, x):
return 0.5 * (1. - tanh(x)**2)
class LinearThenQuadraticWithJump(object):
def __call__(self, x):
dom = x.domain
x = x.copy().val.get_full_data()
cond = where(x > 0.)
not_cond = where(x <= 0.)
x[cond] **= 2
x[not_cond] -= 1
return Field(domain=dom, val=x)
def derivative(self, x):
dom = x.domain
x = x.copy().val.get_full_data()
cond = where(x > 0.)
not_cond = where(x <= 0.)
x[cond] *= 2
x[not_cond] = 1
return Field(domain=dom, val=x)
class ReallyStupidNonlinearity(object):
def __call__(self, x):
dom = x.domain
x = x.copy().val.get_full_data()
cond1 = where(logical_and(x > 0., x < .5))
cond2 = where(x >= .5)
not_cond = where(x <= 0.)
x[cond2] -= 0.5
x[cond2] **= 2
x[cond1] = 0.
x[not_cond] -= 1
return Field(domain=dom, val=x)
def derivative(self, x):
dom = x.domain
x = x.copy().val.get_full_data()
cond1 = where(logical_and(x > 0., x < 0.5))
cond2 = where(x > .5)
not_cond = where(x <= 0.)
x[cond2] -= 0.5
x[cond2] *= 2
x[cond1] = 0.
x[not_cond] = 1
return Field(domain=dom, val=x)
from .. import exp
from ..field import exp
from ..operators.linear_operator import LinearOperator
class LinearizedSignalResponse(LinearOperator):
def __init__(self, Instrument, nonlinearity, FFT, power, m):
super(LinearizedSignalResponse, self).__init__()
self._target = Instrument.target
self._domain = FFT.target
self.Instrument = Instrument
self.FFT = FFT
self.power = power
position = FFT.adjoint_times(self.power * m)
position = FFT.adjoint_times(self.power*m)
self.linearization = nonlinearity.derivative(position)
def _times(self, x):
return self.Instrument(self.linearization * self.FFT.adjoint_times(self.power * x))
tmp = self.FFT.adjoint_times(self.power*x)
tmp *= self.linearization
return self.Instrument(tmp)
def _adjoint_times(self, x):
return self.power * self.FFT(self.linearization * self.Instrument.adjoint_times(x))
tmp = self.Instrument.adjoint_times(x)
tmp *= self.linearization
tmp = self.FFT(tmp)
tmp *= self.power
return tmp
@property
def domain(self):
return self._domain
return self.FFT.target
@property
def target(self):
return self._target
return self.Instrument.target
@property
def unitary(self):
......@@ -35,8 +39,6 @@ class LinearizedSignalResponse(LinearOperator):
class LinearizedPowerResponse(LinearOperator):
def __init__(self, Instrument, nonlinearity, FFT, Projection, t, m):
super(LinearizedPowerResponse, self).__init__()
self._target = Instrument.target
self._domain = t.domain
self.Instrument = Instrument
self.FFT = FFT
self.Projection = Projection
......@@ -47,22 +49,31 @@ class LinearizedPowerResponse(LinearOperator):
self.linearization = nonlinearity.derivative(position)
def _times(self, x):
return 0.5 * self.Instrument(self.linearization
* self.FFT.adjoint_times(self.m
* self.Projection.adjoint_times(self.power * x)))
tmp = self.Projection.adjoint_times(self.power*x)
tmp *= self.m
tmp = self.FFT.adjoint_times(tmp)
tmp *= self.linearization
tmp = self.Instrument(tmp)
tmp *= 0.5
return tmp
def _adjoint_times(self, x):
return 0.5 * self.power * self.Projection(self.m.conjugate()
* self.FFT(self.linearization
* self.Instrument.adjoint_times(x))) # .weight(-1)
tmp = self.Instrument.adjoint_times(x)
tmp *= self.linearization
tmp = self.FFT(tmp)
tmp *= self.m.conjugate()
tmp = self.Projection(tmp)
tmp *= self.power
tmp *= 0.5
return tmp
@property
def domain(self):
return self._domain
return self.power.domain
@property
def target(self):
return self._target
return self.Instrument.target
@property
def unitary(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment