From dda4a0921ed820c496d6e25db1b7221223cf7333 Mon Sep 17 00:00:00 2001 From: Gordian Edenhofer <gordian.edenhofer@gmail.com> Date: Sun, 19 Jan 2025 20:03:27 -0600 Subject: [PATCH] Drop HMC/NUTS hash test b/c of changing PRNG --- test/test_re/test_hmc_hashes.py | 91 --------------------------------- 1 file changed, 91 deletions(-) delete mode 100644 test/test_re/test_hmc_hashes.py diff --git a/test/test_re/test_hmc_hashes.py b/test/test_re/test_hmc_hashes.py deleted file mode 100644 index c906b9784..000000000 --- a/test/test_re/test_hmc_hashes.py +++ /dev/null @@ -1,91 +0,0 @@ -import sys - -import jax -import nifty8.re as jft -import pytest -from jax import numpy as jnp -from numpy import ndarray - -jax.config.update("jax_enable_x64", True) - -NDARRAY_TYPE = [ndarray] - -try: - from jax.numpy import ndarray as jndarray - - NDARRAY_TYPE.append(jndarray) -except ImportError: - pass - -NDARRAY_TYPE = tuple(NDARRAY_TYPE) - - -def _json_serialize(obj): - if isinstance(obj, NDARRAY_TYPE): - return obj.tolist() - raise TypeError(f"unknown type {type(obj)}") - - -def hashit(obj, n_chars=8) -> str: - """Get first `n_chars` characters of Blake2B hash of `obj`.""" - import hashlib - import json - - return hashlib.blake2b( - bytes(json.dumps(obj, default=_json_serialize), "utf-8") - ).hexdigest()[:n_chars] - - -def test_hmc_hash(): - """Test sapmler output against known hash from previous commits.""" - x0 = jnp.array([0.1, 1.223], dtype=jnp.float32) - sampler = jft.HMCChain( - potential_energy=lambda x: jnp.sum(x**2), - inverse_mass_matrix=1.0, - position_proto=x0, - step_size=0.193, - num_steps=100, - max_energy_difference=1.0, - ) - chain, (key, pos) = sampler.generate_n_samples( - key=42, initial_position=x0, num_samples=1000, save_intermediates=True - ) - assert chain.divergences.sum() == 0 - accepted = chain.trees.accepted - results = (pos, key, chain.samples, accepted) - results_hash = hashit(results, n_chars=20) - print(f"full hash: {results_hash}", file=sys.stderr) - old_hash = "3d665689f809a98c81b3" - assert results_hash == old_hash - - -def test_nuts_hash(): - """Test sapmler output against known hash from previous commits.""" - jax.config.update("jax_enable_x64", False) - - x0 = jnp.array([0.1, 1.223], dtype=jnp.float32) - sampler = jft.NUTSChain( - potential_energy=lambda x: jnp.sum(x**2), - inverse_mass_matrix=1.0, - position_proto=x0, - step_size=0.193, - max_tree_depth=10, - bias_transition=False, - max_energy_difference=1.0, - ) - chain, (key, pos) = sampler.generate_n_samples( - key=42, initial_position=x0, num_samples=1000, save_intermediates=False - ) - assert chain.divergences.sum() == 0 - results = (pos, key, chain.samples) - results_hash = hashit(results, n_chars=20) - print(f"full hash: {results_hash}", file=sys.stderr) - old_hash = "8043850d7249acb77b26" - assert results_hash == old_hash - - jax.config.update("jax_enable_x64", True) - - -if __name__ == "__main__": - test_hmc_hash() - test_nuts_hash() -- GitLab