Skip to content
Snippets Groups Projects
Select Git revision
  • 7b6dd3b40a6ff5a35902312d4fb83f9ac8e3f598
  • master default protected
  • BioEM-1.0 protected
  • BioEM-2.1.0
  • BioEM-2.0.3
  • BioEM-2.0.2
  • BioEM-2.0.1
  • BioEM-2.0.0
  • BioEM-1.0.2
  • BioEM-1.0.1
  • BioEM-1.0.0
11 results

Model_Text

Blame
  • demo.py 4.97 KiB
    # This program is free software: you can redistribute it and/or modify
    # it under the terms of the GNU General Public License as published by
    # the Free Software Foundation, either version 3 of the License, or
    # (at your option) any later version.
    #
    # This program is distributed in the hope that it will be useful,
    # but WITHOUT ANY WARRANTY; without even the implied warranty of
    # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    # GNU General Public License for more details.
    #
    # You should have received a copy of the GNU General Public License
    # along with this program.  If not, see <http://www.gnu.org/licenses/>.
    #
    # Copyright(C) 2020 Max-Planck-Society
    # Author: Jakob Knollmueller
    
    import nifty6 as ift
    import nifty_hmc as hmc
    import numpy as np
    import tensorflow.compat.v1 as tf
    import sys
    from operators.tensorflow_operator import TensorFlowOperator
    from operators.multinomial_energy import CategoricalEnergy
    from matplotlib import pyplot as plt
    from sugar import digit_sample_pic
    
    tf.disable_v2_behavior()


    def _plot_diagnostics(axes, samples, digit_op, iterations, ess_mean, ess_min,
                          r_hat_mean, r_hat_max, fontsize=15):
        """Redraw the three diagnostic panels on ``axes``.

        Panel 0 shows the image built by ``digit_sample_pic(samples, digit_op)``.
        Panel 1 shows the mean and minimum effective sample size. Panel 2 shows
        the mean and maximum of R_hat - 1 on a log scale. Panels 1 and 2 are
        plotted against ``iterations``.
        """
        # Reset every panel so new curves are not drawn over the previous round's.
        for ax in axes:
            ax.clear()

        axes[0].imshow(digit_sample_pic(samples, digit_op), cmap='gray')
        axes[0].axis('off')
        axes[0].set_title('samples', fontsize=fontsize)

        axes[1].plot(iterations, ess_mean, 'b-', label='mean ESS')
        axes[1].plot(iterations, ess_min, 'k-', label='min ESS')
        axes[1].legend(fontsize=fontsize)
        axes[1].set_ylabel('effective sample size', fontsize=fontsize)
        axes[1].set_xlabel('iteration', fontsize=fontsize)
        axes[1].set_title('effective sample size', fontsize=fontsize)

        axes[2].plot(iterations, r_hat_mean, 'b-', label='mean R_hat-1')
        axes[2].plot(iterations, r_hat_max, 'k-', label='max R_hat-1')
        # ax.clear() resets the scale to linear, so log scale is re-applied each round.
        axes[2].set_yscale('log')
        axes[2].legend(fontsize=fontsize)
        axes[2].set_ylabel('R_hat - 1', fontsize=fontsize)
        axes[2].set_xlabel('iteration', fontsize=fontsize)
        axes[2].set_title('Gelman Rubin R_hat', fontsize=fontsize)


    if __name__ == '__main__':
        comm, _, _, master = ift.utilities.get_MPI_params()
        ift.random.push_sseq_from_seed(1)

        # Command-line arguments:
        #   argv[1]: digit class (0-9) whose categorical outcome is observed
        #   argv[2]: number of HMC chains
        if len(sys.argv) < 3:
            raise RuntimeError(
                'usage: {} <digit 0-9> <number of chains>'.format(sys.argv[0]))
        condition = sys.argv[1]
        if condition not in {str(digit) for digit in range(10)}:
            raise RuntimeError(
                'condition must be a single digit 0-9, got {!r}'.format(condition))
        condition = int(condition)

        N_chains = int(sys.argv[2])
        if N_chains < 1:
            raise RuntimeError(
                'number of chains must be positive, got {}'.format(N_chains))

        # Passed as `scale` to CategoricalEnergy. It presumably sharpens the
        # likelihood, as if the outcome were observed repeatedly.
        # TODO confirm in operators.multinomial_energy.
        scale = 100

        # One-hot outcome vector: a single count in bin `condition`.
        # `np.int` was removed in NumPy 1.24; the builtin `int` is the same dtype.
        data = np.zeros([10], dtype=int)
        data[condition] = 1

        # Read in the trained networks. Both meta graphs are imported into the
        # same default graph, under the 'generator' and 'classifier' scopes.
        sess = tf.InteractiveSession()
        graph = tf.get_default_graph()

        new_saver = tf.train.import_meta_graph('trained_models/generator_mnist-model.meta',
                                               import_scope='generator')
        new_saver.restore(sess, 'trained_models/generator_mnist-model')
        generator = graph.get_tensor_by_name('generator/generator/strided_slice:0')
        latent_input = graph.get_tensor_by_name('generator/placeholders/Placeholder_1:0')

        new_saver = tf.train.import_meta_graph('trained_models/convnet_mnist-model.meta',
                                               import_scope='classifier')
        new_saver.restore(sess, 'trained_models/convnet_mnist-model')
        digit_input = graph.get_tensor_by_name('classifier/inp:0')
        class_output = graph.get_tensor_by_name('classifier/out:0')

        # Set up the spaces: 28x28 image grid, 128-dim latent vector, 10 classes.
        position_space = ift.RGSpace([28, 28])
        latent_space = ift.UnstructuredDomain([128])
        data_space = ift.UnstructuredDomain([10])

        # Build the model: latent 'xi' -> generated digit -> classifier output.
        Generator = TensorFlowOperator(generator, latent_input, latent_space,
                                       position_space, add_target_axis=True)
        Classifier = TensorFlowOperator(class_output, digit_input, position_space,
                                        data_space, add_domain_axis=True)

        Digit = Generator.ducktape('xi')
        # Keep classifier outputs strictly inside (0, 1). Presumably this stops
        # the categorical energy from ever evaluating log(0).
        Classification = Classifier(Digit).clip(1e-15, 1 - 1e-15)
        data = ift.makeField(data_space, data)
        Likelihood = CategoricalEnergy(data, scale=scale)(Classification)
        H = ift.StandardHamiltonian(Likelihood)

        # One random starting position per chain, drawn inside an RNG context
        # seeded with 43.
        with ift.random.Context(43):
            initial_positions = [ift.from_random(H.domain) for _ in range(N_chains)]

        HMC = hmc.HMC_Sampler(H, initial_positions, steplength=0.01, chains=N_chains,
                              comm=comm)

        # Two warm-up phases of 100 steps each, run before collecting diagnostics.
        for _ in range(2):
            HMC.warmup(100)

        N_rounds = 100    # number of sample/diagnose/plot rounds
        N_sampling = 25   # argument to HMC.sample() per round
        ESS_mean = []
        ESS_min = []
        R_hat_mean = []
        R_hat_max = []

        if master:
            # Create the figure once and redraw it each round. Calling plt.clf()
            # before any figure existed used to open a stray empty figure.
            _, axes = plt.subplots(1, 3, num='results', figsize=(18, 5))

        for i in range(N_rounds):
            HMC.sample(N_sampling)

            # Diagnostics are gathered on every rank before the master-only
            # check, because the sampler may need all ranks to compute them.
            ESS = HMC.ESS
            ESS_mean.append(ESS.val['xi'].mean())
            ESS_min.append(ESS.val['xi'].min())

            R_hat = HMC.R_hat
            R_hat_mean.append(R_hat.val['xi'].mean() - 1)
            R_hat_max.append(R_hat.val['xi'].max() - 1)

            if not master:
                continue

            # The k-th diagnostic point (k = 1, 2, ...) is taken after
            # k * N_sampling sampling steps, so the first point is not at 0.
            iterations = N_sampling * np.arange(1, i + 2)
            # NOTE(review): reaches into the sampler's private `_local_chains`.
            # Draws 16 samples, with replacement, from the first local chain.
            samps = np.random.choice(HMC._local_chains[0].samples, 16)
            _plot_diagnostics(axes, samps, Digit, iterations, ESS_mean, ESS_min,
                              R_hat_mean, R_hat_max)
            plt.pause(0.001)