...
 
Commits (2)
......@@ -83,7 +83,7 @@ def get_poisson_data(n_pixels=1024, dimension=1, length=2., excitation=1., diffu
return data, s, R0, S
def get_hierarchical_data(n_pixels=100, dimension=2, length=2., correlation_length=4., signal_to_noise=2.):
def get_hierarchical_data(n_pixels=100, dimension=2, length=2., correlation_length=.6, signal_to_noise=1.):
"""
Generates a mock data set for a Wiener filter with log-normal prior.
Implemented as d = R(exp(s)) + n where d is the mock data field, s is the 'true' signal field with a gaussian
......
......@@ -11,15 +11,15 @@ def plot_simple_hierarchical_source(sl, d, R, sample_transformation):
true_flux = sample_transformation(sl)['signal']
fig, (x_ax, rx_ax, d_ax) = plt.subplots(nrows=1, ncols=3)
v_min, v_max = d.min(), d.max()
x_ax.imshow(true_flux.to_global_data(), vmin=v_min, vmax=v_max)
x_ax.imshow(true_flux.to_global_data(), interpolation=None, vmin=v_min, vmax=v_max)
x_ax.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False,
labelbottom=False, labelleft=False)
x_ax.set_xlabel('x')
rx_ax.imshow(R(true_flux).to_global_data(), cmap='viridis', vmin=v_min, vmax=v_max)
rx_ax.imshow(R(true_flux).to_global_data(), interpolation=None, vmin=v_min, vmax=v_max)
rx_ax.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False,
labelbottom=False, labelleft=False)
rx_ax.set_xlabel('R(x)')
im = d_ax.imshow(d.to_global_data(), vmin=v_min, vmax=v_max)
im = d_ax.imshow(d.to_global_data(), interpolation=None, vmin=v_min, vmax=v_max)
d_ax.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False,
labelbottom=False, labelleft=False)
d_ax.set_xlabel('d')
......@@ -29,49 +29,33 @@ def plot_simple_hierarchical_source(sl, d, R, sample_transformation):
fig.colorbar(im, cax=cbar_ax)
def plot_simple_hierarchical_result(sl, sampler):
def plot_simple_hierarchical_result(sl, d, sampler):
"""
Parameters
----------
sl : ift.MultiField
d : ift.Field
sampler : hmcf.HMCSampler
"""
true_flux = sampler.sample_transform(sl)['signal']
fig, (signal_ax, mean_ax) = plt.subplots(nrows=1, ncols=2)
fig, axes = plt.subplots(nrows=2, ncols=2)
mean_val = sampler.mean['signal'].to_global_data()
v_min, v_max = mean_val.min(), mean_val.max()
signal_ax.imshow(true_flux.to_global_data(), vmin=v_min, vmax=v_max)
signal_ax.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False,
labelbottom=False, labelleft=False)
signal_ax.set_xlabel('x')
im = mean_ax.imshow(mean_val, vmin=v_min, vmax=v_max)
mean_ax.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False,
labelbottom=False, labelleft=False)
mean_ax.set_xlabel('HMC mean')
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.1, 0.02, 0.8])
fig.colorbar(im, cax=cbar_ax)
fig, (diff_ax, std_ax) = plt.subplots(nrows=1, ncols=2)
std_val = ift.sqrt(sampler.var['signal']).to_global_data()
diff = abs(true_flux.to_global_data() - mean_val)
std_val = ift.sqrt(sampler.var['signal']).to_global_data()
v_min, v_max = min(std_val.min(), diff.min()), max(std_val.max(), diff.max())
images = [true_flux.to_global_data(), mean_val, diff, std_val]
labels = ['x', 'hmc mean', 'difference', 'hmc std']
v_min, v_max = d.min(), d.max()
diff_ax.imshow(diff, vmin=v_min, vmax=v_max)
diff_ax.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False,
labelbottom=False, labelleft=False)
diff_ax.set_xlabel('difference')
im = std_ax.imshow(std_val, vmin=v_min, vmax=v_max)
std_ax.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False,
for i, ax in enumerate(axes.flatten()):
if i in (0, 1):
im = ax.imshow(images[i], interpolation=None, vmin=v_min, vmax=v_max)
else:
im = ax.imshow(images[i])
ax.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False,
labelbottom=False, labelleft=False)
std_ax.set_xlabel('HMC std')
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.1, 0.02, 0.8])
fig.colorbar(im, cax=cbar_ax)
ax.set_xlabel(labels[i])
fig.colorbar(im, ax=ax)
def plot_poisson_result(data, s, sampler):
......
......@@ -9,7 +9,7 @@ import matplotlib.pyplot as plt
from logging import INFO, DEBUG
np.random.seed(41)
np.random.seed(40)
def get_ln_params(mean, mode):
......@@ -47,8 +47,9 @@ if __name__ == '__main__':
sampler.display = hmcf.TableDisplay
x_initial = [sl*c for c in [.5, .7, 1.2]]
sampler.run(100, x_initial=x_initial)
sampler.run(500, x_initial=x_initial)
print(sampler.mean['l_c'].to_global_data()[0])
plot_simple_hierarchical_source(sl, d, R, sample_trafo)
plot_simple_hierarchical_result(sl, sampler)
plot_simple_hierarchical_result(sl, d, sampler)
plt.show()
......@@ -253,14 +253,16 @@ class MassMain(MassBase):
if self.shared:
engage = incoming['accepted']
nominal_clearance = outgoing['converged'] and self._reevaluations[ch_id] > 0
if engage:
self._current_positions[ch_id] = incoming['sample']
self._got_current_position[ch_id] = True
if self._new_mass_flag[ch_id]: # check if there is an new 'shared' mass already
self._new_mass_flag[ch_id] = False
new_mass_flag = True
elif engage and (self._get_initial_mass[ch_id] or nominal_clearance):
elif engage and (self._get_initial_mass[ch_id] or nominal_clearance) and self._got_current_position.all():
# add a new sample
self._current_positions[ch_id] = incoming['sample']
self._got_current_position[ch_id] = True
# try curvature based approach (every time the conditions above are met)
position = self._get_position_for_reeval()
curvature = self._potential.at(position).curvature
......