diff --git a/src/minimization/hmc.py b/src/minimization/hmc.py index 4e671e2006c9a30f51e6c7f418c8dafab1f943b0..eed1e9f35c16f591378e35987ba2cdb6d1cfba68 100644 --- a/src/minimization/hmc.py +++ b/src/minimization/hmc.py @@ -172,9 +172,8 @@ class HMC_chain: tmp = self._sseq.spawn(2)[1] with Context(tmp): momentum = self._M.draw_sample_with_dtype(dtype=np.float64) - - new_position, new_momentum = self._integrate(momentum) - self._accepting(momentum, new_position, new_momentum) + new_position, new_momentum = self._integrate(momentum) + self._accepting(momentum, new_position, new_momentum) self._update_acceptance() def _integrate(self, momentum): @@ -235,8 +234,7 @@ class HMC_chain: rate = np.exp(energy - new_energy) if np.isnan(rate): return - rng = current_rng() - accept = rng.binomial(1, rate) + accept = current_rng().binomial(1, rate) if accept: self._position = new_position self._accepted.append(accept) @@ -245,10 +243,7 @@ class HMC_chain: def _update_acceptance(self): """Calculates the current acceptance rate based on the last ten samples.""" - current_accepted = self._accepted[-10:] - current_accepted = np.array(current_accepted) - current_acceptance = np.mean(current_accepted) - self._current_acceptance.append(current_acceptance) + self._current_acceptance.append(np.mean(self._accepted[-10:])) def _tune_parameters(self, preferred_acceptance): """Increases or decreases the steplength in the leapfrog integration @@ -393,13 +388,11 @@ class HMC_Sampler: The mean and variance over the samples. """ - locmeanvar = [ + lmv = [ chain.estimate_quantity(function) for chain in self._local_chains ] - locmean = [x[0] for x in locmeanvar] - locvar = [x[1] for x in locmeanvar] - mean = allreduce_sum(locmean, self._comm) - var = allreduce_sum(locvar, self._comm) + mean = allreduce_sum([x[0] for x in lmv], self._comm) + var = allreduce_sum([x[1] for x in lmv], self._comm) return mean/self._N_chains, var/self._N_chains @property @@ -433,8 +426,8 @@ class HMC_Sampler: dom = self._dom locfld = [_sample_field(chain.samples) for chain in self._local_chains] locmeanmean = [_mean(fld, dom) for fld in locfld] - locW = [_var(fld, dom) for fld in locfld] mean_mean = allreduce_sum(locmeanmean, self._comm)/M + locW = [_var(fld, dom) for fld in locfld] W = allreduce_sum(locW, self._comm)/M locB = [(mean_mean - _mean(fld, dom))**2 for fld in locfld] B = allreduce_sum(locB, self._comm)*N/(M - 1)