Commit e07a9ff6 authored by Philipp Arras's avatar Philipp Arras
Browse files

Add units to NonlinearPowerEnergy and gradient

parent c7f8b091
......@@ -17,11 +17,11 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from .. import exp
from ..utilities import memo
from ..minimization.energy import Energy
from ..operators.smoothness_operator import SmoothnessOperator
from ..utilities import memo
from .nonlinear_power_curvature import NonlinearPowerCurvature
from .response_operators import LinearizedPowerResponse
from ..minimization.energy import Energy
class NonlinearPowerEnergy(Energy):
......@@ -51,23 +51,14 @@ class NonlinearPowerEnergy(Energy):
default : 3
"""
def __init__(self, position, d, N, m, D, FFT, Instrument, nonlinearity,
Projection, sigma=0., samples=3, sample_list=None,
inverter=None):
def __init__(self, position, d, N, m, D, FFT, Instrument, nonlinearity, Projection, sigma=0., samples=3, sample_list=None, munit=1., sunit=1., inverter=None):
super(NonlinearPowerEnergy, self).__init__(position)
self.m = m
self.D = D
self.d = d
self.N = N
self.T = SmoothnessOperator(domain=self.position.domain[0],
strength=sigma, logarithmic=True)
self.FFT = FFT
self.d, self.N, self.m, self.D, self.FFT = d, N, m, D, FFT
self.Instrument = Instrument
self.nonlinearity = nonlinearity
self.Projection = Projection
self._sigma = sigma
self.power = self.Projection.adjoint_times(exp(0.5*self.position))
self.sigma = sigma
self.munit = munit
if sample_list is None:
if samples is None or samples == 0:
sample_list = [m]
......@@ -76,7 +67,34 @@ class NonlinearPowerEnergy(Energy):
for _ in range(samples)]
self.sample_list = sample_list
self.inverter = inverter
self._value, self._gradient = self._value_and_grad()
self.T = SmoothnessOperator(domain=self.position.domain[0],
strength=sigma, logarithmic=True)
A = Projection.adjoint_times(munit * exp(.5*position)) # unit: munit
map_s = FFT.inverse_times(A * m)
Tpos = self.T(position)
self._gradient = None
for sample in self.sample_list:
map_s = FFT.inverse_times(A * sample)
LinR = LinearizedPowerResponse(Instrument, nonlinearity, FFT, Projection, position, sample, munit, sunit)
residual = self.d - self.Instrument(sunit * self.nonlinearity(map_s))
lh = 0.5 * residual.vdot(self.N.inverse_times(residual))
grad = LinR.adjoint_times(self.N.inverse_times(residual))
if self._gradient is None:
self._value = lh
self._gradient = grad.copy()
else:
self._value += lh
self._gradient += grad
self._value *= 1./len(self.sample_list)
self._value += 0.5*self.position.vdot(Tpos)
self._gradient *= -1./len(self.sample_list)
self._gradient += Tpos
def at(self, position):
return self.__class__(position, self.d, self.N, self.m, self.D,
......@@ -86,30 +104,6 @@ class NonlinearPowerEnergy(Energy):
sample_list=self.sample_list,
inverter=self.inverter)
def _value_and_grad(self):
likelihood_gradient = None
for sample in self.sample_list:
residual = self.d - \
self.Instrument(self.nonlinearity(
self.FFT.adjoint_times(self.power*sample)))
lh = 0.5 * residual.vdot(self.N.inverse_times(residual))
LinR = LinearizedPowerResponse(
self.Instrument, self.nonlinearity, self.FFT, self.Projection,
self.position, sample)
grad = LinR.adjoint_times(self.N.inverse_times(residual))
if likelihood_gradient is None:
likelihood = lh
likelihood_gradient = grad.copy()
else:
likelihood += lh
likelihood_gradient += grad
Tpos = self.T(self.position)
likelihood *= 1./len(self.sample_list)
likelihood += 0.5*self.position.vdot(Tpos)
likelihood_gradient *= -1./len(self.sample_list)
likelihood_gradient += Tpos
return likelihood, likelihood_gradient
@property
def value(self):
return self._value
......
......@@ -23,9 +23,9 @@ def LinearizedSignalResponse(Instrument, nonlinearity, FFT, power, s, sunit):
return sunit * (Instrument * nonlinearity.derivative(s) * FFT.inverse * power)
def LinearizedPowerResponse(Instrument, nonlinearity, FFT, Projection, t, m):
power = exp(0.5*t)
position = FFT.adjoint_times(Projection.adjoint_times(power) * m)
def LinearizedPowerResponse(Instrument, nonlinearity, FFT, Projection, t, m, munit, sunit):
power = exp(0.5*t) * munit
position = FFT.inverse_times(Projection.adjoint_times(power) * m)
linearization = nonlinearity.derivative(position)
return (0.5 * Instrument * linearization * FFT.adjoint * m *
Projection.adjoint * power)
return sunit * (0.5 * Instrument * linearization * FFT.inverse * m *
Projection.adjoint * power)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment