...
 
Commits (85)
HMCF -- Hamilton Monte Carlo Sampling for Fields
HMCF - Hamiltonian Monte Carlo Sampling for Fields
===============================================
Summary
......@@ -56,7 +56,13 @@ and for using HMC tools PyQt5 is mandatory, too:
For a quick start dive into HMCF by running one of the demonstrations, e.g.:
python3 demos/wiener_filter.py
### Further Documentation
For a detailed description of the whole package, see this [paper](https://arxiv.org/abs/1807.02709).
Note that the exact version of this software described in the paper can be found
on the [jss_publication](https://gitlab.mpcdf.mpg.de/ift/HMCF/tree/jss_publication)
branch.
### Release Notes
......
from .generate_mock_data import *
from .energies import *
from .generate_mock_data import *
from .plotting import *
from .energies import *
import nifty4 as ift
from copy import copy
import numpy as np
__all__ = ['SimpleHierarchicalEnergy']
class PowSpec(object):
    """Power law P(k) = l_c**4 / (1 + l_c*k)**4, parameterized by the correlation length l_c.

    NOTE(review): ``grad`` and ``curv`` are not the plain l_c-derivatives of P.
    They appear to be the reciprocals of the first and second l_c-derivatives of
    1/P(k), which matches their use through ``inverse_times`` in the likelihood
    energy's ``_l_c_exp_part`` -- confirm against the caller before changing.
    """
    def __init__(self, field_variance):
        """
        Parameters
        ----------
        field_variance : float
            The correlation length l_c of the field.
        """
        self._l_c = field_variance

    def __call__(self, k):
        """Evaluate the power spectrum at mode k."""
        l_c = self._l_c
        return l_c**4 / (1 + l_c * k)**4

    def grad(self, k):
        """l_c-gradient helper spectrum (see class note on the reciprocal convention)."""
        l_c = self._l_c
        return -.25 * l_c**5 / (1 + l_c * k)**3

    def curv(self, k):
        """l_c-curvature helper spectrum (see class note on the reciprocal convention)."""
        l_c = self._l_c
        return .25 * l_c**6 / (1 + l_c * k)**2 / (5 + 2 * l_c * k)
class LogPowSpec(PowSpec):
    """Natural logarithm of PowSpec: log P(k) = 4*(log(l_c) - log(1 + l_c*k)).

    Unlike PowSpec, ``grad`` and ``curv`` here are the plain first and second
    derivatives of log P(k) with respect to l_c.
    """
    def __init__(self, field_variance):
        """
        Parameters
        ----------
        field_variance : float
            The correlation length l_c of the field.
        """
        super(LogPowSpec, self).__init__(field_variance)
        self._log_l_c = np.log(self._l_c)

    def __call__(self, k):
        """Evaluate log P(k) at mode k."""
        return 4 * (self._log_l_c - np.log(1 + self._l_c * k))

    def grad(self, k):
        """First derivative of log P(k) with respect to l_c."""
        l_c = self._l_c
        return 4 * (1. / l_c - k / (1 + l_c * k))

    def curv(self, k):
        """Second derivative of log P(k) with respect to l_c."""
        l_c = self._l_c
        return 4 * (- 1. / l_c**2 + (k / (1 + l_c * k))**2)
class PriorEnergy(ift.Energy):
    """Prior part of the full posterior energy.

    Encodes the negative log of the log-normal prior on the correlation
    length 'l_c':

        E(l_c) = (log(l_c) - l_c_mu)**2 / (2 * l_c_sig2) + log(l_c)

    where the trailing log(l_c) is the Jacobian term of the log-normal
    density.  The 'signal' component of the position does not enter; its
    gradient and curvature entries are zero.
    """
    def __init__(self, position, l_c_mu, l_c_sig2):
        """
        Parameters
        ----------
        position : ift.MultiField
            The position of the energy.
            The MultiField contains entries for the 'signal' and the correlation length 'l_c'.
        l_c_mu : float
            Hyperparameter for the log-normal distribution of l_c.
        l_c_sig2 : float
            Hyperparameter for the log-normal distribution of l_c.
        """
        super(PriorEnergy, self).__init__(position)
        self._l_c = self._position['l_c']
        self._logl_c = ift.log(self._l_c)
        self._l_c_mu = l_c_mu
        self._l_c_sig2 = l_c_sig2

    def at(self, position):
        """Return a new energy of the same type at `position` with the same hyperparameters."""
        return self.__class__(position, self._l_c_mu, self._l_c_sig2)

    @property
    @ift.memo
    def value(self):
        # 'l_c' lives on a one-element domain (cf. the mock-data generator),
        # hence the scalar [0] access at the end.
        part = (self._logl_c.to_global_data() - self._l_c_mu)**2 / 2. / self._l_c_sig2 + self._logl_c.to_global_data()
        return part[0]

    @property
    @ift.memo
    def gradient(self):
        # dE/dl_c = ((log(l_c) - mu) / sig2 + 1) / l_c; the 'signal' part is zero.
        signal_field = self._position['signal'].copy().fill(0.)
        l_c_field = ((self._logl_c - self._l_c_mu) / self._l_c_sig2 + 1.) / self._l_c
        result_dict = {'signal': signal_field, 'l_c': l_c_field}
        return ift.MultiField(val=result_dict)

    @property
    @ift.memo
    def curvature(self):
        # Exact second derivative d2E/dl_c2, wrapped as a diagonal operator;
        # the 'signal' entry of the diagonal is zero.
        signal_field = self._position['signal'].copy().fill(0.)
        l_c_field = ((1. - self._logl_c + self._l_c_mu) / self._l_c_sig2 - 1.) / self._l_c**2
        diagonal = ift.MultiField(val={'signal': signal_field, 'l_c': l_c_field})
        return ift.makeOp(diagonal)
class LikelihoodEnergy(ift.Energy):
    def __init__(self, position, d, R, N, HT, inverter=None):
        """
        The likelihood part of the full posterior energy, i.e.

            0.5 * (s S^{-1} s + res N^{-1} res + log(det(S)))

        where res = d - R exp(HT(s)) and the signal covariance S depends on
        the correlation length 'l_c' through a power law (see PowSpec).

        Parameters
        ----------
        position : ift.MultiField
            Containing 'signal' and 'l_c'
        d : ift.Field
            The data vector.
        R : ift.LinearOperator
            The instrument's response.
        N : ift.EndomorphicOperator
            The covariance of the noise.
        HT : ift.HarmonicTransformOperator
        inverter : ift.Minimizer, optional
            Used by the curvature operator for inversion / sampling.
        """
        super(LikelihoodEnergy, self).__init__(position)
        self._d = d
        self._R = R
        self._N = N
        self._HT = HT
        self._inverter = inverter
        self._s = self._position['signal']
        # 'l_c' lives on a one-element domain; extract it as a scalar.
        self._l_c_val = self._position['l_c'].to_global_data()[0]
        # x = exp(s) in position space, i.e. the expected flux.
        self._x = ift.exp(self._HT(self._s))
        self._res = self._d - self._R.times(self._x)
        self._pow_spec = PowSpec(self._l_c_val)
        self._log_pow_spec = LogPowSpec(self._l_c_val)
        self._s_space = self._HT.domain
        self._l_c_space = self._position.domain['l_c']

    def at(self, position):
        """Return a new energy of the same type at `position`; data and operators are reused."""
        return self.__class__(position, self._d, self._R, self._N, self._HT, self._inverter)

    def _l_c_det_part(self, power_spectrum):
        # 0.5 * trace of the diagonal operator built from `power_spectrum`,
        # returned as a field on the (one-element) l_c domain.
        covariance = ift.create_power_operator(domain=self._s_space, power_spectrum=power_spectrum)
        return ift.Field(domain=self._l_c_space, val=.5 * covariance.times(ift.Field.full(self._s_space, val=1.)).sum())

    def _l_c_exp_part(self, power_spectrum):
        # 0.5 * s^dagger C^{-1} s with C built from `power_spectrum`.
        # NOTE: because of the inverse_times here, spectra passed in must be the
        # reciprocal of the intended diagonal (cf. PowSpec.grad / PowSpec.curv).
        covariance = ift.create_power_operator(domain=self._s_space, power_spectrum=power_spectrum)
        return ift.Field(domain=self._l_c_space, val=.5 * self._s.vdot(covariance.inverse_times(self._s)))

    @property
    @ift.memo
    def value(self):
        S = ift.create_power_operator(domain=self._s_space, power_spectrum=self._pow_spec)
        s_part = 0.5 * (self._s.vdot(S.inverse_times(self._s)) + self._res.vdot(self._N.inverse_times(self._res)))
        # det part: 0.5 * sum_k log P(k) = 0.5 * log det(S)
        return s_part + self._l_c_det_part(self._log_pow_spec).to_global_data()[0]

    @property
    @ift.memo
    def gradient(self):
        S = ift.create_power_operator(domain=self._s_space, power_spectrum=self._pow_spec)
        # d/ds: S^{-1} s - HT^dagger (x * R^dagger N^{-1} res), chain rule through exp
        signal_grad = (S.inverse_times(self._s)
                       - self._HT.adjoint_times(self._x * self._R.adjoint_times(self._N.inverse_times(self._res))))
        # d/dl_c: derivative of the log-det term plus derivative of the quadratic form
        l_c_grad = self._l_c_det_part(self._log_pow_spec.grad) + self._l_c_exp_part(self._pow_spec.grad)
        val = {'signal': signal_grad, 'l_c': l_c_grad}
        return ift.MultiField(val=val)

    @property
    @ift.memo
    def curvature(self):
        # not really the curvature but rather a metric in the context of differential manifolds
        loc_R = self._R * ift.makeOp(self._x) * self._HT
        S = ift.create_power_operator(domain=self._s_space, power_spectrum=self._pow_spec)
        signal_curv = ift.library.WienerFilterCurvature(loc_R, self._N, S, self._inverter)
        l_c_curv_diag = self._l_c_det_part(self._log_pow_spec.curv) + self._l_c_exp_part(self._pow_spec.curv)
        operators = {'signal': signal_curv, 'l_c': ift.makeOp(l_c_curv_diag)}
        return ift.BlockDiagonalOperator(operators)
class SimpleHierarchicalEnergy(ift.Energy):
    def __init__(self, position, d, R, N, HT, l_c_mu, l_c_sig2, inverter=None):
        """
        A NIFTy Energy implementing a simple hierarchical model.

        The 'signal' s has a Gaussian prior with covariance S.
        S is diagonal in the harmonic representation of s and follows a power law which is dependent on another free
        parameter, the correlation length 'l_c'.
        l_c is log-normal distributed.
        In this example, exp(s) represents the photon flux coming from the large-scale structure of the universe.

        The total energy is the sum of a PriorEnergy (for l_c) and a
        LikelihoodEnergy (data term plus Gaussian signal prior).

        Parameters
        ----------
        position : ift.MultiField
            Containing 'signal' and 'l_c'
        d : ift.Field
            The data vector.
        R : ift.LinearOperator
            The instrument's response.
        N : ift.EndomorphicOperator
            The covariance of the noise.
        HT : ift.HarmonicTransformOperator
        l_c_mu : float
            Hyperparameter for the log-normal distribution of l_c.
        l_c_sig2 : float
            Hyperparameter for the log-normal distribution of l_c.
        inverter : ift.Minimizer
            Numerical method for inverting LinearOperators.
            This is necessary to be able to sample from the curvature which is required if the mass matrix is supposed
            to be reevaluated.
        """
        super(SimpleHierarchicalEnergy, self).__init__(position)
        self._prior = PriorEnergy(self._position, l_c_mu, l_c_sig2)
        self._likel = LikelihoodEnergy(self._position, d, R, N, HT, inverter=inverter)

    def at(self, position):
        # Shallow-copy and re-anchor the two sub-energies instead of
        # re-running __init__ (hyperparameters/operators are reused).
        new_energy = copy(self)
        new_energy._position = position
        new_energy._prior = self._prior.at(position)
        new_energy._likel = self._likel.at(position)
        return new_energy

    @property
    def value(self):
        """Sum of prior and likelihood energies (a float)."""
        return self._prior.value + self._likel.value

    @property
    def gradient(self):
        """Sum of prior and likelihood gradients (an ift.MultiField)."""
        return self._prior.gradient + self._likel.gradient

    @property
    @ift.memo
    def curvature(self):
        # InversionEnabler makes inverse application / sampling possible via `inverter`.
        return ift.InversionEnabler(self._prior.curvature + self._likel.curvature, inverter=self._likel._inverter)
import numpy as np
import nifty4 as ift
__all__ = ['get_poisson_data', 'get_hierarchical_data']
class Exp3(object):
    """Pointwise non-linearity x -> exp(3*x), together with its derivative."""
    def __call__(self, x):
        """Return exp(3*x)."""
        arg = 3 * x
        return ift.exp(arg)

    def derivative(self, x):
        """Return the derivative d/dx exp(3*x) = 3*exp(3*x)."""
        arg = 3 * x
        return 3 * ift.exp(arg)
def get_poisson_data(n_pixels=1024, dimension=1, length=2., excitation=1., diffusion=10.):
    """
    Generates a mock data set for a Poisson log-normal reconstruction.

    A Gaussian signal is drawn from a diffusion-type power spectrum, pushed
    through the Exp3 non-linearity and a masked response, and Poisson counts
    are drawn from the resulting rate.

    Parameters
    ----------
    n_pixels : int
        Number of pixels (per dimension)
    dimension : int
        Dimensionality of underlying space
    length : float
        Total physical length of regular grid.
    excitation : float
        Excitation field level.
    diffusion : float
        Diffusion constant.

    Returns
    -------
    data : ift.Field
        The poisson distributed 'measured' data field.
    s : ift.Field
        The 'true' signal field (in harmonic space representation).
    R : ift.LinearOperator
        Instrument's response.
    S : ift.LinearOperator
        Correlation of signal field.
    """
    # small number to tame zero mode
    eps = 1e-8
    # Define data gaps (masked region between 60% and 70% of each axis)
    na = int(0.6*n_pixels)
    nb = int(0.7*n_pixels)
    # Set up derived constants
    amp = excitation/(2*diffusion)  # spectral normalization
    pow_spec = lambda k: amp / (eps + k**2)
    pixel_width = length/n_pixels
    # Set up the geometry
    s_domain = ift.RGSpace(tuple(n_pixels for _ in range(dimension)), distances=pixel_width)
    h_domain = s_domain.get_default_codomain()
    HT = ift.HarmonicTransformOperator(h_domain, s_domain)
    # Create mock signal
    S = ift.create_power_operator(h_domain, power_spectrum=pow_spec)
    s = S.draw_sample()
    # remove zero mode
    glob = s.to_global_data()
    glob[0] = 0.
    s = ift.Field.from_global_data(s.domain, glob)
    # Setting up an exemplary response
    GeoRem = ift.GeometryRemover(s_domain)
    d_domain = GeoRem.target[0]
    mask = np.ones(d_domain.shape)
    idx = np.arange(na, nb)
    mask[np.ix_(*tuple(idx for _ in range(dimension)))] = 0.
    fmask = ift.Field.from_global_data(d_domain, mask)
    Mask = ift.DiagonalOperator(fmask)
    R0 = Mask*GeoRem
    # generate poisson distributed data counts
    nonlin = Exp3()
    lam = R0(nonlin(HT(s)))
    data = ift.Field.from_local_data(d_domain, np.random.poisson(lam.local_data).astype(np.float64))
    return data, s, R0, S
def get_hierarchical_data(n_pixels=100, dimension=2, length=2., correlation_length=4., signal_to_noise=2.):
    """
    Generates a mock data set for a Wiener filter with log-normal prior.

    Implemented as d = R(exp(s)) + n where d is the mock data field, s is the 'true' signal field with a gaussian
    distribution, R the instrument response and n white noise.

    Parameters
    ----------
    n_pixels : int
        Number of pixels per dimension.
        Default: 100
    dimension : int
        Number of dimensions
        Default: 2
    length : float
        Physical length of underlying space.
        Default: 2.
    correlation_length : float
        Default: 4.
    signal_to_noise : float
        Amplitude of noise compared to variance of the signal.
        Default: 2.

    Returns
    -------
    d : ift.Field
        The mock data field.
    s : ift.MultiField
        The 'true' position: the harmonic 'signal' field and the scalar 'l_c'.
    R : ift.LinearOperator
        The instrument response (harmonic signal space -> data space)
    N : ift.DiagonalOperator
        The covariance for the noise n defined on the data domain.
    """
    shape = tuple(n_pixels for _ in range(dimension))
    pixel_width = length / n_pixels
    a = correlation_length**4

    def p_spec(k):
        # same power law as PowSpec: a / (1 + l_c*k)**4
        return a / (k*correlation_length + 1)**4
    non_linearity = ift.library.nonlinearities.Exponential()
    # Set up position space
    s_space = ift.RGSpace(shape, distances=pixel_width)
    h_space = s_space.get_default_codomain()
    # one-element domain carrying the scalar correlation length
    param_space = ift.UnstructuredDomain((1,))
    # Define harmonic transformation and associated harmonic space
    HT = ift.HarmonicTransformOperator(h_space, target=s_space)
    # Drawing a sample sh from the prior distribution in harmonic space
    S = ift.ScalingOperator(1., h_space)
    sh = S.draw_sample()
    # pointwise amplitude sqrt(P(k)) used to colour the white sample
    p_op = ift.create_power_operator(h_space, p_spec)
    power = ift.sqrt(p_op(ift.Field.full(h_space, 1.)))
    # Choosing the measurement instrument: mask out 60%-80% of each axis
    mask = np.ones(s_space.shape)
    idx = np.arange(int(0.6*n_pixels), int(0.8*n_pixels))
    mask[np.ix_(*tuple(idx for _ in range(dimension)))] = 0.
    fmask = ift.Field.from_global_data(s_space, mask)
    R = ift.GeometryRemover(s_space) * ift.DiagonalOperator(fmask)
    d_space = R.target
    # Creating the mock data
    true_sky = non_linearity(HT(power*sh))
    noiseless_data = R(true_sky)
    # noise level derived from the empirical std of the noiseless data
    noise_amplitude = noiseless_data.val.std() / signal_to_noise
    N = ift.ScalingOperator(noise_amplitude ** 2, d_space)
    n = N.draw_sample()
    d = noiseless_data + n
    # bundle the 'true' position as a MultiField
    l_c = ift.Field(domain=param_space, val=correlation_length)
    val = {'signal': power*sh, 'l_c': l_c}
    return d, ift.MultiField(val=val), R, N
import matplotlib.pyplot as plt
import nifty4 as ift
import hmcf
import numpy as np
__all__ = ['plot_simple_hierarchical_result', 'plot_simple_hierarchical_source', 'plot_poisson_result']
def plot_simple_hierarchical_source(sl, d, R, sample_transformation):
    """Plot the true flux, its response image, and the data next to each other.

    All three panels share a common color scale taken from the data field and
    a common colorbar on the right.

    Parameters
    ----------
    sl : ift.MultiField
        Position containing the 'signal' entry to be visualized.
    d : ift.Field
        The data field.
    R : ift.LinearOperator
        The instrument response.
    sample_transformation : callable
        Maps a position to observable space; its 'signal' entry is plotted.
    """
    def _hide_ticks(ax):
        # all panels share the same tick-less look
        ax.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False,
                       labelbottom=False, labelleft=False)

    true_flux = sample_transformation(sl)['signal']
    fig, (x_ax, rx_ax, d_ax) = plt.subplots(nrows=1, ncols=3)
    v_min, v_max = d.min(), d.max()
    x_ax.imshow(true_flux.to_global_data(), vmin=v_min, vmax=v_max)
    _hide_ticks(x_ax)
    x_ax.set_xlabel('x')
    rx_ax.imshow(R(true_flux).to_global_data(), cmap='viridis', vmin=v_min, vmax=v_max)
    _hide_ticks(rx_ax)
    rx_ax.set_xlabel('R(x)')
    im = d_ax.imshow(d.to_global_data(), vmin=v_min, vmax=v_max)
    _hide_ticks(d_ax)
    d_ax.set_xlabel('d')
    # shared colorbar on the right edge of the figure
    fig.subplots_adjust(right=0.8)
    cbar_ax = fig.add_axes([0.85, 0.1, 0.02, 0.8])
    fig.colorbar(im, cax=cbar_ax)
def plot_simple_hierarchical_result(sl, sampler):
    """Plot the HMC reconstruction next to the truth, plus difference and std maps.

    Creates two figures: (true flux | HMC mean) on a common color scale, and
    (absolute difference | HMC standard deviation) on a second common scale.

    Parameters
    ----------
    sl : ift.MultiField
        The 'true' position containing the 'signal' entry.
    sampler : hmcf.HMCSampler
        Finished sampler providing mean, var and sample_transform.
    """
    def _hide_ticks(ax):
        # all panels share the same tick-less look
        ax.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False,
                       labelbottom=False, labelleft=False)

    def _add_colorbar(fig, im):
        # shared colorbar on the right edge of the figure
        fig.subplots_adjust(right=0.8)
        cbar_ax = fig.add_axes([0.85, 0.1, 0.02, 0.8])
        fig.colorbar(im, cax=cbar_ax)

    true_flux = sampler.sample_transform(sl)['signal']
    fig, (signal_ax, mean_ax) = plt.subplots(nrows=1, ncols=2)
    mean_val = sampler.mean['signal'].to_global_data()
    v_min, v_max = mean_val.min(), mean_val.max()
    signal_ax.imshow(true_flux.to_global_data(), vmin=v_min, vmax=v_max)
    _hide_ticks(signal_ax)
    signal_ax.set_xlabel('x')
    im = mean_ax.imshow(mean_val, vmin=v_min, vmax=v_max)
    _hide_ticks(mean_ax)
    mean_ax.set_xlabel('HMC mean')
    _add_colorbar(fig, im)
    # second figure: reconstruction error vs. posterior uncertainty
    fig, (diff_ax, std_ax) = plt.subplots(nrows=1, ncols=2)
    std_val = ift.sqrt(sampler.var['signal']).to_global_data()
    diff = abs(true_flux.to_global_data() - mean_val)
    v_min, v_max = min(std_val.min(), diff.min()), max(std_val.max(), diff.max())
    diff_ax.imshow(diff, vmin=v_min, vmax=v_max)
    _hide_ticks(diff_ax)
    diff_ax.set_xlabel('difference')
    im = std_ax.imshow(std_val, vmin=v_min, vmax=v_max)
    _hide_ticks(std_ax)
    std_ax.set_xlabel('HMC std')
    _add_colorbar(fig, im)
def plot_poisson_result(data, s, sampler):
    """
    Plots the result of the poisson demo script.

    Works only with one-dimensional problems.
    Saves the result to a pdf called 'HMC_Poisson.pdf'.

    Parameters
    ----------
    data : ift.Field
        The data vector.
    s : ift.Field
        The true signal field.
    sampler : hmcf.HMCSampler
        The sampler used for solving the problem.

    Raises
    ------
    TypeError
        If the signal's position-space co-domain has more than one dimension.
    """
    hmc_mean = sampler.mean
    hmc_std = ift.sqrt(sampler.var)
    # s lives in harmonic space; the plot axis is its position-space co-domain
    HT = ift.HarmonicTransformOperator(s.domain)
    co_dom = HT.target
    plt.clf()
    if len(co_dom.shape) > 1:
        raise TypeError('Poisson demo result plotting only possible for one-dimensional problems.')
    pixel_width = co_dom[0].distances[0]
    length = co_dom.shape[0] * pixel_width
    x = np.arange(0, length, pixel_width)
    plt.rcParams["text.usetex"] = True
    # one-sigma band around the HMC mean
    c1 = (hmc_mean-hmc_std).to_global_data()
    c2 = (hmc_mean+hmc_std).to_global_data()
    plt.scatter(x, data.to_global_data(), label='data', s=1, color='blue', alpha=0.5)
    plt.plot(x, sampler.sample_transform(s).to_global_data(), label='true sky', color='black')
    plt.fill_between(x, c1, c2, color='orange', alpha=0.4)
    plt.plot(x, hmc_mean.to_global_data(), label='hmc mean', color='orange')
    plt.xlim([0, length])
    # NOTE(review): hard-coded y-range; may clip for other parameter choices.
    plt.ylim([-0.1, 7.5])
    plt.title('Poisson log-normal HMC reconstruction')
    plt.legend()
    plt.savefig('HMC_Poisson.pdf')
from demo_utils import *
import numpy as np
import nifty4 as ift
import hmcf
import matplotlib.pyplot as plt
from logging import INFO, DEBUG
np.random.seed(41)
def get_ln_params(mean, mode):
    """Convert (mean, mode) of a log-normal distribution into (mu, sigma^2).

    Inverts mode = exp(mu - sig2) and mean = exp(mu + sig2/2).

    Parameters
    ----------
    mean : float
        Desired mean of the log-normal distribution.
    mode : float
        Desired mode of the log-normal distribution.

    Returns
    -------
    tuple of float
        The (mu, sig2) hyperparameters.
    """
    mu = np.log(mode * mean**2) / 3.
    sig2 = 2. / 3. * np.log(mean / mode)
    return mu, sig2
if __name__ == '__main__':
    ift.logger.setLevel(INFO)
    hmcf.logger.setLevel(INFO)
    # hyperparameter: desired mean and mode of the log-normal prior on l_c
    l_c_mean = 2.
    l_c_mode = .1
    l_c_mu, l_c_sig2 = get_ln_params(l_c_mean, l_c_mode)
    # mock data set (d, true position, response, noise covariance)
    d, sl, R, N = get_hierarchical_data()
    s_space = sl['signal'].domain
    HT = ift.HarmonicTransformOperator(s_space)
    # energy related stuff
    ICI = ift.GradientNormController(iteration_limit=2000,
                                     tol_abs_gradnorm=1e-3)
    inverter = ift.ConjugateGradient(controller=ICI)
    energy = SimpleHierarchicalEnergy(sl, d, R, N, HT, l_c_mu, l_c_sig2, inverter=inverter)

    def sample_trafo(x):
        # map the harmonic 'signal' to position-space flux exp(HT(s)); other entries unchanged
        val = dict(x)
        val['signal'] = ift.exp(HT(val['signal']))
        return ift.MultiField(val=val)
    sampler = hmcf.HMCSampler(energy, num_processes=6, sample_transform=sample_trafo)
    sampler.display = hmcf.TableDisplay
    # chains start at scaled copies of the true position
    # NOTE(review): only 3 initial positions for num_processes=6 -- confirm the
    # sampler fills in the remaining chains.
    x_initial = [sl*c for c in [.5, .7, 1.2]]
    sampler.run(100, x_initial=x_initial)
    plot_simple_hierarchical_source(sl, d, R, sample_trafo)
    plot_simple_hierarchical_result(sl, sampler)
    plt.show()
import numpy as np
import nifty4 as ift
import hmcf
from matplotlib import pyplot as plt
from demo_utils import *
class Exp3(object):
......@@ -13,131 +13,31 @@ class Exp3(object):
if __name__ == "__main__":
np.random.seed(20)
# Set up physical constants
nu = 1. # excitation field level
kappa = 10. # diffusion constant
eps = 1e-8 # small number to tame zero mode
sigma_n = 0.2 # noise level
sigma_n2 = sigma_n**2
L = 1. # Total length of interval or volume the field lives on
nprobes = 1000 # Number of probes for uncertainty quantification
# Define resolution (pixels per dimension)
N_pixels = 1024
# Define data gaps
N1a = int(0.6*N_pixels)
N1b = int(0.64*N_pixels)
N2a = int(0.67*N_pixels)
N2b = int(0.8*N_pixels)
# Set up derived constants
amp = nu/(2*kappa) # spectral normalization
pow_spec = lambda k: amp / (eps + k**2)
lambda2 = 2*kappa*sigma_n2/nu # resulting correlation length squared
lambda1 = np.sqrt(lambda2)
pixel_width = L/N_pixels
x = np.arange(0, 1, pixel_width)
# Set up the geometry
s_domain = ift.RGSpace([N_pixels], distances=pixel_width)
h_domain = s_domain.get_default_codomain()
HT = ift.HarmonicTransformOperator(h_domain, s_domain)
aHT = HT.adjoint
# Create mock signal
Phi_h = ift.create_power_operator(h_domain, power_spectrum=pow_spec)
phi_h = Phi_h.draw_sample()
# remove zero mode
glob = phi_h.to_global_data()
glob[0] = 0.
phi_h = ift.Field.from_global_data(phi_h.domain, glob)
phi = HT(phi_h)
num_processes = 5
# Setting up an exemplary response
GeoRem = ift.GeometryRemover(s_domain)
d_domain = GeoRem.target[0]
mask = np.ones(d_domain.shape)
mask[N1a:N1b] = 0.
mask[N2a:N2b] = 0.
fmask = ift.Field.from_global_data(d_domain, mask)
Mask = ift.DiagonalOperator(fmask)
R0 = Mask*GeoRem
R = R0*HT
np.random.seed(20)
# Linear measurement scenario
N = ift.ScalingOperator(sigma_n2, d_domain) # Noise covariance
n = Mask(N.draw_sample()) # seting the noise to zero in masked region
d = R(phi_h) + n
data, s, R, S = get_poisson_data()
# Wiener filter
j = R.adjoint_times(N.inverse_times(d))
# additional needed objects
HT = ift.HarmonicTransformOperator(s.domain)
IC = ift.GradientNormController(name="inverter", iteration_limit=500,
tol_abs_gradnorm=1e-3)
inverter = ift.ConjugateGradient(controller=IC)
D = (ift.SandwichOperator.make(R, N.inverse) + Phi_h.inverse).inverse
D = ift.InversionEnabler(D, inverter, approximation=Phi_h)
m = HT(D(j))
# Uncertainty
D = ift.SandwichOperator.make(aHT, D) # real space propagator
Dhat = ift.probe_with_posterior_samples(D.inverse, None,
nprobes=nprobes)[1]
nonlin = Exp3()
lam = R0(nonlin(HT(phi_h)))
data = ift.Field.from_local_data(d_domain, np.random.poisson(lam.local_data).astype(np.float64))
# initial guess
psi0 = ift.full(h_domain, 1e-7)
energy = ift.library.PoissonEnergy(psi0, data, R0, nonlin, HT, Phi_h,
inverter)
IC1 = ift.GradientNormController(name="IC1", iteration_limit=200,
tol_abs_gradnorm=1e-4)
minimizer = ift.RelaxedNewton(IC1)
energy = minimizer(energy)[0]
non_lin = Exp3()
m = HT(energy.position)
energy = ift.library.PoissonEnergy(s, data, R, non_lin, HT, S, inverter)
# Using an HMC sampler
def sample_transform(z):
return nonlin(HT(z))
return non_lin(HT(z))
num_processes = 5
mass_reevaluations = 3
num_samples = 200
max_burn_in = 5000
convergence_tolerance = 1.
x_initial = [energy.position + 0.01*Phi_h.draw_sample() for _ in range(num_processes)]
x_initial = [energy.position + 0.01*S.draw_sample() for _ in range(num_processes)]
poisson_hmc = hmcf.HMCSampler(potential=energy, num_processes=num_processes, sample_transform=sample_transform)
poisson_hmc.display = hmcf.TableDisplay
poisson_hmc.epsilon = hmcf.EpsilonPowerLawDivergence
poisson_hmc.mass.reevaluations = mass_reevaluations
poisson_hmc.run(num_samples=num_samples, max_burn_in=max_burn_in, convergence_tolerance=convergence_tolerance,
x_initial=x_initial)
hmc_mean = poisson_hmc.mean
hmc_std = ift.sqrt(poisson_hmc.var)
# Plotting
sig = ift.sqrt(Dhat)
poisson_hmc.mass.reevaluations = 2
poisson_hmc.run(200, x_initial=x_initial)
x_mod = np.where(mask > 0, x, None)
plt.clf()
plt.rcParams["text.usetex"] = True
c1 = (hmc_mean-hmc_std).to_global_data()
c2 = (hmc_mean+hmc_std).to_global_data()
plt.scatter(x_mod, data.to_global_data(), label='data', s=1, color='blue', alpha=0.5)
plt.plot(x, nonlin(phi).to_global_data(), label='true sky', color='black')
plt.plot(x, nonlin(m).to_global_data(), label='reconstructed sky', color='red')
plt.fill_between(x, c1, c2, color='orange', alpha=0.4)
plt.plot(x, hmc_mean.to_global_data(), label='hmc mean', color='orange')
plt.xlim([0, L])
plt.ylim([-0.1, 7.5])
plt.title('Poisson log-normal HMC reconstruction')
plt.legend()
plt.savefig('HMC_Poisson.pdf')
plot_poisson_result(data, s, poisson_hmc)
from .hmc_sampler import *
from .utils import *
from .convergence import *
from .logging import *
......@@ -391,12 +391,13 @@ class MultiConvergence(object):
converged if max(convergence_measure) < tolerance * 10**level
"""
def __init__(self, domain, dtype=np.float64, num_chains=1, tolerance=2., level=1, convergences=None):
def __init__(self, domain, dtype=None, num_chains=1, tolerance=2., level=1, convergences=None):
"""
Parameters
----------
domain : ift.MultiDomain
dtype : np.dtype.dtype
Default: np.float64 for all sub fields.
num_chains : int
number of MCMC chains
tolerance : float
......@@ -408,6 +409,8 @@ class MultiConvergence(object):
"""
self._domain = domain
self._keys = self._domain.keys()
if dtype is None:
dtype = {key: np.float64 for key in self._keys}
self._dtype = dtype
self._num_chains = num_chains
......@@ -479,7 +482,7 @@ class MultiConvergence(object):
quota_sum = np.zeros(self._num_chains, dtype=float)
weight_sum = 0
for key in self._keys:
weight = np.prod(self._convergences[key].shape)
weight = np.prod(self._convergences[key]._shape)
quota_sum += weight * self._convergences[key].quota
weight_sum += weight
return quota_sum / weight_sum
......@@ -513,13 +516,13 @@ class MultiConvergence(object):
def _get_default_convergences(self, tolerance=2., level=0):
if self._num_chains == 1:
result = {key: HansonConvergence(domain=domain, dtype=self._dtype, num_chains=self._num_chains,
result = {key: HansonConvergence(domain=domain, dtype=self._dtype[key], num_chains=self._num_chains,
tolerance=tolerance, level=level)
for key, domain in self._domain.domains.items()}
for key, domain in self._domain.items()}
else:
result = {key: GelmanRubinConvergence(domain=domain, dtype=self._dtype, num_chains=self._num_chains,
result = {key: GelmanRubinConvergence(domain=domain, dtype=self._dtype[key], num_chains=self._num_chains,
tolerance=tolerance, level=level)
for key, domain in self._domain.domains.items()}
for key, domain in self._domain.items()}
return result
def reset(self, level=None):
......
This diff is collapsed.
import logging

# Package-wide logger; sub-modules use `from . import logger`.
logger = logging.getLogger('HMCF')
# '{'-style formatting: "<time> - HMCF - <LEVEL>: <message>"
formatter = logging.Formatter('{asctime} - {name} - {levelname}: {message}', style='{')
logger.setLevel(logging.DEBUG)
# stream handler mirrors the logger's level
_hdlr = logging.StreamHandler()
_hdlr.setLevel(logger.level)
_hdlr.setFormatter(formatter)
logger.addHandler(_hdlr)
from .hmcf_logger import *
from logging.handlers import QueueHandler, BufferingHandler
from . import logger
import nifty4 as ift
__all__ = ['HMCLogHandler', 'handle_logs', 'redirect_logging', 'reset_logging', 'DisplayHandler']
class HMCLogHandler(QueueHandler):
    """ A class for handling sub-process logging messages.

    This logging.Handler deals with logging messages as any other Handler (e.g. it can be combined with a
    logging.Formatter and so on) but appends the output to the queue packages which are sent from the sub-process
    to the main process.
    This is particularly important for logging messages coming from the nifty package.

    Attributes
    ----------
    queue : dict
        Package for data transfer between sub-process and main process during an HMC run.
    """
    def __init__(self, out_pkg):
        """
        Parameters
        ----------
        out_pkg : dict
            Package for data transfer between sub-process and main process during an HMC run.
        """
        super(HMCLogHandler, self).__init__(queue=out_pkg)

    def enqueue(self, record):
        # Collect records under the 'log' key, creating the list on first use.
        self.queue.setdefault('log', []).append(record)
def handle_logs(package):
    """
    Handles the log messages in the multi-process communication.

    Each record is tagged with its chain of origin and re-emitted through the
    main-process HMCF logger.

    Parameters
    ----------
    package : dict
        The queue packages interchanged between sub- and main processes.
    """
    prefix = 'chain{:03d}: '.format(package['chain_identifier'])
    for record in package['log']:
        record.msg = prefix + record.msg
        logger.handle(record)
def redirect_logging(package):
    """ Takes care of redirecting logging messages from nifty and hmcf during sampling in sub-processes.

    All handlers of the nifty and hmcf loggers are removed and replaced by one
    HMCLogHandler each, which appends records to `package` so the main process
    can emit them.

    Parameters
    ----------
    package : dict
        Package of interest (used for exchange of information between sub- and main processes).

    Returns
    -------
    old_nifty_hdlrs : list
        For resetting.
    old_hmcf_hdlrs : list
        For resetting.
    """
    # Order matters: it determines the order of the returned handler lists and
    # must match the documented return value and reset_logging's parameters.
    loggers = [ift.logger, logger]
    old_hdlrs = []
    new_hdlrs = [HMCLogHandler(package) for _ in range(2)]
    for index, log in enumerate(loggers):
        # Copy the handler list: log.handlers is the live list and removing
        # handlers while iterating it would skip every other handler.
        handlers = list(log.handlers)
        for hdlr in handlers:
            log.removeHandler(hdlr)
        new_hdlrs[index].setLevel(log.level)
        log.addHandler(new_hdlrs[index])
        old_hdlrs.append(handlers)
    return tuple(old_hdlrs)
def reset_logging(old_nifty_hdlrs, old_hmcf_hdlrs):
    """ Resets the redirection of logging in sub-processes.

    Parameters
    ----------
    old_nifty_hdlrs : list
        The NIFTy logging handlers before redirection (only necessary if reset=True)
    old_hmcf_hdlrs : list
        The HMCF logging handlers before redirection (only necessary if reset=True)

    Notes
    -----
    Assumes each logger currently carries exactly one handler (the
    HMCLogHandler installed by redirect_logging); only handlers[0] is removed.
    NOTE(review): the parameter order here is (nifty, hmcf) -- verify call
    sites pass the handler lists returned by redirect_logging in that order.
    """
    loggers = [ift.logger, logger]
    old_hdlrs = [old_nifty_hdlrs, old_hmcf_hdlrs]
    for index, log in enumerate(loggers):
        # drop the redirection handler, then restore the saved ones
        hdlr = log.handlers[0]
        log.removeHandler(hdlr)
        for hdlr in old_hdlrs[index]:
            log.addHandler(hdlr)
class DisplayHandler(BufferingHandler):
    """BufferingHandler variant whose buffer holds formatted strings instead of records."""
    def emit(self, record):
        """
        Emit a record.

        Append the formatted record to the buffer; if shouldFlush() says so,
        flush and warn about the overflow.

        The important difference to BufferingHandler is, that the buffer does not contain records but formatted strings.
        """
        formatted = self.format(record)
        self.buffer.append(formatted)
        if not self.shouldFlush(record):
            return
        self.flush()
        logger.warning('Display logging handler overflow. Flushed all messages.')
......@@ -22,7 +22,7 @@ this module provides some functionality to probe a mass matrix given the curvatu
"""
import numpy as np
import nifty4 as ift
from logging import INFO
from ..logging import logger
__all__ = ['get_mass_operator', 'MassBase', 'MassMain', 'MassChain']
......@@ -55,7 +55,8 @@ def get_mass_operator(curvature, probe_count=10, dtype=None):
sample = curvature.draw_sample(from_inverse=True, dtype=dtype)
sample_squared_sum += sample.conjugate() * sample
inv_diagonal = sample_squared_sum / (probe_count-1)
if (inv_diagonal == 0.).any():
raise ValueError("Inverse diagonal of mass operator contains zero-entries.")
return ift.makeOp(1. / inv_diagonal)
......@@ -180,7 +181,7 @@ class MassMain(MassBase):
For non-shared mass operators this class mostly handles information related tasks.
For shared mass operations it also reevaluates mass matrices.
"""
def __init__(self, potential, num_chains=1, num_samples=50, mass_reevaluations=0):
def __init__(self, potential, num_chains=1, num_samples=50, mass_reevaluations=1):
"""
Parameters
----------
......@@ -199,6 +200,20 @@ class MassMain(MassBase):
self._new_mass_flag = np.full(self._num_chains, False, dtype=bool)
self._current_positions = [self._potential.position.copy() for _ in range(self._num_chains)]
self._got_current_position = self._new_mass_flag.copy()
get_init_mass_val = self._get_initial_mass
self._get_initial_mass = np.full(self._num_chains, get_init_mass_val)
@MassBase.get_initial_mass.getter
def get_initial_mass(self):
return self._get_initial_mass.any()
@get_initial_mass.setter
def get_initial_mass(self, value):
if value:
self._reevaluations += 1
else:
self._dec_reevaluations()
self._get_initial_mass = np.full(self._num_chains, value)
def get_client(self):
"""
......@@ -211,7 +226,7 @@ class MassMain(MassBase):
mass_client = MassChain(self._potential, num_samples=self.num_samples,
mass_reevaluations=self._reevaluations[0])
mass_client._operator = self._operator
mass_client._get_initial_mass = self._get_initial_mass
mass_client._get_initial_mass = self._get_initial_mass.any()
mass_client.shared = self.shared
return mass_client
......@@ -236,39 +251,46 @@ class MassMain(MassBase):
step = incoming['step']
new_mass_flag = False
if self.shared:
engage = incoming['accepted'] and incoming['epsilon'].converged
engage = incoming['accepted']
nominal_clearance = outgoing['converged'] and self._reevaluations[ch_id] > 0
if self._new_mass_flag[ch_id]: # check if there is an new 'shared' mass already
self._new_mass_flag[ch_id] = False
new_mass_flag = True
elif engage and (self._get_initial_mass or nominal_clearance):
elif engage and (self._get_initial_mass[ch_id] or nominal_clearance):
# add a new sample
self._current_positions[ch_id] = incoming['sample']
self._got_current_position[ch_id] = True
# try curvature based approach (every time the conditions above are met)
position = self._get_position_for_reeval()
curvature = self._potential.at(position).curvature
self._operator = self._get_mass_operator(curvature)
new_mass_flag = True
incoming['log'].append((INFO, 'mass matrix reevaluated based on curvature. (step %d)' % step))
self.reset(reset_operator=False)
self._new_mass_flag = np.full_like(self._new_mass_flag, True)
self._new_mass_flag[ch_id] = False
try:
new_op = self._get_mass_operator(curvature)
except ValueError:
logger.warning('Curvature not positive definite, try again in next step. (step {:d})'.format(step))
new_mass_flag = False
else:
logger.info('Mass matrix reevaluated. (step {:d})'.format(step))
new_mass_flag = True
self._operator = new_op
self.reset(reset_operator=False)
self._new_mass_flag = np.full_like(self._new_mass_flag, True)
self._new_mass_flag[ch_id] = False
# if this was an initial mass thing, remember to start the real work.
if new_mass_flag:
outgoing.update(new_mass=self._operator)
if self._get_initial_mass:
self._get_initial_mass = False
if self._get_initial_mass[ch_id]:
self._get_initial_mass = np.full(self._num_chains, False)
else:
self._reevaluations[ch_id] -= 1
else:
# chains do all the work. Just do some stuff for information purposes
if incoming['mass_reeval']:
self._reevaluations[ch_id] -= 1
if self._get_initial_mass: # this is problematic since it should be an num_chains dim vector.
self._get_initial_mass = False
if self._get_initial_mass[ch_id]: # this is problematic since it should be an num_chains dim vector.
self._get_initial_mass[ch_id] = False
new_mass_flag = True
return new_mass_flag
......@@ -334,19 +356,25 @@ class MassChain(MassBase):
if 'new_mass' in incoming:
self._operator = incoming['new_mass']
new_mass_flag = True
outgoing['log'].append((INFO, 'Received new Mass matrix. (step %d)' % step))
logger.info('Received new Mass matrix. (step {:d})'.format(step))
else:
sample = outgoing['sample'] # type: ift.Field
engage = outgoing['accepted'] and outgoing['epsilon'].converged
engage = outgoing['accepted']
nominal_clearance = incoming['converged'] and self._reevaluations[ch_id] > 0
if engage and (self._get_initial_mass or nominal_clearance):
# curvature based approach
curvature = self._potential.at(sample).curvature
self._operator = self._get_mass_operator(curvature)
outgoing['log'].append((INFO, 'mass matrix reevaluated based on curvature. (step %d)' % step))
new_mass_flag = True
self.reset(reset_operator=False)
try:
new_op = self._get_mass_operator(curvature)
except ValueError:
logger.warning('Curvature not positive definite, try again in next step. (step {:d})'.format(step))
new_mass_flag = False
else:
logger.info('Mass matrix reevaluated. (step {:d})'.format(step))
new_mass_flag = True
self._operator = new_op
self.reset(reset_operator=False)
if new_mass_flag:
outgoing.update(mass_reeval=True)
......
......@@ -17,7 +17,7 @@
#
# HMCF is being developed at the Max-Planck-Institut fuer Astrophysik
from hmc import load
from hmcf import load
import numpy as np
import os
......@@ -25,7 +25,7 @@ import os
__all__ = ['get_autocorrelation']
def get_autocorrelation(path, shift=1):
def get_autocorrelation(path, shift=1, field_name=None):
"""
Calculates the autocorrelation for the samples of a h5 run-file.
It follows the convention of numpy correlate which is
......@@ -37,6 +37,8 @@ def get_autocorrelation(path, shift=1):
path to run file or sampler directory.
If the latter is used, the latest run file in this directory is evaluated.
shift : int
field_name : str, optional
If MultiFields were used, this specifies which of the sub-fields is supposed to be used.
Returns
-------
auto_corr : np.ndarray
......@@ -45,7 +47,9 @@ def get_autocorrelation(path, shift=1):
"""
samples = load(path=path)
if isinstance(samples, dict):
raise NotImplementedError('MultiField capability not implemented for tools yet')
if field_name is None:
raise ValueError("Samples are present as MultiFields. This requires the field_name keyword argument. ")
samples = samples[field_name]
samples[np.isnan(samples)] = 0
shift_samples = samples.copy()
reverse = shift < 0
......
......@@ -20,7 +20,7 @@
from .matplotlib_gui import EvaluationWindow
import numpy as np
import nifty4 as ift
from hmc import load, load_mean
from hmcf import load, load_mean
import sys
from PyQt5 import QtWidgets
......@@ -28,7 +28,7 @@ from PyQt5 import QtWidgets
__all__ = ['show_trajectories']
def show_trajectories(path, reference_field='mean', solution=None, start=0, stop=None, step=1):
def show_trajectories(path, reference_field=None, field_name=None, solution=None, start=0, stop=None, step=1):
"""
A GUI for examining the trajectories of the individual Markov chains.
......@@ -37,27 +37,31 @@ def show_trajectories(path, reference_field='mean', solution=None, start=0, stop
path : str
path to the h5 file containing the samples or the directory where the run files are located.
If a directory is passed, the latest run file is loaded.
reference_field : str, ift.Field, np.ndarray
reference_field : ift.Field, ift.MultiField, np.ndarray, dict(str -> np.ndarray)
the image of interest. As string the following statements are valid:
'mean' : use the final mean value of the sampling process. If there are only burn_in samples use the mean
of the last sample of each chain
'bad_pixel_distr' : the sum of all convergence measures during burn in (only possible if HMCSamplerDebug or
the like has been used with approbiate attributes
it is also possible to pass your own reference image as either ift-field or numpy array. Check shape and
dimensions though
solution : ift.Field, np.ndarray, optional
if known the 'right' solution for your sampling problem
field_name : str, optional
If MultiFields were used, this specifies which of the sub-fields is supposed to be displayed.
solution : ift.Field, ift.MultiField, np.ndarray, optional
If known the 'right' solution for your sampling problem
start : int
use samples starting with start
Use samples starting with start.
Default: 0
stop : int, optional
use samples until stop
Use samples until stop.
step : int
only use every 'step' sample
Only use every 'step' sample.
Default: 1
"""
# load required samples
burn_in_samples = load(path, 'burn_in', start=start, stop=stop, step=step)
if isinstance(burn_in_samples, dict):
raise NotImplementedError('MultiField capability not implemented for tools yet')
if field_name is None:
raise ValueError("Samples are present as MultiFields. This requires the field_name keyword argument. ")
burn_in_samples = burn_in_samples[field_name]
burn_in_length = burn_in_samples.shape[1]
if stop is None:
samples_stop = None
......@@ -69,6 +73,8 @@ def show_trajectories(path, reference_field='mean', solution=None, start=0, stop
if load_samples:
try:
samples = load(path, 'samples', start=samples_start, stop=samples_stop, step=step)
if isinstance(samples, dict):
samples = samples[field_name]
except KeyError:
all_samples = burn_in_samples
else:
......@@ -77,30 +83,28 @@ def show_trajectories(path, reference_field='mean', solution=None, start=0, stop
all_samples = burn_in_samples
# check reference_field input and set field_val accordingly
if reference_field == 'mean':
if reference_field is None:
try:
field_val = load_mean(path)
reference_field = load_mean(path)
except (KeyError, ValueError):
field_val = np.nanmean(burn_in_samples, axis=(0, 1))
elif reference_field == 'bad_pixel_distr':
conv_meas = load(path, 'conv_meas', start=start, stop=stop, step=step)
conv_meas[np.isnan(conv_meas)] = 0.
inf_mask = np.isinf(conv_meas)
kinda_normalizer = np.nanmax(conv_meas[np.logical_not(inf_mask)])
conv_meas[inf_mask] = kinda_normalizer
conv_meas /= kinda_normalizer
num_samples = conv_meas.shape[0] * conv_meas.shape[1]
conv_meas /= num_samples
field_val = np.nansum(conv_meas, axis=(0, 1))
del conv_meas
elif isinstance(reference_field, ift.Field):
field_val = reference_field.val
else:
field_val = reference_field
reference_field = np.nanmean(burn_in_samples, axis=(0, 1))
if isinstance(reference_field, dict):
reference_field = reference_field[field_name]
if isinstance(reference_field, ift.Field):
reference_field = reference_field.val
if solution is not None:
if isinstance(solution, ift.MultiField) or isinstance(solution, dict):
solution = solution[field_name]
if isinstance(solution, ift.Field):
solution = solution.val
if burn_in_length <= 0:
burn_in_length = None
# start QT App
q_app = QtWidgets.QApplication(sys.argv)
aw = EvaluationWindow(field_val, all_samples, solution=solution)
aw = EvaluationWindow(reference_field, all_samples, solution=solution, max_burn_in_ix=burn_in_length)
aw.setWindowTitle('HMC Sampling Trajectories')
aw.show()
sys.exit(q_app.exec_())
......@@ -77,6 +77,10 @@ class FieldCanvas2D(MyMplCanvas):
class TrajectoriesCanvas(MyMplCanvas):
def __init__(self, *args, **kwargs):
self._max_burn_in_ix = kwargs.pop("max_burn_in_ix", None)
super(TrajectoriesCanvas, self).__init__(*args, **kwargs)
def compute_initial_figure(self):
self.axes.plot([0, 1], [0, 0], 'r')
......@@ -89,14 +93,17 @@ class TrajectoriesCanvas(MyMplCanvas):
legend.append('real')
self.axes.plot([0, x_max], [real_val