Commit 87bae7e9 by Martin Reinecke

### tweak and remove some code that would need more adjustments

parent 7ee6cbfa
Pipeline #22137 passed with stage
in 4 minutes and 43 seconds
 ... @@ -45,9 +45,7 @@ class NonlinearWienerFilterEnergy(Energy): ... @@ -45,9 +45,7 @@ class NonlinearWienerFilterEnergy(Energy): @property @property @memo @memo def gradient(self): def gradient(self): gradient = self._t1.copy() return self._t1 - self.LinearizedResponse.adjoint_times(self._t2) gradient -= self.LinearizedResponse.adjoint_times(self._t2) return gradient @property @property @memo @memo ... ...
 from numpy import logical_and, where from ..field import Field, exp, tanh from .. import Field, exp, tanh class Linear: class Linear(object): def __call__(self, x): def __call__(self, x): return x return x def derivative(self, x): def derivative(self, x): return 1 return Field.ones_like(x) class Exponential(object): class Exponential(object): ... @@ -26,55 +25,9 @@ class Tanh(object): ... @@ -26,55 +25,9 @@ class Tanh(object): return (1. - tanh(x)**2) return (1. - tanh(x)**2) class PositiveTanh: class PositiveTanh(object): def __call__(self, x): def __call__(self, x): return 0.5 * tanh(x) + 0.5 return 0.5 * tanh(x) + 0.5 def derivative(self, x): def derivative(self, x): return 0.5 * (1. - tanh(x)**2) return 0.5 * (1. - tanh(x)**2) class LinearThenQuadraticWithJump(object): def __call__(self, x): dom = x.domain x = x.copy().val.get_full_data() cond = where(x > 0.) not_cond = where(x <= 0.) x[cond] **= 2 x[not_cond] -= 1 return Field(domain=dom, val=x) def derivative(self, x): dom = x.domain x = x.copy().val.get_full_data() cond = where(x > 0.) not_cond = where(x <= 0.) x[cond] *= 2 x[not_cond] = 1 return Field(domain=dom, val=x) class ReallyStupidNonlinearity(object): def __call__(self, x): dom = x.domain x = x.copy().val.get_full_data() cond1 = where(logical_and(x > 0., x < .5)) cond2 = where(x >= .5) not_cond = where(x <= 0.) x[cond2] -= 0.5 x[cond2] **= 2 x[cond1] = 0. x[not_cond] -= 1 return Field(domain=dom, val=x) def derivative(self, x): dom = x.domain x = x.copy().val.get_full_data() cond1 = where(logical_and(x > 0., x < 0.5)) cond2 = where(x > .5) not_cond = where(x <= 0.) x[cond2] -= 0.5 x[cond2] *= 2 x[cond1] = 0. x[not_cond] = 1 return Field(domain=dom, val=x)