Commit 6d5b914b authored by Philipp Arras's avatar Philipp Arras
Browse files

Plotting of energy history

parent baccf4fd
Pipeline #75216 failed with stages
in 3 minutes and 5 seconds
......@@ -114,7 +114,9 @@ if __name__ == '__main__':
ic_newton = ift.AbsDeltaEnergyController(name='Newton',
deltaE=0.01,
iteration_limit=35)
minimizer = ift.NewtonCG(ic_newton)
ic_sampling.enable_logging()
ic_newton.enable_logging()
minimizer = ift.NewtonCG(ic_newton, activate_logging=True)
## number of samples used to estimate the KL
N_samples = 20
......@@ -143,10 +145,15 @@ if __name__ == '__main__':
plot.add([A2.force(KL.position),
A2.force(mock_position)],
title="power2")
plot.output(nx=2,
plot.add((ic_newton.history, ic_sampling.history,
minimizer.inversion_history),
label=['KL', 'Sampling', 'Newton inversion'],
title='Cumulative energies', s=[None, None, 1],
alpha=[None, 0.2, None])
plot.output(nx=3,
ny=2,
ysize=10,
xsize=10,
xsize=15,
name=filename.format("loop_{:02d}".format(i)))
# Done, draw posterior samples
......
......@@ -104,10 +104,8 @@ class EnergyHistory(object):
raise ValueError
self._lst.append((float(x[0]), float(x[1])))
def pop_all(self):
lst = self._lst
def reset(self):
self._lst = []
return lst
@property
def timestamps(self):
......
......@@ -278,19 +278,25 @@ def _plot_history(f, ax, **kwargs):
color = kwargs.pop("color", None)
if not isinstance(color, list):
color = [color] * len(f)
size = kwargs.pop("s", None)
if not isinstance(size, list):
size = [size] * len(f)
ax.set_title(kwargs.pop("title", ""))
ax.set_xlabel(kwargs.pop("xlabel", ""))
ax.set_ylabel(kwargs.pop("ylabel", ""))
plt.xscale(kwargs.pop("xscale", "linear"))
plt.yscale(kwargs.pop("yscale", "linear"))
mi, ma = np.inf, -np.inf
for i, fld in enumerate(f):
xcoord = fld.timestamps
# xcoord = date2num([dt.fromtimestamp(ts) for ts in xcoord])
# xfmt = DateFormatter('%H:%M')
# ax.xaxis.set_major_formatter(xfmt)
xcoord = date2num([dt.fromtimestamp(ts) for ts in fld.timestamps])
ycoord = fld.energy_values
ax.scatter(xcoord, ycoord, label=label[i], alpha=alpha[i],
color=color[i])
color=color[i], s=size[i])
mi, ma = min([min(xcoord), mi]), max([max(xcoord), ma])
delta = (ma-mi)*0.05
ax.set_xlim((mi-delta, ma+delta))
xfmt = DateFormatter('%H:%M')
ax.xaxis.set_major_formatter(xfmt)
_limit_xy(**kwargs)
if label != ([None]*len(f)):
plt.legend()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment