
remove deprecated `jax.tree_map`

Merged: Jakob Roth requested to merge tree_map into NIFTy_8
1 file   +33   -45
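For context: JAX has deprecated the top-level alias `jax.tree_map` in favor of `jax.tree_util.tree_map` (and, on recent JAX versions, `jax.tree.map`), which is what this MR switches the test to; most of the remaining hunks are pure formatting (trailing commas, `1.` to `1.0`, tuple spacing). The snippet below is only a minimal sketch of that substitution with a made-up pytree, not code from this repository:

    import jax.numpy as jnp
    from jax.tree_util import tree_map  # replacement for the deprecated jax.tree_map

    # hypothetical pytree, purely for illustration
    tree = {"a": jnp.arange(3.0), "b": (jnp.ones(2), jnp.zeros(1))}

    # before (deprecated): jax.tree_map(jnp.exp, tree)
    # after:
    exp_tree = tree_map(jnp.exp, tree)  # applies jnp.exp to every leaf array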
@@ -23,19 +23,15 @@ pmp = pytest.mark.parametrize
 def random_draw(key, shape, dtype, method):
     def _isleaf(x):
         if isinstance(x, tuple):
-            return bool(
-                reduce(lambda a, b: a * b, (isinstance(ii, int) for ii in x))
-            )
+            return bool(reduce(lambda a, b: a * b, (isinstance(ii, int) for ii in x)))
         return False
 
-    swd = tree_map(
-        lambda x: jft.ShapeWithDtype(x, dtype), shape, is_leaf=_isleaf
-    )
+    swd = tree_map(lambda x: jft.ShapeWithDtype(x, dtype), shape, is_leaf=_isleaf)
     return jft.random_like(key, jft.Vector(swd), method)
 
 
 def random_noise_std_inv(key, shape):
-    diag = 1. / random_draw(key, shape, float, random.exponential)
+    diag = 1.0 / random_draw(key, shape, float, random.exponential)
 
     def noise_std_inv(tangents):
         return diag * tangents
@@ -48,45 +44,46 @@ LH_INIT = (
         jft.Gaussian,
         {
             "data": partial(random_draw, dtype=float, method=random.normal),
-            "noise_std_inv": random_noise_std_inv
+            "noise_std_inv": random_noise_std_inv,
         },
         partial(random_draw, dtype=float, method=random.normal),
-    ), (
+    ),
+    (
         jft.StudentT,
         {
             "data": partial(random_draw, dtype=float, method=random.normal),
             "dof": partial(random_draw, dtype=float, method=random.exponential),
-            "noise_std_inv": random_noise_std_inv
+            "noise_std_inv": random_noise_std_inv,
         },
         partial(random_draw, dtype=float, method=random.normal),
-    ), (
+    ),
+    (
         jft.Poissonian,
         {
-            "data":
-                partial(
-                    random_draw,
-                    dtype=int,
-                    method=partial(random.poisson, lam=3.14)
-                ),
+            "data": partial(
+                random_draw, dtype=int, method=partial(random.poisson, lam=3.14)
+            ),
         },
         lambda key, shape: tree_map(
-            lambda x: x.astype(float), 6. +
-            random_draw(key, shape, int, partial(random.poisson, lam=3.14))
+            lambda x: x.astype(float),
+            6.0 + random_draw(key, shape, int, partial(random.poisson, lam=3.14)),
         ),
-    ), (
+    ),
+    (
        jft.VariableCovarianceGaussian,
         {
             "data": partial(random_draw, dtype=float, method=random.normal),
         },
         lambda key, shape: jft.Vector(
             (
-                random_draw(key, shape, float, random.normal), 3. + 1. / jax.
-                tree_map(
-                    jnp.exp, random_draw(key, shape, float, random.normal)
-                )
+                random_draw(key, shape, float, random.normal),
+                3.0
+                + 1.0
+                / tree_map(jnp.exp, random_draw(key, shape, float, random.normal)),
             )
         ),
-    ), (
+    ),
+    (
         jft.VariableCovarianceStudentT,
         {
             "data": partial(random_draw, dtype=float, method=random.normal),
@@ -96,40 +93,34 @@ LH_INIT = (
             (
                 random_draw(key, shape, float, random.normal),
                 tree_map(
-                    jnp.exp, 3. + 1e-1 *
-                    random_draw(key, shape, float, random.normal)
-                )
+                    jnp.exp, 3.0 + 1e-1 * random_draw(key, shape, float, random.normal)
+                ),
             )
         ),
-    )
+    ),
 )
 
 
 def _assert_zero_point_estimate_residuals(residuals, point_estimates):
     if not point_estimates:
         return
-    assert_array_equal(jft.sum(point_estimates * residuals**2), 0.)
+    assert_array_equal(jft.sum(point_estimates * residuals**2), 0.0)
 
 
 @pmp("seed", (42, 43))
-@pmp("shape", ((5, 12), (5, ), (1, 2, 3, 4)))
+@pmp("shape", ((5, 12), (5,), (1, 2, 3, 4)))
 @pmp("ndim", (1, 2, 3))
 def test_concatenate_zip(seed, shape, ndim):
     keys = random.split(random.PRNGKey(seed), ndim)
     zip_args = tuple(random.normal(k, shape=shape) for k in keys)
     assert_array_equal(
         concatenate_zip(*zip_args),
-        np.concatenate(list(zip(*(list(el) for el in zip_args))))
+        np.concatenate(list(zip(*(list(el) for el in zip_args)))),
     )
 
 
 @pmp("seed", (12, 42))
-@pmp(
-    "shape",
-    ((5, ), ((4, ), (2, ), (1, ), (1, )), ((2, [1, 1]), {
-        'a': (3, 1)
-    }))
-)
+@pmp("shape", ((5,), ((4,), (2,), (1,), (1,)), ((2, [1, 1]), {"a": (3, 1)})))
 @pmp("lh_init", LH_INIT)
 @pmp("random_point_estimates", (True, False))
 @pmp("sample_mode", ("linear_resample", "nonlinear_resample"))
@@ -144,8 +135,7 @@ def test_optimize_kl_sample_consistency(
     key = random.PRNGKey(seed)
     key, *subkeys = random.split(key, 1 + len(draw))
     init_kwargs = {
-        k: method(key=sk, shape=shape)
-        for (k, method), sk in zip(draw.items(), subkeys)
+        k: method(key=sk, shape=shape) for (k, method), sk in zip(draw.items(), subkeys)
     }
     lh = lh_init_method(**init_kwargs)
     key, sk = random.split(key)
@@ -164,7 +154,7 @@ def test_optimize_kl_sample_consistency(
 
     draw_linear_kwargs = dict(
         cg_name="SL",
-        cg_kwargs=dict(miniter=2, absdelta=absdelta / 10., maxiter=100),
+        cg_kwargs=dict(miniter=2, absdelta=absdelta / 10.0, maxiter=100),
     )
     minimize_kwargs = dict(
         name="SN",
@@ -177,7 +167,7 @@ def test_optimize_kl_sample_consistency(
     if random_point_estimates:
         prob = 0.5
         pe, pe_td = jax.tree_util.tree_flatten(pos)
-        pe = list(random.bernoulli(sk, prob, shape=(len(pe), )))
+        pe = list(random.bernoulli(sk, prob, shape=(len(pe),)))
         pe[0] = False  # Ensure at least one variable is not a point-estimate
         pe = tuple(map(bool, pe))
         point_estimates = jax.tree_util.tree_unflatten(pe_td, pe)
@@ -256,9 +246,7 @@ if __name__ == "__main__":
     test_concatenate_zip(1, (5, 12), 3)
     test_optimize_kl_sample_consistency(
         1,
-        ((2, [1, 1]), {
-            'a': (3, 1)
-        }),
+        ((2, [1, 1]), {"a": (3, 1)}),
         LH_INIT[1],
         random_point_estimates=True,
         sample_mode="nonlinear_resample",