Commit d5563c04 authored by Philipp Arras's avatar Philipp Arras
Browse files

Change Nonlinear Energies to HarmonicTransform convention

parent 51de4744
Pipeline #24108 failed with stage
in 4 minutes and 9 seconds
......@@ -22,7 +22,7 @@ from ..minimization.energy import Energy
class NoiseEnergy(Energy):
def __init__(self, position, d, m, D, t, FFT, Instrument, nonlinearity,
def __init__(self, position, d, m, D, t, HarmonicTransform, Instrument, nonlinearity,
alpha, q, Projection, munit=1., sunit=1., dunit=1., samples=3, sample_list=None,
inverter=None):
super(NoiseEnergy, self).__init__(position=position)
......@@ -32,7 +32,7 @@ class NoiseEnergy(Energy):
self.N = DiagonalOperator(diagonal=dunit**2 * exp(self.position))
self.t = t
self.samples = samples
self.FFT = FFT
self.ht = HarmonicTransform
self.Instrument = Instrument
self.nonlinearity = nonlinearity
self.munit = munit
......@@ -54,11 +54,11 @@ class NoiseEnergy(Energy):
self.inverter = inverter
A = Projection.adjoint_times(munit * exp(.5*self.t)) # unit: munit
map_s = FFT.inverse_times(A * m)
map_s = self.ht(A * m)
self._gradient = None
for sample in self.sample_list:
map_s = FFT.inverse_times(A * sample)
map_s = self.ht(A * sample)
residual = self.d - self.Instrument(sunit * self.nonlinearity(map_s))
lh = .5 * residual.vdot(self.N.inverse_times(residual))
......@@ -79,7 +79,7 @@ class NoiseEnergy(Energy):
def at(self, position):
return self.__class__(
position, self.d, self.m, self.D, self.t, self.FFT,
position, self.d, self.m, self.D, self.t, self.ht,
self.Instrument, self.nonlinearity, self.alpha, self.q,
self.Projection, munit=self.munit, sunit=self.sunit,
dunit=self.dunit, sample_list=self.sample_list,
......
......@@ -20,11 +20,11 @@ from ..operators.inversion_enabler import InversionEnabler
from .response_operators import LinearizedPowerResponse
def NonlinearPowerCurvature(position, FFT, Instrument, nonlinearity,
def NonlinearPowerCurvature(position, HarmonicTransform, Instrument, nonlinearity,
Projection, N, T, sample_list, inverter, munit=1., sunit=1.):
result = None
for sample in sample_list:
LinR = LinearizedPowerResponse(Instrument, nonlinearity, FFT, Projection, position, sample, munit, sunit)
LinR = LinearizedPowerResponse(Instrument, nonlinearity, HarmonicTransform, Projection, position, sample, munit, sunit)
op = LinR.adjoint*N.inverse*LinR
result = op if result is None else result + op
result = result*(1./len(sample_list)) + T
......
......@@ -51,9 +51,9 @@ class NonlinearPowerEnergy(Energy):
default : 3
"""
def __init__(self, position, d, N, m, D, FFT, Instrument, nonlinearity, Projection, sigma=0., samples=3, sample_list=None, munit=1., sunit=1., inverter=None):
def __init__(self, position, d, N, m, D, HarmonicTransform, Instrument, nonlinearity, Projection, sigma=0., samples=3, sample_list=None, munit=1., sunit=1., inverter=None):
super(NonlinearPowerEnergy, self).__init__(position)
self.d, self.N, self.m, self.D, self.FFT = d, N, m, D, FFT
self.d, self.N, self.m, self.D, self.ht = d, N, m, D, HarmonicTransform
self.Instrument = Instrument
self.nonlinearity = nonlinearity
self.Projection = Projection
......@@ -73,13 +73,13 @@ class NonlinearPowerEnergy(Energy):
strength=sigma, logarithmic=True)
A = Projection.adjoint_times(munit * exp(.5*position)) # unit: munit
map_s = FFT.inverse_times(A * m)
map_s = self.ht(A * m)
Tpos = self.T(position)
self._gradient = None
for sample in self.sample_list:
map_s = FFT.inverse_times(A * sample)
LinR = LinearizedPowerResponse(Instrument, nonlinearity, FFT, Projection, position, sample, munit, sunit)
map_s = self.ht(A * sample)
LinR = LinearizedPowerResponse(Instrument, nonlinearity, self.ht, Projection, position, sample, munit, sunit)
residual = self.d - self.Instrument(sunit * self.nonlinearity(map_s))
lh = 0.5 * residual.vdot(self.N.inverse_times(residual))
......@@ -99,7 +99,7 @@ class NonlinearPowerEnergy(Energy):
def at(self, position):
return self.__class__(position, self.d, self.N, self.m, self.D,
self.FFT, self.Instrument, self.nonlinearity,
self.ht, self.Instrument, self.nonlinearity,
self.Projection, sigma=self.sigma,
samples=len(self.sample_list),
sample_list=self.sample_list,
......@@ -119,6 +119,6 @@ class NonlinearPowerEnergy(Energy):
@memo
def curvature(self):
return NonlinearPowerCurvature(
self.position, self.FFT, self.Instrument, self.nonlinearity,
self.position, self.ht, self.Instrument, self.nonlinearity,
self.Projection, self.N, self.T, self.sample_list,
self.inverter, self.munit, self.sunit)
......@@ -23,18 +23,18 @@ from .response_operators import LinearizedSignalResponse
class NonlinearWienerFilterEnergy(Energy):
def __init__(self, position, d, Instrument, nonlinearity, FFT, power, N, S, sunit=1.,
def __init__(self, position, d, Instrument, nonlinearity, HarmonicTransform, power, N, S, sunit=1.,
inverter=None):
super(NonlinearWienerFilterEnergy, self).__init__(position=position)
self.d = d
self.sunit = sunit
self.Instrument = Instrument
self.nonlinearity = nonlinearity
self.FFT = FFT
self.ht = HarmonicTransform
self.power = power
position_map = FFT.inverse_times(self.power * self.position)
position_map = self.ht(self.power * self.position)
self.LinearizedResponse = \
LinearizedSignalResponse(Instrument, nonlinearity, FFT, power,
LinearizedSignalResponse(Instrument, nonlinearity, self.ht, power,
position_map, sunit)
residual = d - Instrument(sunit * nonlinearity(position_map))
self.N = N
......@@ -48,7 +48,7 @@ class NonlinearWienerFilterEnergy(Energy):
def at(self, position):
return self.__class__(position, self.d, self.Instrument,
self.nonlinearity, self.FFT, self.power, self.N,
self.nonlinearity, self.ht, self.power, self.N,
self.S, self.sunit, inverter=self.inverter)
@property
......
......@@ -19,13 +19,13 @@
from ..field import exp
def LinearizedSignalResponse(Instrument, nonlinearity, FFT, power, s, sunit):
return sunit * (Instrument * nonlinearity.derivative(s) * FFT.inverse * power)
def LinearizedSignalResponse(Instrument, nonlinearity, HarmonicTransform, power, s, sunit):
return sunit * (Instrument * nonlinearity.derivative(s) * HarmonicTransform * power)
def LinearizedPowerResponse(Instrument, nonlinearity, FFT, Projection, t, m, munit, sunit):
def LinearizedPowerResponse(Instrument, nonlinearity, HarmonicTransform, Projection, t, m, munit, sunit):
power = exp(0.5*t) * munit
position = FFT.inverse_times(Projection.adjoint_times(power) * m)
position = HarmonicTransform(Projection.adjoint_times(power) * m)
linearization = nonlinearity.derivative(position)
return sunit * (0.5 * Instrument * linearization * FFT.inverse * m *
return sunit * (0.5 * Instrument * linearization * HarmonicTransform * m *
Projection.adjoint * power)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment