Skip to content
Snippets Groups Projects
Commit 87bae7e9 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

tweak and remove some code that would need more adjustments

parent 7ee6cbfa
No related branches found
No related tags found
No related merge requests found
Pipeline #
......@@ -45,9 +45,7 @@ class NonlinearWienerFilterEnergy(Energy):
@property
@memo
def gradient(self):
gradient = self._t1.copy()
gradient -= self.LinearizedResponse.adjoint_times(self._t2)
return gradient
return self._t1 - self.LinearizedResponse.adjoint_times(self._t2)
@property
@memo
......
from numpy import logical_and, where
from .. import Field, exp, tanh
from ..field import Field, exp, tanh
class Linear:
class Linear(object):
def __call__(self, x):
return x
def derivative(self, x):
return 1
return Field.ones_like(x)
class Exponential(object):
......@@ -26,55 +25,9 @@ class Tanh(object):
return (1. - tanh(x)**2)
class PositiveTanh:
class PositiveTanh(object):
def __call__(self, x):
return 0.5 * tanh(x) + 0.5
def derivative(self, x):
return 0.5 * (1. - tanh(x)**2)
class LinearThenQuadraticWithJump(object):
def __call__(self, x):
dom = x.domain
x = x.copy().val.get_full_data()
cond = where(x > 0.)
not_cond = where(x <= 0.)
x[cond] **= 2
x[not_cond] -= 1
return Field(domain=dom, val=x)
def derivative(self, x):
dom = x.domain
x = x.copy().val.get_full_data()
cond = where(x > 0.)
not_cond = where(x <= 0.)
x[cond] *= 2
x[not_cond] = 1
return Field(domain=dom, val=x)
class ReallyStupidNonlinearity(object):
def __call__(self, x):
dom = x.domain
x = x.copy().val.get_full_data()
cond1 = where(logical_and(x > 0., x < .5))
cond2 = where(x >= .5)
not_cond = where(x <= 0.)
x[cond2] -= 0.5
x[cond2] **= 2
x[cond1] = 0.
x[not_cond] -= 1
return Field(domain=dom, val=x)
def derivative(self, x):
dom = x.domain
x = x.copy().val.get_full_data()
cond1 = where(logical_and(x > 0., x < 0.5))
cond2 = where(x > .5)
not_cond = where(x <= 0.)
x[cond2] -= 0.5
x[cond2] *= 2
x[cond1] = 0.
x[not_cond] = 1
return Field(domain=dom, val=x)
from .. import exp
from ..field import exp
from ..operators.linear_operator import LinearOperator
class LinearizedSignalResponse(LinearOperator):
def __init__(self, Instrument, nonlinearity, FFT, power, m):
super(LinearizedSignalResponse, self).__init__()
self._target = Instrument.target
self._domain = FFT.target
self.Instrument = Instrument
self.FFT = FFT
self.power = power
......@@ -14,18 +12,24 @@ class LinearizedSignalResponse(LinearOperator):
self.linearization = nonlinearity.derivative(position)
def _times(self, x):
return self.Instrument(self.linearization * self.FFT.adjoint_times(self.power * x))
tmp = self.FFT.adjoint_times(self.power*x)
tmp *= self.linearization
return self.Instrument(tmp)
def _adjoint_times(self, x):
return self.power * self.FFT(self.linearization * self.Instrument.adjoint_times(x))
tmp = self.Instrument.adjoint_times(x)
tmp *= self.linearization
tmp = self.FFT(tmp)
tmp *= self.power
return tmp
@property
def domain(self):
return self._domain
return self.FFT.target
@property
def target(self):
return self._target
return self.Instrument.target
@property
def unitary(self):
......@@ -35,8 +39,6 @@ class LinearizedSignalResponse(LinearOperator):
class LinearizedPowerResponse(LinearOperator):
def __init__(self, Instrument, nonlinearity, FFT, Projection, t, m):
super(LinearizedPowerResponse, self).__init__()
self._target = Instrument.target
self._domain = t.domain
self.Instrument = Instrument
self.FFT = FFT
self.Projection = Projection
......@@ -47,22 +49,31 @@ class LinearizedPowerResponse(LinearOperator):
self.linearization = nonlinearity.derivative(position)
def _times(self, x):
return 0.5 * self.Instrument(self.linearization
* self.FFT.adjoint_times(self.m
* self.Projection.adjoint_times(self.power * x)))
tmp = self.Projection.adjoint_times(self.power*x)
tmp *= self.m
tmp = self.FFT.adjoint_times(tmp)
tmp *= self.linearization
tmp = self.Instrument(tmp)
tmp *= 0.5
return tmp
def _adjoint_times(self, x):
return 0.5 * self.power * self.Projection(self.m.conjugate()
* self.FFT(self.linearization
* self.Instrument.adjoint_times(x))) # .weight(-1)
tmp = self.Instrument.adjoint_times(x)
tmp *= self.linearization
tmp = self.FFT(tmp)
tmp *= self.m.conjugate()
tmp = self.Projection(tmp)
tmp *= self.power
tmp *= 0.5
return tmp
@property
def domain(self):
return self._domain
return self.power.domain
@property
def target(self):
return self._target
return self.Instrument.target
@property
def unitary(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment