# demo.py
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2020 Max-Planck-Society
# Author: Jakob Knollmueller
import nifty6 as ift
import nifty_hmc as hmc
import numpy as np
import tensorflow.compat.v1 as tf
import sys
from operators.tensorflow_operator import TensorFlowOperator
from operators.multinomial_energy import CategoricalEnergy
from matplotlib import pyplot as plt
from sugar import digit_sample_pic
# This script drives TensorFlow through the 1.x graph/session interface of
# ``tensorflow.compat.v1`` (InteractiveSession, get_default_graph,
# train.import_meta_graph), so TF2 behavior is switched off up front.
tf.disable_v2_behavior()
def parse_cli_args(argv):
    """Parse and validate the command-line arguments of the demo.

    Parameters
    ----------
    argv : sequence of str
        Full argument vector laid out like ``sys.argv``: program name, the
        digit class to condition on, and the number of HMC chains.

    Returns
    -------
    tuple of int
        ``(condition, n_chains)``.

    Raises
    ------
    RuntimeError
        If fewer than two arguments follow the program name, or if the digit
        is not exactly one of the strings ``'0'`` ... ``'9'``.
    ValueError
        If the number of chains is not an integer or is smaller than one.
    """
    if len(argv) < 3:
        program = argv[0] if argv else 'demo.py'
        raise RuntimeError(
            'usage: {} <digit 0-9> <number of chains>'.format(program))

    condition = argv[1]
    if condition not in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']:
        raise RuntimeError(
            'condition must be a single digit 0-9, got {!r}'.format(condition))

    n_chains = int(argv[2])
    if n_chains < 1:
        raise ValueError(
            'number of chains must be at least 1, got {}'.format(n_chains))
    return int(condition), n_chains


if __name__ == '__main__':
    comm, _, _, master = ift.utilities.get_MPI_params()
    ift.random.push_sseq_from_seed(1)

    # Digit class to condition on and the number of chains for HMC
    condition, N_chains = parse_cli_args(sys.argv)

    # Encode class expectations in categorical data via outcomes:
    # a length-10 count vector with a single 1 at the requested class.
    scale = 100
    # ``np.int`` was only an alias of the builtin ``int`` and has been removed
    # from NumPy, so the builtin is used directly.
    data = np.zeros([10], dtype=int)
    data[condition] = 1

    # Read in trained networks. Both meta graphs are imported into the default
    # graph, each under its own import scope.
    sess = tf.InteractiveSession()
    graph = tf.get_default_graph()
    new_saver = tf.train.import_meta_graph(
        'trained_models/generator_mnist-model.meta',
        import_scope='generator')
    new_saver.restore(sess, 'trained_models/generator_mnist-model')
    generator = graph.get_tensor_by_name(
        'generator/generator/strided_slice:0')
    latent_input = graph.get_tensor_by_name(
        'generator/placeholders/Placeholder_1:0')

    new_saver = tf.train.import_meta_graph(
        'trained_models/convnet_mnist-model.meta',
        import_scope='classifier')
    new_saver.restore(sess, 'trained_models/convnet_mnist-model')
    digit_input = graph.get_tensor_by_name('classifier/inp:0')
    class_output = graph.get_tensor_by_name('classifier/out:0')

    # Set up the spaces
    position_space = ift.RGSpace([28, 28])
    latent_space = ift.UnstructuredDomain([128])
    data_space = ift.UnstructuredDomain([10])

    # Build model: latent 'xi' -> generated digit -> class probabilities
    Generator = TensorFlowOperator(generator, latent_input, latent_space,
                                   position_space, add_target_axis=True)
    Classifier = TensorFlowOperator(class_output, digit_input, position_space,
                                    data_space, add_domain_axis=True)
    Digit = Generator.ducktape('xi')
    # Clip the classifier output away from exactly 0 and 1.
    Classification = Classifier(Digit).clip(1e-15, 1 - 1e-15)
    data = ift.makeField(data_space, data)
    Likelihood = CategoricalEnergy(data, scale=scale)(Classification)
    H = ift.StandardHamiltonian(Likelihood)

    # Set up initial positions: one random draw per chain, under a fixed seed
    initial_positions = []
    with ift.random.Context(43):
        for _ in range(N_chains):
            initial_positions.append(ift.from_random(H.domain))

    HMC = hmc.HMC_Sampler(H, initial_positions, steplength=0.01,
                          chains=N_chains, comm=comm)
    # Two warmup rounds of 100 steps each
    for _ in range(2):
        HMC.warmup(100)

    # Diagnostics history, one entry appended per sampling round
    ESS_mean = []
    ESS_min = []
    R_hat_mean = []
    R_hat_max = []
    N_sampling = 25
    fontsize = 15

    for i in range(100):
        HMC.sample(N_sampling)

        ESS = HMC.ESS
        ESS_mean.append(ESS.val['xi'].mean())
        ESS_min.append(ESS.val['xi'].min())

        # R_hat is stored with 1 subtracted, matching the plot labels below.
        R_hat = HMC.R_hat
        R_hat_mean.append(R_hat.val['xi'].mean() - 1)
        R_hat_max.append(R_hat.val['xi'].max() - 1)

        # Plotting happens only where ``master`` is true; every other task
        # goes straight on to the next sampling round.
        if not master:
            continue

        plt.clf()
        fig, axes = plt.subplots(1, 3, num='results', figsize=(18, 5))
        # Pick 16 samples from the first local chain.
        # NOTE(review): reads the sampler's private ``_local_chains``.
        samps = np.random.choice(HMC._local_chains[0].samples, 16)
        plt.gray()
        # One x value per completed round: 0, N_sampling, ..., i*N_sampling
        iterations = np.arange(0, (i + 1) * N_sampling, N_sampling)

        axes[0].imshow(digit_sample_pic(samps, Digit))
        axes[0].axis('off')
        axes[0].set_title('samples', fontsize=fontsize)

        axes[1].plot(iterations, ESS_mean, 'b-', label='mean ESS')
        axes[1].plot(iterations, ESS_min, 'k-', label='min ESS')
        axes[1].legend(fontsize=fontsize)
        axes[1].set_ylabel('effective sample size', fontsize=fontsize)
        axes[1].set_xlabel('iteration', fontsize=fontsize)
        axes[1].set_title('effective sample size', fontsize=fontsize)

        axes[2].plot(iterations, R_hat_mean, 'b-', label='mean R_hat-1')
        axes[2].plot(iterations, R_hat_max, 'k-', label='max R_hat-1')
        axes[2].set_yscale('log')
        axes[2].legend(fontsize=fontsize)
        axes[2].set_ylabel('R_hat - 1', fontsize=fontsize)
        axes[2].set_xlabel('iteration', fontsize=fontsize)
        axes[2].set_title('Gelman Rubin R_hat', fontsize=fontsize)

        # Short pause so the figure is redrawn after each round.
        plt.pause(0.001)