remove depricated `jax.tree_map`
Compare changes
+ 33
− 45
@@ -23,19 +23,15 @@ pmp = pytest.mark.parametrize
@@ -48,45 +44,46 @@ LH_INIT = (
@@ -96,40 +93,34 @@ LH_INIT = (
@@ -144,8 +135,7 @@ def test_optimize_kl_sample_consistency(
@@ -164,7 +154,7 @@ def test_optimize_kl_sample_consistency(
@@ -177,7 +167,7 @@ def test_optimize_kl_sample_consistency(
@@ -256,9 +246,7 @@ if __name__ == "__main__":