Commit 87bae7e9 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

tweak and remove some code that would need more adjustments

parent 7ee6cbfa
Pipeline #22137 passed with stage
in 4 minutes and 43 seconds
...@@ -45,9 +45,7 @@ class NonlinearWienerFilterEnergy(Energy): ...@@ -45,9 +45,7 @@ class NonlinearWienerFilterEnergy(Energy):
@property @property
@memo @memo
def gradient(self): def gradient(self):
gradient = self._t1.copy() return self._t1 - self.LinearizedResponse.adjoint_times(self._t2)
gradient -= self.LinearizedResponse.adjoint_times(self._t2)
return gradient
@property @property
@memo @memo
......
from numpy import logical_and, where from ..field import Field, exp, tanh
from .. import Field, exp, tanh
class Linear: class Linear(object):
def __call__(self, x): def __call__(self, x):
return x return x
def derivative(self, x): def derivative(self, x):
return 1 return Field.ones_like(x)
class Exponential(object): class Exponential(object):
...@@ -26,55 +25,9 @@ class Tanh(object): ...@@ -26,55 +25,9 @@ class Tanh(object):
return (1. - tanh(x)**2) return (1. - tanh(x)**2)
class PositiveTanh: class PositiveTanh(object):
def __call__(self, x): def __call__(self, x):
return 0.5 * tanh(x) + 0.5 return 0.5 * tanh(x) + 0.5
def derivative(self, x): def derivative(self, x):
return 0.5 * (1. - tanh(x)**2) return 0.5 * (1. - tanh(x)**2)
class LinearThenQuadraticWithJump(object):
def __call__(self, x):
dom = x.domain
x = x.copy().val.get_full_data()
cond = where(x > 0.)
not_cond = where(x <= 0.)
x[cond] **= 2
x[not_cond] -= 1
return Field(domain=dom, val=x)
def derivative(self, x):
dom = x.domain
x = x.copy().val.get_full_data()
cond = where(x > 0.)
not_cond = where(x <= 0.)
x[cond] *= 2
x[not_cond] = 1
return Field(domain=dom, val=x)
class ReallyStupidNonlinearity(object):
def __call__(self, x):
dom = x.domain
x = x.copy().val.get_full_data()
cond1 = where(logical_and(x > 0., x < .5))
cond2 = where(x >= .5)
not_cond = where(x <= 0.)
x[cond2] -= 0.5
x[cond2] **= 2
x[cond1] = 0.
x[not_cond] -= 1
return Field(domain=dom, val=x)
def derivative(self, x):
dom = x.domain
x = x.copy().val.get_full_data()
cond1 = where(logical_and(x > 0., x < 0.5))
cond2 = where(x > .5)
not_cond = where(x <= 0.)
x[cond2] -= 0.5
x[cond2] *= 2
x[cond1] = 0.
x[not_cond] = 1
return Field(domain=dom, val=x)
from .. import exp from ..field import exp
from ..operators.linear_operator import LinearOperator from ..operators.linear_operator import LinearOperator
class LinearizedSignalResponse(LinearOperator): class LinearizedSignalResponse(LinearOperator):
def __init__(self, Instrument, nonlinearity, FFT, power, m): def __init__(self, Instrument, nonlinearity, FFT, power, m):
super(LinearizedSignalResponse, self).__init__() super(LinearizedSignalResponse, self).__init__()
self._target = Instrument.target
self._domain = FFT.target
self.Instrument = Instrument self.Instrument = Instrument
self.FFT = FFT self.FFT = FFT
self.power = power self.power = power
position = FFT.adjoint_times(self.power * m) position = FFT.adjoint_times(self.power*m)
self.linearization = nonlinearity.derivative(position) self.linearization = nonlinearity.derivative(position)
def _times(self, x): def _times(self, x):
return self.Instrument(self.linearization * self.FFT.adjoint_times(self.power * x)) tmp = self.FFT.adjoint_times(self.power*x)
tmp *= self.linearization
return self.Instrument(tmp)
def _adjoint_times(self, x): def _adjoint_times(self, x):
return self.power * self.FFT(self.linearization * self.Instrument.adjoint_times(x)) tmp = self.Instrument.adjoint_times(x)
tmp *= self.linearization
tmp = self.FFT(tmp)
tmp *= self.power
return tmp
@property @property
def domain(self): def domain(self):
return self._domain return self.FFT.target
@property @property
def target(self): def target(self):
return self._target return self.Instrument.target
@property @property
def unitary(self): def unitary(self):
...@@ -35,8 +39,6 @@ class LinearizedSignalResponse(LinearOperator): ...@@ -35,8 +39,6 @@ class LinearizedSignalResponse(LinearOperator):
class LinearizedPowerResponse(LinearOperator): class LinearizedPowerResponse(LinearOperator):
def __init__(self, Instrument, nonlinearity, FFT, Projection, t, m): def __init__(self, Instrument, nonlinearity, FFT, Projection, t, m):
super(LinearizedPowerResponse, self).__init__() super(LinearizedPowerResponse, self).__init__()
self._target = Instrument.target
self._domain = t.domain
self.Instrument = Instrument self.Instrument = Instrument
self.FFT = FFT self.FFT = FFT
self.Projection = Projection self.Projection = Projection
...@@ -47,22 +49,31 @@ class LinearizedPowerResponse(LinearOperator): ...@@ -47,22 +49,31 @@ class LinearizedPowerResponse(LinearOperator):
self.linearization = nonlinearity.derivative(position) self.linearization = nonlinearity.derivative(position)
def _times(self, x): def _times(self, x):
return 0.5 * self.Instrument(self.linearization tmp = self.Projection.adjoint_times(self.power*x)
* self.FFT.adjoint_times(self.m tmp *= self.m
* self.Projection.adjoint_times(self.power * x))) tmp = self.FFT.adjoint_times(tmp)
tmp *= self.linearization
tmp = self.Instrument(tmp)
tmp *= 0.5
return tmp
def _adjoint_times(self, x): def _adjoint_times(self, x):
return 0.5 * self.power * self.Projection(self.m.conjugate() tmp = self.Instrument.adjoint_times(x)
* self.FFT(self.linearization tmp *= self.linearization
* self.Instrument.adjoint_times(x))) # .weight(-1) tmp = self.FFT(tmp)
tmp *= self.m.conjugate()
tmp = self.Projection(tmp)
tmp *= self.power
tmp *= 0.5
return tmp
@property @property
def domain(self): def domain(self):
return self._domain return self.power.domain
@property @property
def target(self): def target(self):
return self._target return self.Instrument.target
@property @property
def unitary(self): def unitary(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment