Skip to content
Snippets Groups Projects
Commit dd5faa0a authored by Gordian Edenhofer's avatar Gordian Edenhofer
Browse files

geomap: Plot evolution of orthonormal samples

parent 66a8b1e7
No related branches found
No related tags found
1 merge request!832Better Lanczos interface
...@@ -130,6 +130,7 @@ def geomap(ham: jft.StandardHamiltonian, order: int, key, mirror_samples=True): ...@@ -130,6 +130,7 @@ def geomap(ham: jft.StandardHamiltonian, order: int, key, mirror_samples=True):
logdet, vecs, smpl = stochastic_lq_logdet( logdet, vecs, smpl = stochastic_lq_logdet(
mat, order, pos, key_lcz, shape0=p.size mat, order, pos, key_lcz, shape0=p.size
) )
# smpl = random.normal(key_smpls, p.shape, dtype=p.dtype)
s = smpl.copy() s = smpl.copy()
# TODO: Pull into new lanczos method which computes orthoganlized smpls # TODO: Pull into new lanczos method which computes orthoganlized smpls
# for vecs # for vecs
...@@ -168,7 +169,7 @@ jax.config.update("jax_log_compiles", False) ...@@ -168,7 +169,7 @@ jax.config.update("jax_log_compiles", False)
print("!!!!!!!!!!!!!!!!!!!!!!! HAM", ham(pos)) print("!!!!!!!!!!!!!!!!!!!!!!! HAM", ham(pos))
print("!!!!!!!!!!!!!!!!!!!!!!! metric", ham.metric(pos, pos) @ pos) print("!!!!!!!!!!!!!!!!!!!!!!! metric", ham.metric(pos, pos) @ pos)
# This is 50 times slower in compile time than ham.metric # This is 50 times slower in compile time than ham.metric
geomap_order = 10 geomap_order = 5
geomap_energy = geomap(ham, geomap_order, subkey_geomap, mirror_samples=False) geomap_energy = geomap(ham, geomap_order, subkey_geomap, mirror_samples=False)
# jft.disable_jax_control_flow._DISABLE_CONTROL_FLOW_PRIM = True # jft.disable_jax_control_flow._DISABLE_CONTROL_FLOW_PRIM = True
...@@ -201,6 +202,24 @@ plt.plot(jnp.abs(prr_smpl - ortho_smpl), label="abs diff", alpha=0.3) ...@@ -201,6 +202,24 @@ plt.plot(jnp.abs(prr_smpl - ortho_smpl), label="abs diff", alpha=0.3)
plt.legend() plt.legend()
plt.show() plt.show()
# %%
smpls_by_order = []
for i in range(1, geomap_order):
_, _, s = geomap(ham, i, subkey_geomap, mirror_samples=False)(opt_state_geomap.x, return_sample=True)
smpls_by_order += [s]
smpls_by_order = jnp.array(smpls_by_order)
# %%
fig, axs = plt.subplots(2, 1, sharex=True)
d = jnp.diff(smpls_by_order, axis=0)
axs.flat[0].plot(smpls_by_order.T, label=jnp.arange(1, geomap_order), alpha=0.3, marker=".")
axs.flat[0].axhline(0., color="red")
axs.flat[0].legend()
axs.flat[1].plot(d.T, label=jnp.arange(1, geomap_order - 1), alpha=0.3, marker=".")
axs.flat[1].axhline(0., color="red")
axs.flat[1].legend()
plt.show()
# %% # %%
plt.plot( plt.plot(
jnp.array( jnp.array(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment