Commit 741d11a6 authored by Philipp Arras's avatar Philipp Arras
Browse files

Renamings in energy classes

parent 4cf96788
Pipeline #24175 failed with stage
in 3 minutes and 57 seconds
...@@ -20,13 +20,13 @@ from ..operators.inversion_enabler import InversionEnabler ...@@ -20,13 +20,13 @@ from ..operators.inversion_enabler import InversionEnabler
from .response_operators import LinearizedPowerResponse from .response_operators import LinearizedPowerResponse
def NonlinearPowerCurvature(position, ht, Instrument, nonlinearity, def NonlinearPowerCurvature(tau, ht, Instrument, nonlinearity,
Projection, N, T, sample_list, inverter, munit=1., sunit=1.): Projection, N, T, xi_sample_list, inverter, munit=1., sunit=1.):
result = None result = None
for sample in sample_list: for xi_sample in xi_sample_list:
LinearizedResponse = LinearizedPowerResponse( LinearizedResponse = LinearizedPowerResponse(
Instrument, nonlinearity, ht, Projection, position, sample, munit, sunit) Instrument, nonlinearity, ht, Projection, tau, xi_sample, munit, sunit)
op = LinearizedResponse.adjoint*N.inverse*LinearizedResponse op = LinearizedResponse.adjoint*N.inverse*LinearizedResponse
result = op if result is None else result + op result = op if result is None else result + op
result = result * (1. / len(sample_list)) + T result = result * (1. / len(xi_sample_list)) + T
return InversionEnabler(result, inverter) return InversionEnabler(result, inverter)
...@@ -51,11 +51,11 @@ class NonlinearPowerEnergy(Energy): ...@@ -51,11 +51,11 @@ class NonlinearPowerEnergy(Energy):
default : 3 default : 3
""" """
def __init__(self, position, d, N, m, D, ht, Instrument, nonlinearity, def __init__(self, position, d, N, xi, D, ht, Instrument, nonlinearity,
Projection, sigma=0., samples=3, sample_list=None, Projection, sigma=0., samples=3, sample_list=None,
inverter=None, munit=1., sunit=1.): inverter=None, munit=1., sunit=1.):
super(NonlinearPowerEnergy, self).__init__(position) super(NonlinearPowerEnergy, self).__init__(position)
self.m = m self.xi = xi
self.D = D self.D = D
self.d = d self.d = d
self.N = N self.N = N
...@@ -70,36 +70,27 @@ class NonlinearPowerEnergy(Energy): ...@@ -70,36 +70,27 @@ class NonlinearPowerEnergy(Energy):
self.sunit = sunit self.sunit = sunit
if sample_list is None: if sample_list is None:
if samples is None or samples == 0: if samples is None or samples == 0:
sample_list = [m] sample_list = [xi]
else: else:
sample_list = [D.generate_posterior_sample() + m sample_list = [D.generate_posterior_sample() + xi
for _ in range(samples)] for _ in range(samples)]
self.sample_list = sample_list self.sample_list = sample_list
self.inverter = inverter self.inverter = inverter
A = Projection.adjoint_times(munit * exp(.5 * position)) # unit: munit A = Projection.adjoint_times(munit * exp(.5 * position)) # unit: munit
map_s = self.ht(A * m) map_s = self.ht(A * xi)
Tpos = self.T(position) Tpos = self.T(position)
self._gradient = None self._gradient = None
for sample in self.sample_list: for xi_sample in self.sample_list:
map_s = self.ht(A * sample) map_s = self.ht(A * xi_sample)
LinR = LinearizedPowerResponse( LinR = LinearizedPowerResponse(
Instrument, self.Instrument, self.nonlinearity, self.ht, self.Projection,
nonlinearity, self.position, xi_sample, munit=self.munit, sunit=self.sunit)
self.ht,
Projection,
position,
sample,
munit,
sunit)
residual = self.d - \ residual = self.d - \
self.Instrument(sunit * self.nonlinearity(map_s)) self.Instrument(sunit * self.nonlinearity(map_s))
lh = 0.5 * residual.vdot(self.N.inverse_times(residual)) lh = 0.5 * residual.vdot(self.N.inverse_times(residual))
LinR = LinearizedPowerResponse(
self.Instrument, self.nonlinearity, self.ht, self.Projection,
self.position, sample, munit=self.munit, sunit=self.sunit)
grad = LinR.adjoint_times(self.N.inverse_times(residual)) grad = LinR.adjoint_times(self.N.inverse_times(residual))
if self._gradient is None: if self._gradient is None:
...@@ -115,7 +106,7 @@ class NonlinearPowerEnergy(Energy): ...@@ -115,7 +106,7 @@ class NonlinearPowerEnergy(Energy):
self._gradient += Tpos self._gradient += Tpos
def at(self, position): def at(self, position):
return self.__class__(position, self.d, self.N, self.m, self.D, return self.__class__(position, self.d, self.N, self.xi, self.D,
self.ht, self.Instrument, self.nonlinearity, self.ht, self.Instrument, self.nonlinearity,
self.Projection, sigma=self.sigma, self.Projection, sigma=self.sigma,
samples=len(self.sample_list), samples=len(self.sample_list),
......
...@@ -23,9 +23,9 @@ def LinearizedSignalResponse(Instrument, nonlinearity, ht, power, s, sunit): ...@@ -23,9 +23,9 @@ def LinearizedSignalResponse(Instrument, nonlinearity, ht, power, s, sunit):
return sunit * (Instrument * nonlinearity.derivative(s) * ht * power) return sunit * (Instrument * nonlinearity.derivative(s) * ht * power)
def LinearizedPowerResponse(Instrument, nonlinearity, ht, Projection, t, m, munit, sunit): def LinearizedPowerResponse(Instrument, nonlinearity, ht, Projection, tau, xi, munit, sunit):
power = exp(0.5*t) * munit power = exp(0.5*tau) * munit
position = ht(Projection.adjoint_times(power) * m) position = ht(Projection.adjoint_times(power) * xi)
linearization = nonlinearity.derivative(position) linearization = nonlinearity.derivative(position)
return sunit * (0.5 * Instrument * linearization * ht * m * return sunit * (0.5 * Instrument * linearization * ht * xi *
Projection.adjoint * power) Projection.adjoint * power)
...@@ -313,5 +313,5 @@ class Curvature_Tests(unittest.TestCase): ...@@ -313,5 +313,5 @@ class Curvature_Tests(unittest.TestCase):
a = (gradient1 - gradient0) / eps a = (gradient1 - gradient0) / eps
b = energy0.curvature(direction) b = energy0.curvature(direction)
tol = 1e-1 tol = 1e-3
assert_allclose(a.val, b.val, rtol=tol, atol=tol) assert_allclose(a.val, b.val, rtol=tol, atol=tol)
...@@ -143,7 +143,7 @@ class Energy_Tests(unittest.TestCase): ...@@ -143,7 +143,7 @@ class Energy_Tests(unittest.TestCase):
energy0 = ift.library.NonlinearPowerEnergy( energy0 = ift.library.NonlinearPowerEnergy(
position=tau0, position=tau0,
d=d, d=d,
m=xi, xi=xi,
D=D, D=D,
Instrument=R, Instrument=R,
Projection=P, Projection=P,
...@@ -154,7 +154,7 @@ class Energy_Tests(unittest.TestCase): ...@@ -154,7 +154,7 @@ class Energy_Tests(unittest.TestCase):
energy1 = ift.library.NonlinearPowerEnergy( energy1 = ift.library.NonlinearPowerEnergy(
position=tau1, position=tau1,
d=d, d=d,
m=xi, xi=xi,
D=D, D=D,
Instrument=R, Instrument=R,
Projection=P, Projection=P,
...@@ -286,7 +286,7 @@ class Curvature_Tests(unittest.TestCase): ...@@ -286,7 +286,7 @@ class Curvature_Tests(unittest.TestCase):
energy0 = ift.library.NonlinearPowerEnergy( energy0 = ift.library.NonlinearPowerEnergy(
position=tau0, position=tau0,
d=d, d=d,
m=xi, xi=xi,
D=D, D=D,
Instrument=R, Instrument=R,
Projection=P, Projection=P,
...@@ -300,5 +300,5 @@ class Curvature_Tests(unittest.TestCase): ...@@ -300,5 +300,5 @@ class Curvature_Tests(unittest.TestCase):
a = (gradient1 - gradient0) / eps a = (gradient1 - gradient0) / eps
b = energy0.curvature(direction) b = energy0.curvature(direction)
tol = 1e-1 tol = 1e-3
assert_allclose(a.val, b.val, rtol=tol, atol=tol) assert_allclose(a.val, b.val, rtol=tol, atol=tol)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment