Commit 6751d5bc authored by Philipp Arras's avatar Philipp Arras
Browse files

Revert unit stuff

This stuff may be merged separately.
parent 367e0628
...@@ -25,26 +25,23 @@ from ..minimization.energy import Energy ...@@ -25,26 +25,23 @@ from ..minimization.energy import Energy
class NoiseEnergy(Energy): class NoiseEnergy(Energy):
def __init__(self, position, d, xi, D, t, ht, Instrument, def __init__(self, position, d, xi, D, t, ht, Instrument,
nonlinearity, alpha, q, Projection, munit=1., sunit=1., nonlinearity, alpha, q, Projection, samples=3,
dunit=1., samples=3, xi_sample_list=None, inverter=None): xi_sample_list=None, inverter=None):
super(NoiseEnergy, self).__init__(position=position) super(NoiseEnergy, self).__init__(position=position)
self.xi = xi self.xi = xi
self.D = D self.D = D
self.d = d self.d = d
self.N = DiagonalOperator(diagonal=dunit**2 * exp(self.position)) self.N = DiagonalOperator(diagonal=exp(self.position))
self.t = t self.t = t
self.samples = samples self.samples = samples
self.ht = ht self.ht = ht
self.Instrument = Instrument self.Instrument = Instrument
self.nonlinearity = nonlinearity self.nonlinearity = nonlinearity
self.munit = munit
self.sunit = sunit
self.dunit = dunit
self.alpha = alpha self.alpha = alpha
self.q = q self.q = q
self.Projection = Projection self.Projection = Projection
self.power = self.Projection.adjoint_times(munit * exp(0.5 * self.t)) self.power = self.Projection.adjoint_times(exp(0.5 * self.t))
if xi_sample_list is None: if xi_sample_list is None:
if samples is None or samples == 0: if samples is None or samples == 0:
xi_sample_list = [xi] xi_sample_list = [xi]
...@@ -54,14 +51,14 @@ class NoiseEnergy(Energy): ...@@ -54,14 +51,14 @@ class NoiseEnergy(Energy):
self.xi_sample_list = xi_sample_list self.xi_sample_list = xi_sample_list
self.inverter = inverter self.inverter = inverter
A = Projection.adjoint_times(munit * exp(.5 * self.t)) # unit: munit A = Projection.adjoint_times(exp(.5 * self.t))
self._gradient = None self._gradient = None
for sample in self.xi_sample_list: for sample in self.xi_sample_list:
map_s = self.ht(A * sample) map_s = self.ht(A * sample)
residual = self.d - \ residual = self.d - \
self.Instrument(sunit * self.nonlinearity(map_s)) self.Instrument(self.nonlinearity(map_s))
lh = .5 * residual.vdot(self.N.inverse_times(residual)) lh = .5 * residual.vdot(self.N.inverse_times(residual))
grad = -.5 * self.N.inverse_times(residual.conjugate() * residual) grad = -.5 * self.N.inverse_times(residual.conjugate() * residual)
...@@ -84,8 +81,7 @@ class NoiseEnergy(Energy): ...@@ -84,8 +81,7 @@ class NoiseEnergy(Energy):
return self.__class__( return self.__class__(
position, self.d, self.xi, self.D, self.t, self.ht, position, self.d, self.xi, self.D, self.t, self.ht,
self.Instrument, self.nonlinearity, self.alpha, self.q, self.Instrument, self.nonlinearity, self.alpha, self.q,
self.Projection, munit=self.munit, sunit=self.sunit, self.Projection, xi_sample_list=self.xi_sample_list,
dunit=self.dunit, xi_sample_list=self.xi_sample_list,
samples=self.samples, inverter=self.inverter) samples=self.samples, inverter=self.inverter)
@property @property
......
...@@ -20,22 +20,12 @@ from ..operators.inversion_enabler import InversionEnabler ...@@ -20,22 +20,12 @@ from ..operators.inversion_enabler import InversionEnabler
from .response_operators import LinearizedPowerResponse from .response_operators import LinearizedPowerResponse
def NonlinearPowerCurvature( def NonlinearPowerCurvature(tau, ht, Instrument, nonlinearity, Projection, N,
tau, T, xi_sample_list, inverter):
ht,
Instrument,
nonlinearity,
Projection,
N,
T,
xi_sample_list,
inverter,
munit=1.,
sunit=1.):
result = None result = None
for xi_sample in xi_sample_list: for xi_sample in xi_sample_list:
LinearizedResponse = LinearizedPowerResponse( LinearizedResponse = LinearizedPowerResponse(
Instrument, nonlinearity, ht, Projection, tau, xi_sample, munit, sunit) Instrument, nonlinearity, ht, Projection, tau, xi_sample)
op = LinearizedResponse.adjoint * N.inverse * LinearizedResponse op = LinearizedResponse.adjoint * N.inverse * LinearizedResponse
result = op if result is None else result + op result = op if result is None else result + op
result = result * (1. / len(xi_sample_list)) + T result = result * (1. / len(xi_sample_list)) + T
......
...@@ -53,7 +53,7 @@ class NonlinearPowerEnergy(Energy): ...@@ -53,7 +53,7 @@ class NonlinearPowerEnergy(Energy):
def __init__(self, position, d, N, xi, D, ht, Instrument, nonlinearity, def __init__(self, position, d, N, xi, D, ht, Instrument, nonlinearity,
Projection, sigma=0., samples=3, xi_sample_list=None, Projection, sigma=0., samples=3, xi_sample_list=None,
inverter=None, munit=1., sunit=1.): inverter=None):
super(NonlinearPowerEnergy, self).__init__(position) super(NonlinearPowerEnergy, self).__init__(position)
self.xi = xi self.xi = xi
self.D = D self.D = D
...@@ -66,8 +66,6 @@ class NonlinearPowerEnergy(Energy): ...@@ -66,8 +66,6 @@ class NonlinearPowerEnergy(Energy):
self.nonlinearity = nonlinearity self.nonlinearity = nonlinearity
self.Projection = Projection self.Projection = Projection
self.sigma = sigma self.sigma = sigma
self.munit = munit
self.sunit = sunit
if xi_sample_list is None: if xi_sample_list is None:
if samples is None or samples == 0: if samples is None or samples == 0:
xi_sample_list = [xi] xi_sample_list = [xi]
...@@ -77,7 +75,7 @@ class NonlinearPowerEnergy(Energy): ...@@ -77,7 +75,7 @@ class NonlinearPowerEnergy(Energy):
self.xi_sample_list = xi_sample_list self.xi_sample_list = xi_sample_list
self.inverter = inverter self.inverter = inverter
A = Projection.adjoint_times(munit * exp(.5 * position)) # unit: munit A = Projection.adjoint_times(exp(.5 * position))
map_s = self.ht(A * xi) map_s = self.ht(A * xi)
Tpos = self.T(position) Tpos = self.T(position)
...@@ -86,10 +84,10 @@ class NonlinearPowerEnergy(Energy): ...@@ -86,10 +84,10 @@ class NonlinearPowerEnergy(Energy):
map_s = self.ht(A * xi_sample) map_s = self.ht(A * xi_sample)
LinR = LinearizedPowerResponse( LinR = LinearizedPowerResponse(
self.Instrument, self.nonlinearity, self.ht, self.Projection, self.Instrument, self.nonlinearity, self.ht, self.Projection,
self.position, xi_sample, munit=self.munit, sunit=self.sunit) self.position, xi_sample)
residual = self.d - \ residual = self.d - \
self.Instrument(sunit * self.nonlinearity(map_s)) self.Instrument(self.nonlinearity(map_s))
lh = 0.5 * residual.vdot(self.N.inverse_times(residual)) lh = 0.5 * residual.vdot(self.N.inverse_times(residual))
grad = LinR.adjoint_times(self.N.inverse_times(residual)) grad = LinR.adjoint_times(self.N.inverse_times(residual))
...@@ -111,8 +109,6 @@ class NonlinearPowerEnergy(Energy): ...@@ -111,8 +109,6 @@ class NonlinearPowerEnergy(Energy):
self.Projection, sigma=self.sigma, self.Projection, sigma=self.sigma,
samples=len(self.xi_sample_list), samples=len(self.xi_sample_list),
xi_sample_list=self.xi_sample_list, xi_sample_list=self.xi_sample_list,
munit=self.munit,
sunit=self.sunit,
inverter=self.inverter) inverter=self.inverter)
@property @property
...@@ -129,4 +125,4 @@ class NonlinearPowerEnergy(Energy): ...@@ -129,4 +125,4 @@ class NonlinearPowerEnergy(Energy):
return NonlinearPowerCurvature( return NonlinearPowerCurvature(
self.position, self.ht, self.Instrument, self.nonlinearity, self.position, self.ht, self.Instrument, self.nonlinearity,
self.Projection, self.N, self.T, self.xi_sample_list, self.Projection, self.N, self.T, self.xi_sample_list,
self.inverter, self.munit, self.sunit) self.inverter)
...@@ -24,19 +24,18 @@ from .response_operators import LinearizedSignalResponse ...@@ -24,19 +24,18 @@ from .response_operators import LinearizedSignalResponse
class NonlinearWienerFilterEnergy(Energy): class NonlinearWienerFilterEnergy(Energy):
def __init__(self, position, d, Instrument, nonlinearity, ht, power, N, S, def __init__(self, position, d, Instrument, nonlinearity, ht, power, N, S,
inverter=None, sunit=1.): inverter=None):
super(NonlinearWienerFilterEnergy, self).__init__(position=position) super(NonlinearWienerFilterEnergy, self).__init__(position=position)
self.d = d self.d = d
self.sunit = sunit
self.Instrument = Instrument self.Instrument = Instrument
self.nonlinearity = nonlinearity self.nonlinearity = nonlinearity
self.ht = ht self.ht = ht
self.power = power self.power = power
m = self.ht(self.power * self.position) m = self.ht(self.power * self.position)
self.LinearizedResponse = LinearizedSignalResponse( self.LinearizedResponse = LinearizedSignalResponse(
Instrument, nonlinearity, ht, power, m, sunit) Instrument, nonlinearity, ht, power, m)
residual = d - Instrument(sunit * nonlinearity(m)) residual = d - Instrument(nonlinearity(m))
self.N = N self.N = N
self.S = S self.S = S
self.inverter = inverter self.inverter = inverter
...@@ -49,7 +48,7 @@ class NonlinearWienerFilterEnergy(Energy): ...@@ -49,7 +48,7 @@ class NonlinearWienerFilterEnergy(Energy):
def at(self, position): def at(self, position):
return self.__class__(position, self.d, self.Instrument, return self.__class__(position, self.d, self.Instrument,
self.nonlinearity, self.ht, self.power, self.N, self.nonlinearity, self.ht, self.power, self.N,
self.S, self.inverter, self.sunit) self.S, self.inverter)
@property @property
def value(self): def value(self):
......
...@@ -19,21 +19,12 @@ ...@@ -19,21 +19,12 @@
from ..field import exp from ..field import exp
def LinearizedSignalResponse(Instrument, nonlinearity, ht, power, m, sunit): def LinearizedSignalResponse(Instrument, nonlinearity, ht, power, m):
return sunit * (Instrument * nonlinearity.derivative(m) * ht * power) return (Instrument * nonlinearity.derivative(m) * ht * power)
def LinearizedPowerResponse( def LinearizedPowerResponse(Instrument, nonlinearity, ht, Projection, tau, xi):
Instrument, power = exp(0.5 * tau)
nonlinearity,
ht,
Projection,
tau,
xi,
munit,
sunit):
power = exp(0.5 * tau) * munit
position = ht(Projection.adjoint_times(power) * xi) position = ht(Projection.adjoint_times(power) * xi)
linearization = nonlinearity.derivative(position) linearization = nonlinearity.derivative(position)
return sunit * (0.5 * Instrument * linearization * ht * xi * return (0.5 * Instrument * linearization * ht * xi * Projection.adjoint * power)
Projection.adjoint * power)
...@@ -209,9 +209,6 @@ def plot(f, **kwargs): ...@@ -209,9 +209,6 @@ def plot(f, **kwargs):
ax.set_xlabel(_get_kw("xlabel", "", **kwargs)) ax.set_xlabel(_get_kw("xlabel", "", **kwargs))
ax.set_ylabel(_get_kw("ylabel", "", **kwargs)) ax.set_ylabel(_get_kw("ylabel", "", **kwargs))
cmap = _get_kw("colormap", plt.rcParams['image.cmap'], **kwargs) cmap = _get_kw("colormap", plt.rcParams['image.cmap'], **kwargs)
unit = kwargs.get('xunit')
if not unit:
unit = 1.
if isinstance(dom, RGSpace): if isinstance(dom, RGSpace):
if len(dom.shape) == 1: if len(dom.shape) == 1:
npoints = dom.shape[0] npoints = dom.shape[0]
...@@ -231,8 +228,8 @@ def plot(f, **kwargs): ...@@ -231,8 +228,8 @@ def plot(f, **kwargs):
ny = dom.shape[1] ny = dom.shape[1]
dx = dom.distances[0] dx = dom.distances[0]
dy = dom.distances[1] dy = dom.distances[1]
xc = np.arange(nx, dtype=np.float64)*dx/unit xc = np.arange(nx, dtype=np.float64)*dx
yc = np.arange(ny, dtype=np.float64)*dy/unit yc = np.arange(ny, dtype=np.float64)*dy
im = ax.imshow(dobj.to_global_data(f.val), im = ax.imshow(dobj.to_global_data(f.val),
extent=[xc[0], xc[-1], yc[0], yc[-1]], extent=[xc[0], xc[-1], yc[0], yc[-1]],
vmin=kwargs.get("zmin"), vmin=kwargs.get("zmin"),
...@@ -249,7 +246,7 @@ def plot(f, **kwargs): ...@@ -249,7 +246,7 @@ def plot(f, **kwargs):
plt.xscale('log') plt.xscale('log')
plt.yscale('log') plt.yscale('log')
plt.title('power') plt.title('power')
xcoord = dom.k_lengths / unit xcoord = dom.k_lengths
for i, fld in enumerate(f): for i, fld in enumerate(f):
ycoord = dobj.to_global_data(fld.val) ycoord = dobj.to_global_data(fld.val)
plt.plot(xcoord, ycoord, label=label[i]) plt.plot(xcoord, ycoord, label=label[i])
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment