Commit 7180662a authored by Philipp Arras's avatar Philipp Arras
Browse files

Get rid of wrappers

parent 9c64937b
......@@ -31,58 +31,6 @@ from ..sugar import makeField, makeOp
from ..utilities import allreduce_sum, get_MPI_params_from_comm, shareRange
def get_ESS(samples):
""" The effective sample size over a set of samples.
Returns
-----------
ESS: MultiField
The effective sample size for all parameters in the samples.
"""
result = {}
autocorrelations = get_AFC(samples)
for key in autocorrelations.keys():
fld = autocorrelations[key].val
addaxis = False
if len(fld.shape) == 1:
# empty domaintuple war weird
fld = fld.reshape((1,) + fld.shape)
addaxis = True
cum_field = np.cumsum(fld, axis=-1)
correlation_length = np.argmax(fld < 0, axis=-1)
indices = np.where(np.ones(cum_field[..., 0].shape))
indices += (correlation_length.flatten() - 1,)
integr_corr = cum_field[indices] - 1
ESS = len(samples)/(1 + 2*integr_corr)
if addaxis:
result[key[:-2]] = Field(samples[0].domain[key[:-2]], ESS[0])
else:
result[key[:-2]] = Field(samples[0].domain[key[:-2]],
ESS.reshape(correlation_length.shape))
return MultiField.from_dict(result)
def get_AFC(samples):
"""Calculates the auto-correlation function for every parameter in the samples.
Returns
-----------
result: MultiField
The auto-correlation function for every parameter in the samples
"""
sample_field = _standardized_sample_field(samples)
result = {}
for key in sample_field.keys():
AFC = ACF_Selector(sample_field[key].domain, len(samples))
FFT = FFTOperator(sample_field[key].domain,
space=len(sample_field[key].domain._dom) - 1)
h = FFT(sample_field[key])
hch = h.conjugate()*h
autocorr = FFT.inverse(hch)
result[key] = AFC(autocorr).real
return result
def _mean(fld, dom):
result = {}
for key in fld.keys():
......@@ -327,7 +275,35 @@ class HMC_chain:
ESS: MultiField
The effective sample size of all model parameters of the chain.
"""
return get_ESS(self.samples)
sample_field = _standardized_sample_field(self.samples)
autocorrelations = {}
for key in sample_field.keys():
AFC = ACF_Selector(sample_field[key].domain, len(self.samples))
FFT = FFTOperator(sample_field[key].domain,
space=len(sample_field[key].domain._dom) - 1)
h = FFT(sample_field[key])
hch = h.conjugate()*h
autocorr = FFT.inverse(hch)
autocorrelations[key] = AFC(autocorr).real
result = {}
for key, fld in autocorrelations.items():
addaxis = False
if len(fld.shape) == 1: # FIXME ?
fld = fld.val.reshape((1,) + fld.shape)
addaxis = True
cum_field = np.cumsum(fld.val, axis=-1)
correlation_length = np.argmax(fld.val < 0, axis=-1)
indices = np.where(np.ones(cum_field[..., 0].shape))
indices += (correlation_length.flatten() - 1,)
integr_corr = cum_field[indices] - 1
ESS = len(self.samples)/(1 + 2*integr_corr)
if addaxis:
result[key[:-2]] = Field(self.samples[0].domain[key[:-2]], ESS[0])
else:
result[key[:-2]] = Field(self.samples[0].domain[key[:-2]],
ESS.reshape(correlation_length.shape))
return MultiField.from_dict(result)
def mean(self):
"""The mean over all samples of the chain.
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment