Commit 741d11a6 authored by Philipp Arras's avatar Philipp Arras
Browse files

Renamings in energy classes

parent 4cf96788
Pipeline #24175 failed with stage
in 3 minutes and 57 seconds
......@@ -20,13 +20,13 @@ from ..operators.inversion_enabler import InversionEnabler
from .response_operators import LinearizedPowerResponse
def NonlinearPowerCurvature(position, ht, Instrument, nonlinearity,
Projection, N, T, sample_list, inverter, munit=1., sunit=1.):
def NonlinearPowerCurvature(tau, ht, Instrument, nonlinearity,
Projection, N, T, xi_sample_list, inverter, munit=1., sunit=1.):
result = None
for sample in sample_list:
for xi_sample in xi_sample_list:
LinearizedResponse = LinearizedPowerResponse(
Instrument, nonlinearity, ht, Projection, position, sample, munit, sunit)
Instrument, nonlinearity, ht, Projection, tau, xi_sample, munit, sunit)
op = LinearizedResponse.adjoint*N.inverse*LinearizedResponse
result = op if result is None else result + op
result = result * (1. / len(sample_list)) + T
result = result * (1. / len(xi_sample_list)) + T
return InversionEnabler(result, inverter)
......@@ -51,11 +51,11 @@ class NonlinearPowerEnergy(Energy):
default : 3
"""
def __init__(self, position, d, N, m, D, ht, Instrument, nonlinearity,
def __init__(self, position, d, N, xi, D, ht, Instrument, nonlinearity,
Projection, sigma=0., samples=3, sample_list=None,
inverter=None, munit=1., sunit=1.):
super(NonlinearPowerEnergy, self).__init__(position)
self.m = m
self.xi = xi
self.D = D
self.d = d
self.N = N
......@@ -70,36 +70,27 @@ class NonlinearPowerEnergy(Energy):
self.sunit = sunit
if sample_list is None:
if samples is None or samples == 0:
sample_list = [m]
sample_list = [xi]
else:
sample_list = [D.generate_posterior_sample() + m
sample_list = [D.generate_posterior_sample() + xi
for _ in range(samples)]
self.sample_list = sample_list
self.inverter = inverter
A = Projection.adjoint_times(munit * exp(.5 * position)) # unit: munit
map_s = self.ht(A * m)
map_s = self.ht(A * xi)
Tpos = self.T(position)
self._gradient = None
for sample in self.sample_list:
map_s = self.ht(A * sample)
for xi_sample in self.sample_list:
map_s = self.ht(A * xi_sample)
LinR = LinearizedPowerResponse(
Instrument,
nonlinearity,
self.ht,
Projection,
position,
sample,
munit,
sunit)
self.Instrument, self.nonlinearity, self.ht, self.Projection,
self.position, xi_sample, munit=self.munit, sunit=self.sunit)
residual = self.d - \
self.Instrument(sunit * self.nonlinearity(map_s))
lh = 0.5 * residual.vdot(self.N.inverse_times(residual))
LinR = LinearizedPowerResponse(
self.Instrument, self.nonlinearity, self.ht, self.Projection,
self.position, sample, munit=self.munit, sunit=self.sunit)
grad = LinR.adjoint_times(self.N.inverse_times(residual))
if self._gradient is None:
......@@ -115,7 +106,7 @@ class NonlinearPowerEnergy(Energy):
self._gradient += Tpos
def at(self, position):
return self.__class__(position, self.d, self.N, self.m, self.D,
return self.__class__(position, self.d, self.N, self.xi, self.D,
self.ht, self.Instrument, self.nonlinearity,
self.Projection, sigma=self.sigma,
samples=len(self.sample_list),
......
......@@ -23,9 +23,9 @@ def LinearizedSignalResponse(Instrument, nonlinearity, ht, power, s, sunit):
return sunit * (Instrument * nonlinearity.derivative(s) * ht * power)
def LinearizedPowerResponse(Instrument, nonlinearity, ht, Projection, t, m, munit, sunit):
power = exp(0.5*t) * munit
position = ht(Projection.adjoint_times(power) * m)
def LinearizedPowerResponse(Instrument, nonlinearity, ht, Projection, tau, xi, munit, sunit):
power = exp(0.5*tau) * munit
position = ht(Projection.adjoint_times(power) * xi)
linearization = nonlinearity.derivative(position)
return sunit * (0.5 * Instrument * linearization * ht * m *
return sunit * (0.5 * Instrument * linearization * ht * xi *
Projection.adjoint * power)
......@@ -313,5 +313,5 @@ class Curvature_Tests(unittest.TestCase):
a = (gradient1 - gradient0) / eps
b = energy0.curvature(direction)
tol = 1e-1
tol = 1e-3
assert_allclose(a.val, b.val, rtol=tol, atol=tol)
......@@ -143,7 +143,7 @@ class Energy_Tests(unittest.TestCase):
energy0 = ift.library.NonlinearPowerEnergy(
position=tau0,
d=d,
m=xi,
xi=xi,
D=D,
Instrument=R,
Projection=P,
......@@ -154,7 +154,7 @@ class Energy_Tests(unittest.TestCase):
energy1 = ift.library.NonlinearPowerEnergy(
position=tau1,
d=d,
m=xi,
xi=xi,
D=D,
Instrument=R,
Projection=P,
......@@ -286,7 +286,7 @@ class Curvature_Tests(unittest.TestCase):
energy0 = ift.library.NonlinearPowerEnergy(
position=tau0,
d=d,
m=xi,
xi=xi,
D=D,
Instrument=R,
Projection=P,
......@@ -300,5 +300,5 @@ class Curvature_Tests(unittest.TestCase):
a = (gradient1 - gradient0) / eps
b = energy0.curvature(direction)
tol = 1e-1
tol = 1e-3
assert_allclose(a.val, b.val, rtol=tol, atol=tol)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment