From 87bae7e97353804d305c75cbe3ecce3a19cbba7c Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Thu, 23 Nov 2017 15:50:30 +0100
Subject: [PATCH] tweak and remove some code that would need more adjustments
---
nifty/library/nonlinear_signal_energy.py | 4 +-
nifty/library/nonlinearities.py | 55 ++----------------------
nifty/library/response_operators.py | 47 ++++++++++++--------
3 files changed, 34 insertions(+), 72 deletions(-)
diff --git a/nifty/library/nonlinear_signal_energy.py b/nifty/library/nonlinear_signal_energy.py
index 9e6aed8e9..a62b9f625 100644
--- a/nifty/library/nonlinear_signal_energy.py
+++ b/nifty/library/nonlinear_signal_energy.py
@@ -45,9 +45,7 @@ class NonlinearWienerFilterEnergy(Energy):
@property
@memo
def gradient(self):
- gradient = self._t1.copy()
- gradient -= self.LinearizedResponse.adjoint_times(self._t2)
- return gradient
+ return self._t1 - self.LinearizedResponse.adjoint_times(self._t2)
@property
@memo
diff --git a/nifty/library/nonlinearities.py b/nifty/library/nonlinearities.py
index 10ab15f02..baaefdb6d 100644
--- a/nifty/library/nonlinearities.py
+++ b/nifty/library/nonlinearities.py
@@ -1,13 +1,12 @@
-from numpy import logical_and, where
-from .. import Field, exp, tanh
+from ..field import Field, exp, tanh
-class Linear:
+class Linear(object):
def __call__(self, x):
return x
def derivative(self, x):
- return 1
+ return Field.ones_like(x)
class Exponential(object):
@@ -26,55 +25,9 @@ class Tanh(object):
return (1. - tanh(x)**2)
-class PositiveTanh:
+class PositiveTanh(object):
def __call__(self, x):
return 0.5 * tanh(x) + 0.5
def derivative(self, x):
return 0.5 * (1. - tanh(x)**2)
-
-
-class LinearThenQuadraticWithJump(object):
- def __call__(self, x):
- dom = x.domain
- x = x.copy().val.get_full_data()
- cond = where(x > 0.)
- not_cond = where(x <= 0.)
- x[cond] **= 2
- x[not_cond] -= 1
- return Field(domain=dom, val=x)
-
- def derivative(self, x):
- dom = x.domain
- x = x.copy().val.get_full_data()
- cond = where(x > 0.)
- not_cond = where(x <= 0.)
- x[cond] *= 2
- x[not_cond] = 1
- return Field(domain=dom, val=x)
-
-
-class ReallyStupidNonlinearity(object):
- def __call__(self, x):
- dom = x.domain
- x = x.copy().val.get_full_data()
- cond1 = where(logical_and(x > 0., x < .5))
- cond2 = where(x >= .5)
- not_cond = where(x <= 0.)
- x[cond2] -= 0.5
- x[cond2] **= 2
- x[cond1] = 0.
- x[not_cond] -= 1
- return Field(domain=dom, val=x)
-
- def derivative(self, x):
- dom = x.domain
- x = x.copy().val.get_full_data()
- cond1 = where(logical_and(x > 0., x < 0.5))
- cond2 = where(x > .5)
- not_cond = where(x <= 0.)
- x[cond2] -= 0.5
- x[cond2] *= 2
- x[cond1] = 0.
- x[not_cond] = 1
- return Field(domain=dom, val=x)
diff --git a/nifty/library/response_operators.py b/nifty/library/response_operators.py
index c080df1ed..f2f05c129 100644
--- a/nifty/library/response_operators.py
+++ b/nifty/library/response_operators.py
@@ -1,31 +1,35 @@
-from .. import exp
+from ..field import exp
from ..operators.linear_operator import LinearOperator
class LinearizedSignalResponse(LinearOperator):
def __init__(self, Instrument, nonlinearity, FFT, power, m):
super(LinearizedSignalResponse, self).__init__()
- self._target = Instrument.target
- self._domain = FFT.target
self.Instrument = Instrument
self.FFT = FFT
self.power = power
- position = FFT.adjoint_times(self.power * m)
+ position = FFT.adjoint_times(self.power*m)
self.linearization = nonlinearity.derivative(position)
def _times(self, x):
- return self.Instrument(self.linearization * self.FFT.adjoint_times(self.power * x))
+ tmp = self.FFT.adjoint_times(self.power*x)
+ tmp *= self.linearization
+ return self.Instrument(tmp)
def _adjoint_times(self, x):
- return self.power * self.FFT(self.linearization * self.Instrument.adjoint_times(x))
+ tmp = self.Instrument.adjoint_times(x)
+ tmp *= self.linearization
+ tmp = self.FFT(tmp)
+ tmp *= self.power
+ return tmp
@property
def domain(self):
- return self._domain
+ return self.FFT.target
@property
def target(self):
- return self._target
+ return self.Instrument.target
@property
def unitary(self):
@@ -35,8 +39,6 @@ class LinearizedSignalResponse(LinearOperator):
class LinearizedPowerResponse(LinearOperator):
def __init__(self, Instrument, nonlinearity, FFT, Projection, t, m):
super(LinearizedPowerResponse, self).__init__()
- self._target = Instrument.target
- self._domain = t.domain
self.Instrument = Instrument
self.FFT = FFT
self.Projection = Projection
@@ -47,22 +49,31 @@ class LinearizedPowerResponse(LinearOperator):
self.linearization = nonlinearity.derivative(position)
def _times(self, x):
- return 0.5 * self.Instrument(self.linearization
- * self.FFT.adjoint_times(self.m
- * self.Projection.adjoint_times(self.power * x)))
+ tmp = self.Projection.adjoint_times(self.power*x)
+ tmp *= self.m
+ tmp = self.FFT.adjoint_times(tmp)
+ tmp *= self.linearization
+ tmp = self.Instrument(tmp)
+ tmp *= 0.5
+ return tmp
def _adjoint_times(self, x):
- return 0.5 * self.power * self.Projection(self.m.conjugate()
- * self.FFT(self.linearization
- * self.Instrument.adjoint_times(x))) # .weight(-1)
+ tmp = self.Instrument.adjoint_times(x)
+ tmp *= self.linearization
+ tmp = self.FFT(tmp)
+ tmp *= self.m.conjugate()
+ tmp = self.Projection(tmp)
+ tmp *= self.power
+ tmp *= 0.5
+ return tmp
@property
def domain(self):
- return self._domain
+ return self.power.domain
@property
def target(self):
- return self._target
+ return self.Instrument.target
@property
def unitary(self):
--
GitLab