...
 
Commits (41)
from .generate_mock_data import *
from .energies import *
from .generate_mock_data import *
from .plotting import *
from .energies import *
import nifty4 as ift
from copy import copy
import numpy as np
__all__ = ['SimpleHierarchicalEnergy']
class PowSpec(object):
    """Power spectrum P(k) = l_c^4 / (1 + l_c k)^4 of the signal prior.

    NOTE: `grad` and `curv` do NOT return plain derivatives of P.  They
    return auxiliary spectra whose reciprocals equal the first and second
    l_c-derivative of the *inverse* spectrum 1/P:
        d(1/P)/dl_c     = -4 (1 + l_c k)^3 / l_c^5             = 1 / grad(k)
        d^2(1/P)/dl_c^2 =  4 (1 + l_c k)^2 (5 + 2 l_c k) / l_c^6 = 1 / curv(k)
    A caller can therefore build a power operator from these spectra and
    apply its *inverse* to obtain derivatives of S^{-1}.
    """

    def __init__(self, field_variance):
        # the correlation length l_c (the caller names it field_variance)
        self._l_c = field_variance

    def __call__(self, k):
        base = 1 + self._l_c * k
        return self._l_c**4 / base**4

    def grad(self, k):
        # reciprocal of d(1/P)/dl_c -- see class docstring
        base = 1 + self._l_c * k
        return -.25 * self._l_c**5 / base**3

    def curv(self, k):
        # reciprocal of d^2(1/P)/dl_c^2 -- see class docstring
        base = 1 + self._l_c * k
        return .25 * self._l_c**6 / base**2 / (5 + 2 * self._l_c * k)
class LogPowSpec(PowSpec):
    """Logarithm of the `PowSpec` spectrum and its l_c-derivatives.

    __call__(k) = log P(k) = 4 (log l_c - log(1 + l_c k));
    `grad` and `curv` are the first and second derivatives of this
    expression with respect to l_c (unlike the parent class, these are
    plain derivatives).
    """

    def __init__(self, field_variance):
        super(LogPowSpec, self).__init__(field_variance)
        # cache log(l_c); it is needed on every __call__
        self._log_l_c = np.log(self._l_c)

    def __call__(self, k):
        return 4 * (self._log_l_c - np.log(1 + self._l_c * k))

    def grad(self, k):
        # d/dl_c [4 log l_c - 4 log(1 + l_c k)]
        ratio = k / (1 + self._l_c * k)
        return 4 * (1. / self._l_c - ratio)

    def curv(self, k):
        # second l_c-derivative of the log spectrum
        ratio = k / (1 + self._l_c * k)
        return 4 * (ratio**2 - 1. / self._l_c**2)
class PriorEnergy(ift.Energy):
    """Prior part of the posterior energy.

    Negative log prior of the log-normally distributed correlation
    length l_c:
        E(l_c) = (log(l_c) - mu)^2 / (2 sig2) + log(l_c)
    where the trailing log(l_c) is the Jacobian of the log transform.
    The 'signal' entry of the position does not enter the prior, so its
    gradient and curvature contributions are zero.
    """

    def __init__(self, position, l_c_mu, l_c_sig2):
        """
        Parameters
        ----------
        position : ift.MultiField
            The position of the energy.
            The MultiField contains entries for the 'signal' and the correlation length 'l_c'.
        l_c_mu : float
            Hyperparameter for the log-normal distribution of l_c.
        l_c_sig2 : float
            Hyperparameter for the log-normal distribution of l_c.
        """
        super(PriorEnergy, self).__init__(position)
        self._l_c = self._position['l_c']
        self._logl_c = ift.log(self._l_c)
        self._l_c_mu = l_c_mu
        self._l_c_sig2 = l_c_sig2

    def at(self, position):
        # Re-instantiate at a new position, keeping the hyperparameters.
        return self.__class__(position, self._l_c_mu, self._l_c_sig2)

    @property
    @ift.memo
    def value(self):
        # l_c lives on a one-component domain; extract the scalar entry.
        part = (self._logl_c.to_global_data() - self._l_c_mu)**2 / 2. / self._l_c_sig2 + self._logl_c.to_global_data()
        return part[0]

    @property
    @ift.memo
    def gradient(self):
        # d(value)/dl_c; the signal block of the gradient is identically zero.
        # NOTE(review): assumes Field.fill returns the filled field -- confirm
        # against the nifty4 Field API.
        signal_field = self._position['signal'].copy().fill(0.)
        l_c_field = ((self._logl_c - self._l_c_mu) / self._l_c_sig2 + 1.) / self._l_c
        result_dict = {'signal': signal_field, 'l_c': l_c_field}
        return ift.MultiField(val=result_dict)

    @property
    @ift.memo
    def curvature(self):
        # d^2(value)/dl_c^2 packaged as a diagonal operator over the
        # MultiField domain; the signal block of the diagonal is zero.
        signal_field = self._position['signal'].copy().fill(0.)
        l_c_field = ((1. - self._logl_c + self._l_c_mu) / self._l_c_sig2 - 1.) / self._l_c**2
        diagonal = ift.MultiField(val={'signal': signal_field, 'l_c': l_c_field})
        return ift.makeOp(diagonal)
class LikelihoodEnergy(ift.Energy):
    """Likelihood part of the posterior energy of the hierarchical model.

    The signal covariance S = S(l_c) is diagonal in harmonic space with
    power spectrum PowSpec(l_c), so all l_c-derivatives of the energy are
    expressed through (auxiliary) power spectra.
    """

    def __init__(self, position, d, R, N, HT, inverter=None):
        """
        The likelihood part of the full posterior energy, i.e.
        0.5 (sS^{-1}s + res N.inverse res + log(det(S)))
        where res = d - R exp(s)

        Parameters
        ----------
        position : ift.MultiField
            Containing 'signal' and 'l_c'
        d : ift.Field
            The data vector.
        R : ift.LinearOperator
            The instrument's response.
        N : ift.EndomorphicOperator
            The covariance of the noise.
        HT : ift.HarmonicTransformOperator
        inverter : ift.Minimizer, optional
            Numerical inverter handed on to the curvature operator.
        """
        super(LikelihoodEnergy, self).__init__(position)
        self._d = d
        self._R = R
        self._N = N
        self._HT = HT
        self._inverter = inverter
        self._s = self._position['signal']
        # l_c is stored as a one-component field; extract the scalar value.
        self._l_c_val = self._position['l_c'].to_global_data()[0]
        # physical flux exp(HT(s))
        self._x = ift.exp(self._HT(self._s))
        # residual between data and the noiseless prediction
        self._res = self._d - self._R.times(self._x)
        self._pow_spec = PowSpec(self._l_c_val)
        self._log_pow_spec = LogPowSpec(self._l_c_val)
        self._s_space = self._HT.domain
        self._l_c_space = self._position.domain['l_c']

    def at(self, position):
        # Re-instantiate at a new position; data and operators are reused.
        return self.__class__(position, self._d, self._R, self._N, self._HT, self._inverter)

    def _l_c_det_part(self, power_spectrum):
        # Trace-like term: 0.5 * sum over modes of the given spectrum,
        # returned as a field on the one-component l_c domain.
        covariance = ift.create_power_operator(domain=self._s_space, power_spectrum=power_spectrum)
        return ift.Field(domain=self._l_c_space, val=.5 * covariance.times(ift.Field.full(self._s_space, val=1.)).sum())

    def _l_c_exp_part(self, power_spectrum):
        # Quadratic term 0.5 * s^T C^{-1} s, with C built from the given
        # spectrum.  The spectra passed in by `gradient`/`curvature`
        # (PowSpec.grad / PowSpec.curv) are constructed so that C^{-1}
        # equals the corresponding l_c-derivative of S^{-1}.
        covariance = ift.create_power_operator(domain=self._s_space, power_spectrum=power_spectrum)
        return ift.Field(domain=self._l_c_space, val=.5 * self._s.vdot(covariance.inverse_times(self._s)))

    @property
    @ift.memo
    def value(self):
        S = ift.create_power_operator(domain=self._s_space, power_spectrum=self._pow_spec)
        # 0.5 (s^T S^{-1} s + res^T N^{-1} res)
        s_part = 0.5 * (self._s.vdot(S.inverse_times(self._s)) + self._res.vdot(self._N.inverse_times(self._res)))
        # ... + 0.5 log det S, computed via the log power spectrum
        return s_part + self._l_c_det_part(self._log_pow_spec).to_global_data()[0]

    @property
    @ift.memo
    def gradient(self):
        S = ift.create_power_operator(domain=self._s_space, power_spectrum=self._pow_spec)
        # d/ds: S^{-1} s - HT^T (x * R^T N^{-1} res)  (chain rule through exp)
        signal_grad = (S.inverse_times(self._s)
                       - self._HT.adjoint_times(self._x * self._R.adjoint_times(self._N.inverse_times(self._res))))
        # d/dl_c of the log-det term plus the quadratic term
        l_c_grad = self._l_c_det_part(self._log_pow_spec.grad) + self._l_c_exp_part(self._pow_spec.grad)
        val = {'signal': signal_grad, 'l_c': l_c_grad}
        return ift.MultiField(val=val)

    @property
    @ift.memo
    def curvature(self):
        # not really the curvature but rather a metric in the context of differential manifolds
        loc_R = self._R * ift.makeOp(self._x) * self._HT
        S = ift.create_power_operator(domain=self._s_space, power_spectrum=self._pow_spec)
        signal_curv = ift.library.WienerFilterCurvature(loc_R, self._N, S, self._inverter)
        l_c_curv_diag = self._l_c_det_part(self._log_pow_spec.curv) + self._l_c_exp_part(self._pow_spec.curv)
        operators = {'signal': signal_curv, 'l_c': ift.makeOp(l_c_curv_diag)}
        return ift.BlockDiagonalOperator(operators)
class SimpleHierarchicalEnergy(ift.Energy):
    """Total posterior energy: sum of PriorEnergy and LikelihoodEnergy."""

    def __init__(self, position, d, R, N, HT, l_c_mu, l_c_sig2, inverter=None):
        """
        A NIFTy Energy implementing a simple hierarchical model.

        The 'signal' s has a Gaussian prior with covariance S.
        S is diagonal in the harmonic representation of s and follows a power law which is dependent on another free
        parameter, the correlation length 'l_c'.
        l_c is log-normal distributed.
        In this example, exp(s) represents the photon flux coming from the large-scale structure of the universe.

        Parameters
        ----------
        position : ift.MultiField
            Containing 'signal' and 'l_c'
        d : ift.Field
            The data vector.
        R : ift.LinearOperator
            The instrument's response.
        N : ift.EndomorphicOperator
            The covariance of the noise.
        HT : ift.HarmonicTransformOperator
        l_c_mu : float
            Hyperparameter for the log-normal distribution of l_c.
        l_c_sig2 : float
            Hyperparameter for the log-normal distribution of l_c.
        inverter : ift.Minimizer
            Numerical method for inverting LinearOperators.
            This is necessary to be able to sample from the curvature which is required if the mass matrix is supposed
            to be reevaluated.
        """
        super(SimpleHierarchicalEnergy, self).__init__(position)
        # Both sub-energies share the same position.
        self._prior = PriorEnergy(self._position, l_c_mu, l_c_sig2)
        self._likel = LikelihoodEnergy(self._position, d, R, N, HT, inverter=inverter)

    def at(self, position):
        # Shallow-copy and move both sub-energies to the new position.
        new_energy = copy(self)
        new_energy._position = position
        new_energy._prior = self._prior.at(position)
        new_energy._likel = self._likel.at(position)
        return new_energy

    @property
    def value(self):
        # Sub-energy values are memoized, so no caching is needed here.
        return self._prior.value + self._likel.value

    @property
    def gradient(self):
        return self._prior.gradient + self._likel.gradient

    @property
    @ift.memo
    def curvature(self):
        # InversionEnabler allows applying the inverse of (and sampling
        # from) the combined curvature numerically.
        return ift.InversionEnabler(self._prior.curvature + self._likel.curvature, inverter=self._likel._inverter)
import numpy as np
import nifty4 as ift
__all__ = ['get_poisson_data', 'get_hierarchical_data']
class Exp3(object):
    """The non-linearity x -> exp(3 x) together with its first derivative."""

    def __call__(self, x):
        """Return exp(3 x)."""
        return ift.exp(3*x)

    def derivative(self, x):
        """Return d/dx exp(3 x) = 3 exp(3 x)."""
        return 3*ift.exp(3*x)
def get_poisson_data(n_pixels=1024, dimension=1, length=2., excitation=1., diffusion=10.):
    """
    Generates a Poisson-distributed mock data set.

    A 'true' signal s is drawn from a Gaussian prior with a diffusion-type
    power spectrum; the expected counts are R(exp(3 s)) and the data are a
    Poisson realization thereof.

    Parameters
    ----------
    n_pixels : int
        Number of pixels (per dimension)
    dimension : int
        Dimensionality of underlying space
    length : float
        Total physical length of regular grid.
    excitation : float
        Excitation field level.
    diffusion : float
        Diffusion constant.

    Returns
    -------
    data : ift.Field
        The poisson distributed `measured' data field.
    s : ift.Field
        The `true' signal field (in harmonic space representation).
    R : ift.LinearOperator
        Instrument's response.
    S : ift.LinearOperator
        Correlation of signal field.
    """
    # small number to tame zero mode
    eps = 1e-8
    # Define data gaps (masked region of the observation)
    na = int(0.6*n_pixels)
    nb = int(0.7*n_pixels)
    # Set up derived constants
    amp = excitation/(2*diffusion)  # spectral normalization

    def pow_spec(k):
        return amp / (eps + k**2)

    pixel_width = length/n_pixels
    # Set up the geometry
    s_domain = ift.RGSpace(tuple(n_pixels for _ in range(dimension)), distances=pixel_width)
    h_domain = s_domain.get_default_codomain()
    HT = ift.HarmonicTransformOperator(h_domain, s_domain)
    # Create mock signal
    S = ift.create_power_operator(h_domain, power_spectrum=pow_spec)
    s = S.draw_sample()
    # remove zero mode
    glob = s.to_global_data()
    glob[0] = 0.
    s = ift.Field.from_global_data(s.domain, glob)
    # Setting up an exemplary response
    GeoRem = ift.GeometryRemover(s_domain)
    d_domain = GeoRem.target[0]
    mask = np.ones(d_domain.shape)
    idx = np.arange(na, nb)
    mask[np.ix_(*tuple(idx for _ in range(dimension)))] = 0.
    fmask = ift.Field.from_global_data(d_domain, mask)
    Mask = ift.DiagonalOperator(fmask)
    R0 = Mask*GeoRem
    # generate poisson distributed data counts
    nonlin = Exp3()
    lam = R0(nonlin(HT(s)))
    data = ift.Field.from_local_data(d_domain, np.random.poisson(lam.local_data).astype(np.float64))
    return data, s, R0, S
def get_hierarchical_data(n_pixels=100, dimension=2, length=2., correlation_length=4., signal_to_noise=2.):
    """
    Generates a mock data set for a Wiener filter with log-normal prior.

    Implemented as d = R(exp(s)) + n where d is the mock data field, s is the 'true' signal field with a gaussian
    distribution, R the instrument response and n white noise.

    Parameters
    ----------
    n_pixels : int
        Number of pixels per dimension.
        Default: 100
    dimension : int
        Number of dimensions
        Default: 2
    length : float
        Physical length of underlying space.
        Default: 2.
    correlation_length : float
        Default: 4.
    signal_to_noise : float
        Amplitude of noise compared to variance of the signal.
        Default: 2.

    Returns
    -------
    d : ift.Field
        The mock data field.
    s : ift.MultiField
        The 'true' position: the power-weighted harmonic signal under key
        'signal' and the correlation length under key 'l_c'.
    R : ift.LinearOperator
        The instrument response (harmonic signal space -> data space)
    N : ift.ScalingOperator
        The covariance for the noise n defined on the data domain.
    """
    shape = tuple(n_pixels for _ in range(dimension))
    pixel_width = length / n_pixels
    a = correlation_length**4

    def p_spec(k):
        # prior power spectrum: l_c^4 / (1 + l_c k)^4
        return a / (k*correlation_length + 1)**4

    non_linearity = ift.library.nonlinearities.Exponential()
    # Set up position space
    s_space = ift.RGSpace(shape, distances=pixel_width)
    h_space = s_space.get_default_codomain()
    param_space = ift.UnstructuredDomain((1,))
    # Define harmonic transformation and associated harmonic space
    HT = ift.HarmonicTransformOperator(h_space, target=s_space)
    # Drawing a white sample sh in harmonic space; the spectrum is applied
    # below through `power`.
    S = ift.ScalingOperator(1., h_space)
    sh = S.draw_sample()
    p_op = ift.create_power_operator(h_space, p_spec)
    power = ift.sqrt(p_op(ift.Field.full(h_space, 1.)))
    # Choosing the measurement instrument: mask out part of the domain
    mask = np.ones(s_space.shape)
    idx = np.arange(int(0.6*n_pixels), int(0.8*n_pixels))
    mask[np.ix_(*tuple(idx for _ in range(dimension)))] = 0.
    fmask = ift.Field.from_global_data(s_space, mask)
    R = ift.GeometryRemover(s_space) * ift.DiagonalOperator(fmask)
    d_space = R.target
    # Creating the mock data
    true_sky = non_linearity(HT(power*sh))
    noiseless_data = R(true_sky)
    # noise level derived from the realized data variance
    noise_amplitude = noiseless_data.val.std() / signal_to_noise
    N = ift.ScalingOperator(noise_amplitude ** 2, d_space)
    n = N.draw_sample()
    d = noiseless_data + n
    l_c = ift.Field(domain=param_space, val=correlation_length)
    val = {'signal': power*sh, 'l_c': l_c}
    return d, ift.MultiField(val=val), R, N
import matplotlib.pyplot as plt
import nifty4 as ift
import hmcf
import numpy as np
__all__ = ['plot_simple_hierarchical_result', 'plot_simple_hierarchical_source', 'plot_poisson_result']
def plot_simple_hierarchical_source(sl, d, R, sample_transformation):
    """Show the true flux, its response image and the data side by side.

    All three panels share the color scale of the data field and one
    common colorbar on the right.

    Parameters
    ----------
    sl : ift.MultiField
        Latent position containing the 'signal' entry.
    d : ift.Field
        The data field.
    R : ift.LinearOperator
        The instrument response.
    sample_transformation : callable
        Maps a latent sample to the physical flux MultiField.
    """
    true_flux = sample_transformation(sl)['signal']
    v_min = d.min()
    v_max = d.max()
    panels = [
        (true_flux.to_global_data(), {}, 'x'),
        (R(true_flux).to_global_data(), {'cmap': 'viridis'}, 'R(x)'),
        (d.to_global_data(), {}, 'd'),
    ]
    fig, axes = plt.subplots(nrows=1, ncols=3)
    im = None
    for ax, (img, extra, label) in zip(axes, panels):
        im = ax.imshow(img, vmin=v_min, vmax=v_max, **extra)
        ax.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False,
                       labelbottom=False, labelleft=False)
        ax.set_xlabel(label)
    fig.subplots_adjust(right=0.8)
    cbar_ax = fig.add_axes([0.85, 0.1, 0.02, 0.8])
    fig.colorbar(im, cax=cbar_ax)
def plot_simple_hierarchical_result(sl, sampler):
    """Compare the true flux with the HMC posterior mean and spread.

    Creates two figures: (true flux | HMC mean) on the color scale of the
    mean, and (absolute difference | HMC std) on a combined color scale,
    each with one common colorbar.

    Parameters
    ----------
    sl : ift.MultiField
    sampler : hmcf.HMCSampler
    """
    def _pair_figure(left, right, labels, v_min, v_max):
        # Two images side by side on a common color scale plus a colorbar.
        fig, axes = plt.subplots(nrows=1, ncols=2)
        im = None
        for ax, img, label in zip(axes, (left, right), labels):
            im = ax.imshow(img, vmin=v_min, vmax=v_max)
            ax.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False,
                           labelbottom=False, labelleft=False)
            ax.set_xlabel(label)
        fig.subplots_adjust(right=0.8)
        cbar_ax = fig.add_axes([0.85, 0.1, 0.02, 0.8])
        fig.colorbar(im, cax=cbar_ax)

    true_flux = sampler.sample_transform(sl)['signal'].to_global_data()
    mean_val = sampler.mean['signal'].to_global_data()
    _pair_figure(true_flux, mean_val, ('x', 'HMC mean'),
                 mean_val.min(), mean_val.max())

    std_val = ift.sqrt(sampler.var['signal']).to_global_data()
    diff = abs(true_flux - mean_val)
    _pair_figure(diff, std_val, ('difference', 'HMC std'),
                 min(std_val.min(), diff.min()), max(std_val.max(), diff.max()))
def plot_poisson_result(data, s, sampler):
    """
    Plots the result of the poisson demo script.
    Works only with one-dimensional problems.
    Saves the result to a pdf called 'HMC_Poisson.pdf'.

    Parameters
    ----------
    data : ift.Field
        The data vector.
    s : ift.Field
        The true signal field.
    sampler : hmcf.HMCSampler
        The sampler used for solving the problem.
    """
    posterior_mean = sampler.mean
    posterior_std = ift.sqrt(sampler.var)
    HT = ift.HarmonicTransformOperator(s.domain)
    co_dom = HT.target
    plt.clf()
    if len(co_dom.shape) > 1:
        raise TypeError('Poisson demo result plotting only possible for one-dimensional problems.')
    pixel_width = co_dom[0].distances[0]
    length = co_dom.shape[0] * pixel_width
    x = np.arange(0, length, pixel_width)
    plt.rcParams["text.usetex"] = True
    # one-sigma band around the posterior mean
    lower = (posterior_mean - posterior_std).to_global_data()
    upper = (posterior_mean + posterior_std).to_global_data()
    plt.scatter(x, data.to_global_data(), label='data', s=1, color='blue', alpha=0.5)
    plt.plot(x, sampler.sample_transform(s).to_global_data(), label='true sky', color='black')
    plt.fill_between(x, lower, upper, color='orange', alpha=0.4)
    plt.plot(x, posterior_mean.to_global_data(), label='hmc mean', color='orange')
    plt.xlim([0, length])
    plt.ylim([-0.1, 7.5])
    plt.title('Poisson log-normal HMC reconstruction')
    plt.legend()
    plt.savefig('HMC_Poisson.pdf')
from demo_utils import *
import numpy as np
import nifty4 as ift
import hmcf
import matplotlib.pyplot as plt
from logging import INFO, DEBUG
np.random.seed(41)
def get_ln_params(mean, mode):
    """Convert (mean, mode) of a log-normal distribution to (mu, sigma^2).

    Inverts mean = exp(mu + sig2/2) and mode = exp(mu - sig2).
    """
    log_mean = np.log(mean)
    log_mode = np.log(mode)
    sig2 = (log_mean - log_mode) * 2. / 3.
    mu = (2. * log_mean + log_mode) / 3.
    return mu, sig2
if __name__ == '__main__':
    ift.logger.setLevel(INFO)
    hmcf.logger.setLevel(INFO)
    # hyperparameter: translate desired mean/mode of l_c into
    # log-normal parameters (mu, sigma^2)
    l_c_mean = 2.
    l_c_mode = .1
    l_c_mu, l_c_sig2 = get_ln_params(l_c_mean, l_c_mode)
    d, sl, R, N = get_hierarchical_data()
    s_space = sl['signal'].domain
    HT = ift.HarmonicTransformOperator(s_space)
    # energy related stuff
    ICI = ift.GradientNormController(iteration_limit=2000,
                                     tol_abs_gradnorm=1e-3)
    inverter = ift.ConjugateGradient(controller=ICI)
    energy = SimpleHierarchicalEnergy(sl, d, R, N, HT, l_c_mu, l_c_sig2, inverter=inverter)

    def sample_trafo(x):
        # Map a latent sample to the physical flux exp(HT(signal)),
        # leaving the other entries (l_c) unchanged.
        val = dict(x)
        val['signal'] = ift.exp(HT(val['signal']))
        return ift.MultiField(val=val)

    sampler = hmcf.HMCSampler(energy, num_processes=6, sample_transform=sample_trafo)
    sampler.display = hmcf.TableDisplay
    # Chains start at scaled copies of the true position.
    # NOTE(review): 3 initial positions vs num_processes=6 -- confirm the
    # sampler handles fewer initial positions than processes.
    x_initial = [sl*c for c in [.5, .7, 1.2]]
    sampler.run(100, x_initial=x_initial)
    plot_simple_hierarchical_source(sl, d, R, sample_trafo)
    plot_simple_hierarchical_result(sl, sampler)
    plt.show()
import numpy as np
import nifty4 as ift
import hmcf
from matplotlib import pyplot as plt
from demo_utils import *
class Exp3(object):
......@@ -13,131 +13,31 @@ class Exp3(object):
if __name__ == "__main__":
np.random.seed(20)
# Set up physical constants
nu = 1. # excitation field level
kappa = 10. # diffusion constant
eps = 1e-8 # small number to tame zero mode
sigma_n = 0.2 # noise level
sigma_n2 = sigma_n**2
L = 1. # Total length of interval or volume the field lives on
nprobes = 1000 # Number of probes for uncertainty quantification
# Define resolution (pixels per dimension)
N_pixels = 1024
# Define data gaps
N1a = int(0.6*N_pixels)
N1b = int(0.64*N_pixels)
N2a = int(0.67*N_pixels)
N2b = int(0.8*N_pixels)
# Set up derived constants
amp = nu/(2*kappa) # spectral normalization
pow_spec = lambda k: amp / (eps + k**2)
lambda2 = 2*kappa*sigma_n2/nu # resulting correlation length squared
lambda1 = np.sqrt(lambda2)
pixel_width = L/N_pixels
x = np.arange(0, 1, pixel_width)
# Set up the geometry
s_domain = ift.RGSpace([N_pixels], distances=pixel_width)
h_domain = s_domain.get_default_codomain()
HT = ift.HarmonicTransformOperator(h_domain, s_domain)
aHT = HT.adjoint
# Create mock signal
Phi_h = ift.create_power_operator(h_domain, power_spectrum=pow_spec)
phi_h = Phi_h.draw_sample()
# remove zero mode
glob = phi_h.to_global_data()
glob[0] = 0.
phi_h = ift.Field.from_global_data(phi_h.domain, glob)
phi = HT(phi_h)
num_processes = 5
# Setting up an exemplary response
GeoRem = ift.GeometryRemover(s_domain)
d_domain = GeoRem.target[0]
mask = np.ones(d_domain.shape)
mask[N1a:N1b] = 0.
mask[N2a:N2b] = 0.
fmask = ift.Field.from_global_data(d_domain, mask)
Mask = ift.DiagonalOperator(fmask)
R0 = Mask*GeoRem
R = R0*HT
np.random.seed(20)
# Linear measurement scenario
N = ift.ScalingOperator(sigma_n2, d_domain) # Noise covariance
n = Mask(N.draw_sample()) # seting the noise to zero in masked region
d = R(phi_h) + n
data, s, R, S = get_poisson_data()
# Wiener filter
j = R.adjoint_times(N.inverse_times(d))
# additional needed objects
HT = ift.HarmonicTransformOperator(s.domain)
IC = ift.GradientNormController(name="inverter", iteration_limit=500,
tol_abs_gradnorm=1e-3)
inverter = ift.ConjugateGradient(controller=IC)
D = (ift.SandwichOperator.make(R, N.inverse) + Phi_h.inverse).inverse
D = ift.InversionEnabler(D, inverter, approximation=Phi_h)
m = HT(D(j))
# Uncertainty
D = ift.SandwichOperator.make(aHT, D) # real space propagator
Dhat = ift.probe_with_posterior_samples(D.inverse, None,
nprobes=nprobes)[1]
nonlin = Exp3()
lam = R0(nonlin(HT(phi_h)))
data = ift.Field.from_local_data(d_domain, np.random.poisson(lam.local_data).astype(np.float64))
# initial guess
psi0 = ift.full(h_domain, 1e-7)
energy = ift.library.PoissonEnergy(psi0, data, R0, nonlin, HT, Phi_h,
inverter)
IC1 = ift.GradientNormController(name="IC1", iteration_limit=200,
tol_abs_gradnorm=1e-4)
minimizer = ift.RelaxedNewton(IC1)
energy = minimizer(energy)[0]
non_lin = Exp3()
m = HT(energy.position)
energy = ift.library.PoissonEnergy(s, data, R, non_lin, HT, S, inverter)
# Using an HMC sampler
def sample_transform(z):
return nonlin(HT(z))
return non_lin(HT(z))
num_processes = 5
mass_reevaluations = 3
num_samples = 200
max_burn_in = 5000
convergence_tolerance = 1.
x_initial = [energy.position + 0.01*Phi_h.draw_sample() for _ in range(num_processes)]
x_initial = [energy.position + 0.01*S.draw_sample() for _ in range(num_processes)]
poisson_hmc = hmcf.HMCSampler(potential=energy, num_processes=num_processes, sample_transform=sample_transform)
poisson_hmc.display = hmcf.TableDisplay
poisson_hmc.epsilon = hmcf.EpsilonPowerLawDivergence
poisson_hmc.mass.reevaluations = mass_reevaluations
poisson_hmc.run(num_samples=num_samples, max_burn_in=max_burn_in, convergence_tolerance=convergence_tolerance,
x_initial=x_initial)
hmc_mean = poisson_hmc.mean
hmc_std = ift.sqrt(poisson_hmc.var)
# Plotting
sig = ift.sqrt(Dhat)
poisson_hmc.mass.reevaluations = 2
poisson_hmc.run(200, x_initial=x_initial)
x_mod = np.where(mask > 0, x, None)
plt.clf()
plt.rcParams["text.usetex"] = True
c1 = (hmc_mean-hmc_std).to_global_data()
c2 = (hmc_mean+hmc_std).to_global_data()
plt.scatter(x_mod, data.to_global_data(), label='data', s=1, color='blue', alpha=0.5)
plt.plot(x, nonlin(phi).to_global_data(), label='true sky', color='black')
plt.plot(x, nonlin(m).to_global_data(), label='reconstructed sky', color='red')
plt.fill_between(x, c1, c2, color='orange', alpha=0.4)
plt.plot(x, hmc_mean.to_global_data(), label='hmc mean', color='orange')
plt.xlim([0, L])
plt.ylim([-0.1, 7.5])
plt.title('Poisson log-normal HMC reconstruction')
plt.legend()
plt.savefig('HMC_Poisson.pdf')
plot_poisson_result(data, s, poisson_hmc)
This diff is collapsed.
......@@ -200,6 +200,20 @@ class MassMain(MassBase):
self._new_mass_flag = np.full(self._num_chains, False, dtype=bool)
self._current_positions = [self._potential.position.copy() for _ in range(self._num_chains)]
self._got_current_position = self._new_mass_flag.copy()
get_init_mass_val = self._get_initial_mass
self._get_initial_mass = np.full(self._num_chains, get_init_mass_val)
@MassBase.get_initial_mass.getter
def get_initial_mass(self):
return self._get_initial_mass.any()
@get_initial_mass.setter
def get_initial_mass(self, value):
if value:
self._reevaluations += 1
else:
self._dec_reevaluations()
self._get_initial_mass = np.full(self._num_chains, value)
def get_client(self):
"""
......@@ -212,7 +226,7 @@ class MassMain(MassBase):
mass_client = MassChain(self._potential, num_samples=self.num_samples,
mass_reevaluations=self._reevaluations[0])
mass_client._operator = self._operator
mass_client._get_initial_mass = self._get_initial_mass
mass_client._get_initial_mass = self._get_initial_mass.any()
mass_client.shared = self.shared
return mass_client
......@@ -237,39 +251,46 @@ class MassMain(MassBase):
step = incoming['step']
new_mass_flag = False
if self.shared:
engage = incoming['accepted'] and incoming['epsilon'].converged
engage = incoming['accepted']
nominal_clearance = outgoing['converged'] and self._reevaluations[ch_id] > 0
if self._new_mass_flag[ch_id]: # check if there is an new 'shared' mass already
self._new_mass_flag[ch_id] = False
new_mass_flag = True
elif engage and (self._get_initial_mass or nominal_clearance):
elif engage and (self._get_initial_mass[ch_id] or nominal_clearance):
# add a new sample
self._current_positions[ch_id] = incoming['sample']
self._got_current_position[ch_id] = True
# try curvature based approach (every time the conditions above are met)
position = self._get_position_for_reeval()
curvature = self._potential.at(position).curvature
self._operator = self._get_mass_operator(curvature)
new_mass_flag = True
logger.info('Mass matrix reevaluated. (step {:d})'.format(step))
self.reset(reset_operator=False)
self._new_mass_flag = np.full_like(self._new_mass_flag, True)
self._new_mass_flag[ch_id] = False
try:
new_op = self._get_mass_operator(curvature)
except ValueError:
logger.warning('Curvature not positive definite, try again in next step. (step {:d})'.format(step))
new_mass_flag = False
else:
logger.info('Mass matrix reevaluated. (step {:d})'.format(step))
new_mass_flag = True
self._operator = new_op
self.reset(reset_operator=False)
self._new_mass_flag = np.full_like(self._new_mass_flag, True)
self._new_mass_flag[ch_id] = False
# if this was an initial mass thing, remember to start the real work.
if new_mass_flag:
outgoing.update(new_mass=self._operator)
if self._get_initial_mass:
self._get_initial_mass = False
if self._get_initial_mass[ch_id]:
self._get_initial_mass = np.full(self._num_chains, False)
else:
self._reevaluations[ch_id] -= 1
else:
# chains do all the work. Just do some stuff for information purposes
if incoming['mass_reeval']:
self._reevaluations[ch_id] -= 1
if self._get_initial_mass: # this is problematic since it should be an num_chains dim vector.
self._get_initial_mass = False
if self._get_initial_mass[ch_id]: # this is problematic since it should be an num_chains dim vector.
self._get_initial_mass[ch_id] = False
new_mass_flag = True
return new_mass_flag
......@@ -338,16 +359,22 @@ class MassChain(MassBase):
logger.info('Received new Mass matrix. (step {:d})'.format(step))
else:
sample = outgoing['sample'] # type: ift.Field
engage = outgoing['accepted'] and outgoing['epsilon'].converged
engage = outgoing['accepted']
nominal_clearance = incoming['converged'] and self._reevaluations[ch_id] > 0
if engage and (self._get_initial_mass or nominal_clearance):
# curvature based approach
curvature = self._potential.at(sample).curvature
self._operator = self._get_mass_operator(curvature)
logger.info('Mass matrix reevaluated. (step {:d})'.format(step))
new_mass_flag = True
self.reset(reset_operator=False)
try:
new_op = self._get_mass_operator(curvature)
except ValueError:
logger.warning('Curvature not positive definite, try again in next step. (step {:d})'.format(step))
new_mass_flag = False
else:
logger.info('Mass matrix reevaluated. (step {:d})'.format(step))
new_mass_flag = True
self._operator = new_op
self.reset(reset_operator=False)
if new_mass_flag:
outgoing.update(mass_reeval=True)
......
......@@ -25,7 +25,7 @@ import os
__all__ = ['get_autocorrelation']
def get_autocorrelation(path, shift=1):
def get_autocorrelation(path, shift=1, field_name=None):
"""
Calculates the autocorrelation for the samples of a h5 run-file.
It follows the convention of numpy correlate which is
......@@ -37,6 +37,8 @@ def get_autocorrelation(path, shift=1):
path to run file or sampler directory.
If the latter is used, the latest run file in this directory is evaluated.
shift : int
field_name : str, optional
If MultiFields were used, this specifies which of the sub-fields is supposed to be used.
Returns
-------
auto_corr : np.ndarray
......@@ -45,7 +47,9 @@ def get_autocorrelation(path, shift=1):
"""
samples = load(path=path)
if isinstance(samples, dict):
raise NotImplementedError('MultiField capability not implemented for tools yet')
if field_name is None:
raise ValueError("Samples are present as MultiFields. This requires the field_name keyword argument. ")
samples = samples[field_name]
samples[np.isnan(samples)] = 0
shift_samples = samples.copy()
reverse = shift < 0
......
......@@ -28,7 +28,7 @@ from PyQt5 import QtWidgets
__all__ = ['show_trajectories']
def show_trajectories(path, reference_field='mean', solution=None, start=0, stop=None, step=1, field_name=None):
def show_trajectories(path, reference_field=None, field_name=None, solution=None, start=0, stop=None, step=1):
"""
A GUI for examining the trajectories of the individual Markov chains.
......@@ -37,14 +37,14 @@ def show_trajectories(path, reference_field='mean', solution=None, start=0, stop
path : str
path to the h5 file containing the samples or the directory where the run files are located.
If a directory is passed, the latest run file is loaded.
reference_field : str, ift.Field, np.ndarray
reference_field : ift.Field, ift.MultiField, np.ndarray, dict(str -> np.ndarray)
the image of interest. As string the following statements are valid:
'mean' : use the final mean value of the sampling process. If there are only burn_in samples use the mean
of the last sample of each chain
'bad_pixel_distr' : the sum of all convergence measures during burn in (only possible if HMCSamplerDebug or
the like has been used with approbiate attributes
it is also possible to pass your own reference image as either ift-field or numpy array. Check shape and
dimensions though
field_name : str, optional
If MultiFields were used, this specifies which of the sub-fields is supposed to be displayed.
solution : ift.Field, ift.MultiField, np.ndarray, optional
If known the 'right' solution for your sampling problem
start : int
......@@ -55,8 +55,6 @@ def show_trajectories(path, reference_field='mean', solution=None, start=0, stop
step : int
Only use every 'step' sample.
Default: 1
field_name : str, optional
If MultiFields were used, this specifies which of the sub-fields is supposed to be displayed.
"""
# load required samples
burn_in_samples = load(path, 'burn_in', start=start, stop=stop, step=step)
......@@ -85,28 +83,15 @@ def show_trajectories(path, reference_field='mean', solution=None, start=0, stop
all_samples = burn_in_samples
# check reference_field input and set field_val accordingly
if isinstance(reference_field, ift.Field):
field_val = reference_field.val
elif isinstance(reference_field, np.ndarray):
field_val = reference_field
elif reference_field == 'mean':
if reference_field is None:
try:
field_val = load_mean(path)
reference_field = load_mean(path)
except (KeyError, ValueError):
field_val = np.nanmean(burn_in_samples, axis=(0, 1))
elif reference_field == 'bad_pixel_distr':
conv_meas = load(path, 'conv_meas', start=start, stop=stop, step=step)
conv_meas[np.isnan(conv_meas)] = 0.
inf_mask = np.isinf(conv_meas)
kinda_normalizer = np.nanmax(conv_meas[np.logical_not(inf_mask)])
conv_meas[inf_mask] = kinda_normalizer
conv_meas /= kinda_normalizer
num_samples = conv_meas.shape[0] * conv_meas.shape[1]
conv_meas /= num_samples
field_val = np.nansum(conv_meas, axis=(0, 1))
del conv_meas
else:
raise TypeError("Reference_field type not supported.")
reference_field = np.nanmean(burn_in_samples, axis=(0, 1))
if isinstance(reference_field, dict):
reference_field = reference_field[field_name]
if isinstance(reference_field, ift.Field):
reference_field = reference_field.val
if solution is not None:
if isinstance(solution, ift.MultiField) or isinstance(solution, dict):
......@@ -116,12 +101,10 @@ def show_trajectories(path, reference_field='mean', solution=None, start=0, stop
if burn_in_length <= 0:
burn_in_length = None
if isinstance(field_val, dict):
field_val = field_val[field_name]
# start QT App
q_app = QtWidgets.QApplication(sys.argv)
aw = EvaluationWindow(field_val, all_samples, solution=solution, max_burn_in_ix=burn_in_length)
aw = EvaluationWindow(reference_field, all_samples, solution=solution, max_burn_in_ix=burn_in_length)
aw.setWindowTitle('HMC Sampling Trajectories')
aw.show()
sys.exit(q_app.exec_())
......@@ -158,7 +158,7 @@ class EvaluationWindow(QtWidgets.QMainWindow):
if len(self.solution.shape) == 1:
self.trs.update_figure(self.samples[:, :, x], self.solution[x])
else:
self.trs.update_figure(self.samples[:, :, x, y], self.solution[x, y])
self.trs.update_figure(self.samples[:, :, y, x], self.solution[y, x])
def file_quit(self):
self.close()
......
This diff is collapsed.
......@@ -23,7 +23,7 @@ import curses
from ..convergence import Convergence
from .chain import Epsilon
from ..logging import logger, DisplayHandler
from logging import INFO, DEBUG, StreamHandler, Formatter, _checkLevel
from logging import INFO, DEBUG, StreamHandler, Formatter, NullHandler
__all__ = ['Display', 'LineByLineDisplay', 'TableDisplay']
......@@ -35,14 +35,18 @@ class Display(object):
This class can be used as an 'information display' but it shows nothing and leaves (other) terminal output
unchanged.
"""
_logging_hdlr = NullHandler()
def __init__(self, num_chains=1):
self._num_chains = num_chains
# logging related stuff
self._logging_hdlr.setLevel(logger.level)
def init_display(self):
pass
def update_chain(self, convergence, package, num_samples, samples_count, acceptance_rate):
def update_chain(self, convergence, package, num_samples, acceptance_rate):
"""
Parameters
----------
......@@ -50,8 +54,6 @@ class Display(object):
package : dict
num_samples : int
total number of steps after burn_in per chain
samples_count : int
number of steps since burn_in finished
acceptance_rate : float
acceptance rate of samples drawn after burn_in
"""
......@@ -63,6 +65,10 @@ class Display(object):
def close(self):
pass
def setLevel(self, level):
"""Sets the logging level during HMC sampling for the display."""
self._logging_hdlr.setLevel(level)
def _update_num_chains(self, new_val):
"""mostly for inheritance"""
self._num_chains = new_val
......@@ -74,10 +80,19 @@ class LineByLineDisplay(Display):
burn_in_debug_line = """dEnergy: {:9.2E}, ConvMeas: {:8.2E}, Epsilon: {:8.2E}"""
sampling_info_line = """Sampling progress: {:6.2f}%, Acceptance Rate: {:4^.2f}"""
_logging_hdlr = DisplayHandler(100)
_logging_hdlr_fmt = Formatter('{asctime} - {levelname}: {message}', style='{', datefmt='%H:%M:%S')
def __init__(self, num_chains=1):
super(LineByLineDisplay, self).__init__(num_chains=num_chains)
self._old_stream_hdlrs = []
for hdlr in logger.handlers:
if isinstance(hdlr, StreamHandler):
self._old_stream_hdlrs.append(hdlr)
logger.removeHandler(hdlr)
logger.addHandler(self._logging_hdlr)
def update_chain(self, convergence, package, num_samples, samples_count, acceptance_rate):
def update_chain(self, convergence, package, num_samples, acceptance_rate):
"""
Parameters
----------
......@@ -85,34 +100,40 @@ class LineByLineDisplay(Display):
package : dict
num_samples : int
total number of steps after burn_in per chain
samples_count : int
number of steps since burn_in finished
acceptance_rate : float
acceptance rate of samples drawn after burn_in
"""
ch_id = package['chain_identifier']
msg = ''
if package['burn_in']:
if logger.level <= INFO:
if self._logging_hdlr.level <= INFO:
msg += (self.burn_in_info_line.format(package['accepted'], convergence.converged[ch_id],
convergence.level[ch_id]) + '\n')
if logger.level <= DEBUG:
if self._logging_hdlr.level <= DEBUG:
msg += (self.burn_in_debug_line.format(package['delta_energy'], convergence.measure_max[ch_id],
package['epsilon'].val) + '\n')
else:
if logger.level <= INFO:
progress = float(samples_count) / num_samples
if self._logging_hdlr.level <= INFO:
progress = float(package['step']) / num_samples
msg += self.sampling_info_line.format(progress*100., acceptance_rate) + '\n'
for log in package['log']:
if log[0] >= self.level:
msg += log[1] + '\n'
import logging
logging.getLogger()
if msg:
print('chain', ch_id)
print(msg)
print('-'*80)
self._print_logs()
def _print_logs(self):
for msg in self._logging_hdlr.buffer:
print(msg)
print('-'*80)
self._logging_hdlr.flush()
def close(self):
logger.removeHandler(self._logging_hdlr)
for hdlr in self._old_stream_hdlrs:
logger.addHandler(hdlr)
self._old_stream_hdlrs = []
class TableDisplay(Display):
......@@ -126,8 +147,9 @@ class TableDisplay(Display):
row_burn_in = """ {:3d} | {:4^.2f} | {:9.2E} | {!r:5} {:3d} {:8.2E} {:6.2f}% | {:8.2E} {!r:5} {:8.2E} """
progress_bar_length = 33
row_sampling = """ {:3d} | {:4^.2f} | {:9.2E} | Progress: {:""" + str(progress_bar_length) + """} {:6.2f}% """
_table_hdlr = DisplayHandler(2000)
_table_hdlr_fmt = Formatter('{asctime} - {levelname}: {message}', style='{', datefmt='%H:%M:%S')
_logging_hdlr = DisplayHandler(2000)
_logging_hdlr_fmt = Formatter('{asctime} - {levelname}: {message}', style='{', datefmt='%H:%M:%S')
def __init__(self, num_chains=1, title=''):
"""
......@@ -153,14 +175,9 @@ class TableDisplay(Display):
self._title = title
self._reset_table()
# logging related stuff
self._level = logger.level
# for logging
self._old_stream_hdlrs = []
self._table_hdlr.setFormatter(self._table_hdlr_fmt)
def setLevel(self, level):
"""Sets the logging level during HMC sampling for the display."""
self._level = _checkLevel(level)
self._logging_hdlr.setFormatter(self._logging_hdlr_fmt)
def _reset_table(self):
if self._title == '':
......@@ -183,8 +200,7 @@ class TableDisplay(Display):
if isinstance(hdlr, StreamHandler):
self._old_stream_hdlrs.append(hdlr)
logger.removeHandler(hdlr)
self._table_hdlr.setLevel(self._level)
logger.addHandler(self._table_hdlr)
logger.addHandler(self._logging_hdlr)
self._window = curses.initscr()
self._window.clear()
......@@ -201,10 +217,10 @@ class TableDisplay(Display):
def _draw_log_msgs(self):
lines_to_end = self._window_shape[0] - self._eot - 1
for i, msg in enumerate(self._table_hdlr.buffer[-lines_to_end:]):
for i, msg in enumerate(self._logging_hdlr.buffer[-lines_to_end:]):
self._draw(self._eot+1+i, msg[:80])
def update_chain(self, convergence, package, num_samples, samples_count, acceptance_rate):
def update_chain(self, convergence, package, num_samples, acceptance_rate):
"""
Parameters
----------
......@@ -212,8 +228,6 @@ class TableDisplay(Display):
package : dict
num_samples : int
total number of steps after burn_in per chain
samples_count : int
number of steps since burn_in finished
acceptance_rate : float
acceptance rate of samples drawn after burn_in
"""
......@@ -231,7 +245,7 @@ class TableDisplay(Display):
epsilon.val, epsilon.converged, np.max(epsilon.convergence_measure),
num_mass_samples)
else:
progress = float(samples_count) / num_samples
progress = float(package['step']) / num_samples
# fill_string = '█'*int(progress*45) + '-'*(45-int(progress*45))
fill_string = ('|' * int(progress * self.progress_bar_length) + '-'
* (self.progress_bar_length - int(progress * self.progress_bar_length)))
......@@ -281,7 +295,7 @@ class TableDisplay(Display):
curses.echo()
curses.nocbreak()
curses.endwin()
logger.removeHandler(self._table_hdlr)
logger.removeHandler(self._logging_hdlr)
for hdlr in self._old_stream_hdlrs:
logger.addHandler(hdlr)
self._old_stream_hdlrs = []
......@@ -289,6 +303,6 @@ class TableDisplay(Display):
for line in self._table:
print(line)
print('')
for msg in self._table_hdlr.buffer:
for msg in self._logging_hdlr.buffer:
print(msg)
self._table_hdlr.flush()
self._logging_hdlr.flush()
......@@ -29,19 +29,21 @@ __all__ = ['Epsilon', 'EpsilonConst', 'EpsilonSimple', 'EpsilonPowerLaw', 'Epsil
class Epsilon(object):
"""
class to handle adjusting the epsilon parameter in the integration.
Class to handle adjusting the epsilon parameter in the integration.
Attributes
----------
val : float
value for epsilon.
Value for epsilon.
limits : list float
lower and upper bound for epsilon to avoid unrealistic values, especially in the beginning of optimization
(default: [1.E-20, 10.])
Lower and upper bound for epsilon to avoid unrealistic values, especially in the beginning of optimization
Default: [1.E-20, 10.]
convergence_tolerance : float
defines the range in which epsilon optimization is said to have converged (default: 0.5)
Defines the range in which epsilon optimization is said to have converged.
Default: 0.5
divergence_threshold : float
a value for which abs(denergy) is assumed to be divergent (default: 1.E20)
A value for which abs(denergy) is assumed to be divergent.
Default: 1.E20
"""
def __init__(self, target_acceptance_rate=0.8, limits=None):
"""
......@@ -52,7 +54,7 @@ class Epsilon(object):
limits : list of float
list with two elements containing minimum and maximum value of epsilon (default: [1.E-10, 10]
"""
self._locality = 10
self._locality = 50
self._tar = target_acceptance_rate
if limits is None:
self.limits = [1.E-40, 10.]
......@@ -67,7 +69,7 @@ class Epsilon(object):
self._converged = False
self._convergence_measure = [np.inf, np.inf]
self._min_samples = 10
self._min_samples = 30
self.convergence_tolerance = 0.5
self.divergence_threshold = 1E50
......@@ -76,7 +78,7 @@ class Epsilon(object):
@property
def locality(self):
"""int: how many recent epsilon/dEnergy pairs play a role in everything. Setting this will reset the instance!
default: 10"""
default: 50"""
return self._locality
@locality.setter
......@@ -86,10 +88,10 @@ class Epsilon(object):
else:
self._locality = 1
print('locality has to be bigger than zero. Set to 1.')
if self._locality < 10:
if self._locality < 30:
self._min_samples = self._locality
else:
self._min_samples = 10
self._min_samples = 30
self.reset()
@property
......@@ -296,7 +298,8 @@ class EpsilonPowerLawDivergence(Epsilon):
Attributes
----------
power : int
the power in power law
The power in power law.
Default: 2
"""
def __init__(self, target_acceptance_rate=0.8, limits=None):
"""
......@@ -313,7 +316,7 @@ class EpsilonPowerLawDivergence(Epsilon):
self._log_eps_hist = self._eps_hist.copy()
self._divergence_counter = 0
self._diverged_last_time = False
self.power = 5
self.power = 2
self._a = 1.
self._b = 2.
......
......@@ -24,6 +24,7 @@ import numpy as np
import h5py
from datetime import datetime
import nifty4 as ift
from ..logging import logger
__all__ = ['load', 'load_mean', 'load_var']
......@@ -93,7 +94,7 @@ def _load_single_field(field_name, h5_file, attr_type, field_meta_data, start=0,
stop = max_num_samples
num_samples = int(np.ceil(float(stop - start) / step))
if num_samples <= 0:
print('start/stop/step such that no samples will be returned')
logger.warning('start/stop/step such that no samples will be returned')
return np.full((field_meta_data[0][0],)+field_meta_data[0][1:], np.nan, dtype=field_meta_data[1])
# everything alright, start loading the data from the h5df file
......@@ -103,7 +104,7 @@ def _load_single_field(field_name, h5_file, attr_type, field_meta_data, start=0,
index = int(search(r'\d+', chain_name).group())
chain_data_chunk = chain_data[attr_type][field_name][start:stop:step]
field_val[index][:chain_data_chunk.shape[0]] = chain_data_chunk
print('loaded chain', index, 'for field', field_name)
logger.debug('loaded chain ' + str(index) + ' for field ' + field_name)
return field_val
......@@ -134,7 +135,7 @@ def load(path, attr_type='samples', start=0, stop=None, step=1):
"""
if os.path.isdir(path):
path = _get_recent_file_path(path)
print('path is not a file. Load most recent file instead:', path)
logger.debug('path is not a file. Load most recent file instead: ' + path)
result_field = dict()
f = h5py.File(path, 'r', libver='latest')
......@@ -168,7 +169,7 @@ def _load_single_field_mean(h5_file, field_name, field_meta_data, domain=None):
chunk_mean = np.mean(chunk, axis=0)
mean_val = np.average([mean_val, chunk_mean], axis=0, weights=weights)
weights[0] += weights[1]
print('processed chain', index, 'for mean value calculation.')
logger.debug('processed chain ' + str(index) + 'for mean value calculation.')
if domain is None:
return mean_val
......@@ -194,7 +195,7 @@ def load_mean(path, domain=None):
"""
if os.path.isdir(path):
path = _get_recent_file_path(path)
print('path is not a file. Load most recent file instead:', path)
logger.debug('path is not a file. Load most recent file instead: ' + path)
f = h5py.File(path, 'r', libver='latest')
fields_meta_data = _get_field_meta(f, attr_type='samples')
......@@ -228,7 +229,7 @@ def _load_single_field_var(h5_file, field_name, field_meta_data, domain=None):
for i in range(0, chain_samples.shape[0], chunk_size):
squared_chunk = (chain_samples[i:i + chunk_size])**2 / (num_samples - 1) # type: np.ndarray
sum_samples_squared += np.sum(squared_chunk, axis=0)
print('processed chain', index, 'for variance calculation.')
logger.debug('processed chain ' + str(index) + 'for variance value calculation.')
if domain is not None:
if isinstance(domain, ift.MultiDomain):
......@@ -260,7 +261,7 @@ def load_var(path, domain=None):
"""
if os.path.isdir(path):
path = _get_recent_file_path(path)
print('path is not a file. Load most recent file instead:', path)
logger.debug('path is not a file. Load most recent file instead: ' + path)
f = h5py.File(path, 'r', libver='latest')
fields_meta_data = _get_field_meta(h5_file=f, attr_type='samples')
......