diff --git a/ChangeLog.md b/ChangeLog.md index 6f7b699d15e2301d15e41726e1b46890bbbf61fe..fb887b1a0b40c3826fe9372b42ad396f88e13d58 100644 --- a/ChangeLog.md +++ b/ChangeLog.md @@ -1,6 +1,12 @@ Changes since NIFTy 8 ===================== +`estimate_evidence_lower_bound` +------------- + +Renamed `batch_size` to `n_batches` for clarity. Improved batch logic. + + Minimum Python version increased to 3.10 Stabilize ICR at the cost of disallowing using old ICR reconstructions diff --git a/src/evidence_lower_bound.py b/src/evidence_lower_bound.py index 41e235142966764ddb64cc2264951538fce3ce8a..58e52ed30b81c1043e858d7d518537c779587a72 100644 --- a/src/evidence_lower_bound.py +++ b/src/evidence_lower_bound.py @@ -42,14 +42,15 @@ class _Projector(ssl.LinearOperator): Projector : LinearOperator Operator representing the projection. """ + def __init__(self, eigenvectors): - super().__init__(np.dtype('f8'), 2 * (eigenvectors.shape[0],)) + super().__init__(eigenvectors.dtype, 2 * (eigenvectors.shape[0],)) self.eigenvectors = eigenvectors def _matvec(self, x): res = x.copy() for eigenvector in self.eigenvectors.T: - res -= eigenvector * eigenvector.dot(x) + res -= eigenvector * np.vdot(eigenvector, x) return res def _rmatvec(self, x): @@ -61,40 +62,76 @@ def _explicify(M): m = [] for v in identity: m.append(M.matvec(v)) - return np.vstack(m).T - - -def _eigsh(metric, n_eigenvalues, tot_dofs, min_lh_eval=1e-4, batch_number=10, tol=0., verbose=True): + return np.column_stack(m) + + +def _eigsh( + metric, + n_eigenvalues, + tot_dofs, + dtype, + min_lh_eval=1e-4, + n_batches=10, + tol=0.0, + verbose=True, +): metric = SandwichOperator.make(_DomRemover(metric.domain).adjoint, metric) metric_size = metric.domain.size - M = ssl.LinearOperator(shape=2 * (metric_size,), matvec=lambda x: metric(makeField(metric.domain, x)).val) + M = ssl.LinearOperator( + shape=2 * (metric_size,), + matvec=lambda x: metric(makeField(metric.domain, x)).val, + dtype=dtype, + ) eigenvectors = None if 
n_eigenvalues > tot_dofs: - raise ValueError("Number of requested eigenvalues exceeds the number of relevant degrees of freedom!") + raise ValueError( + "Number of requested eigenvalues exceeds the number of relevant degrees of freedom!" + ) if tot_dofs == n_eigenvalues: # Compute exact eigensystem if verbose: logger.info(f"Computing all {tot_dofs} relevant metric eigenvalues.") - eigenvalues = slg.eigh(_explicify(M), eigvals_only=True, - subset_by_index=[metric_size - tot_dofs, metric_size - 1]) + eigenvalues = slg.eigh( + _explicify(M), + eigvals_only=True, + subset_by_index=[metric_size - tot_dofs, metric_size - 1], + ) eigenvalues = np.flip(eigenvalues) else: # Set up batches - batch_size = n_eigenvalues // batch_number - batches = [batch_size, ] * (batch_number - 1) - batches += [n_eigenvalues - batch_size * (batch_number - 1), ] + batch_size = n_eigenvalues // n_batches + remainder = n_eigenvalues % n_batches + batches = [ + batch_size, + ] * n_batches + batches += ( + [ + remainder, + ] + if remainder > 0 + else [] + ) eigenvalues, projected_metric = None, M + for batch in batches: if verbose: logger.info(f"\nNumber of eigenvalues being computed: {batch}") # Get eigensystem for current batch - eigvals, eigvecs = ssl.eigsh(projected_metric, k=batch, tol=tol, return_eigenvectors=True, which='LM') + eigvals, eigvecs = ssl.eigsh( + projected_metric, k=batch, tol=tol, return_eigenvectors=True, which="LM" + ) i = np.argsort(eigvals) eigvals, eigvecs = np.flip(eigvals[i]), np.flip(eigvecs[:, i], axis=1) - eigenvalues = eigvals if eigenvalues is None else np.concatenate((eigenvalues, eigvals)) - eigenvectors = eigvecs if eigenvectors is None else np.hstack((eigenvectors, eigvecs)) + eigenvalues = ( + eigvals + if eigenvalues is None + else np.concatenate((eigenvalues, eigvals)) + ) + eigenvectors = ( + eigvecs if eigenvectors is None else np.hstack((eigenvectors, eigvecs)) + ) if abs(1.0 - np.min(eigenvalues)) < min_lh_eval: break @@ -104,8 +141,17 @@ def 
_eigsh(metric, n_eigenvalues, tot_dofs, min_lh_eval=1e-4, batch_number=10, t return eigenvalues, eigenvectors -def estimate_evidence_lower_bound(hamiltonian, samples, n_eigenvalues, min_lh_eval=1e-3, batch_number=10, tol=0., - verbose=True): +def estimate_evidence_lower_bound( + hamiltonian, + samples, + n_eigenvalues, + *, + min_lh_eval=1e-3, + n_batches=10, + tol=0.0, + verbose=True, + dtype=np.float64, +): """Provides an estimate for the Evidence Lower Bound (ELBO). Statistical inference deals with the problem of hypothesis testing, given @@ -161,7 +207,7 @@ def estimate_evidence_lower_bound(hamiltonian, samples, n_eigenvalues, min_lh_ev eigenvalue estimation terminates and uses the smallest eigenvalue as a proxy for all remaining eigenvalues in the trace-log estimation. Default is 1e-3. - batch_number : int + n_batches : int Number of batches into which the eigenvalue estimation gets subdivided into. Only after completing one batch the early stopping criterion based on `min_lh_eval` is checked for. @@ -171,6 +217,8 @@ def estimate_evidence_lower_bound(hamiltonian, samples, n_eigenvalues, min_lh_ev verbose : Optional[bool] Print list of eigenvalues and summary of evidence calculation. Default is True. + dtype : Optional[numpy.dtype] + Data type of the eigenvalues and eigenvectors. Default is np.float64. 
Returns ------- @@ -218,36 +266,58 @@ def estimate_evidence_lower_bound(hamiltonian, samples, n_eigenvalues, min_lh_ev if not isinstance(hamiltonian, StandardHamiltonian): raise TypeError("hamiltonian is not an instance of `ift.StandardHamiltonian`.") - n_data_points = hamiltonian.likelihood_energy.data_domain.size if not None else hamiltonian.domain.size - n_relevant_dofs = min(n_data_points, hamiltonian.domain.size) # Number of metric eigenvalues that are not one + n_data_points = ( + hamiltonian.likelihood_energy.data_domain.size + if hamiltonian.likelihood_energy.data_domain is not None + else hamiltonian.domain.size + ) + n_relevant_dofs = min( + n_data_points, hamiltonian.domain.size + ) # Number of metric eigenvalues that are not one metric = hamiltonian(Linearization.make_var(samples._m, want_metric=True)).metric metric_size = metric.domain.size - eigenvalues, _ = _eigsh(metric, n_eigenvalues, tot_dofs=n_relevant_dofs, min_lh_eval=min_lh_eval, - batch_number=batch_number, tol=tol, verbose=verbose) + + eigenvalues, _ = _eigsh( + metric, + n_eigenvalues, + n_relevant_dofs, + dtype, + min_lh_eval=min_lh_eval, + n_batches=n_batches, + tol=tol, + verbose=verbose, + ) if verbose: # FIXME logger.info( f"\nComputed {eigenvalues.size} largest eigenvalues (out of {n_relevant_dofs} relevant degrees of freedom)." f"\nThe remaining {metric_size - n_relevant_dofs} metric eigenvalues (out of {metric_size}) are equal to " - f"1.\n\n{eigenvalues}.") + f"1.\n\n{eigenvalues}." 
+ ) # Return a list of ELBO samples and a summary of the ELBO statistics log_eigenvalues = np.log(eigenvalues) - tr_log_lat_cov = - 0.5 * np.sum(log_eigenvalues) - tr_log_lat_cov_lower = 0.5 * (n_relevant_dofs - log_eigenvalues.size) * np.min(log_eigenvalues) + tr_log_lat_cov = -0.5 * np.sum(log_eigenvalues) + tr_log_lat_cov_lower = ( + 0.5 * (n_relevant_dofs - log_eigenvalues.size) * np.min(log_eigenvalues) + ) tr_log_lat_cov_lower = Field.scalar(tr_log_lat_cov_lower) posterior_contribution = Field.scalar(tr_log_lat_cov + 0.5 * metric_size) - elbo_samples = SampleList(list(samples.iterator(lambda x: posterior_contribution - hamiltonian(x)))) + elbo_samples = SampleList( + list(samples.iterator(lambda x: posterior_contribution - hamiltonian(x))) + ) - stats = {'lower_error': tr_log_lat_cov_lower} + stats = {"lower_error": tr_log_lat_cov_lower} elbo_mean, elbo_var = elbo_samples.sample_stat() elbo_up = elbo_mean + elbo_var.sqrt() elbo_lw = elbo_mean - elbo_var.sqrt() - stats["lower_error"] - stats['elbo_mean'], stats['elbo_up'], stats['elbo_lw'] = elbo_mean, elbo_up, elbo_lw + stats["elbo_mean"], stats["elbo_up"], stats["elbo_lw"] = elbo_mean, elbo_up, elbo_lw if verbose: - s = (f"\nELBO decomposition (in log units)" - f"\nELBO mean : {elbo_mean.val:.4e} (upper: {elbo_up.val:.4e}, lower: {elbo_lw.val:.4e})") + s = ( + f"\nELBO decomposition (in log units)" + f"\nELBO mean : {elbo_mean.val:.4e} (upper: {elbo_up.val:.4e}, lower: {elbo_lw.val:.4e})" + ) logger.info(s) return elbo_samples, stats diff --git a/src/re/evidence_lower_bound.py b/src/re/evidence_lower_bound.py index 20c7bf11d33662dacf89bc6d4123041b7c9b6fcc..f35921ece1d00da44392ef4278470de35ebaa1d1 100644 --- a/src/re/evidence_lower_bound.py +++ b/src/re/evidence_lower_bound.py @@ -22,11 +22,6 @@ class _Projector(ssl.LinearOperator): ---------- eigenvectors : ndarray The eigenvectors representing the directions to project out. 
- - Returns - ------- - Projector : LinearOperator - Operator representing the projection. """ def __init__(self, eigenvectors): @@ -44,35 +39,33 @@ class _Projector(ssl.LinearOperator): def _explicify(M): - n = M.shape[0] - m = [] - for i in range(n): - basis_vector = np.zeros(n) - basis_vector[i] = 1 - m.append(M @ basis_vector) - return np.stack(m, axis=1) + identity = np.identity(M.shape[0], dtype=np.float64) + return np.column_stack([M.matvec(v) for v in identity]) def _ravel_metric(metric, position, dtype): - shape = 2 * (size(metric(position, position)),) + def ravel(x): + return jax.flatten_util.ravel_pytree(x)[0] - ravel = lambda x: jax.flatten_util.ravel_pytree(x)[0] - unravel = lambda x: jax.linear_transpose(ravel, position)(x)[0] - met = lambda x: ravel(metric(position, unravel(x))) + shp, unravel = jax.flatten_util.ravel_pytree(position) + shape = 2 * (shp.size,) + + def met(x): + return ravel(metric(position, unravel(x))) return ssl.LinearOperator(shape=shape, dtype=dtype, matvec=met) def _eigsh( metric, + metric_size, n_eigenvalues, tot_dofs, min_lh_eval=1e-4, - batch_size=10, + n_batches=10, tol=0.0, verbose=True, ): - metric_size = metric.shape[0] eigenvectors = None if n_eigenvalues > tot_dofs: raise ValueError( @@ -83,7 +76,7 @@ def _eigsh( if tot_dofs == n_eigenvalues: # Compute exact eigensystem if verbose: - logger.info(f"Computing all {tot_dofs} relevant " f"metric eigenvalues.") + logger.info(f"Computing all {tot_dofs} relevant metric eigenvalues.") eigenvalues = slg.eigh( _explicify(metric), eigvals_only=True, @@ -92,14 +85,12 @@ def _eigsh( eigenvalues = np.flip(eigenvalues) else: # Set up batches - batch_size = n_eigenvalues // batch_size - batches = [ - batch_size, - ] * (batch_size - 1) - batches += [ - n_eigenvalues - batch_size * (batch_size - 1), - ] + batch_size = n_eigenvalues // n_batches + remainder = n_eigenvalues % n_batches + batches = [batch_size] * n_batches + batches += [remainder] if remainder > 0 else [] eigenvalues, 
projected_metric = None, metric + for batch in batches: if verbose: logger.info(f"\nNumber of eigenvalues being computed: {batch}") @@ -130,8 +121,9 @@ def estimate_evidence_lower_bound( likelihood, samples, n_eigenvalues, + *, min_lh_eval=1e-3, - batch_size=10, + n_batches=10, tol=0.0, verbose=True, ): @@ -189,7 +181,7 @@ def estimate_evidence_lower_bound( eigenvalue estimation terminates and uses the smallest eigenvalue as a proxy for all remaining eigenvalues in the trace-log estimation. Default is 1e-3. - batch_size : int + n_batches : int Number of batches into which the eigenvalue estimation gets subdivided into. Only after completing one batch the early stopping criterion based on `min_lh_eval` is checked for. @@ -250,19 +242,18 @@ def estimate_evidence_lower_bound( hamiltonian = StandardHamiltonian(likelihood) metric = hamiltonian.metric + metric_size = jax.flatten_util.ravel_pytree(samples.pos)[0].size metric = _ravel_metric(metric, samples.pos, dtype=likelihood.target.dtype) - metric_size = metric.shape[0] n_data_points = size(likelihood.lsm_tangents_shape) if not None else metric_size - n_relevant_dofs = min( - n_data_points, metric_size - ) # Number of metric eigenvalues that are not 1 + n_relevant_dofs = min(n_data_points, metric_size) eigenvalues, _ = _eigsh( metric, + metric_size, n_eigenvalues, tot_dofs=n_relevant_dofs, min_lh_eval=min_lh_eval, - batch_size=batch_size, + n_batches=n_batches, tol=tol, verbose=verbose, ) diff --git a/test/test_re/test_estimate_evidence_lower_bound.py b/test/test_re/test_estimate_evidence_lower_bound.py index 072f1992a53a19b3f8b7636d3715ca9fe0048b9b..03dba313a3737de50fa38bfb671eba34b6da64a0 100644 --- a/test/test_re/test_estimate_evidence_lower_bound.py +++ b/test/test_re/test_estimate_evidence_lower_bound.py @@ -24,10 +24,7 @@ def _explicify(M, position): unravel = lambda x: jax.linear_transpose(ravel, position)(x)[0] mat = lambda x: M(unravel(x)) identity = np.identity(dim, dtype=np.float64) - m = [] - for v in 
identity: - m.append(mat(v)) - return np.vstack(m).T + return np.column_stack([mat(v) for v in identity]) def get_linear_response(slope_op, intercept_op, sampling_points): @@ -240,10 +237,8 @@ def test_estimate_elbo_nifty_re_vs_nifty(seed): n_ham = ift.StandardHamiltonian(n_like) - elbo, stats = jft.estimate_evidence_lower_bound(like, samples, 4, batch_size=2) - n_elbo, nstats = ift.estimate_evidence_lower_bound( - n_ham, n_samples, 4, batch_number=2 - ) + elbo, stats = jft.estimate_evidence_lower_bound(like, samples, 4, n_batches=2) + n_elbo, nstats = ift.estimate_evidence_lower_bound(n_ham, n_samples, 4, n_batches=2) n_elbo_samples = [] for n_elbo_sample in n_elbo.iterator():