Commit 4380d2b9 authored by Philipp Arras's avatar Philipp Arras

Add new minimization heuristics to getting_started_3

parent 4c3c5422
Pipeline #53583 passed with stages
in 8 minutes and 13 seconds
......@@ -2,6 +2,7 @@
git_version.py
# custom
*.txt
setup.cfg
.idea
.DS_Store
......
......@@ -103,10 +103,15 @@ if __name__ == '__main__':
data = signal_response(mock_position) + N.draw_sample()
# Minimization parameters
ic_sampling = ift.GradientNormController(iteration_limit=100)
ic_newton = ift.GradInfNormController(
name='Newton', tol=1e-7, iteration_limit=35)
minimizer = ift.NewtonCG(ic_newton)
fo = 'energy0.txt'
fs = 'energy1.txt'
fi = 'energy2.txt'
ic_sampling = ift.AbsDeltaEnergyController(0.5, convergence_level=5,
iteration_limit=100,
file_name=fs)
ic_newton = ift.GradInfNormController(name='Newton', tol=1e-7,
iteration_limit=35, file_name=fo)
minimizer = ift.NewtonCG(ic_newton, file_name=fi)
# Set up likelihood and information Hamiltonian
likelihood = ift.GaussianEnergy(mean=data,
......@@ -128,7 +133,7 @@ if __name__ == '__main__':
# Draw new samples to approximate the KL five times
for i in range(5):
# Draw new samples and minimize KL
KL = ift.MetricGaussianKL(mean, H, N_samples)
KL = ift.MetricGaussianKL(mean, H, N_samples, napprox=20)
KL, convergence = minimizer(KL)
mean = KL.position
......@@ -159,3 +164,5 @@ if __name__ == '__main__':
linewidth=[1.]*len(powers) + [3., 3.])
plot.output(ny=1, nx=3, xsize=24, ysize=6, name=filename_res)
print("Saved results as '{}'.".format(filename_res))
ift.energy_history_analysis(fo, fi, fs, fname='energy_history.png')
......@@ -555,7 +555,7 @@ def energy_history_analysis(fname_outer, fname_inner, fname_sampling,
tsa, esa, _ = np.loadtxt(fname_sampling, delimiter=' ').T
tou, eou, _ = np.loadtxt(fname_outer, delimiter=' ').T
tin, ein, _ = np.loadtxt(fname_inner, delimiter=' ').T
t0 = np.min([tsa, tou, tin])
t0 = np.min([*tsa, *tou, *tin])
tsa = (tsa-t0)/3600
tou = (tou-t0)/3600
tin = (tin-t0)/3600
......@@ -569,11 +569,11 @@ def energy_history_analysis(fname_outer, fname_inner, fname_sampling,
ax0.legend([p1, p2], [p1.get_label(), p2.get_label()])
ax1.scatter(tou, eou, marker='>', c='g')
ax1.set_ylabel('Newton energy')
ax1.set_yscale('log')
# ax1.set_yscale('log')
ax1.set_xlabel("Time [h]")
plt.tight_layout()
if fname is None:
plt.savefig(fname)
else:
plt.show()
else:
plt.savefig(fname)
plt.close()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment