Skip to content
Snippets Groups Projects
Commit 87bae7e9 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

tweak and remove some code that would need more adjustments

parent 7ee6cbfa
No related branches found
No related tags found
No related merge requests found
Pipeline #
...@@ -45,9 +45,7 @@ class NonlinearWienerFilterEnergy(Energy): ...@@ -45,9 +45,7 @@ class NonlinearWienerFilterEnergy(Energy):
@property @property
@memo @memo
def gradient(self): def gradient(self):
gradient = self._t1.copy() return self._t1 - self.LinearizedResponse.adjoint_times(self._t2)
gradient -= self.LinearizedResponse.adjoint_times(self._t2)
return gradient
@property @property
@memo @memo
......
from numpy import logical_and, where from ..field import Field, exp, tanh
from .. import Field, exp, tanh
class Linear: class Linear(object):
def __call__(self, x): def __call__(self, x):
return x return x
def derivative(self, x): def derivative(self, x):
return 1 return Field.ones_like(x)
class Exponential(object): class Exponential(object):
...@@ -26,55 +25,9 @@ class Tanh(object): ...@@ -26,55 +25,9 @@ class Tanh(object):
return (1. - tanh(x)**2) return (1. - tanh(x)**2)
class PositiveTanh: class PositiveTanh(object):
def __call__(self, x): def __call__(self, x):
return 0.5 * tanh(x) + 0.5 return 0.5 * tanh(x) + 0.5
def derivative(self, x): def derivative(self, x):
return 0.5 * (1. - tanh(x)**2) return 0.5 * (1. - tanh(x)**2)
class LinearThenQuadraticWithJump(object):
def __call__(self, x):
dom = x.domain
x = x.copy().val.get_full_data()
cond = where(x > 0.)
not_cond = where(x <= 0.)
x[cond] **= 2
x[not_cond] -= 1
return Field(domain=dom, val=x)
def derivative(self, x):
dom = x.domain
x = x.copy().val.get_full_data()
cond = where(x > 0.)
not_cond = where(x <= 0.)
x[cond] *= 2
x[not_cond] = 1
return Field(domain=dom, val=x)
class ReallyStupidNonlinearity(object):
def __call__(self, x):
dom = x.domain
x = x.copy().val.get_full_data()
cond1 = where(logical_and(x > 0., x < .5))
cond2 = where(x >= .5)
not_cond = where(x <= 0.)
x[cond2] -= 0.5
x[cond2] **= 2
x[cond1] = 0.
x[not_cond] -= 1
return Field(domain=dom, val=x)
def derivative(self, x):
dom = x.domain
x = x.copy().val.get_full_data()
cond1 = where(logical_and(x > 0., x < 0.5))
cond2 = where(x > .5)
not_cond = where(x <= 0.)
x[cond2] -= 0.5
x[cond2] *= 2
x[cond1] = 0.
x[not_cond] = 1
return Field(domain=dom, val=x)
from .. import exp from ..field import exp
from ..operators.linear_operator import LinearOperator from ..operators.linear_operator import LinearOperator
class LinearizedSignalResponse(LinearOperator): class LinearizedSignalResponse(LinearOperator):
def __init__(self, Instrument, nonlinearity, FFT, power, m): def __init__(self, Instrument, nonlinearity, FFT, power, m):
super(LinearizedSignalResponse, self).__init__() super(LinearizedSignalResponse, self).__init__()
self._target = Instrument.target
self._domain = FFT.target
self.Instrument = Instrument self.Instrument = Instrument
self.FFT = FFT self.FFT = FFT
self.power = power self.power = power
position = FFT.adjoint_times(self.power * m) position = FFT.adjoint_times(self.power*m)
self.linearization = nonlinearity.derivative(position) self.linearization = nonlinearity.derivative(position)
def _times(self, x): def _times(self, x):
return self.Instrument(self.linearization * self.FFT.adjoint_times(self.power * x)) tmp = self.FFT.adjoint_times(self.power*x)
tmp *= self.linearization
return self.Instrument(tmp)
def _adjoint_times(self, x): def _adjoint_times(self, x):
return self.power * self.FFT(self.linearization * self.Instrument.adjoint_times(x)) tmp = self.Instrument.adjoint_times(x)
tmp *= self.linearization
tmp = self.FFT(tmp)
tmp *= self.power
return tmp
@property @property
def domain(self): def domain(self):
return self._domain return self.FFT.target
@property @property
def target(self): def target(self):
return self._target return self.Instrument.target
@property @property
def unitary(self): def unitary(self):
...@@ -35,8 +39,6 @@ class LinearizedSignalResponse(LinearOperator): ...@@ -35,8 +39,6 @@ class LinearizedSignalResponse(LinearOperator):
class LinearizedPowerResponse(LinearOperator): class LinearizedPowerResponse(LinearOperator):
def __init__(self, Instrument, nonlinearity, FFT, Projection, t, m): def __init__(self, Instrument, nonlinearity, FFT, Projection, t, m):
super(LinearizedPowerResponse, self).__init__() super(LinearizedPowerResponse, self).__init__()
self._target = Instrument.target
self._domain = t.domain
self.Instrument = Instrument self.Instrument = Instrument
self.FFT = FFT self.FFT = FFT
self.Projection = Projection self.Projection = Projection
...@@ -47,22 +49,31 @@ class LinearizedPowerResponse(LinearOperator): ...@@ -47,22 +49,31 @@ class LinearizedPowerResponse(LinearOperator):
self.linearization = nonlinearity.derivative(position) self.linearization = nonlinearity.derivative(position)
def _times(self, x): def _times(self, x):
return 0.5 * self.Instrument(self.linearization tmp = self.Projection.adjoint_times(self.power*x)
* self.FFT.adjoint_times(self.m tmp *= self.m
* self.Projection.adjoint_times(self.power * x))) tmp = self.FFT.adjoint_times(tmp)
tmp *= self.linearization
tmp = self.Instrument(tmp)
tmp *= 0.5
return tmp
def _adjoint_times(self, x): def _adjoint_times(self, x):
return 0.5 * self.power * self.Projection(self.m.conjugate() tmp = self.Instrument.adjoint_times(x)
* self.FFT(self.linearization tmp *= self.linearization
* self.Instrument.adjoint_times(x))) # .weight(-1) tmp = self.FFT(tmp)
tmp *= self.m.conjugate()
tmp = self.Projection(tmp)
tmp *= self.power
tmp *= 0.5
return tmp
@property @property
def domain(self): def domain(self):
return self._domain return self.power.domain
@property @property
def target(self): def target(self):
return self._target return self.Instrument.target
@property @property
def unitary(self): def unitary(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment